?

提升聯邦學習通信效率的梯度壓縮算法①

2022-11-07 09:07田金簫
計算機系統應用 2022年10期
關鍵詞:梯度投影客戶端

田金簫

(西南交通大學 計算機與人工智能學院,成都 611756)

1 引言

近年來,隨著人工智能技術的快速發展和廣泛應用,數據隱私保護也得到了密切關注.歐盟出臺了首個關于數據隱私保護的法案《通用數據保護條例》(General Data Protection Regulation,GDPR)[1],明確了對數據隱私保護的若干規定.中國自2017年起實施的《中華人民共和國網絡安全法》和《中華人民共和國民法總則》中也對用戶隱私數據的使用做出了明確的規定.在機器學習中,模型的好壞很大程度上依托于建模的數據.但由于相關法律法規的限制,數據孤島問題變得十分普遍,導致企業很難獲取訓練數據.為此,谷歌在2016年提出了聯邦學習的概念.聯邦學習是一種基于分布式機器學習的框架,在這種框架中,多個客戶端在中央服務器的協調下共同訓練模型,并保證訓練數據可以保留在本地,不需要像傳統的機器學習方法一樣將數據上傳至中央服務器[2],從而保護了用戶隱私.

構建一個高性能的聯邦模型通常需要多輪通信,同時規模龐大的神經網絡模型,往往包含數百萬個參數[3],這導致了巨大的通信開銷.此外,相較于傳統的分布式機器學習,聯邦學習還面臨如下問題:

1)客戶端數據非獨立同分布: 在傳統分布式機器學習中的訓練數據隨機均勻地分布在客戶端上[4],即遵循獨立同分布(independent and identically distributed,IID).這在聯邦學習中通常是不成立的,由于用戶的喜好不同,客戶端的數據通常是非獨立同分布(non-IID)的.即客戶端擁有的局部數據集不能代表整體數據的分布,不同客戶端之間的數據分布也不同.

2)數據不平衡: 不同的客戶端可能擁有不同的數據量.

3)客戶端數量龐大且不可靠: 參與訓練的客戶端為大量的移動設備,通常大部分客戶端經常離線或者處于不可靠的連接上,因此無法確??蛻舳藚⑴c每一輪的訓練.

本文主要研究聯邦學習中的通信效率問題,利用梯度稀疏化的思想減少客戶端與服務器之間通信的參數量,并在服務器聚合時使用投影的方式緩解非獨立同分布數據帶來的影響.經過在MNIST 和CIFAR10數據集上的實驗證明,本文提出的算法能夠在聯邦學習的約束條件下高效訓練模型.

2 相關工作

一般來說,減少聯邦學習中的通信開銷有兩種策略,一種是減少訓練過程中的通信輪次,另一種是減少每輪傳遞的通信量.減少通信輪次的經典方案是聯邦學習中最常用的FedAvg 算法[2],即令客戶端在本地執行多輪本地更新,服務器再進行全局聚合,來減少通信輪數.FedAvg 在每次通信中,客戶端需要上傳或下載整個模型,由于聯邦客戶端通常運行在緩慢且不可靠的網絡連接上,這一要求使得使用FedAvg 訓練大型模型變得困難.在實際應用中,FedAvg 算法可以較好地處理非凸問題,但該算法不能很好處理聯邦學習中數據non-IID 的情況,在此應用場景很可能導致模型不收斂[5].因此針對non-IID 場景,Briggs 等[6]在FedAvg的基礎上引入層次聚類技術,根據局部更新與全局模型的相似度對客戶端進行聚類和分離,以減少總通信輪數.此外Karimireddy 等[7]通過估計服務器與客戶端更新方向的差異來修正客戶端本地更新的方向,有效地克服了non-IID 問題,能在較少的通信輪次達到收斂.

