曹炅宣 常 明 張 蕊③** 支 天** 張曦珊**
(*中國科學(xué)技術(shù)大學(xué) 合肥 230026)
(**中國科學(xué)院計算技術(shù)研究所 北京 100190)
(***中科寒武紀(jì)科技股份有限公司 北京 100191)
近年來,得益于越來越深的網(wǎng)絡(luò)層數(shù)和越來越大的參數(shù)量,深度神經(jīng)網(wǎng)絡(luò)(deep neural networks,DNN)在各類任務(wù)中取得了顯著的成功。然而,DNN有著較高的計算復(fù)雜度和較大的參數(shù)儲存要求,因此將其部署到運(yùn)算資源有限的設(shè)備或者對即時性要求較高的應(yīng)用場景變得比較困難,例如智能手機(jī)、嵌入式設(shè)備和邊緣計算。因此壓縮大的模型并提高其運(yùn)行速度變得非常重要。知識蒸餾(distilling the knowledge,KD)就是一種十分有效的模型壓縮方法,其通過從大型教師網(wǎng)絡(luò)中提取有用的知識轉(zhuǎn)移給小型學(xué)生網(wǎng)絡(luò)從而提高小型網(wǎng)絡(luò)的性能。知識蒸餾的損失函數(shù)包含2 個部分,一個是來自于真實(shí)標(biāo)簽的任務(wù)損失,另一個部分則是來自于教師網(wǎng)絡(luò)的蒸餾損失。
因此,如何有效地找到2 個損失函數(shù)的權(quán)重成為了一個待解決的問題。換言之,如何在訓(xùn)練過程中更合理地混合這2 個損失的梯度?,F(xiàn)在大多數(shù)已有的知識蒸餾方法都是手動調(diào)整損失權(quán)重,這種方法既繁瑣又十分浪費(fèi)計算資源,并且往往無法達(dá)到最佳性能。手動搜索權(quán)重的問題主要在于權(quán)重的搜索空間范圍特別大,而且往往是連續(xù)的。例如,根據(jù)框架RepDistiller[1],在相同的數(shù)據(jù)集和師生網(wǎng)絡(luò)組合下,蒸餾損失權(quán)重在0.02(基于概率的知識轉(zhuǎn)移方法(probabilistic knowledge transfer,PKT)[2]) 到30 000(計算相關(guān)性的知識蒸餾方法(correlation congruence for knowledge,CC)[3])之間變化。
針對這個問題,可以采用超參數(shù)優(yōu)化(hyperparameter optimization,HPO[4])和多任務(wù)學(xué)習(xí)(multitask learning,MTL)來確定2 個損失的權(quán)重,但將這2 種方法應(yīng)用到知識蒸餾的訓(xùn)練時存在著一些缺陷。在知識蒸餾訓(xùn)練中,存在2 個優(yōu)化目標(biāo):用于任務(wù)損失的真實(shí)標(biāo)簽以及用于蒸餾損失的教師網(wǎng)絡(luò)。而知識蒸餾設(shè)計的初衷就是使用蒸餾損失作為輔助,幫助作為主要目標(biāo)的任務(wù)損失降低到最小。但超參數(shù)優(yōu)化和多任務(wù)學(xué)習(xí)方法會認(rèn)為這2 個損失處于一個平等情況,因此會產(chǎn)生大量冗余的搜索空間導(dǎo)致參數(shù)調(diào)節(jié)過程的效率十分低下,并且過分平衡用于輔助的蒸餾損失也有一定可能損害到主優(yōu)化目標(biāo)(即任務(wù)損失),后續(xù)多任務(wù)學(xué)習(xí)的實(shí)驗(yàn)也證明了這一點(diǎn)。
為了解決上述問題,本文提出了一種新穎的自動梯度混合方法,該方法可以自動地為知識蒸餾訓(xùn)練找到合適的損失函數(shù)權(quán)重。本文將尋找合適的損失函數(shù)權(quán)重的問題轉(zhuǎn)換為尋找2 個損失通過反向傳播得到的最佳混合梯度的問題??紤]到在知識蒸餾中,蒸餾損失是任務(wù)損失的輔助這一重要的先驗(yàn)知識,自動梯度混合方法可以顯著減少混合梯度的搜索空間。通過找到混合梯度的模長和方向從而確定用于更新模型參數(shù)的混合梯度。在具體訓(xùn)練過程中,混合梯度的模長用來控制模型參數(shù)更新速度,而方向則是決定著模型最終的訓(xùn)練結(jié)果。因此自動梯度混合方法通過固定混合梯度模長與任務(wù)損失產(chǎn)生的梯度模長相同,用來保證模型迭代的穩(wěn)定性。在只需要搜索方向的情況下,可以有效地減少混合梯度的搜索空間并提高搜索效率。在確定了混合梯度的模長和方向后,就可以計算出2 個損失函數(shù)的權(quán)重,從而避免了復(fù)雜的手動調(diào)節(jié)過程。
與現(xiàn)有的手動調(diào)節(jié)方法相比,本文提出的自動梯度混合方法有效利用了知識蒸餾的先驗(yàn)知識,具有以下幾個優(yōu)點(diǎn):首先,自動梯度混合方法將混合梯度的模長約束到與任務(wù)梯度模長相同,這樣能夠保證模型訓(xùn)練的收斂穩(wěn)定性,解耦了梯度向量模長和方向,只需要在方向上進(jìn)行搜索,顯著減少了搜索空間;此外,在進(jìn)行了該梯度模長的約束后,早期訓(xùn)練輪次的結(jié)果與最終訓(xùn)練輪次的結(jié)果具備一個較好的保序性,從而通過一個極短時間的預(yù)訓(xùn)練即可找到較優(yōu)的混合方向,從而實(shí)現(xiàn)了比手動設(shè)置權(quán)重更好的性能;最后,自動梯度混合方法是一種簡單易用的方法,能夠適用于絕大部分的知識蒸餾方法,可以對某種蒸餾方法在某類應(yīng)用場景下是否有效進(jìn)行一個快速驗(yàn)證。
為了證明自動梯度混合方法的效果,本文在CIFAR-100[5]和ImageNet-1k[6]數(shù)據(jù)集上使用Rep-Disitiller[1]框架進(jìn)行實(shí)驗(yàn),自動梯度混合方法在130個組別中表現(xiàn)超過70%的手動調(diào)節(jié)結(jié)果。在時間上,與超參數(shù)優(yōu)化方法相比,自動梯度混合方法只需要1/10 或者更少的時間就能達(dá)到與超參數(shù)優(yōu)化方法相當(dāng)?shù)木取?/p>
知識蒸餾將大的、笨重的教師網(wǎng)絡(luò)的知識轉(zhuǎn)移給更小、更敏捷的學(xué)生網(wǎng)絡(luò)中,從而能夠有效提高學(xué)生網(wǎng)絡(luò)的性能。Hinton 等人[7]提出了這種方法,該方法使用溫度來修正教師網(wǎng)絡(luò)輸出的softmax,使其作為軟標(biāo)簽來指導(dǎo)小型的學(xué)生網(wǎng)絡(luò)。目前有3 種不同類型的知識蒸餾,分別是基于響應(yīng)、基于特征和基于關(guān)系的知識蒸餾方法[8]?;陧憫?yīng)的方法[7]旨在通過使用教師網(wǎng)絡(luò)的logits 作為知識來直接模擬教師網(wǎng)絡(luò)的最終預(yù)測。基于特征的方法[9-12]則是專注于匹配教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中間層的特征?;陉P(guān)系的方法[1,3,13-14]認(rèn)為不同層或數(shù)據(jù)樣本間的關(guān)系能有助于蒸餾。然而,現(xiàn)有絕大部分方法都使用手動調(diào)整來找到合適的任務(wù)損失權(quán)重和蒸餾損失權(quán)重,這既繁瑣又十分耗時,而且往往無法達(dá)到最佳性能。
超參數(shù)優(yōu)化(HPO)方法是一類尋找最優(yōu)的超參數(shù)組合的方法。這些方法可以分成3 類。第1 類是窮舉搜索,例如隨機(jī)搜索和網(wǎng)格搜索。網(wǎng)格搜索將超參數(shù)空間劃分為不同的網(wǎng)格并運(yùn)行每個網(wǎng)格對應(yīng)的參數(shù)組合以此找到最佳參數(shù)。這種遍歷式的搜索方法由于沒有對搜索空間進(jìn)行任何裁剪,因此非常耗時。為了使得搜索過程效率更高,研究人員提出了第2 類啟發(fā)式搜索方法,該類搜索方法可以在搜索過程中根據(jù)可用信息(例如之前訓(xùn)練的結(jié)果)選擇后續(xù)最佳的搜索分支。超參數(shù)優(yōu)化方法中包含有一些經(jīng)典的啟發(fā)式搜索方法,例如樸素進(jìn)化和模擬退火。最近,研究人員也提出了Hyperband[15]、Popluation-Based Training[16]等新的啟發(fā)式方法。第3 類是貝葉斯優(yōu)化,它通過條件概率建模來預(yù)測給定超參數(shù)的最終性能,例如序列貝葉斯優(yōu)化(sequential Bayesian optimization hyperband,BOHB)[17]、樹形Parzen 估計方法(tree-structured Parzen estimator approach,TPE)[18]等。與手動設(shè)置參數(shù)相比,超參數(shù)優(yōu)化方法的調(diào)節(jié)器理論上可以節(jié)省一些搜索時間,但是仍然非常耗時。
多任務(wù)學(xué)習(xí)(multi-task learning,MTL)是指通過使用所有任務(wù)和其他一些任務(wù)中包含的知識來共同學(xué)習(xí)多個任務(wù),以此來提高每個任務(wù)性能的一種訓(xùn)練方法。多任務(wù)學(xué)習(xí)方法包括2 個方面[19]。一些多任務(wù)學(xué)習(xí)方法設(shè)計深度學(xué)習(xí)多任務(wù)架構(gòu),包含有設(shè)計側(cè)重于編碼器[20-21]或側(cè)重于解碼器[22-23]的架構(gòu)。其他的一些多任務(wù)學(xué)習(xí)方法則是側(cè)重于平衡多個任務(wù)的訓(xùn)練優(yōu)化,例如Uncertainly[24]、GradNorm[25]、DWA[26]、DTP[27]、Multi-Objective Optim[28]。絕大部分多任務(wù)學(xué)習(xí)方法都會等權(quán)重優(yōu)化所有任務(wù)或者是所有損失函數(shù),因此它可能會和知識蒸餾中將任務(wù)損失視為主要損失、將蒸餾損失視為輔助的理念相沖突。
為了提高小型學(xué)生網(wǎng)絡(luò)的性能,知識蒸餾除了利用來自于真實(shí)數(shù)據(jù)的監(jiān)督外,還額外引入了來自于大型的教師網(wǎng)絡(luò)中的有益的知識。因此,總的損失函數(shù)由來自于真實(shí)標(biāo)簽的任務(wù)損失和來自于教師網(wǎng)絡(luò)的蒸餾損失構(gòu)成,公式為
這里L(fēng)kd是總的知識蒸餾損失函數(shù),Ltask是任務(wù)損失,Ldistill是蒸餾損失。α和β是任務(wù)損失和蒸餾損失的縮放系數(shù)。為了獲得合適的系數(shù)α和β,絕大部分已有的知識蒸餾方法都是通過手動調(diào)節(jié)方法來進(jìn)行搜索,這類方法非常繁瑣又耗時,并且往往無法使學(xué)生網(wǎng)絡(luò)擁有最佳的性能。為了解決這個問題,本文提出了一種自動梯度混合方法來自動高效地找到損失權(quán)重。
假設(shè)在整個訓(xùn)練過程中,第t輪的模型參數(shù)更新迭代時,損失函數(shù)對模型參數(shù)求導(dǎo)后得到的梯度被用來迭代模型參數(shù),公式為
為了有效搜索最優(yōu)混合梯度,需要盡可能地縮小搜索空間。在這項(xiàng)工作中,本文利用了知識蒸餾中的一個重要的先驗(yàn)知識,即任務(wù)損失是主要優(yōu)化目標(biāo),而蒸餾損失是任務(wù)損失的輔助。因此,混合梯度Gkd應(yīng)當(dāng)與任務(wù)梯度Gtask更加相關(guān),蒸餾梯度Gdistill用來做一個細(xì)化調(diào)整。本文通過確定混合梯度Gkd的方向和模長來找到這個混合梯度。一般而言,在使用梯度來更新模型參數(shù)的過程中,梯度向量具有2 個自變量,一個是方向,另一個則是模長,兩者的功能具有一定差異。梯度的模長主要影響著模型參數(shù)的更新速度,從而控制模型收斂,當(dāng)模長太長時,會出現(xiàn)梯度爆炸使得模型無法收斂或者是在最優(yōu)值附近徘徊的情況;而模長過短時,模型收斂會非常緩慢,找到最優(yōu)值的時間過長,也有可能陷入到某個局部最優(yōu)點(diǎn)中。梯度的方向則是決定著模型參數(shù)的更新方向,決定模型最終的收斂位置能否在相應(yīng)的指標(biāo)上取得好的效果(如分類任務(wù)中的準(zhǔn)確率,檢測任務(wù)中的mAP 等)。在非蒸餾訓(xùn)練中,模型僅使用任務(wù)損失產(chǎn)生的梯度就能訓(xùn)練出來一個穩(wěn)定的結(jié)果。本文基于上述先驗(yàn)知識,為了提高效率減小搜索空間,以及保證模型訓(xùn)練的收斂穩(wěn)定性,自動梯度混合方法將混合梯度的模長約束到與任務(wù)損失梯度模長相同,公式為
在實(shí)現(xiàn)該約束后,可以很方便地將學(xué)生網(wǎng)絡(luò)的非蒸餾訓(xùn)練版本的超參數(shù),如學(xué)習(xí)率、權(quán)重衰減等,方便應(yīng)用到本文中使用的蒸餾訓(xùn)練上。因此可以通過對Gkd的模長約束得到一個穩(wěn)定的訓(xùn)練過程。
在確定了模長大小后,自動梯度混合方法只需要在搜索空間中搜索Gkd梯度方向,該梯度方向由任務(wù)梯度Gtask和蒸餾梯度Gdistill決定。如圖1 所示,Gkd方向的搜索空間為Gtask和Gdistill之間的角度空間。θ為Gtask和Gdistill夾角大小:
圖1 梯度混合示意圖
假設(shè)Gtask和Gkd的夾角為λθ,只需要在λ∈[0,1] 這個范圍內(nèi)進(jìn)行搜索。在這種方式下,由于不需要對Gkd的模長進(jìn)行搜索,整個搜索空間得到大幅度縮減,同時對最優(yōu)方向的搜索可以保證混合梯度Gkd的有效性。
通過搜索得到λ后,可以用如下公式表示Gkd的方向:
使用式(3)和(4),可以得到:
聯(lián)立式(5)~(7),可以解得損失權(quán)重系數(shù)α和β為
如式(9)所示,損失權(quán)重系數(shù)α和β取決于λ。λ的有效值為[0,1]。當(dāng)λ等于0 時,蒸餾損失對混合梯度沒有任何影響;當(dāng)λ等于1 的時候,混合梯度方向會完全遵循蒸餾梯度的方向。此外,實(shí)驗(yàn)結(jié)果表明自動梯度混合方法在訓(xùn)練早期和后期的性能(在分類任務(wù)中為準(zhǔn)確率)有著良好的保序性。因此,為了進(jìn)一步提高搜索過程中實(shí)驗(yàn)的效率,本文使用訓(xùn)練早期的訓(xùn)練效果來預(yù)測最終的性能。在具體操作中,本文在搜索空間中對λ進(jìn)行一個早期的搜索來作為預(yù)熱訓(xùn)練模型。然后選擇性能最佳的一個作為λ的最佳值。之后可以采用式(9)來計算損失權(quán)重α和β,并且使用它們來完成訓(xùn)練。搜索和訓(xùn)練模型的整個過程如算法1 所示。
本節(jié)中,本文將提出的自動梯度混合方法應(yīng)用在被廣泛使用的圖像分類數(shù)據(jù)集CIFAR-100[5]和ImagNet LSVRC 2012[6]上。此外,本文使用的Rep-Distiller[1]框架基于Pytorch,其模型庫中包含有13種流行的蒸餾方法。在實(shí)驗(yàn)中,本文遵循RepDistiller 默認(rèn)的超參數(shù)設(shè)置,如訓(xùn)練輪次、學(xué)習(xí)率、優(yōu)化器等。在自動梯度混合方法中,預(yù)熱輪次設(shè)置為5。作為對比實(shí)驗(yàn),本文使用RepDistiller 中給出的手動調(diào)整的損失權(quán)重的訓(xùn)練結(jié)果作為基線。
本文在KD[7]、Fitnets[11]、SP[29]、AT[12]、CC[3]、VID[29]、RKD[13]、PKT[3]、FT[10]和NST[9]這10 種蒸餾方法上進(jìn)行實(shí)驗(yàn)。此外,實(shí)驗(yàn)還包含有7 個相似架構(gòu)的師生網(wǎng)絡(luò)組合和6 個不同架構(gòu)的師生網(wǎng)絡(luò)架構(gòu),即整個實(shí)驗(yàn)包含有10 ×13 個小的實(shí)驗(yàn)。
結(jié)果如表1 所示,可以發(fā)現(xiàn)自動梯度混合方法和手動方法比較,無論是在教師網(wǎng)絡(luò)架構(gòu)和學(xué)生網(wǎng)絡(luò)架構(gòu)相似的VGG13-VGG8 和ResNet110-ResNet32亦或者是ResNet32x4-ShuffleNetV2 和VGG13-MobileNetV2 這類架構(gòu)差異很大的網(wǎng)絡(luò)上都有比較好的效果??偨Y(jié)表1 的結(jié)果可以發(fā)現(xiàn),自動梯度混合方法在70%的蒸餾組合上都要比手動調(diào)節(jié)的方法表現(xiàn)得更好。
表1 在數(shù)據(jù)集CIFAR-100 上使用手動調(diào)節(jié)(Manual)和 自動梯度混合方法(AGB)在10 種不同的蒸餾方法和13 種不同的師生網(wǎng)絡(luò)組合的Top-1 準(zhǔn)確率(%)
本文使用KD、CC、對比表示知識蒸餾方法(contrastive representation distillation,CRD)和注意知識蒸餾方法(attention on distillation,AT)在Image Net-1K數(shù)據(jù)集進(jìn)行實(shí)驗(yàn)。因?yàn)镽epDistiller 框架沒有ImageNet-1K 對應(yīng)代碼,所以本文在ImageNet-1K 上復(fù)現(xiàn)了這4 種方法。超參數(shù)和手動調(diào)整的損失權(quán)重是按照另一個蒸餾框架TorchDistil 設(shè)置的。本文使用Pytorch 團(tuán)隊(duì)發(fā)布的模型ResNet34 和ResNet18 作為教師和學(xué)生網(wǎng)絡(luò),并遵循TorchDistill 的ImageNet 訓(xùn)練設(shè)置。
表2 展示了自動梯度混合方法和手動參數(shù)設(shè)置方法在以ResNet34 和ResNet18 作為師生網(wǎng)絡(luò)組合上的top-1 準(zhǔn)確度。對于KD、CC 和AT 方法,自適應(yīng)梯度混合方法可以獲得更好的性能,對于CRD 方法,自動梯度混合方法也可以達(dá)到和手動設(shè)置接近的性能。因此,ImageNet-1K 上的實(shí)驗(yàn)有效證明了自動梯度混合方法的有效性。
表2 自動梯度混合方法(AGB)和手動調(diào)整(Manual)在ImageNet-1k 上的Top-1 準(zhǔn)確度(%),其中教師網(wǎng)絡(luò)是ResNet34(top-1 準(zhǔn)確度73.314%),學(xué)生網(wǎng)絡(luò)是ResNet18(top-1 準(zhǔn)確度69.76%)
本文在CIFAR100 上使用自動梯度混合方法和Microsoft Neural Network Intelligence (NNI)的3 個不同的超參數(shù)優(yōu)化調(diào)節(jié)器進(jìn)行了對比。這些超參數(shù)優(yōu)化方法包括有啟發(fā)式搜索方法模擬退火(simulated annealing)、Hyperband[15]和貝葉斯優(yōu)化方法TPE[18]。選擇VGG13 和VGG8 作為師生網(wǎng)絡(luò),并使用AT 蒸餾方法進(jìn)行實(shí)驗(yàn),在超參數(shù)優(yōu)化方法中,參照式(1),設(shè)置α等于1,β的搜索空間為0.02 到30 000。
圖2 顯示了3 個超參數(shù)優(yōu)化調(diào)節(jié)器和自動梯度混合方法的比較實(shí)驗(yàn)??梢杂^察到自動梯度混合方法只需要極少訓(xùn)練的時間就能達(dá)到非常高的精度。相比之下,在運(yùn)行同樣的時間中,超參數(shù)優(yōu)化方法只能實(shí)現(xiàn)更低的精度。盡管超參數(shù)優(yōu)化方法在最終的結(jié)果中達(dá)到了與自動梯度混合方法相當(dāng)或者略高的精度,但它們需要更多的時間來進(jìn)行搜索,這是非常低效的。
分析超參數(shù)優(yōu)化方法出現(xiàn)的問題,可以發(fā)現(xiàn)無論是手動調(diào)節(jié)、超參數(shù)優(yōu)化或者是一些簡單約束情況,都會導(dǎo)致超參數(shù)搜索過程變得漫長而復(fù)雜。本質(zhì)上,這是由于這類方法在搜索超參數(shù)時會將總梯度向量模長和方向進(jìn)行耦合,同時去搜索梯度向量的方向和模長,會影響模型的收斂性,并出現(xiàn)兩類冗余搜索的情況:(1)搜索到合適的方向而模長過長或過短,導(dǎo)致出現(xiàn)模型無法收斂;(2)搜索到合適的模長而方向不對,這樣會影響模型最終的收斂位置,即影響模型最終的結(jié)果。而當(dāng)一些更為奇怪的約束使得總梯度向量的方向與模長耦合得更加緊密時,甚至無法搜索到對應(yīng)合適方向。
將Uncertainly 和GradNorm 這2 種無超參數(shù)的多任務(wù)學(xué)習(xí)方法與自動梯度混合方法進(jìn)行對比實(shí)驗(yàn)。本文對所有的10 種蒸餾方法進(jìn)行了實(shí)驗(yàn),所有的13 種教師學(xué)生網(wǎng)絡(luò)組合與3.1 節(jié)中的相同。
如表3 所示,自動梯度混合方法應(yīng)用到絕大多數(shù)蒸餾方法中都優(yōu)于這2 種多任務(wù)學(xué)習(xí)方法。多任務(wù)學(xué)習(xí)方法將蒸餾損失和任務(wù)損失平等對待,忽略了知識蒸餾的重要先驗(yàn)知識,即任務(wù)損失是起到主導(dǎo)作用的,而蒸餾損失是用于輔助的。因此,多任務(wù)學(xué)習(xí)方法可能會為了最大限度地降低蒸餾損失而犧牲了性能。還可以發(fā)現(xiàn),當(dāng)使用GradNorm 時,大多數(shù)蒸餾方法的性能都很差。這是因?yàn)镚radNorm 完全忽略了任務(wù)損失應(yīng)該為主導(dǎo)地位。而且,與任務(wù)損失相比,蒸餾損失通常非常大或者非常小。例如,在CC 中,蒸餾梯度的模長約為任務(wù)梯度的100 倍,
表3 多任務(wù)學(xué)習(xí)方法GradNorm 和Uncertainly 在CIFAR-100 上與自動梯度混合方法(AGB)相比的Top-1 測試準(zhǔn)確度(%)。由于訓(xùn)練過程中的梯度爆炸,一些方法顯示出非常差的準(zhǔn)確性或無法訓(xùn)練出有效的結(jié)果(用表示)。null 表示此蒸餾方法不支持多任務(wù)學(xué)習(xí)方法。
而在PKT 中,蒸餾梯度的模長約為任務(wù)梯度的0.001倍。因此,GradNorm 簡單地平衡2 個損失將會導(dǎo)致整個訓(xùn)練過程不穩(wěn)定。相比之下,自動梯度混合方法將混合梯度的模長限制為與任務(wù)梯度的模長相同。因此,自動梯度混合方法在獲得穩(wěn)定訓(xùn)練過程的同時,可以保留任務(wù)梯度占據(jù)主導(dǎo)地位這一重要信息。
本文驗(yàn)證了在自動梯度混合方法中訓(xùn)練早期和訓(xùn)練后期準(zhǔn)確率的保序性。在CIFAR-100 上使用AT 蒸餾方法進(jìn)行這些實(shí)驗(yàn),在NNI 上用VGG13 作為教師網(wǎng)絡(luò),VGG8 作為學(xué)生網(wǎng)絡(luò)。計算早期(第5輪)的準(zhǔn)確率和整個訓(xùn)練結(jié)束的最終準(zhǔn)確率之間的相關(guān)系數(shù)。本文還對手動調(diào)節(jié)方法進(jìn)行了這些實(shí)驗(yàn),α設(shè)置為1,β從0.003 變化到30 000。為了公平地比較,本文選擇結(jié)果接近收斂時的最后80 次實(shí)驗(yàn)來驗(yàn)證相關(guān)性。
如圖3所示,下圖為自動梯度混合方法,其相關(guān)系數(shù)為0.724,遠(yuǎn)高于上圖中手動調(diào)節(jié)方法的0.410。這個實(shí)驗(yàn)說明了使用自動梯度混合方法時早期輪次表現(xiàn)較好的設(shè)置同樣可以運(yùn)用到晚期輪次。因此,預(yù)熱策略可以在不損失性能的前提下大幅提升自動梯度混合方法的效率。
圖3 最佳精度與早期輪次精度之間的相關(guān)性
本文就預(yù)熱階段設(shè)置的熱身輪次和預(yù)熱階段用于離散化的步長進(jìn)行了消融實(shí)驗(yàn)。在CIFAR-100上使用KD 蒸餾方法進(jìn)行實(shí)驗(yàn),教師網(wǎng)絡(luò)為Res-Net32x4,學(xué)生網(wǎng)絡(luò)為ResNet8x4。
圖4 顯示了準(zhǔn)確率、時間開銷與步長的關(guān)系??梢钥吹?當(dāng)步長從0.2 變小后,時間開銷增大,對應(yīng)的結(jié)果略有上升;而當(dāng)步長變大后,實(shí)際上的節(jié)省的時間相當(dāng)有限,而性能也會出現(xiàn)一定程度的下降。圖5 則顯示了準(zhǔn)確率、時間和熱身輪次的關(guān)系。可以發(fā)現(xiàn),與前面步長類似,選取更小的熱身輪次并不會導(dǎo)致運(yùn)行時間有一個顯著的變小。而當(dāng)熱身輪次提升后,時間開銷增大了,對于實(shí)驗(yàn)的準(zhǔn)確率也沒有提升太多。因此本文取的熱身輪次和步長并不具備特殊性,取附近的幾個值結(jié)果差異不會太大,這也說明了是前面模長約束在方法中起到了主要的作用而預(yù)熱的等間距選取最優(yōu)的策略只是用于輔助的。
圖4 準(zhǔn)確率、時間與步長之間的關(guān)系
圖5 準(zhǔn)確率、時間與熱身輪次之間的關(guān)系
圖2 中的結(jié)果也顯示了自動梯度混合方法的高效性。圖2 中圓點(diǎn)表示一次超參數(shù)優(yōu)化方法實(shí)驗(yàn)的準(zhǔn)確性。隨著訓(xùn)練實(shí)驗(yàn)的增加,每條虛線表示 超參數(shù)優(yōu)化方法的最佳準(zhǔn)確性。三角形標(biāo)記表示自動梯度混合方法的結(jié)果,該方法需要大約1.50 次實(shí)驗(yàn)時間才能達(dá)到72.48%的準(zhǔn)確率。在知識蒸餾中尋找損失權(quán)重時,手動調(diào)整會受到大的搜索空間的影響。通過使用貝葉斯優(yōu)化或者是其他算法改進(jìn)搜索過程,超參數(shù)優(yōu)化方法會高效一些,但是仍然有著比較大的搜索空間。相比之下,自動梯度混合方法通過約束混合梯度的模長并僅僅在預(yù)熱階段在方向上進(jìn)行搜索,從而顯著減少了搜索空間。如圖2 所示,超參數(shù)優(yōu)化方法需要10 次以上的實(shí)驗(yàn)才能達(dá)到與自動梯度混合方法相當(dāng)?shù)木取R虼?與超參數(shù)優(yōu)化方法相比,自動梯度混合方法效率更高。
本文提出了一種自動梯度混合方法,可以有效地為絕大部分知識蒸餾方法找到合適的損失權(quán)重。利用蒸餾損失是用于輔助任務(wù)損失這一先驗(yàn)知識,自動梯度混合方法通過減少超參數(shù)搜索空間來優(yōu)化搜索過程。自動梯度混合方法只搜索梯度方向,即2 個損失梯度之間的角度,同時將混合梯度的模長約束為與任務(wù)損失梯度模長相同。本文在13 種不同的師生網(wǎng)絡(luò)組合之間對10 種不同的知識蒸餾方法進(jìn)行了實(shí)驗(yàn)。自動梯度混合方法在使用更少的運(yùn)算資源的前提下在70%的蒸餾方法上性能超過了手動調(diào)節(jié)方法,這說明自動梯度混合方法具有更好的效果以及更高的效率。本文工作的前提是假設(shè)當(dāng)有多個蒸餾損失時,所有的蒸餾損失共享相同的權(quán)重。未來,可以將本文工作擴(kuò)展到具有多種蒸餾損失的情況。