劉金金,李清寶,李曉楠
1.戰(zhàn)略支援部隊(duì)信息工程大學(xué),鄭州450003
2.數(shù)學(xué)工程與先進(jìn)計(jì)算國(guó)家重點(diǎn)實(shí)驗(yàn)室,鄭州450003
3.中原工學(xué)院 計(jì)算機(jī)學(xué)院,鄭州450007
深度神經(jīng)網(wǎng)絡(luò)在多種計(jì)算機(jī)視覺(jué)相關(guān)的任務(wù)中展現(xiàn)出了最優(yōu)越的性能,例如圖像分類(lèi)[1]、工業(yè)視覺(jué)檢測(cè)[2]、姿態(tài)估計(jì)[3]、行人再識(shí)別[4]和人臉識(shí)別[5]等。隨著配套硬件設(shè)備的發(fā)展和對(duì)卷積神經(jīng)網(wǎng)絡(luò)認(rèn)識(shí)的不斷加深,研究表明越深的網(wǎng)絡(luò)能夠提取越抽象的語(yǔ)義信息,網(wǎng)絡(luò)的表示能力越強(qiáng)。然而更深更寬的神經(jīng)網(wǎng)絡(luò)將難以收斂,并且會(huì)導(dǎo)致反向傳播算法中的梯度消失[6-7]。殘差網(wǎng)絡(luò)ResNet[1]和批量歸一化(Batch Normalization,BN)[8]能夠在一定程度上解決這一問(wèn)題,但是具有大量參數(shù)的深度學(xué)習(xí)模型需要更大的存儲(chǔ)空間和更強(qiáng)的運(yùn)算單元,無(wú)法在移動(dòng)終端上進(jìn)行部署和實(shí)時(shí)推理,從而影響深度學(xué)習(xí)模型在實(shí)際應(yīng)用中的落地和推廣。例如公共區(qū)域的視頻監(jiān)控系統(tǒng)多部署在內(nèi)存有限和計(jì)算能力較低的嵌入式設(shè)備上,無(wú)法實(shí)時(shí)準(zhǔn)確地對(duì)視頻幀中的多人進(jìn)行身份識(shí)別和行為分析。本研究旨在改善深度學(xué)習(xí)網(wǎng)絡(luò)在人臉識(shí)別系統(tǒng)中的部署和應(yīng)用問(wèn)題。
為了解決這一問(wèn)題,研究人員采用多種技術(shù)壓縮網(wǎng)絡(luò)參數(shù),主要包括量化或二值化、因子分解、網(wǎng)絡(luò)剪枝和知識(shí)蒸餾[9]等。其中知識(shí)蒸餾的方法基于教師-學(xué)生的策略,旨在訓(xùn)練一個(gè)輕量級(jí)的學(xué)生網(wǎng)絡(luò),使學(xué)生網(wǎng)絡(luò)模仿完備的大尺寸的教師網(wǎng)絡(luò)輸出的軟目標(biāo),達(dá)到知識(shí)遷移的目的。相較于樣本固有的一位有效信息標(biāo)簽,教師網(wǎng)絡(luò)的輸出能夠提供“不正確”分類(lèi)的相對(duì)概率,使分類(lèi)概率具有更大的信息量。學(xué)生網(wǎng)絡(luò)的權(quán)重更新是一個(gè)最小化知識(shí)蒸餾損失的過(guò)程,即最小化學(xué)生網(wǎng)絡(luò)輸出與教師網(wǎng)絡(luò)輸出、學(xué)生網(wǎng)絡(luò)輸出與真實(shí)標(biāo)簽間的差異。
雖然通過(guò)最小化知識(shí)蒸餾損失能夠使學(xué)生網(wǎng)絡(luò)模仿教師網(wǎng)絡(luò)的輸出,但是其性能仍有差距。主要有以下幾個(gè)原因。首先學(xué)生網(wǎng)絡(luò)只學(xué)習(xí)教師網(wǎng)絡(luò)輸出的分類(lèi)概率分布,而忽略了包含豐富語(yǔ)義信息和空間相關(guān)性的中間特征圖。一些現(xiàn)有方法直接對(duì)齊學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的中間層表示,不能有效地轉(zhuǎn)移潛在的空間相關(guān)性。其次由于教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)具有不同的拓?fù)浣Y(jié)構(gòu),其最優(yōu)解空間也存在差異。如果只采用蒸餾損失全程監(jiān)督學(xué)生網(wǎng)絡(luò)的訓(xùn)練過(guò)程,學(xué)生網(wǎng)絡(luò)無(wú)法找到自己的最優(yōu)解空間。此外,教師網(wǎng)絡(luò)的預(yù)測(cè)不是完全正確的,在訓(xùn)練過(guò)程中,如果學(xué)生網(wǎng)絡(luò)完全學(xué)習(xí)教師網(wǎng)絡(luò)輸出的軟目標(biāo),會(huì)遷移錯(cuò)誤的知識(shí)。
針對(duì)上述問(wèn)題,本研究提出了一種針對(duì)分類(lèi)概率和特征圖兩個(gè)層面的深度學(xué)習(xí)模型壓縮算法??蚣苡扇糠纸M成,分別為預(yù)訓(xùn)練得到的教師網(wǎng)絡(luò)、小規(guī)模的學(xué)生網(wǎng)絡(luò)和判別器。其中教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)可以為任意結(jié)構(gòu)的卷積神經(jīng)網(wǎng)絡(luò),判別器則為由多個(gè)全連接層構(gòu)成的深度學(xué)習(xí)網(wǎng)絡(luò),并在訓(xùn)練過(guò)程中更新權(quán)重。為了減輕教師網(wǎng)絡(luò)錯(cuò)誤分類(lèi)的影響,添加指示函數(shù),優(yōu)化知識(shí)蒸餾損失,使學(xué)生網(wǎng)絡(luò)只學(xué)習(xí)正確的輸出。此外由于人臉特征具有豐富的空間相關(guān)性,學(xué)生網(wǎng)絡(luò)模仿教師網(wǎng)絡(luò)提取的特征圖是十分必要的。引入生成對(duì)抗網(wǎng)絡(luò)[10]中的判別器,識(shí)別輸入的特征圖是“真”(教師網(wǎng)絡(luò))還是“假”(學(xué)生網(wǎng)絡(luò)),使學(xué)生網(wǎng)絡(luò)能夠自動(dòng)學(xué)習(xí)類(lèi)間的相關(guān)性。為了使學(xué)生網(wǎng)絡(luò)能夠自主探索自己的最優(yōu)解空間,在訓(xùn)練過(guò)程中打破教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)之間的單向轉(zhuǎn)換通路,使其相互學(xué)習(xí),交替更新。完整流程如圖1所示,其中實(shí)線(xiàn)為網(wǎng)絡(luò)的正向傳播過(guò)程,虛線(xiàn)為目標(biāo)函數(shù)的計(jì)算過(guò)程。
圖1 對(duì)抗學(xué)習(xí)輔助下的知識(shí)蒸餾過(guò)程Fig.1 Process of knowledge distillation assisted by adversarial learning
本研究的貢獻(xiàn)總結(jié)如下:
(1)改進(jìn)經(jīng)典的知識(shí)蒸餾損失,使其不僅能夠?qū)W習(xí)教師網(wǎng)絡(luò)輸出的正確軟目標(biāo),而且能夠從中間層獲取豐富的隱含知識(shí)。
(2)引入對(duì)抗學(xué)習(xí)中的判別器,鑒別教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)特征圖的差異,進(jìn)一步縮小大模型和容量有限的小模型最優(yōu)解空間之間的差異。
(3)在訓(xùn)練過(guò)程中采用互學(xué)習(xí)的策略,使教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)學(xué)習(xí)對(duì)方的特征圖,提升泛化能力。
(4)由于本研究針對(duì)人臉識(shí)別的應(yīng)用,采用公開(kāi)的人臉數(shù)據(jù)集訓(xùn)練模型,并與已有的理想算法進(jìn)行比較,驗(yàn)證所提算法的有效性和先進(jìn)性。
神經(jīng)網(wǎng)絡(luò)的中存在很多冗余參數(shù),文獻(xiàn)[11]的研究表明,模型中有僅用1%的深度卷積就能達(dá)到和原來(lái)網(wǎng)絡(luò)相近的性能。神經(jīng)網(wǎng)絡(luò)壓縮早在文獻(xiàn)[12]的工作中就已為人所知,但最近由于現(xiàn)代深度模型的性能和計(jì)算需求的綜合增長(zhǎng)而受到了廣泛關(guān)注。
神經(jīng)網(wǎng)絡(luò)壓縮的相關(guān)方法主要分為五大類(lèi):量化、剪枝、因子分解、精細(xì)模型設(shè)計(jì)和知識(shí)蒸餾等。
量化的方法是將網(wǎng)絡(luò)的權(quán)重離散化,逐步將預(yù)先訓(xùn)練的全精度卷積網(wǎng)絡(luò)轉(zhuǎn)換為低精度卷積網(wǎng)絡(luò)?;谶@一思想,Gong等人[13]使用k-means對(duì)權(quán)值進(jìn)行聚類(lèi),然后進(jìn)行量化。量化可以簡(jiǎn)化到二進(jìn)制級(jí)別的?1和1,如XNOR-Net[14]和BinaryConnect[15],但后者不在參數(shù)更新期間量化,而是在前向和后向梯度傳遞的過(guò)程中二進(jìn)制化權(quán)重。
剪枝則是根據(jù)一定規(guī)則剔除網(wǎng)絡(luò)中鏈接的過(guò)程,根據(jù)粒度的粗細(xì)又可以分為面向鏈接的剪枝和面向卷積模板的剪枝。Han等人[16]將量化與剪枝相結(jié)合,進(jìn)一步減少存儲(chǔ)需求和網(wǎng)絡(luò)計(jì)算。在HashedNet[17]中,網(wǎng)絡(luò)連接被隨機(jī)分組到哈希桶中,相同桶的連接共享權(quán)值。然而當(dāng)使用卷積神經(jīng)網(wǎng)絡(luò)時(shí),稀疏連接不一定會(huì)加速推理。出于這個(gè)原因,Li等人[18]裁剪完整的卷積模板,而不是單個(gè)的連接。因此剪枝后的神經(jīng)網(wǎng)絡(luò)仍然進(jìn)行密集矩陣乘法,而不需要稀疏卷積庫(kù)。
因子分解的方法旨在對(duì)卷積模板的矩陣進(jìn)行低秩分解或找尋近似的低秩矩陣。其中,使用深度可分離卷積和點(diǎn)卷積的組合可以近似深度卷積模板,例如MobileNet[19]和ShuffleNet[20]。
相較于AlexNet[21]和VGG16[22]需要較多的計(jì)算資源,殘差網(wǎng)絡(luò)(Residual Network,ResNet)及其變體不僅減少了參數(shù)數(shù)量,同時(shí)保持(甚至提高)了性能。SqueezeNet[23]通過(guò)利用1×1卷積模板替換3×3卷積模板并減少3×3卷積模板的通道數(shù)量來(lái)進(jìn)一步修剪參數(shù)。此外Inception[24]、Xception[25]、CondenseNet[26]和ResNeXt[27]也有效地設(shè)計(jì)更深更寬的網(wǎng)絡(luò),而不引入比AlexNet和VGG16更多的參數(shù)。Octave卷積[28]根據(jù)不同的頻率將特征圖進(jìn)行因式分解,對(duì)不同頻率的信息進(jìn)行不同的存儲(chǔ)和操作,以實(shí)現(xiàn)基于高低頻率的輕量化存儲(chǔ)方式。Octave卷積可用于ResNet、GoogLeNet等基線(xiàn)網(wǎng)絡(luò)結(jié)構(gòu)的優(yōu)化,也可以對(duì)如MobileNet-v1&v2,ShuffleNet v1&v2等常規(guī)輕量化網(wǎng)絡(luò)進(jìn)行進(jìn)一步地優(yōu)化,能夠有效減少深度神經(jīng)網(wǎng)絡(luò)對(duì)于存儲(chǔ)空間的要求,實(shí)現(xiàn)輕量化。
除了人工設(shè)計(jì)輕量化深度神經(jīng)網(wǎng)絡(luò)以外,基于神經(jīng)架構(gòu)搜索的自動(dòng)化模型設(shè)計(jì)的優(yōu)勢(shì)愈加凸顯,如MnasNet[29]等。
知識(shí)蒸餾(Knowledge Distillation,KD)的目標(biāo)是通過(guò)使用Softmax函數(shù)之前(Logits)或者之后的輸出(分類(lèi)概率),將知識(shí)從教師網(wǎng)絡(luò)轉(zhuǎn)移至學(xué)生網(wǎng)絡(luò)。
為了使學(xué)生網(wǎng)絡(luò)能夠完成多個(gè)計(jì)算機(jī)視覺(jué)任務(wù),文獻(xiàn)[30]則是蒸餾多個(gè)教師網(wǎng)絡(luò),構(gòu)成多分支的學(xué)生網(wǎng)絡(luò)。文獻(xiàn)[31]優(yōu)化了教師-學(xué)生策略,弱化策略中兩者的指導(dǎo)與學(xué)習(xí)關(guān)系,兩個(gè)均從頭訓(xùn)練,相互學(xué)習(xí),并且采用循環(huán)訓(xùn)練策略同時(shí)訓(xùn)練多個(gè)網(wǎng)絡(luò)。文獻(xiàn)[32]中提出的教師-學(xué)生策略中,兩種網(wǎng)絡(luò)采用相同的架構(gòu),而使用不同分辨率的人臉圖像,能在一定程度上解決實(shí)際應(yīng)用中圖像分辨率低的問(wèn)題。學(xué)生網(wǎng)絡(luò)的訓(xùn)練方法分為知識(shí)蒸餾和知識(shí)遷移兩種,前者隨機(jī)初始化學(xué)生網(wǎng)絡(luò)參數(shù),采用基于分類(lèi)的交叉熵?fù)p失和特征向量間的歐氏距離更新網(wǎng)路,后者則是用教師網(wǎng)絡(luò)的參數(shù)初始化學(xué)生網(wǎng)絡(luò),然后只采用交叉熵?fù)p失更新網(wǎng)絡(luò)。
生成對(duì)抗網(wǎng)絡(luò)通過(guò)對(duì)抗學(xué)習(xí)生成圖像,使生成器能夠模擬特定的特征分布空間。這一本質(zhì)特性與知識(shí)蒸餾的目的存在交叉,可以將小容量的學(xué)生網(wǎng)絡(luò)看作生成器,在給定相同輸入圖像的情況下,將學(xué)生輸出映射到教師輸出。
Belagiannis等人提出了一種基于對(duì)抗學(xué)習(xí)的網(wǎng)絡(luò)壓縮算法[33],去掉了KD損失中學(xué)生網(wǎng)絡(luò)輸出與真實(shí)標(biāo)簽的交叉熵?fù)p失,而是使用兩者Logits間的L2范數(shù),故不需要提前給出訓(xùn)練樣本的真實(shí)標(biāo)簽。由于學(xué)生網(wǎng)絡(luò)具有較小的容量,很難使其完全精確地模仿教師網(wǎng)絡(luò)的軟目標(biāo),增加對(duì)抗損失,使學(xué)生網(wǎng)絡(luò)能夠能快地收斂于教師網(wǎng)絡(luò)的最優(yōu)解空間。由于判別器過(guò)早達(dá)到平衡會(huì)使學(xué)生網(wǎng)絡(luò)無(wú)法從教師網(wǎng)絡(luò)學(xué)習(xí)到有效的梯度,引入對(duì)判別器的正則化,避免判別器支配后續(xù)的訓(xùn)練過(guò)程。
針對(duì)人臉識(shí)別任務(wù)的特性,本研究?jī)?yōu)化了現(xiàn)有的知識(shí)蒸餾方法,不只簡(jiǎn)單學(xué)習(xí)分類(lèi)概率,同時(shí)考慮特征圖間的知識(shí)遷移。為了使學(xué)生網(wǎng)絡(luò)能夠探索自己的最優(yōu)解空間,加入判別損失這一更加宏觀的標(biāo)準(zhǔn),使學(xué)生網(wǎng)絡(luò)在訓(xùn)練過(guò)程中具有更多的自主性。
在視頻監(jiān)控系統(tǒng)的應(yīng)用中,為了具有較好的人臉識(shí)別性能,多采用具有更深更寬結(jié)構(gòu)的深度學(xué)習(xí)模型。相反,在一些現(xiàn)實(shí)場(chǎng)景中,為了滿(mǎn)足資源有限的設(shè)備的需求,需要對(duì)已有模型進(jìn)行剪枝或量化。為了解決這兩個(gè)目標(biāo)之間的權(quán)衡困難,本研究提出了一種對(duì)抗學(xué)習(xí)輔助下的知識(shí)蒸餾算法。
本研究主要從知識(shí)獲取對(duì)象和知識(shí)蒸餾策略?xún)蓚€(gè)方面入手優(yōu)化了現(xiàn)有算法,從以下三個(gè)方面進(jìn)行詳細(xì)闡述。
知識(shí)蒸餾的基本思想是通過(guò)最小化教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)間的預(yù)測(cè)分布的差異,使學(xué)生網(wǎng)絡(luò)近似于教師網(wǎng)絡(luò)。神經(jīng)網(wǎng)絡(luò)通常通過(guò)使用Softmax輸出層來(lái)產(chǎn)生分類(lèi)概率,將計(jì)算出的每個(gè)類(lèi)別的Logits轉(zhuǎn)換為分類(lèi)概率,如式(1)所示:
其中,zi為L(zhǎng)ogits的第i個(gè)分量,T為溫度參數(shù),越高的溫度會(huì)產(chǎn)生越軟的類(lèi)間分類(lèi)概率。
知識(shí)蒸餾損失有兩部分組成,一是分類(lèi)概率間的交叉熵,學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)使用相同的溫度T,二是學(xué)生網(wǎng)絡(luò)的分類(lèi)預(yù)測(cè)與真實(shí)標(biāo)簽間的交叉熵?fù)p失,溫度為1,如式(2)所示:
其中,N為小批量的尺寸,LCE代表交叉熵,也可以用相對(duì)熵,即Kullback-Leibler散度代替。σ()代表Softmax函數(shù),T為蒸餾溫度,yi為樣本i的真實(shí)標(biāo)簽,zS∈?C和zT∈?C分別為C類(lèi)分類(lèi)任務(wù)的學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)輸出的Logits。
雖然訓(xùn)練初期教師網(wǎng)絡(luò)比學(xué)生網(wǎng)絡(luò)更準(zhǔn)確,但教師仍然會(huì)有一些預(yù)測(cè)錯(cuò)誤。當(dāng)教師網(wǎng)絡(luò)預(yù)測(cè)錯(cuò)誤時(shí),知識(shí)同樣會(huì)轉(zhuǎn)移到學(xué)生網(wǎng)絡(luò)身上,這將會(huì)影響學(xué)生網(wǎng)絡(luò)的表現(xiàn)。因此改進(jìn)傳統(tǒng)知識(shí)蒸餾的方法,忽略教師網(wǎng)絡(luò)錯(cuò)誤的預(yù)測(cè)分布,只把正確的預(yù)測(cè)分布傳遞給學(xué)生網(wǎng)絡(luò),具體目標(biāo)函數(shù)如式(3)所示:
其中,I()為指示函數(shù),yS為學(xué)生網(wǎng)絡(luò)預(yù)測(cè)的標(biāo)簽。當(dāng)教師網(wǎng)絡(luò)能夠正確預(yù)測(cè)輸入樣本的分類(lèi)時(shí),指示函數(shù)為1,學(xué)生網(wǎng)絡(luò)同時(shí)學(xué)習(xí)樣本標(biāo)簽和教師網(wǎng)絡(luò)輸出的軟目標(biāo);教師網(wǎng)絡(luò)無(wú)法正確分類(lèi)時(shí),指示函數(shù)為0,僅計(jì)算學(xué)生網(wǎng)絡(luò)的分類(lèi)情況和真實(shí)標(biāo)簽間的交叉熵。
研究表明[34],位于較淺層面的卷積層會(huì)對(duì)邊緣、角度、曲線(xiàn)等低級(jí)特征做出響應(yīng);下一級(jí)卷積層則能響應(yīng)更復(fù)雜的特征,如圓和矩形。因此,當(dāng)卷積操作逐漸加深,卷積層會(huì)提取出更復(fù)雜、高維的特征。另一方面,深度卷積特征也比淺層卷積特征更能代表網(wǎng)絡(luò)的泛化能力。因此,選擇教師網(wǎng)絡(luò)Logits前的特征圖作為學(xué)生網(wǎng)絡(luò)學(xué)習(xí)的對(duì)象,分別記為f T和f S。
常用相似度計(jì)算方法有余弦距離、歐氏距離、馬氏距離等。這些方法的目的均是使學(xué)生網(wǎng)絡(luò)最大程度模仿教師網(wǎng)絡(luò)輸出的特征圖。由于學(xué)生網(wǎng)絡(luò)的容量小,它可能無(wú)法精確地再現(xiàn)某一特定的輸出模態(tài),并且實(shí)際中學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)具有不同的結(jié)構(gòu),沒(méi)有必要精確地模擬一個(gè)教師網(wǎng)絡(luò)的輸出來(lái)達(dá)到良好的表現(xiàn)。因此本研究提出一個(gè)面向教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的對(duì)抗學(xué)習(xí)機(jī)制。對(duì)抗訓(xùn)練策略緩解了人工設(shè)計(jì)損失函數(shù)的困難,已經(jīng)在多個(gè)計(jì)算機(jī)視覺(jué)任務(wù)中顯示出了優(yōu)越性。
特征圖學(xué)習(xí)機(jī)制由三部分組成,即教師網(wǎng)絡(luò)、學(xué)生網(wǎng)絡(luò)和判別器。教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的輸入為相同的人臉圖像,將其輸出的特征圖作為判別器的輸入,判別器鑒定其來(lái)自哪個(gè)網(wǎng)絡(luò)。采用生成對(duì)抗網(wǎng)絡(luò)中的對(duì)抗損失作為目標(biāo)函數(shù),如式(4)所示:
在訓(xùn)練過(guò)程中,判別器的目的是最小化對(duì)抗損失,確保正確區(qū)分兩個(gè)不同的分布;學(xué)生網(wǎng)絡(luò)的目的則是使判別器無(wú)法區(qū)分其與教師網(wǎng)絡(luò)的差異,以此構(gòu)成對(duì)抗訓(xùn)練判別器和學(xué)生網(wǎng)絡(luò)交替更新,直至判別器的識(shí)別準(zhǔn)確率為1/2,此時(shí)網(wǎng)絡(luò)收斂。
相較分類(lèi)概率,高維特征圖能夠保留更多的特征,采用高維的特征圖作為判別器的鑒定對(duì)象能夠使判別器具有更強(qiáng)的鑒別能力,指導(dǎo)學(xué)生網(wǎng)絡(luò)的更新,最小化與教師網(wǎng)絡(luò)的差異。
相關(guān)研究表明[35],教師網(wǎng)絡(luò)的影響不總是積極的。在網(wǎng)絡(luò)訓(xùn)練的前期,知識(shí)蒸餾輔助學(xué)生網(wǎng)絡(luò)的更新,但是隨后會(huì)抑制學(xué)生網(wǎng)絡(luò)的優(yōu)化。實(shí)驗(yàn)結(jié)果表明,在某一時(shí)期,交叉熵?fù)p失會(huì)反向上升。此外模型蒸餾算法需要有提前預(yù)訓(xùn)練好的教師網(wǎng)絡(luò),且教師網(wǎng)絡(luò)在學(xué)習(xí)過(guò)程中保持固定,僅對(duì)學(xué)生網(wǎng)絡(luò)進(jìn)行單向的知識(shí)傳遞,難以從學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)狀態(tài)中得到反饋信息來(lái)對(duì)訓(xùn)練過(guò)程進(jìn)行優(yōu)化調(diào)整。
深度互學(xué)習(xí)[36]指即多個(gè)網(wǎng)絡(luò)相互學(xué)習(xí),每個(gè)網(wǎng)絡(luò)在訓(xùn)練過(guò)程中不僅接受來(lái)自真值標(biāo)記的監(jiān)督,還參考同伴網(wǎng)絡(luò)的學(xué)習(xí)經(jīng)驗(yàn)來(lái)進(jìn)一步提升泛化能力。真值標(biāo)簽提供的信息僅包含樣本是否屬于某一類(lèi),但缺少不同類(lèi)別之間的聯(lián)系,而網(wǎng)絡(luò)輸出的分類(lèi)概率則能夠在一定程度上恢復(fù)該信息,因此網(wǎng)絡(luò)之間進(jìn)行分類(lèi)概率交叉學(xué)習(xí)可以傳遞學(xué)習(xí)到的數(shù)據(jù)分布特性,從而幫助網(wǎng)絡(luò)改善泛化性能。其次,網(wǎng)絡(luò)在訓(xùn)練過(guò)程中會(huì)參考同伴網(wǎng)絡(luò)的經(jīng)驗(yàn)來(lái)調(diào)整自己的學(xué)習(xí)過(guò)程,最終能夠收斂到一個(gè)更平緩的極小值點(diǎn),小的波動(dòng)不會(huì)對(duì)網(wǎng)絡(luò)的預(yù)測(cè)結(jié)果造成劇烈影響,從而具備更好的泛化性能。
本研究中采用學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)特征圖間的Jensen-Shannon散度作為互學(xué)習(xí)的目標(biāo)函數(shù),如式(5)所示。相較于KL散度,JS散度是對(duì)稱(chēng)的,解決了KL散度非對(duì)稱(chēng)的問(wèn)題。
其中,qT和qS分別為教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的分類(lèi)概率分布。聯(lián)合基于對(duì)抗學(xué)習(xí)的特征圖遷移和互學(xué)習(xí)方法,學(xué)生網(wǎng)絡(luò)不僅能夠模仿教師網(wǎng)絡(luò)特征圖的分布,同時(shí)能夠保留自主學(xué)習(xí)的能力。
綜上所述,用以訓(xùn)練學(xué)生網(wǎng)絡(luò)的目標(biāo)函數(shù)的完整形式為:
其中,α,β∈[0,1)為超參數(shù),用以平衡各部分間的權(quán)重。
3.1.1 數(shù)據(jù)集
本研究中所采用CASIA-WebFace[37]和CelebA[38]數(shù)據(jù)集作為訓(xùn)練樣本集。CASIA-WebFace數(shù)據(jù)集中的樣本來(lái)自IMDb網(wǎng)站,有10 575人的494 414張照片。CelebA數(shù)據(jù)集包含超過(guò)200 000的名人圖片,每張圖片有40個(gè)屬性標(biāo)注。該數(shù)據(jù)集中的圖像具有較大的姿態(tài)變化和復(fù)雜背景。對(duì)所有圖像進(jìn)行歸一化處理,尺寸統(tǒng)一為256×256。
3.1.2 網(wǎng)絡(luò)結(jié)構(gòu)
由于殘差網(wǎng)絡(luò)在圖像分類(lèi)應(yīng)用中具有最優(yōu)的性能,所以本研究中采用ResNet-101作為教師網(wǎng)絡(luò)。ResNet-101由兩種殘差塊組成,一是Identity Block,輸入和輸出的維度相同,二是Conv Block,輸入和輸出的維度不同,用以改變特征向量的維度。教師網(wǎng)絡(luò)預(yù)先在數(shù)據(jù)集上訓(xùn)練,在學(xué)生網(wǎng)絡(luò)的訓(xùn)練過(guò)程中的互學(xué)習(xí)階段微調(diào)。學(xué)生網(wǎng)絡(luò)在實(shí)驗(yàn)中采用和教師網(wǎng)絡(luò)同樣采用殘差結(jié)構(gòu)的ResNet-18,在訓(xùn)練前隨機(jī)分配鏈接權(quán)重。具體結(jié)構(gòu)如表1所示。
表1 教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的具體結(jié)構(gòu)Table 1 Detailed structure of teacher and student
判別器是模型中的重要組成部分,必須在簡(jiǎn)單性和網(wǎng)絡(luò)容量之間取得平衡。判別器由三個(gè)全連接層組成(128fc-256fc-128fc),中間激活為非線(xiàn)性激活單元ReLU。輸出則是由Sigmoid函數(shù)給出的二元預(yù)測(cè),判定輸入的特征圖來(lái)自哪個(gè)網(wǎng)絡(luò)。
3.1.3 相關(guān)參數(shù)設(shè)置
在所有實(shí)驗(yàn)中,均采用隨機(jī)梯度下降的算法。根據(jù)文獻(xiàn)[39],動(dòng)量設(shè)為0.9,權(quán)重衰減為0.000 1,小批量尺寸為128。對(duì)于兩個(gè)訓(xùn)練集,均訓(xùn)練200輪次。初始學(xué)習(xí)率為0.1,分別在第80輪次和160輪次分別將學(xué)習(xí)率調(diào)整為0.01和0.001。
在前160輪中,只采用知識(shí)蒸餾損失和對(duì)抗損失指導(dǎo)學(xué)生網(wǎng)絡(luò)的更新,使學(xué)生網(wǎng)絡(luò)的分類(lèi)錯(cuò)誤率迅速下降;在后40輪中加入兩個(gè)網(wǎng)絡(luò)的互學(xué)習(xí)損失,對(duì)學(xué)生網(wǎng)絡(luò)微調(diào),具有更好的泛化能力。在多次實(shí)驗(yàn)中選取最優(yōu)的超參數(shù),使學(xué)生網(wǎng)絡(luò)盡快收斂,故在上述兩個(gè)訓(xùn)練階段超參數(shù)分別取α=0.5,β=0和α=0.5,β=0.05。
網(wǎng)絡(luò)壓縮分為兩個(gè)階段:首先,在學(xué)生網(wǎng)絡(luò)訓(xùn)練的前半段,知識(shí)蒸餾損失和對(duì)抗損失指導(dǎo)學(xué)生網(wǎng)絡(luò)和判別器的更新,在訓(xùn)練的后半段,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)互學(xué)習(xí),對(duì)模型進(jìn)行微調(diào),以提高泛化能力。算法偽碼如算法1所示。
算法1基于知識(shí)蒸餾和對(duì)抗學(xué)習(xí)的網(wǎng)絡(luò)壓縮算法
輸入 預(yù)訓(xùn)練的教師網(wǎng)絡(luò);隨機(jī)初始化的學(xué)生網(wǎng)絡(luò)和判別器;訓(xùn)練數(shù)據(jù)集;學(xué)習(xí)率、小批量尺寸及α,β
輸出 小尺寸的學(xué)生網(wǎng)絡(luò)
本節(jié)中從兩個(gè)方面定量地分析所提方法的有效性,首先比較全監(jiān)督下訓(xùn)練的學(xué)生網(wǎng)絡(luò)和不同溫度下知識(shí)蒸餾得到的學(xué)生網(wǎng)絡(luò),如表2所示;然后在CASIA WEBFACE和CelebA數(shù)據(jù)集上訓(xùn)練其他知識(shí)蒸餾方法,與所提方法比較,如表3所示。將數(shù)據(jù)集按照4∶4∶2的比例隨機(jī)劃分為訓(xùn)練集、測(cè)試集和驗(yàn)證集,表中準(zhǔn)確率在驗(yàn)證集上獲得,評(píng)價(jià)指標(biāo)為T(mén)op-1準(zhǔn)確率。
表2 全監(jiān)督訓(xùn)練與知識(shí)蒸餾性能對(duì)比Table 2 Comparison between fully supervised training and knowledge distillation%
表3 所提方法與其他知識(shí)蒸餾方法的性能對(duì)比Table 3 Comparison between proposed method and other knowledge distillation methods %
表2 中前兩行分別為全監(jiān)督訓(xùn)練方式下得到的教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò),網(wǎng)絡(luò)結(jié)構(gòu)分別固定為ResNet-101和ResNet-18。ResNet-101的參數(shù)量為44.6×106,浮點(diǎn)運(yùn)算量為1.8 GFLOPs,而ResNet-18的參數(shù)量?jī)H為11.2×106,浮點(diǎn)運(yùn)算量為7.6 GFLOPs。同時(shí)利用本文所提出的知識(shí)蒸餾算法訓(xùn)練ResNet-18網(wǎng)絡(luò),溫度T∈{1,2,5,10}。表2中還可以看出,由于網(wǎng)絡(luò)深度不同,全監(jiān)督訓(xùn)練下的學(xué)生網(wǎng)絡(luò)在兩種驗(yàn)證集上的準(zhǔn)確率均低于教師網(wǎng)絡(luò),但是學(xué)生網(wǎng)絡(luò)的參數(shù)數(shù)量?jī)H為教師網(wǎng)絡(luò)的1/4,模型尺寸和識(shí)別性能間有較好的平衡。采用知識(shí)蒸餾算法得到的學(xué)生網(wǎng)絡(luò)在驗(yàn)證集上的性能略低于全監(jiān)督訓(xùn)練的教師網(wǎng)絡(luò),但是隨著溫度T的增加,其性能超越了全監(jiān)督訓(xùn)練的學(xué)生網(wǎng)絡(luò)。由此可見(jiàn),本研究所提出的知識(shí)蒸餾算法能夠使小尺寸的網(wǎng)絡(luò)學(xué)習(xí)到大規(guī)模網(wǎng)絡(luò)的知識(shí),有效地實(shí)現(xiàn)知識(shí)遷移。同時(shí),溫度T越高,輸出的分類(lèi)概率越平緩,得到的識(shí)別性能越好;但是隨著溫度的逐漸增加,性能提高趨于平緩,表明較高的溫度會(huì)引起錯(cuò)誤標(biāo)簽的概率增加,使學(xué)生網(wǎng)絡(luò)更多地關(guān)注相關(guān)知識(shí)。因此在知識(shí)蒸餾方法中選擇合適的溫度T是比較重要的。
本研究中還將所提方法與其他知識(shí)蒸餾算法進(jìn)行了比較,包括經(jīng)典知識(shí)蒸餾算法KD、深度互學(xué)習(xí)算法DML、FitNet[40]、Channel Distillation[41]和KTAN[42]。FitNet將中間層的表示作為學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)對(duì)象,要求學(xué)生網(wǎng)絡(luò)的中間層模仿老師網(wǎng)絡(luò)特定的中間層的輸出。CD中分別計(jì)算教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)每個(gè)通道的注意力信息,教師監(jiān)督學(xué)生學(xué)習(xí)其注意力信息,已達(dá)到通道間傳遞信息的目的。KTAN除了采用分類(lèi)概率間的交叉熵?fù)p失外,采用均方誤差損失最小化學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)特征圖之間的差異,同時(shí)引入判別器鑒定特征圖的出處。
表3 中的數(shù)據(jù)均在CASIA WEBFACE和CelebA數(shù)據(jù)集上獲得,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的結(jié)構(gòu)分別為ResNet-101和ResNet-18,訓(xùn)練參數(shù)如3.1.3小節(jié)所示。由表中數(shù)據(jù)可知,由于教師網(wǎng)絡(luò)具有更深的結(jié)構(gòu),所以展現(xiàn)出最好的性能。五種引自文獻(xiàn)的方法,除了CD和KTAN在CelebA數(shù)據(jù)集上的性能外,其余數(shù)據(jù)在Top-1性能中均低于全監(jiān)督的學(xué)生網(wǎng)絡(luò)。本文所提出的知識(shí)蒸餾方法相較其他知識(shí)蒸餾方法,學(xué)生網(wǎng)絡(luò)不僅學(xué)習(xí)了教師網(wǎng)絡(luò)正確的分類(lèi)概率,同時(shí)利用對(duì)抗學(xué)習(xí)的方法使學(xué)生網(wǎng)絡(luò)定性學(xué)習(xí)教師網(wǎng)絡(luò)的特征圖;在訓(xùn)練的后半段,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)相互學(xué)習(xí),促進(jìn)學(xué)生網(wǎng)絡(luò)探索自己的最優(yōu)解空間,具有更好的泛化能力。在CASIA WEBFACE和CelebA兩個(gè)數(shù)據(jù)集上,本研究的識(shí)別準(zhǔn)確率均超過(guò)了全監(jiān)督訓(xùn)練得到的學(xué)生網(wǎng)絡(luò),證明了所提方法的有效性和先進(jìn)性。
本節(jié)中,通過(guò)消融實(shí)驗(yàn)分別驗(yàn)證了式(6)中知識(shí)蒸餾損失,對(duì)抗損失和互學(xué)習(xí)損失的有效性。實(shí)驗(yàn)結(jié)果如表4所示。
表4 基于不同損失函數(shù)的性能對(duì)比Table 3 Performance comparison based on different loss functions %
對(duì)比表3和表4中數(shù)據(jù)可以看出,改進(jìn)后的知識(shí)蒸餾損失剔除了錯(cuò)誤估計(jì)樣本,使識(shí)別性能得到了小幅度提升。分別在知識(shí)蒸餾損失的基礎(chǔ)上添加基于特征圖的對(duì)抗損失和互學(xué)習(xí)損失后,性能的提成在1%左右,均超過(guò)了互學(xué)習(xí)DML算法、將中間層作為學(xué)習(xí)對(duì)象的FitNet和CD方法。對(duì)比KTAN,本研究方法摒棄了均方誤差,使學(xué)生網(wǎng)絡(luò)不必完全模擬教師網(wǎng)絡(luò)的特征圖,能夠自主更新,性能得到了進(jìn)一步的提升。
通過(guò)消融實(shí)驗(yàn)可知,本研究所采用的三種損失函數(shù)能夠有效地完成從教師網(wǎng)絡(luò)到學(xué)生網(wǎng)絡(luò)的知識(shí)遷移,不僅包括分類(lèi)概率知識(shí),同時(shí)學(xué)習(xí)了特征圖的分布,并且適度保留自主學(xué)習(xí)的空間。
深度學(xué)習(xí)方法越來(lái)越多地應(yīng)用于計(jì)算機(jī)視覺(jué)任務(wù)中,并且在人臉識(shí)別任務(wù)中的準(zhǔn)確率已經(jīng)超越了人眼,但是仍然面對(duì)著訓(xùn)練時(shí)間長(zhǎng)、泛化能力弱、部署困難等落地問(wèn)題。本研究針對(duì)深度學(xué)習(xí)模型在嵌入式設(shè)備難以進(jìn)行部署和實(shí)時(shí)性能差的問(wèn)題,深入研究了現(xiàn)有的模型壓縮和加速算法,提出了一種基于知識(shí)蒸餾和對(duì)抗學(xué)習(xí)的神經(jīng)網(wǎng)絡(luò)壓縮算法。算法框架由三部分組成,預(yù)訓(xùn)練得到的大規(guī)模教師網(wǎng)絡(luò),輕量級(jí)的學(xué)生網(wǎng)絡(luò)和輔助對(duì)抗學(xué)習(xí)的判別器。學(xué)生網(wǎng)絡(luò)通過(guò)知識(shí)蒸餾損失學(xué)習(xí)教師網(wǎng)絡(luò)的分類(lèi)概率,同時(shí)通過(guò)對(duì)抗損失模擬教師網(wǎng)絡(luò)的特征圖知識(shí)。鑒于教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)具有不同的最優(yōu)解空間,在訓(xùn)練的后半段利用深度互學(xué)習(xí)理論,促使學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)相互學(xué)習(xí),以促使學(xué)生網(wǎng)絡(luò)探索自己的最優(yōu)解。
針對(duì)人臉識(shí)別任務(wù),采用CASIA WEBFACE和CelebA兩個(gè)數(shù)據(jù)集作為訓(xùn)練集,通過(guò)消融實(shí)驗(yàn)驗(yàn)證了所提組合目標(biāo)函數(shù)的有效性,同時(shí)與面向特征圖知識(shí)蒸餾算法和基于對(duì)抗學(xué)習(xí)訓(xùn)練的模型壓縮算法對(duì)比,實(shí)驗(yàn)數(shù)據(jù)表明,根據(jù)所提算法訓(xùn)練得到的學(xué)生網(wǎng)絡(luò)具有較少的鏈接數(shù),同時(shí)保證了較好的識(shí)別準(zhǔn)確率。