另一類方法的核心思想在于減少傳輸的數據量,主要通過量化、稀疏化等一系列方法對模型參數或者梯度進行壓縮.量化通過將元素低精度表示或者映射到預定義的一組碼字來減少梯度張量中每個元素的位數,例如Dettmers[8]將梯度的32 位浮點數量化至8 位,SignSGD[9-11]則只保留梯度的符號來更新模型,將負梯度量化為-1,其余量化為1,實現了32 倍的壓縮.稀疏化方法通過只上傳部分重要的梯度來進行全局模型的更新,如何選擇這些梯度成為該方法的關鍵.Strom[12]提出使用梯度的大小來衡量其重要性,通過預先設立閾值,當梯度大于該閾值時對其進行上傳.然而在實際情況中,由于不同的網絡結構參數分布差異較大,導致我們無法選擇合適的閾值.因此目前稀疏化方法通常使用Aji 等[13]提出的固定稀疏率,每次傳遞一定比例的最大梯度或每次傳遞前k個最大梯度的Topk 方法[14].上述工作有效地解決了分布式機器學習中的通信開銷問題,針對聯邦學習的訓練環境,Rothchild 等[15]使用了一種特殊的數據結構計數草圖(count sketch)對客戶端梯度進行壓縮.Chen 等[16]將神經網絡的不同層分為淺層和深層,并認為深層參數更新頻率低于淺層參數,因此提出了異步更新策略,有效減少了每輪傳遞的參數量.Haddadpour 等[17]在FedAvg 的基礎上對每輪傳遞的參數進行壓縮,并針對non-IID 場景采用梯度跟蹤技術對客戶端梯度方向進行修正,在收斂速度和準確率上都取得了較好的效果.

Sattler 等[18]也針對聯邦學習的訓練環境提出了稀疏三元壓縮(sparse ternary compression,STC),該方法在Topk 梯度稀疏化的基礎上進行了量化進一步減少了通信量,并利用錯誤反饋機制實現了客戶端與服務器之間的雙向壓縮,在聯邦學習場景中表現出了良好的效果.該方法考慮了聯邦學習中客戶端non-IID數據的場景,通過利用稀疏的特性以及減少本地訓練次數與服務器端頻繁通信去減輕non-IID 數據帶來的問題,但該方法對non-IID 數據的優化能力有限.因此本文將在稀疏三元壓縮算法的基礎上,關注non-IID下的聯邦場景,提升聯邦學習的通信效率.

3 算法設計

3.1 稀疏三元壓縮

常規的Topk 稀疏方法以全精度傳遞稀疏元素,Sattler 等[19]證明了當稀疏化與非零元素的量化相結合時,可以獲得更高的壓縮增益.如算法1 所示,當獲得Topk 稀疏元素Tmasked后,會將其量化為稀疏元素的平均值,因此最后只需要傳遞一個包含值{-μ,0,μ}的三元張量.如果將每一層的梯度看做一個矩陣,那么使用Topk 和稀疏三元壓縮后得到的結果如圖1 所示,原始梯度是一個稠密矩陣,顏色深淺代表值的大小,通過Topk 方法會得到一個保留較大值的稀疏矩陣,值較小的則置為0,而稀疏三元壓縮則在Topk 的基礎上做了量化,進一步提升了壓縮率.

圖1 梯度壓縮效果

算法1.STC[18]: 稀疏三元壓縮算法T∈Rn輸入: 張量,稀疏率p 1.v←topk(|T|)k←max(np,1)2.mask←(|T|≥v)∈{0,1}n 3.Tmasked←mask⊙T 4.μ←1∑ni=1|Tmaskedi|5.T*←μ×sign(Tmasked)6.輸出k

Sattler 等[18]在聯邦學習中使用稀疏三元壓縮對客戶端和服務器之間通信的梯度進行雙向壓縮,并結合錯誤反饋機制[20]在客戶端和服務器保留壓縮前后的誤差累加至下一輪訓練過程.

其中,gti為第i個客戶端第t輪訓練得到的原始梯度,為壓縮后的梯度,errort為壓縮前后的誤差.該方法取得了與非壓縮算法相似的收斂速度并大大減少了每一輪的通信量,因此本文也將使用稀疏三元壓縮方法進行梯度壓縮.

