葉峰 陳彪 賴乙宗
基于特征空間嵌入的對比知識蒸餾算法
葉峰 陳彪 賴乙宗
(華南理工大學(xué) 機(jī)械與汽車工程學(xué)院,廣東 廣州 510640)
因能有效地壓縮卷積神經(jīng)網(wǎng)絡(luò)模型,知識蒸餾在深度學(xué)習(xí)領(lǐng)域備受關(guān)注。然而,經(jīng)典知識蒸餾算法在進(jìn)行知識遷移時,只利用了單個樣本的信息,忽略了樣本間關(guān)系的重要性,算法性能欠佳。為了提高知識蒸餾算法知識遷移的效率和性能,文中提出了一種基于特征空間嵌入的對比知識蒸餾(FSECD)算法。該算法采用批次內(nèi)構(gòu)建策略,將學(xué)生模型的輸出特征嵌入到教師模型特征空間中,使得每個學(xué)生模型的輸出特征和教師模型輸出的個特征構(gòu)成個對比對。每個對比對中,教師模型的輸出特征是已優(yōu)化、固定的,學(xué)生模型的輸出特征是待優(yōu)化、可調(diào)優(yōu)的。在訓(xùn)練過程中,F(xiàn)SECD縮小正對比對的距離并擴(kuò)大負(fù)對比對的距離,使得學(xué)生模型可感知并學(xué)習(xí)教師模型輸出特征的樣本間關(guān)系,進(jìn)而實(shí)現(xiàn)教師模型知識向?qū)W生模型的遷移。在CIFAR-100和ImageNet數(shù)據(jù)集上對不同師生網(wǎng)絡(luò)架構(gòu)進(jìn)行的實(shí)驗(yàn)結(jié)果表明,與其他主流蒸餾算法相比,F(xiàn)SECD算法在不需要額外的網(wǎng)絡(luò)結(jié)構(gòu)和數(shù)據(jù)的情況下,顯著提升了性能,進(jìn)一步證明了樣本間關(guān)系在知識蒸餾中的重要性。
圖像分類;知識蒸餾;卷積神經(jīng)網(wǎng)絡(luò);深度學(xué)習(xí);對比學(xué)習(xí)
近十年,卷積神經(jīng)網(wǎng)絡(luò)在計(jì)算機(jī)視覺任務(wù)中取得了巨大的成功,并廣泛應(yīng)用于圖像分類[1-6]、圖像檢測[7-8]和圖像分割[9-10]等領(lǐng)域,其中圖像分類被認(rèn)為是其他視覺任務(wù)的基礎(chǔ)。隨著網(wǎng)絡(luò)容量的增加,在有限的硬件資源下部署卷積神經(jīng)網(wǎng)絡(luò)變得越來越困難,獲得準(zhǔn)確率高且輕量級的卷積神經(jīng)網(wǎng)絡(luò),對于實(shí)際應(yīng)用至關(guān)重要。針對這一問題,研究者們提出了網(wǎng)絡(luò)修剪[11]、量化[12]、低秩分解[13]和知識蒸餾[14-19]等技術(shù)。經(jīng)典知識蒸餾(KD)方法最初由Hinton等[14]提出,該方法通過縮小兩個卷積神經(jīng)網(wǎng)絡(luò)預(yù)測概率之間的Kullback-Leibler(KL)散度來實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)之間知識的遷移。知識蒸餾的原理在于,相較于原始的標(biāo)簽,教師模型的預(yù)測概率中隱含了輸入樣本與非目標(biāo)類別之間相似性關(guān)系的信息。通常情況下,教師模型結(jié)構(gòu)復(fù)雜,擁有良好的性能和泛化能力,學(xué)生模型是輕量級的,更適合部署在邊緣設(shè)備,但性能不如教師模型。
Gou等[20]對主流的知識蒸餾方法做了梳理,從知識類別、訓(xùn)練方案、師生架構(gòu)等角度對知識蒸餾方法進(jìn)行了全面的介紹。根據(jù)文獻(xiàn)[20],可以將知識蒸餾方法中教師模型的知識分為基于響應(yīng)的知識、基于特征的知識以及基于關(guān)系的知識。Hinton等[14]提出的傳統(tǒng)KD方法使用的是基于響應(yīng)的知識。Zhao等[21]將KL散度分解為目標(biāo)類知識蒸餾(TCKD)與非目標(biāo)類知識蒸餾(NCKD)兩部分,以研究目標(biāo)類別的響應(yīng)和非目標(biāo)類別的響應(yīng)對知識蒸餾的影響,結(jié)果發(fā)現(xiàn)NCKD更加重要且TCKD與NCKD是耦合的,并進(jìn)一步提出了解耦知識蒸餾(DKD)用以實(shí)現(xiàn)解耦。
Romero等[17]通過縮小教師模型和學(xué)生模型中間特征層輸出特征圖的差異來訓(xùn)練學(xué)生模型,使用了基于特征的知識。該方法將教師模型某一隱含層之前的網(wǎng)絡(luò)結(jié)構(gòu)定義為Hint層,將學(xué)生模型某一隱含層之前的網(wǎng)絡(luò)定義為Guided層,并定義教師模型和學(xué)生模型之間的距離為Hint層的輸出特征與回歸器轉(zhuǎn)化后的Guided層輸出特征之間的歐幾里得距離。該方法將卷積神經(jīng)網(wǎng)絡(luò)知識的定義從網(wǎng)絡(luò)的整體輸出拓寬到中間層的隱含表達(dá)上。Zagoruyko等[18]從人類視覺體驗(yàn)中注意力的作用得到啟發(fā),提出了學(xué)生模型通過模仿教師模型的注意力特征圖來提高學(xué)生模型性能的方法。該方法在訓(xùn)練過程中實(shí)時計(jì)算教師模型和學(xué)生模型的多個隱含層以及最終輸出層的注意力特征圖,并縮小教師模型和學(xué)生模型對應(yīng)的注意力特征圖之間的歐幾里得距離。Heo等[19]對教師模型和學(xué)生模型的特征變換及距離函數(shù)的形式進(jìn)行了全面的分析,認(rèn)為對教師模型的特征進(jìn)行變形會導(dǎo)致教師模型知識的缺失,提出了使用pre-ReLU的特征進(jìn)行知識遷移的方法,同時設(shè)計(jì)了一個新的距離函數(shù)以實(shí)現(xiàn)教師模型和學(xué)生模型之間的知識遷移。
有些學(xué)者則提出要關(guān)注樣本間關(guān)系所包含的豐富的結(jié)構(gòu)信息,即使用基于關(guān)系的知識。Park等[15]提出了將教師模型輸出特征之間的結(jié)構(gòu)化關(guān)系遷移給學(xué)生模型的方法,同時提出了二階的樣本間距離損失和三階的樣本間角度損失,但該方法在計(jì)算距離時,分配給所有樣本的權(quán)重是一樣的,缺乏樣本之間的相互重要性的考慮。Gou等[22]基于網(wǎng)絡(luò)的注意力區(qū)域具有更多信息和使用單層特征圖進(jìn)行知識蒸餾容易過擬合等特點(diǎn),利用來自多個中間層的注意力圖構(gòu)建樣本間關(guān)系,在多種不同類型數(shù)據(jù)集上進(jìn)行了廣泛的實(shí)驗(yàn)并取得了優(yōu)秀的結(jié)果。
近年來對比學(xué)習(xí)方法廣泛應(yīng)用于無監(jiān)督學(xué)習(xí)[23-24],該方法通過縮小正對比對的距離并擴(kuò)大負(fù)對比對的距離來實(shí)現(xiàn)特征聚類。有學(xué)者[16,25]嘗試將對比學(xué)習(xí)引入知識蒸餾,以實(shí)現(xiàn)學(xué)生模型對教師模型的模仿,取得了一定的成績,這些方法的基本原理是挖掘數(shù)據(jù)中的結(jié)構(gòu)化關(guān)系,使用的也是基于關(guān)系的知識。在對比學(xué)習(xí)中,對比對的構(gòu)建至關(guān)重要,目前主流的方法主要有兩種。第一種方法是將同一圖像進(jìn)行兩種不同的數(shù)據(jù)增強(qiáng)處理,得到該圖像的兩種數(shù)據(jù)增強(qiáng)形式且兩種形式互為正例,并將批次內(nèi)的其他圖像數(shù)據(jù)增強(qiáng)形式作為負(fù)例。Xu等[25]采用了該方法構(gòu)建對比對,提出了自我監(jiān)督知識蒸餾(SSKD)方法,該方法在教師模型和學(xué)生模型分別獨(dú)立進(jìn)行對比學(xué)習(xí),通過提高學(xué)生模型和教師模型的對比矩陣的相似度來實(shí)現(xiàn)知識蒸餾。該方法將批次內(nèi)圖像的數(shù)量擴(kuò)大了兩倍,因而計(jì)算和顯存的消耗也擴(kuò)大了兩倍。第二種方法是將訓(xùn)練集所有圖像在上一個迭代的特征存儲在記憶庫中。在當(dāng)前迭代中,用輸入網(wǎng)絡(luò)的圖像的輸出特征和記憶庫中相同類別的圖像特征構(gòu)建正對比對,并在記憶庫中隨機(jī)用其他圖像的特征構(gòu)建負(fù)對比對。Tian等[16]提出的對比表示知識蒸餾(CRD)算法屬于此類,其優(yōu)化目標(biāo)是最大化教師模型和學(xué)生模型輸出特征之間的互信息。雖然該方法沒有提高計(jì)算上的消耗,但在訓(xùn)練時用內(nèi)存庫存儲每個樣本的特征需要大量額外的顯存。此外,在CRD中還需要額外的網(wǎng)絡(luò)模塊對齊教師模型和學(xué)生模型輸出特征的維度。
針對以上問題,文中提出了一種基于特征空間嵌入的對比知識蒸餾(FSECD)算法,該算法使用基于關(guān)系的知識和模型全連接層的輸出作為輸入,使得算法在保持高性能的同時,在硬件資源和計(jì)算耗時上的增加幾乎可以忽略不計(jì)。在訓(xùn)練過程中,F(xiàn)SECD算法將對應(yīng)相同輸入圖像的學(xué)生模型輸出特征和教師模型輸出特征組成正對比對,輸入圖像不互相對應(yīng)的特征則視為負(fù)對比對,該操作等價于將任一學(xué)生網(wǎng)絡(luò)輸出特征嵌入到教師模型的特征空間中,在多個教師模型輸出特征的作用下,根據(jù)梯度下降原理,自動靠近匹配度最高的教師模型輸出特征。在此過程中,學(xué)生模型間接學(xué)習(xí)了教師模型的特征空間知識,實(shí)現(xiàn)對學(xué)生模型特征空間的優(yōu)化。文中通過實(shí)驗(yàn)探究正/負(fù)例的兩種選取策略對FSECD算法性能的影響,并在兩個主流的圖像分類數(shù)據(jù)集CIFAR-100和ImageNet上評估了文中所提出的知識蒸餾算法的性能。
式中:(·;)為卷積神經(jīng)網(wǎng)絡(luò);τ為溫度系數(shù),用以調(diào)節(jié)損失函數(shù)的平滑度。假定卷積神經(jīng)網(wǎng)絡(luò)輸入圖像樣本,并生成特征向量()。為了簡單起見,文中將從()中省略,直接標(biāo)記為,同時繼承樣本的上標(biāo)和下標(biāo)。
使用該損失函數(shù)訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)(·;)時,相似的輸入圖像的輸出特征也會趨近,反之,則互相遠(yuǎn)離。默認(rèn)情況下,為了保證訓(xùn)練時模型的穩(wěn)定性,需要對特征使用2正則進(jìn)行處理。
相較于學(xué)生模型,教師模型的結(jié)構(gòu)更復(fù)雜,因此往往有強(qiáng)的分辨能力,在相同的訓(xùn)練條件下,教師模型的準(zhǔn)確率也更高。本研究將學(xué)生模型輸出的特征嵌入教師模型的特征空間,以學(xué)習(xí)教師模型輸出的多個特征之間形成的結(jié)構(gòu)化知識,實(shí)現(xiàn)更好的蒸餾效果。在使用FSECD算法訓(xùn)練的過程中,教師模型的參數(shù)被凍結(jié),只優(yōu)化學(xué)生模型的參數(shù)。
圖1 正/負(fù)例的選取策略示意圖
若采用類別級策略,即來自同一類別的圖像互為正例,則標(biāo)簽為狗的樣本,有兩個正例和兩個負(fù)例;若采取實(shí)例級策略,則只將同一圖像的教師模型輸出視為正例,其余為負(fù)例。此外,采用第二種策略時,可以將式(2)簡化為
此處可以將實(shí)例級別策略視為一個多分類器,該分類器的權(quán)重由教師模型動態(tài)提供,實(shí)現(xiàn)將任一學(xué)生模型特征正確識別為對應(yīng)教師模型特征的功能。無論采用哪種策略,都可以使學(xué)生模型向教師模型學(xué)習(xí)樣本間的關(guān)系。每個學(xué)生模型的輸出特征是獨(dú)立優(yōu)化的,但它們都被嵌入了同一個教師模型的特征空間并學(xué)習(xí)該特征空間的結(jié)構(gòu)化知識,即學(xué)習(xí)和模仿的是同一個對象。在訓(xùn)練過程中,間接優(yōu)化了學(xué)生模型自身的樣本間關(guān)系,最終使得學(xué)生模型和教師模型具有相似的特征空間。
FSECD算法能夠以模型的任一卷積層經(jīng)全局平均池處理的輸出特征為輸入(需經(jīng)全局平均池處理或展平處理),或者是以模型最終全連接層的輸出特征為輸入,然而后者與前者相比,具有以下兩個優(yōu)點(diǎn):
圖2 以全連接層輸出特征為輸入的FSECD算法的流程圖
1)全連接層的輸出特征具有更抽象的語義信息。在圖像分類中,得到全連接層的輸出特征的方法是將最后一層卷積層的特征依次輸入全局平均池化層和全連接層。該操作在數(shù)學(xué)上等價于:先將最后一層的卷積層特征輸入到一個卷積核大小為1的卷積層(且該卷積層輸出通道數(shù)等于待預(yù)測的類別數(shù)),然后進(jìn)行全局平均池化。因此,全連接層的輸出特征可以視為一種特殊的卷積層特征,且該特征處于網(wǎng)絡(luò)的最后一層。由于模型的層度越深,特征的語義越豐富,所以全連接層的輸出特征是具有最高級別的語義的特征,用于知識蒸餾可以使學(xué)生模型學(xué)習(xí)到更好的知識。
2)不存在特征對齊的問題。無論網(wǎng)絡(luò)結(jié)構(gòu)和容量如何變化,全連接層輸出特征的維數(shù)總是等于數(shù)據(jù)集要預(yù)測的類別數(shù)。采用全連接層的輸出特征甚至可以在不知道教師網(wǎng)絡(luò)結(jié)構(gòu)的情況下進(jìn)行蒸餾,即只需要教師模型的輸出。使用卷積層的輸出特征進(jìn)行知識蒸餾,在教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的特征維數(shù)不同時,存在特征維數(shù)對齊的問題。例如,將ResNet50模型的知識蒸餾給ResNet18模型時,教師模型的特征維數(shù)為2 048,而學(xué)生模型的特征維數(shù)為512,無法直接蒸餾,需要訓(xùn)練額外的模型分支以實(shí)現(xiàn)維度對齊。在網(wǎng)絡(luò)結(jié)構(gòu)不同時,特征維度對齊的情況會變得更加復(fù)雜。
基于以上原因,本研究采用全連接層的輸出特征進(jìn)行對比蒸餾,故不存在維度對齊問題。圖2是以模型全連接層輸出特征為輸入的FSECD算法的流程圖。
=CE+FSECD(4)
式中,是兩種損失函數(shù)的平衡系數(shù)。
本研究采用圖像分類領(lǐng)域主流的兩個數(shù)據(jù)集(CIFAR-100[26]和ImageNet[27]數(shù)據(jù)集)對網(wǎng)絡(luò)進(jìn)行訓(xùn)練和測試。CIFAR-100數(shù)據(jù)集包含100個類別的圖像,圖像尺寸為32×32;訓(xùn)練集包含5萬幅圖像,每個類別各500幅圖像;測試集包含1萬幅圖像,每個類別各100幅圖像。在使用CIFAR-100數(shù)據(jù)集訓(xùn)練網(wǎng)絡(luò)時,對訓(xùn)練集圖像進(jìn)行標(biāo)準(zhǔn)的數(shù)據(jù)增強(qiáng)處理,即在圖像的每邊填充4個像素,再重新將圖像裁剪為32×32的大小,然后以50%的概率進(jìn)行水平翻轉(zhuǎn);而對測試集圖像不進(jìn)行數(shù)據(jù)增強(qiáng)處理。ImageNet數(shù)據(jù)集的訓(xùn)練集包含128萬幅圖像,共1 000個類別,每個類別1 300幅圖像左右;驗(yàn)證集包含5萬幅圖像,每個類別各50幅圖像。在使用ImageNet數(shù)據(jù)集訓(xùn)練網(wǎng)絡(luò)時,隨機(jī)裁剪訓(xùn)練集圖像中的一個區(qū)域并將該區(qū)域尺寸拉伸為224×224,然后以50%的概率進(jìn)行水平翻轉(zhuǎn);而對測試集圖像同樣不進(jìn)行數(shù)據(jù)增強(qiáng)處理。
本研究在Ubuntu環(huán)境下,使用Pytorch深度學(xué)習(xí)框架對不同的知識蒸餾算法的性能進(jìn)行評估。為了覆蓋相同/相異架構(gòu)上的教師-學(xué)生模型,本研究在一系列不同架構(gòu)的卷積神經(jīng)網(wǎng)絡(luò)上進(jìn)行了實(shí)驗(yàn),使用的卷積神經(jīng)網(wǎng)絡(luò)架構(gòu)包括:
(1)WRN--[6],其中為該網(wǎng)絡(luò)架構(gòu)的深度因子,為寬度因子。
(2)ResNet。在CIFAR-100數(shù)據(jù)集上,ResNet表示具有3個卷積組的Cifar風(fēng)格的ResNet[2],為該網(wǎng)絡(luò)架構(gòu)的深度因子,每個卷積組分別有16、32和64個通道。此外,ResNet8×4和ResNet32×4分別表示深度因子為8和32且具有4倍通道數(shù)的模型。在ImageNet數(shù)據(jù)集上,ResNet表示ImageNet風(fēng)格的ResNet。
(3)MobileNetV1和MobileNetV2[5,28],本研究使用寬度因子為0.5的MobileNetV2。
(4)VGG[1],為該網(wǎng)絡(luò)架構(gòu)的深度因子,本研究采用具有Batchnorm層的VGG網(wǎng)絡(luò)。
(5)ShuffleNet和ShuffleNetV2[3-4],其中ShuffleNet的模型寬度系數(shù)為3,ShuffleNetV2的模型尺寸系數(shù)默認(rèn)為1。
為了與其他知識蒸餾算法進(jìn)行客觀的對比,所有實(shí)驗(yàn)均采用與文獻(xiàn)[16]相同的實(shí)驗(yàn)設(shè)置。在CIFAR-100數(shù)據(jù)集上,所有網(wǎng)絡(luò)訓(xùn)練240個回合,在第150、180和210個訓(xùn)練回合時學(xué)習(xí)率除以10;重量衰減和動量分別設(shè)置為5×10-4和0.9;所有網(wǎng)絡(luò)的批次大小為128;ShuffleNet系列和MobileNetV2的學(xué)習(xí)率為0.02,其余模型的學(xué)習(xí)率為0.1;對于FSECD算法,溫度系數(shù)設(shè)置為4,對于不同的師生對,損失函數(shù)平衡系數(shù)的取值是不同的,具體見表1。在ImageNet數(shù)據(jù)集上,網(wǎng)絡(luò)訓(xùn)練100個回合,批次大小為512,學(xué)習(xí)率為0.2(學(xué)習(xí)率在第30、60和90個訓(xùn)練回合時除以10),重量衰減和動量分別設(shè)置為1×10-4和0.9,所有師生對的損失函數(shù)平衡系數(shù)都取為1,溫度系數(shù)設(shè)置為8。
表1 教師+學(xué)生模型師生對的超參數(shù)取值
Table 1 Values of hyperparameter of teacher + student pairs
教師+學(xué)生模型師生對λ ResNet56+ResNet204 ResNet110+ResNet328 ResNet5+MobileNetV26 ResNet32×4+ ResNet8x410 ResNet32×4+ShuffleNetV18 ResNet32×4+ShuffleNetV28 WRN-40-2+WRN-40-18 WRN-40-2+WRN-16-212 WRN-40-2+ShuffleNetV18 VGG13+VGG84 VGG13+MobileNetV28
在CIFAR-100數(shù)據(jù)集上,本研究使用Top-1準(zhǔn)確度cc1作為評價指標(biāo)。設(shè)測試集的圖像數(shù)量為,模型預(yù)測的概率最高的類等于真實(shí)標(biāo)簽的圖像數(shù)量為1,則
在ImageNet數(shù)據(jù)集上,除了cc1外,還使用了Top-5準(zhǔn)確度cc5作為評價指標(biāo)。設(shè)真實(shí)標(biāo)簽是模型預(yù)測出的概率最高的5個類之一的圖像數(shù)量為2,則
對同一網(wǎng)絡(luò)模型,所有知識蒸餾算法重復(fù)進(jìn)行3次訓(xùn)練,取3次訓(xùn)練的cc1和cc5作為該知識蒸餾算法的最終得分。CE表示只使用基于硬標(biāo)簽的交叉熵?fù)p失函數(shù)CE訓(xùn)練的算法。
文中通過實(shí)驗(yàn)比較了類別級策略和實(shí)例級策略對FSECD算法性能的影響,并對結(jié)果進(jìn)行分析。兩種策略的核心區(qū)別在于正/負(fù)例選取的標(biāo)準(zhǔn),在批次大小較小時,由于批次內(nèi)屬于相同類別的圖像數(shù)量較少,兩種策略的實(shí)際差異不大;隨著批次大小的增加,在每次迭代中屬于同一類的樣本數(shù)量將增加,兩種策略的實(shí)際差異會不斷擴(kuò)大;使對比實(shí)驗(yàn)的結(jié)果差異更明顯。為了更好地展示對比實(shí)驗(yàn)的結(jié)果,文中采用了3種批次大小,分別為128、512和1 024,選擇WRN-40-2+WRN-40-1和ResNet32×4+ResNet8×4兩對教師-學(xué)生模型師生對在CIFAR-100數(shù)據(jù)集上進(jìn)行對比實(shí)驗(yàn),結(jié)果見表2。從表中可知,當(dāng)批次大小過大時,模型過擬合程度增大,導(dǎo)致網(wǎng)絡(luò)性能下降,但相較于CE算法,采用兩種選取策略的FSECD算法提升了網(wǎng)絡(luò)的性能,并且使用實(shí)例級策略時性能提升的幅度更大。
表2 在CIFAR-100數(shù)據(jù)集上使用類別級策略和實(shí)例級策略的FSECD算法與CE算法的性能對比
Table 2 Comparison of performance among CE algorithm and FSECD algorithms with class-level policy and instance-level policy on CIFAR-100 dataset
算法教師+學(xué)生模型師生對Acc1/% B=128B=512B=1 024 CEWRN-40-2+WRN-40-173.2669.7568.11 ResNet32×4+ResNet8×472.5070.6369.14 類別級策略的FSECDWRN-40-2+WRN-40-173.3771.5270.33 ResNet32×4+ResNet8×475.7472.9471.24 實(shí)例級策略的FSECDWRN-40-2+WRN-40-174.4974.3973.26 ResNet32×4+ResNet8×476.5775.5874.58
當(dāng)批次大小為512時,對于WRN-40-2+WRN-40-1師生對,相對于CE算法,使用實(shí)例級策略的FSECD算法的網(wǎng)絡(luò)性能提升了4.64個百分點(diǎn),而使用類別級策略的FSECD算法則使網(wǎng)絡(luò)性能提升了1.77個百分點(diǎn)。對于WRN-40-2+WRN-40-1師生對,當(dāng)批次大小分別為128、512、1 024時,采用實(shí)例級策略時網(wǎng)絡(luò)的性能比采用類別級策略時分別多提升了1.12、2.87、2.93個百分點(diǎn)??梢钥吹?,批次大小越大,采用實(shí)例級策略的FSECD算法在性能上超過采用類別級策略的FSECD算法的幅度也相應(yīng)變大。同樣的現(xiàn)象也存在于另一組師生對中。
采用實(shí)例級策略訓(xùn)練的卷積神經(jīng)網(wǎng)絡(luò)性能更好的原因在于:采用實(shí)例級策略訓(xùn)練時,訓(xùn)練的細(xì)粒度更高。實(shí)例級策略要求能夠一一區(qū)分每一幅圖像,而不是單純地區(qū)分某一個類別,前者的難度要高于后者。采用實(shí)例級策略訓(xùn)練的模型,可以為每幅圖像生成獨(dú)特的特征,從而與教師圖像生成的特征相對應(yīng),因此網(wǎng)絡(luò)的性能更高。在后續(xù)與其他知識蒸餾算法的對比實(shí)驗(yàn)中,默認(rèn)FSECD算法使用的是實(shí)例級策略。
在CIFAR-100數(shù)據(jù)集上,采用FSECD算法與其他主流的知識蒸餾算法(包括KD[14]、FitNets[17]、AT[18]、RKD[15]、OFD[19]、CRD[16]和DKD[21])進(jìn)行了對比實(shí)驗(yàn),結(jié)果見表3和表4,分別用粗體和斜體標(biāo)記最優(yōu)和次優(yōu)的結(jié)果。從表3可見,對于實(shí)驗(yàn)的6組教師-學(xué)生模型師生對,使用FSECD算法訓(xùn)練的網(wǎng)絡(luò)的cc1性能取得了4組最優(yōu)和2組次優(yōu),這表明了特征空間嵌入的優(yōu)越性。從表4可見,對于實(shí)驗(yàn)的5組教師-學(xué)生模型師生對,采用FSECD算法訓(xùn)練的網(wǎng)絡(luò)的cc1性能取得了2組最優(yōu)和2組次優(yōu),只有ResNet32×4+ShuffleNetV2師生對的cc1低于OFD和DKD。
表3 在CIFAR-100數(shù)據(jù)集上使用9種知識蒸餾算法訓(xùn)練的6種相同網(wǎng)絡(luò)架構(gòu)模型的Acc1對比
Table 3 Comparison of Acc1 among six models with the same network architecture trained by nine knowledge distillation algorithms on CIFAR-100 dataset %
算法ResNet56+ResNet20ResNet110+ResNet32ResNet32×4+ResNet8×4WRN-40-2+WRN-16-2WRN-40-2+WRN-40-1VGG13+VGG8 CE69.0671.1472.5173.2671.9870.36 KD70.6673.0873.3374.9273.5472.98 FitNet69.2171.0673.5073.5872.2471.02 AT70.5572.3173.4474.0872.7771.43 RKD69.6171.8271.9073.3572.2271.48 OFD70.9873.2374.9575.2474.3373.95 CRD71.1673.4875.5175.4874.1473.94 DKD1)71.3273.7775.9275.3274.1474.41 FSECD71.3973.5176.5775.6274.4974.11
1)使用作者提供的代碼復(fù)現(xiàn)得到的結(jié)果,下同。
表4 在CIFAR-100數(shù)據(jù)集上使用9種知識蒸餾算法訓(xùn)練的5種相異網(wǎng)絡(luò)架構(gòu)模型的Acc1對比
Table 4 Comparison of Acc1 among five models with different network architectures trained by nine knowledge distillation algorithms on CIFAR-100 dataset %
算法ResNet32×4+ShuffleNetV1WRN-40-2+ShuffleNetV1VGG13+MobileNetV2ResNet50+MobileNetV2ResNet32×4+ShuffleNetV2 CE70.5070.5064.6064.6071.82 KD74.0774.8367.3767.3574.45 FitNet73.5973.7364.1463.1673.54 AT71.7373.3259.4058.5872.73 RKD72.2872.2164.5264.4373.21 OFD75.9875.8569.4869.0476.82 CRD75.1176.0569.7369.1175.65 DKD1)76.4576.6769.2969.9676.70 FSECD76.0176.3269.9770.0676.15
在ImageNet數(shù)據(jù)集上,采用FSECD算法與其他主流知識蒸餾算法(包括KD[14]、AT[18]、OFD[19]、CRD[16]和DKD[21]),對兩組教師-學(xué)生模型師生對進(jìn)行了對比實(shí)驗(yàn),結(jié)果見表5。表5顯示,對于具有相同網(wǎng)絡(luò)架構(gòu)的師生對ResNet34+ResNet18,DKD算法訓(xùn)練的網(wǎng)絡(luò)的cc1略微超過FSECD算法,但差距非常小,為0.05%,這兩種算法訓(xùn)練的網(wǎng)絡(luò)的cc5幾乎相等。對于具有相異網(wǎng)絡(luò)架構(gòu)的師生對ResNet50+MobileNetV1,使用FSECD算法訓(xùn)練的網(wǎng)絡(luò),其cc1和cc5性能均優(yōu)于其他知識蒸餾算法。
在CIFAR100和ImageNet數(shù)據(jù)集上,F(xiàn)SECD和DKD算法均取得優(yōu)秀的結(jié)果,然而FSECD算法使用的超參數(shù)少于DKD算法,使得文中FSECD算法的拓展性更好。
表5 ImageNet數(shù)據(jù)集上使用7種知識蒸餾算法訓(xùn)練的2種網(wǎng)絡(luò)模型的Acc1和Acc5對比
Table 5 Comparison of Acc1 and Acc5 between two network models trained by seven knowledge distillation algorithms on ImageNet dataset
算法Acc1/%Acc5/% ResNet34+ResNet18ResNet50+MobileNetvV1ResNet34+ResNet18ResNet50+MobileNetV1 CE69.7568.8789.0788.76 KD71.0370.5090.0589.80 AT70.6969.5690.0189.33 CRD71.1771.3790.1390.41 OFD70.8171.2589.9890.34 DKD1)71.5472.0190.4390.02 FSECD71.4972.1990.4490.98
FSECD算法通過在每一個訓(xùn)練迭代內(nèi),學(xué)生模型輸出特征根據(jù)與其他所有教師模型輸出特征的關(guān)系,產(chǎn)生吸引或排斥的效果。文中用式(7)所示的損失函數(shù)替換式(4)中的FSECD,并將修改后的算法命名為FSECD_S。在FSECD_S中,只構(gòu)建正對比對,不構(gòu)建負(fù)對比對,學(xué)生模型只能學(xué)習(xí)到一對一的樣本間關(guān)系的知識,無法學(xué)習(xí)到結(jié)構(gòu)化的知識。
損失函數(shù)(7)只顯式地最小化正對比對的距離,不考慮負(fù)對比對,無結(jié)構(gòu)化知識的學(xué)習(xí)。文中選擇WRN-40-2+WRN-40-1和ResNet32×4+ResNet8×4兩組教師-學(xué)生模型師生對,在CIFAR-100數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),結(jié)果見表6。從表中可知:兩種損失函數(shù)都可以提高模型的性能;FSECD算法提升學(xué)生模型性能的幅度更大,充分證明了基于特征空間嵌入進(jìn)行結(jié)構(gòu)化知識學(xué)習(xí)的重要性。
表6 在CIFAR-100數(shù)據(jù)集上結(jié)構(gòu)化知識對模型性能的影響
Table 6 Influence of structural knowledge on the performance of models on CIFAR-100 dataset %
算法WRN-40-2+WRN-40-1ResNet32×4+ResNet8×4 CE71.9872.51 FSECD_S73.6674.49 FSECD75.7476.57
為探究充當(dāng)負(fù)例的樣本數(shù)量對FSECD算法的影響,比較了使用不同數(shù)量負(fù)例的學(xué)生模型的cc1得分。在每個訓(xùn)練迭代,對任一學(xué)生模型輸出特征,先根據(jù)負(fù)例與該學(xué)生模型輸出特征的相似度進(jìn)行降序排序,然后只保留降序排列后top-的負(fù)例用于FSECD算法,此處是一個預(yù)先設(shè)置好的超參數(shù)。此外,還比較了只使用一個負(fù)例進(jìn)行訓(xùn)練的學(xué)生模型的性能。在CIFAR100數(shù)據(jù)集上兩組師生對的實(shí)驗(yàn)結(jié)果如表7所示。
表7 在CIFAR-100數(shù)據(jù)集上負(fù)例樣本數(shù)對模型性能的影響
Table 7 Influence of the number of negative instances on the performance of models on CIFAR-100 dataset %
算法ResNet32×4+ResNet8×4ResNet56+ResNet20 CE72.5169.02 FSECD(1個負(fù)例)73.9270.08 FSECD(k=25%)75.8670.96 FSECD(k=50%)76.1471.15 FSECD(k=100%)76.5771.39
當(dāng)只使用一個負(fù)例時,相對于CE算法,使用FSECD算法訓(xùn)練的ResNet8×4的cc1只提升了1.41個百分點(diǎn),ResNet20的cc1只提升了1.06個百分點(diǎn)。隨著使用負(fù)例的比例增大,學(xué)生模型的性能也不斷提升,當(dāng)=100%時,兩個學(xué)生網(wǎng)絡(luò)的cc1都達(dá)到峰值。該實(shí)驗(yàn)結(jié)果表明了負(fù)例數(shù)量的重要性。隨著負(fù)例個數(shù)的增加,學(xué)生獲取到更多教師特征空間的信息,使得學(xué)生模型的特征空間有更好的泛化能力,獲得了更好的性能。
3.3.1溫度系數(shù)對模型性能的影響
溫度系數(shù)取不同值(1、2、4、6、8、10)時對模型性能的影響如圖3所示,過高或過低的溫度系數(shù)都會導(dǎo)致模型性能的下降。溫度系數(shù)的最佳值與數(shù)據(jù)集預(yù)測類別的數(shù)量呈正相關(guān)關(guān)系。數(shù)據(jù)集預(yù)測類別的數(shù)量越多,合適的溫度系數(shù)的數(shù)值就越大,而且它對模型不敏感。CIFAR-100數(shù)據(jù)集的類別數(shù)量為100,所用溫度系數(shù)的數(shù)值在3到5之間;ImageNet數(shù)據(jù)集的類別數(shù)量為1 000,溫度系數(shù)的數(shù)值設(shè)置為8比較合適。
3.3.2損失函數(shù)平衡系數(shù)對模型性能的影響
損失函數(shù)平衡系數(shù)取不同值(1、2、4、8、10、15、20)時對模型性能的影響如圖4所示,隨著損失函數(shù)平衡系數(shù)的增加,模型性能先上升后下降。損失函數(shù)平衡系數(shù)的最佳值需要根據(jù)網(wǎng)絡(luò)模型師生對的情況進(jìn)行調(diào)優(yōu),如ResNet32×4+ResNet8×4的損失函數(shù)平衡系數(shù)最佳值是10,而ResNet56-ResNet20的最佳值是4。此外,當(dāng)損失函數(shù)平衡系數(shù)的取值過大導(dǎo)致網(wǎng)絡(luò)性能下降時,不同網(wǎng)絡(luò)性能的下降程度是不同的,如ResNet32×4-ResNet8×4相比其最佳性能下降的幅度不大,而ResNet56-ResNet20相比其最佳性能的下降幅度很大。
圖4 在CIFAR-100數(shù)據(jù)集上損失函數(shù)平衡系數(shù)對模型性能的影響
文中使用t-SNE[29]來可視化ResNet8×4模型倒數(shù)第二層卷積層的特征,t-SNE算法可以在降維的同時,保持特征之間的相互關(guān)系??梢暬窃贑IFAR-100數(shù)據(jù)集的測試集上進(jìn)行的,先對測試集上的每幅圖像進(jìn)行特征提取,然后使用t-SNE算法對高維特征進(jìn)行降維。本研究將每個高維特征降到二維空間,并在二維地圖上以點(diǎn)表示。
圖5 基于t-SNE算法的特征可視化
KD和FSECD算法的t-SNE可視化結(jié)果如圖5所示,由于ResNet8×4在測試集上的cc1在70%~80%之間,因此存在一些特征是分散的,而不是靠近其特征中心,在生成的特征分布圖中心形成了一個不太清晰的區(qū)域。由特征分布圖可知,使用FSECD算法得到的模型實(shí)現(xiàn)了更緊密的特征聚集和更高區(qū)分度的類間邊界。
文中通過可視化距離矩陣,對比了采用KD和FSECD算法訓(xùn)練的學(xué)生模型與教師模型的相似程度,采用的網(wǎng)絡(luò)模型師生對為ResNet32×4+ResNet8×4。
首先,按真實(shí)標(biāo)簽對測試集中的圖像進(jìn)行分類;然后,對于某一類別的所有圖像,統(tǒng)計(jì)模型預(yù)測概率并取均值,對所有類別的圖像進(jìn)行預(yù)測后,可得到類別預(yù)測概率矩陣∈ R100×100,其中P,定義為所有真實(shí)標(biāo)簽為的圖像被預(yù)測為類別的平均概率;最后,計(jì)算教師模型類別預(yù)測概率矩陣t和學(xué)生模型類別預(yù)測概率矩陣s之間的距離矩陣,計(jì)算公式為
為了可視化,對距離矩陣采用全局歸一化:
可視化結(jié)果如圖6所示,圖中用顏色深淺表示該點(diǎn)距離的大小。學(xué)生模型與教師模型越相似,點(diǎn)的顏色越淺,同時也代表學(xué)生模型的性能越好。從圖中可知,使用FSECD算法訓(xùn)練的學(xué)生模型歐幾里得距離更小,進(jìn)一步證明了FSECD算法遷移教師模型知識的優(yōu)越性。
圖6 兩種算法的教師和學(xué)生模型的差異
本研究提出了一種基于特征空間嵌入的對比知識蒸餾算法FSECD,該算法將教師模型的結(jié)構(gòu)化知識提取到學(xué)生模型。在每個訓(xùn)練迭代,批次內(nèi)的教師模型輸出特征或充當(dāng)正例或充當(dāng)負(fù)例,共同優(yōu)化每個學(xué)生模型的輸出特征。批次內(nèi)的學(xué)生模型輸出特征被嵌入到教師模型的特征空間,學(xué)習(xí)相同的結(jié)構(gòu)化的教師模型的知識,最終學(xué)生模型模仿了教師模型的特征空間,并能夠輸出與教師相似的樣本間關(guān)系,實(shí)現(xiàn)了知識的遷移。文中算法通過在批次內(nèi)進(jìn)行對比學(xué)習(xí)的對比對的構(gòu)建,克服了以往對比學(xué)習(xí)需要額外的計(jì)算或內(nèi)存的缺點(diǎn),是一種高效簡潔的知識蒸餾算法。本研究在CIFAR-100和ImageNet數(shù)據(jù)集上進(jìn)行了大量的對比實(shí)驗(yàn),結(jié)果顯示,在大多數(shù)實(shí)驗(yàn)配置下,文中提出的算法均取得最優(yōu)或次優(yōu)的結(jié)果,充分證明了文中算法的優(yōu)越性,并進(jìn)一步證明了樣本間關(guān)系在知識蒸餾中的重要性。
[1] SIMONYAN K,ZISSERMAN A.Very deep convolutional networks for large-scale image recognition [EB/OL].(2015-04-10)[2022-10-20].https://arxiv.org/abs/1409.1556v1.
[2] HE K,ZHANG X,REN S,et al.Deep residual learning for image recognition[C]∥ Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition.Las Vegas:IEEE,2016:770-778.
[3] ZHANG X,ZHOU X,LIN M,et al.ShuffleNet:an extremely efficient convolutional neural network for mobile devices[C]∥ Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition.Salt Lake City:IEEE,2018:6848-6856.
[4] MA N,ZHANG X,ZHENG H-T,et al.ShuffleNet V2:practical guidelines for efficient CNN architecture design[C]∥ Proceedings of the 15th European Conference on Computer Vision.Munich:Springer,2018:122-138.
[5] SANDLER M,HOWARD A,ZHU M,et al.MobileNetV2:inverted residuals and linear bottlenecks [C]∥ Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition.Salt Lake City:IEEE,2018:4510-4520.
[6] ZAGORUYKO S,KOMODAKIS N.Wide residual networks[EB/OL].(2017-06-14)[2022-10-20].https://arxiv.org/abs/1605.07146.
[7] REDMON J,DIVVALA S,GIRSHICK R,et al.You only look once:unified,real-time object detection[C]∥ Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition.Las Vegas:IEEE,2016:779-788.
[8] LIU W,ANGUELOV D,ERHAN D,et al.SSD:single shot multibox detector[C]∥ Proceedings of the 14th European Conference on Computer Vision.Amsterdam:Springer,2016:21-37.
[9] HE K,GKIOXARI G,DOLLáR P,et al.Mask R-CNN[C]∥ Proceedings of 2017 IEEE International Conference on Computer Vision.Venice:IEEE,2017:2961-2969.
[10] ZHAO H,SHI J,QI X,et al.Pyramid scene parsing network[C]∥ Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition.Honolulu:IEEE,2017:2881-2890.
[11] LUO J-H,WU J,LIN W.ThiNet:a filter level pruning method for deep neural network compression[C]∥ Proceedings of 2017 IEEE International Conference on Computer Vision.Venice:IEEE,2017:5058-5066.
[12] JACOB B,KLIGYS S,CHEN B,et al.Quantization and training of neural networks for efficient integer-arithmetic-only inference[C]∥ Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition.Salt Lake City:IEEE,2018:2704-2713.
[13] YU X,LIU T,WANG X,et al.On compressing deep models by low rank and sparse decomposition[C]∥ Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition.Honolulu:IEEE,2017:7370-7379.
[14] HINTON G,VINYALS O,DEAN J.Distilling the knowledge in a neural network[EB/OL].(2015-05-09)[2022-10-20].https://arxiv.org/abs/1503.02531.
[15] PARK W,KIM D,LU Y,et al.Relational knowledge distillation[C]∥ Proceedings of 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition.Long Beach:IEEE,2019:3967-3976.
[16] TIAN Y,KRISHNAN D,ISOLA P.Contrastive representation distillation[C]∥ Proceedings of the 8th International Conference on Learning Representations.Addis Ababa:OpenReview.net,2020:1-19.
[17] ROMERO A,BALLAS N,KAHOU S E,et al.FitNets:hints for thin deep nets[C]∥ Proceedings of the 3rd International Conference on Learning Representations.San Diego:OpenReview.net,2015:1-13.
[18] ZAGORUYKO S,KOMODAKIS N.Paying more attention to attention:improving the performance of convolutional neural networks via attention transfer[C]∥ Proceedings of the 5th International Conference on Learning Representations.Toulon:OpenReview.net,2017:1-13.
[19] HEO B,KIM J,YUN S,et al.A comprehensive overhaul of feature distillation[C]∥ Proceedings of 2019 IEEE/CVF International Conference on Computer Vision.Long Beach:IEEE,2019:1921-1930.
[20] GOU J,YU B,MAYBANK S J,et al.Knowledge distillation:a survey[J].International Journal of Computer Vision,2021,129(6):1789-1819.
[21] ZHAO B,CUI Q,SONG R,et al.Decoupled knowledge distillation[C]∥ Proceedings of 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition.New Orleans:IEEE,2022:11953-11962.
[22] GOU J,SUN L,YU B,et al.Multi-level attention-based sample correlations for knowledge distillation[J].IEEE Transactions on Industrial Informatics,2022,DOI:10.1109/TII.2022.3209672.
[23] CHEN T,KORNBLITH S,NOROUZI M,et al.A simple framework for contrastive learning of visual representations[C]∥ Proceedings of the Thirty-seventh International Conference on Machine Learning.Vienna:IMLS,2020:1597-1607.
[24] RADFORD A,KIM J W,HALLACY C,et al.Learning transferable visual models from natural language supervision[C]∥ Proceedings of the 38th International Conference on Machine Learning.Vienna:IMLS,2021:8748-8763.
[25] XU G,LIU Z,LI X,et al.Knowledge distillation meets self-supervision[C]∥ Proceedings of the 16th European Conference on Computer Vision.Glasgow:Springer,2020:588-604.
[26] KRIZHEVSKY A.Learning multiple layers of features from tiny images[D].Toronto:University of Toronto,2009.
[27] DENG J,DONG W,SOCHER R,et al.ImageNet:a large-scale hierarchical image database[C]∥ Proceedings of 2009 IEEE Conference on Computer Vision and Pattern Recognition.Miami:IEEE,2009:248-255.
[28] HOWARD A G,ZHU M,CHEN B,et al.MobileNets:efficient convolutional neural networks for mobile vision applications[EB/OL].(2017-04-17)[2022-10-20].https://arxiv.org/abs/1704.04861.
[29] Van der MAATEN L,HINTON G.Visualizing data using t-SNE[J].Journal of Machine Learning Research,2008,9(11):2579-2605.
Contrastive Knowledge Distillation Method Based on Feature Space Embedding
(School of Mechanical and Automotive Engineering,South China University of Technology,Guangzhou 510640,Guangdong,China)
Because of its important role in model compression, knowledge distillation has attracted much attention in the field of deep learning. However, the classical knowledge distillation algorithm only uses the information of a single sample, and neglects the importance of the relationship between samples, leading to its poor performance. To improve the efficiency and performance of knowledge transfer in knowledge distillation algorithm, this paper proposed a feature-space-embedding based contrastive knowledge distillation (FSECD) algorithm. The algorithm adopts efficient batch construction strategy, which embeds the student feature into the teacher feature space so that each student feature buildscontrastive pairs withteacher features. In each pair, the teacher feature is optimized and fixed, while student feature is to be optimized and tunable. In the training process, the distance for positive pairs is narrowed and the distance for negative pairs is expanded, so that student model can perceive and learn the inter-sample relations of teacher model and realize the transfer of knowledge from teacher model to student model. Extensive experiments with different teacher/student architecture settings on CIFAR-100 and ImageNet datasets show that, FSECD algorithm achieves significant performance improvement without additional network structures and data when compared with other cutting-edge distillation methods, which further proves the importance of the inter-sample relations in knowledge distillation.
image classification;knowledge distillation;convolutional neural network;deep learning;contrastive learning
Supported by the Key-Area R&D Program of Guangdong Province (2021B0101420003)
10.12141/j.issn.1000-565X.220684
2022?10?24
廣東省重點(diǎn)領(lǐng)域研發(fā)計(jì)劃項(xiàng)目(2021B0101420003)
葉峰(1972-),男,博士,副教授,主要從事機(jī)器視覺及移動機(jī)器人傳感控制研究。E-mail:mefengye@scut.edu.cn
TP391
1000-565X(2023)05-0013-11
華南理工大學(xué)學(xué)報(bào)(自然科學(xué)版)2023年5期