3.2 Non-IID 數據的處理

目前在聯邦學習中,我們通常采用平均各個客戶端梯度的方法計算全局模型.當不同客戶端數據滿足IID 條件時,各客戶端梯度更新方向相近,且聚合后梯度與基于傳統的集中式學習獲得的梯度相似性較高.故此方法能獲得全局目標函數的最優解.若客戶端數據non-IID 且數據量差異較大,各客戶端梯度差異性較大,存在相互干擾的情況,導致全局模型收斂速率降低.同時,簡單平均各方梯度易使數據量多的客戶端占主導作用,使得全局模型無法較好地處理數據量較少的客戶端,最終導致全局模型整體性能低下.

Wang 等[21]提出使用梯度投影處理non-IID 數據的問題,服務器端在進行梯度平均之前,通過修改梯度方向減輕non-IID 數據帶來的影響.該方法首先對客戶端之間的梯度沖突做出定義,當客戶端i的梯度gi和客戶端j的梯度gj滿足gi·gj<0時,則稱為客戶端i和客戶端j之間存在梯度沖突.當客戶端之間存在梯度沖突時,梯度方向差異性較大,這時可以通過將一個客戶端的梯度投影到另一個有沖突的客戶端梯度平面上,使用原梯度減去投影來縮小客戶端之間的梯度差異,如式(3)所示:

此外,該方法定義了內部沖突和外部沖突,分別對其進行投影處理.將參與訓練的客戶端之間的梯度沖突定義為內部沖突,將客戶端梯度按照訓練損失從小到大排序得到并引入參數 α來控制每輪參與投影的客戶端數目.從POt中選擇損失較小的客戶端Sαt迭代的判斷與其他客戶端之間的梯度沖突,并進行投影修改梯度方向以緩解內部沖突.對于未選擇的損失較大的客戶端則保持原有的梯度,此后進行梯度平均得到聚合后的梯度gt,如算法2 所示.

在實際聯邦場景中,客戶端non-IID 程度較大,在每輪聚合中,若對所有客戶端統一采用投影方案,則導致訓練損失大的客戶端的梯度方向不斷靠近損失小的客戶端.這將導致聚合模型無法學習到所有客戶端的信息.但通過調整參數 α,自適應地讓部分訓練損失較大的客戶端直接參與最終的聚合階段,有效地緩解了上述問題.

算法2.MitigateInternalConflict[21]: 緩解內部沖突算法輸入: 客戶端梯度投影順序,參數POtα POtS1-αtα 1.服務器從選擇損失較小的客戶集合參與投影,保留 比例損失較大的客戶端梯度k∈S1-α t 2.for each client in parallel do gpc k ←gtk 3.gti∈POti=1,···,m 4.for each ,do k ·gti<0k≠i 5.if and then gPC||gti||2 gti 6.投影修正客戶端梯度:gPC k ←gPCk -(gti)·gPC k 7.end if 8.end for 9.end for ∑mk=1 gPCk 10.計算聚合梯度:gt←1 m 11.返回聚合梯度gt

由于聯邦學習中客戶端的部分參與和不可靠連接,在第t輪未被選中參與訓練的客戶端可能會遭受被全局模型遺忘的風險, 因此可以在服務端保留其最近一次參與訓練的梯度根據它們的近鄰歷史梯度來估計真實梯度以避免客戶端被遺忘, 如算法3 第6 步所示.第t輪未被選中客戶端的估計梯度gcon與參與更新的客戶端平均后的梯度gt之間的沖突稱為外部沖突, 通過將gt迭代的投影到不同輪次的估計梯度gcon的法平面以緩解外部沖突, 通過參數τ控制投影的輪次. 具體步驟如算法3 所示.

算法3.MitigateExternalConflict[21]: 緩解外部沖突算法gtGHτ輸入:聚合梯度 ,所有客戶端近鄰歷史梯度,參數1.for round do gcon←0 t-i,i=τ,τ-1,···,1 2.初始化估計梯度:k=1,2,···,K 3.for each client do tk=t-i 4.if then gt·gtkk <0 5.if then gcon←gcon+gtkk 6.計算未被選中客戶端的估計梯度:7.end if 8.end if 9.end for gt·gcon<0 10.if then 11.對聚合梯度投影修正:12.end if 13.end for gt 14.返回聚合梯度gt←gt- gt·gcon||gcon||2 gcon

3.3 基于投影聚合的稀疏三元壓縮算法

鑒于投影能夠有效地處理聯邦學習中的non-IID數據問題,因此本文將在稀疏三元壓縮的基礎上,在服務器端使用投影聚合的方式,進一步提高模型的正確率與收斂速度,具體步驟如算法4 所示.

服務器端接收到客戶端梯度與訓練損失后,首先在算法第14 行更新每個客戶端最近一次參與訓練的梯度以便在緩解外部沖突時使用,其中K是所有客戶端個數,tK是客戶端最近一次參與訓練的輪次.之后在第15 行根據訓練損失的大小對本輪參與訓練的客戶端梯度進行排序得到其中m是本輪參與訓練的客戶端個數.然后依次根據算法2 中的緩解內部沖突算法和算法3 中的緩解外部沖突算法得到聚合梯度gt.算法2 和算法3 的主要作用是對聚合梯度gt的方向進行修正以緩解non-IID 問題,因此在第20 行中,保留修正后的聚合梯度gt的方向與原始聚合梯度的大小得到最終的聚合梯度.最后使用與客戶端相同的STC 壓縮算法壓縮聚合梯度并發送至客戶端.

算法4.基于投影聚合的稀疏三元壓縮算法輸入: 初始化模型w 1.for do 2.服務器從K 個客戶端隨機選取m 個客戶端參與訓練i=1,···,m t=1,···,T 3.for in parallel do Ci 4.客戶端 :5.從服務器端下載聚合梯度wti←wt-1i -gˉg 6.)-wti 7.gti←S TC(gti+errort-1,p)8.errort=gti-?gti 9.?gtilti 10.上傳客戶端梯度 和訓練損失至服務器11.end for 12.服務器器端:?gtilti 13.接收參與訓練的客戶端梯度 和訓練損失gti←SGD(wti,Datai

GH={?gt11 ,?gt22 ,···,?gtKK 14.更新所有客戶端近鄰歷史梯度信息:POt={?gt1,?gt2,···,?gtm}15.根據客戶端訓練損失對梯度排序:gt←MitigateInternalCon flict(POt,α)16.緩解內部沖突:t≥τ}17.if then gt←MitigateExternalCon flict(gt,GH,τ)18.緩解外部沖突:19.end if gt=gt/||gt||*|| 1∑mi ?gti||20.m g=S TC(gt+error,p)21.22.error=gt-g 23.發送聚合梯度 至客戶端24.end for g

算法4 中的步驟可簡化為圖2,在客戶端,首先接收聚合梯度,然后根據模型和客戶端數據進行本地訓練得到客戶端梯度,本地訓練完成后使用STC 算法壓縮梯度上傳至服務器,并計算壓縮誤差存儲在本地,在下一輪被選中訓練時進行梯度修正.

圖2 基于投影聚合的稀疏三元壓縮算法流程

服務端接收到所有參與訓練的客戶端發送的梯度后判斷客戶端梯度之間是否存在梯度沖突,并依次通過緩解內部沖突和外部沖突的算法對梯度方向進行修正.最終聚合投影后的梯度生成全局梯度gt,采用STC 算法壓縮全局梯度gt得到發送至客戶端.該算法實現了客戶端與服務器之間的雙向壓縮,并且在服務器端進行投影緩解數據異構的問題.

4 實驗分析

4.1 實驗設置

本文的實驗使用了MNIST 和CIFAR10 數據集.MNIST 數據集包含60 000 張訓練圖片,10 000 張測試圖片,每張圖片是2 828 的灰度手寫數字圖像,實驗使用帶有3 個卷積層的CNN 模型對MNIST 進行訓練.CIFAR10 數據集包含50 000 張訓練圖片,10 000 張測試圖片,每張圖片是3 232 的RGB 圖像,使用文獻[18]中簡化的VGG11 網絡進行訓練.客戶端數據集劃分參照文獻[2],首先按照數據集的類別進行排序,然后將數據集劃分為200 個分片,每個客戶端隨機選擇兩個不會替換的分片來模擬客戶端數據非獨立同分布的場景.實驗中部分參數設置如表1 所示.

表1 參數設置

4.2 實驗結果

我們將本文提出的算法與FedAvg 以及稀疏三元壓縮算法進行了對比,圖3 和圖4 是在MNIST 數據集上的結果,圖3 是全局模型在所有客戶端上的平均測試準確率,圖4 為測試準確率的方差,其中稀疏三元壓縮以及本文提出的算法在實驗中設置了0.1 的稀疏率,也就是每輪傳遞10%的參數進行訓練,根據圖1 的實驗結果可以看到本文提出的算法相較于其他算法收斂速度和收斂精度都略有提升,特別是相較于STC 算法,在相同壓縮率的條件下本文提出的算法大約在第75 輪收斂,而STC 算法在訓練過程非常震蕩,并且在大約100 輪才收斂.

圖3 MNISTS 數據集測試正確率

圖4 MNISTS 數據集測試方差

圖5 和圖6 是在CIFAR10 數據集上的測試準確率和測試方差,稀疏率同樣為0.1,與MNIST 數據集相比,在CIFAR10 數據集上的訓練過程更加震蕩,但是本文提出的算法相較其他算法收斂速度和收斂精度都有大幅度提升,并且訓練過程中的震蕩幅度遠遠小于FedAvg 和STC 算法,這說明本文的算法是非常有效的.

圖5 CIFAR10 數據集平均測試正確率

圖6 CIFAR10 數據集測試方差

表2 中記錄了客戶端與服務器之間每輪通信的參數大小,通信輪次是達到固定正確率(MNIST 95%CIFAR10 50% )大約所用的通信輪數,以FedAvg 作為基線算法,本文提出的算法在上傳和下載時都進行了壓縮,在MNIST 數據集上相較于FedAvg 每輪的通信量減少了45 倍,并且本文的算法在第100 輪時就達到了指定的正確率,相較于FedAvg 和STC 分別減少了97 和57 個通信輪次,在CIFAR10 數據集上每輪的通信量更是減少了47 倍,通信輪次相較于FedAvg 和STC 減少了295 輪和300 輪.

表2 通信開銷計算

5 結論

本文提出了基于投影聚合的稀疏三元壓縮算法,提升聯邦學習的通信效率.該算法在客戶端和服務端采用稀疏三元壓縮減少客戶端在每一輪訓練過程中上傳和下載的通信量,同時在服務器端利用梯度投影的方式緩解了由于客戶端數據異構以及部分參與導致的梯度沖突問題.通過在MNIST 和CIFAR10 數據集上的實驗驗證,本文提出的算法在通信量、收斂速度和正確率3 個方面都要由于傳統的FedAvg 算法和稀疏三元壓縮算法.由于梯度壓縮會略微改變原始梯度的方向,在未來我們將針對不同的壓縮方法對投影聚合的方式做進一步的研究,進一步提高算法的有效性.

猜你喜歡
梯度投影客戶端
論詞樂“均拍”對詞體格律之投影
“人民網+客戶端”推出數據新聞
——穩就業、惠民生,“數”讀十年成績單
投影向量問題
一個具梯度項的p-Laplace 方程弱解的存在性
找投影
內容、形式與表達——有梯度的語言教學策略研究
航磁梯度數據實測與計算對比研究
虛擬專用網絡訪問保護機制研究
新聞客戶端差異化發展策略
《投影與視圖》單元測試題
91香蕉高清国产线观看免费-97夜夜澡人人爽人人喊a-99久久久无码国产精品9-国产亚洲日韩欧美综合