鄒美霞
(北京交通大學(xué)計(jì)算機(jī)與信息技術(shù)學(xué)院,北京 100044)
近年來,卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)得益于神經(jīng)網(wǎng)絡(luò)模型結(jié)構(gòu)的發(fā)展,受到了人工智能領(lǐng)域的廣泛關(guān)注。它在圖像分類、目標(biāo)檢測(cè)等方面取得了巨大的成功,這也促使人們一直致力于開發(fā)更深入、更復(fù)雜的模型,以達(dá)到更高的精度。然而,更深的模型會(huì)顯著增加模型推理所需的延遲和計(jì)算成本。例如,在Titan X GPU上對(duì)比VGG19和AlexNet。實(shí)驗(yàn)表明,兩個(gè)模型的運(yùn)行時(shí)間和計(jì)算消耗增加了20倍,錯(cuò)誤率降低4%(從11%到7%)。如果不能妥善降低CNN巨大的計(jì)算成本和延遲,將在很大程度上阻礙CNN在各種資源受限的平臺(tái)上的大規(guī)模部署,特別是在延遲敏感和實(shí)時(shí)的場(chǎng)景中,比如自動(dòng)駕駛、機(jī)器人導(dǎo)航或移動(dòng)設(shè)備上。
為了解決CNN巨大的計(jì)算成本和延遲問題,研究者提出了多分支結(jié)構(gòu),其能夠使網(wǎng)絡(luò)在達(dá)到一定退出條件時(shí)停止計(jì)算,從而退出網(wǎng)絡(luò)。但是當(dāng)需要強(qiáng)制網(wǎng)絡(luò)退出時(shí)會(huì)出現(xiàn)淺層網(wǎng)絡(luò)精度較低的情況。研究者除了提出多分支結(jié)構(gòu)的方法,還提出了輕量化網(wǎng)絡(luò)設(shè)計(jì)、剪枝、量化和知識(shí)蒸餾等模型壓縮技術(shù)來減少模型的計(jì)算量。其中,知識(shí)蒸餾是一種切實(shí)可用的技術(shù)之一,傳統(tǒng)的知識(shí)蒸餾是基于“教師-學(xué)生”網(wǎng)絡(luò)的訓(xùn)練方法實(shí)現(xiàn)知識(shí)的轉(zhuǎn)移,其中“教師”是復(fù)雜且計(jì)算量大的模型,“學(xué)生”是精簡(jiǎn)模型,通過讓學(xué)生模型學(xué)習(xí)教師模型的輸出,使得學(xué)生模型接近或超過教師模型的表現(xiàn)力。通過使用緊湊的學(xué)生模型代替龐大復(fù)雜的教師模型,可以實(shí)現(xiàn)模型的壓縮和推理加速。但是,此種方式訓(xùn)練的學(xué)生模型優(yōu)于教師模型是相對(duì)困難的,兩個(gè)模型之間存在著知識(shí)不能很好地轉(zhuǎn)移和學(xué)習(xí)的問題。
針對(duì)上述問題,文獻(xiàn)[11]提出了一種自蒸餾方法(BOT),在一個(gè)模型上進(jìn)行知識(shí)蒸餾,將模型分為幾個(gè)部分,通過在不同模型深度上添加分支分類器,將深層網(wǎng)絡(luò)作為教師,淺層網(wǎng)絡(luò)作為學(xué)生進(jìn)行網(wǎng)絡(luò)訓(xùn)練。文獻(xiàn)[12]提出了基于知識(shí)蒸餾的學(xué)習(xí)方法,以提高卷積神經(jīng)網(wǎng)絡(luò)(CNN)的分類性能,而無需預(yù)先訓(xùn)練的教師網(wǎng)絡(luò)。其也是使用深層的網(wǎng)絡(luò)教授淺層網(wǎng)絡(luò),但是淺層網(wǎng)絡(luò)與深層網(wǎng)絡(luò)之間存在性能差距。為了縮小二者的差距,本文提出一種新的自蒸餾方法,將整個(gè)網(wǎng)絡(luò)分支的輸出作為集成的結(jié)果,借助助教網(wǎng)絡(luò)的思想,使用集成的結(jié)果訓(xùn)練最后的分支,同時(shí)使用最后的分支訓(xùn)練網(wǎng)絡(luò)的其他分支,以提高各分支的分類精度。本文將提出的自蒸餾方法與其他自蒸餾方法進(jìn)行對(duì)比,實(shí)驗(yàn)結(jié)果顯示,該方法能有效提高各個(gè)分支的精度,以便網(wǎng)絡(luò)能夠更好地提前退出。
一般來說,隨著網(wǎng)絡(luò)深度的加大,網(wǎng)絡(luò)性能也會(huì)隨著深度的加大而提升,但是網(wǎng)絡(luò)模型訓(xùn)練和推理的時(shí)間也會(huì)隨之變得越來越長(zhǎng),隨之資源的消耗也會(huì)越來越多。針對(duì)這個(gè)問題,一些研究者提出了多分支結(jié)構(gòu)。多分支結(jié)構(gòu)是在傳統(tǒng)CNN模型上添加分支器,使得模型能夠在滿足一定的精度需求時(shí)終止計(jì)算提早退出網(wǎng)絡(luò),以達(dá)到節(jié)約計(jì)算資源的目的。文獻(xiàn)[3]提出了BranchyNet,其通過在主干網(wǎng)絡(luò)上添加多個(gè)分支器,使用分類結(jié)果的信息熵作為對(duì)預(yù)測(cè)結(jié)果的置信度度量。具體的,在每個(gè)退出點(diǎn)位置設(shè)置退出閾值,并對(duì)分支器中樣本置信度進(jìn)行判斷。當(dāng)樣本的熵低于退出點(diǎn)閾值時(shí),即意味著分類器認(rèn)為預(yù)測(cè)結(jié)果是可信的,這樣能夠使樣本提前退出網(wǎng)絡(luò)而不再參與后面網(wǎng)絡(luò)層的計(jì)算,從而減少計(jì)算量,節(jié)約計(jì)算資源。否則,當(dāng)前樣本的預(yù)測(cè)結(jié)果不可信,則繼續(xù)到下一個(gè)退出點(diǎn),直到主干網(wǎng)絡(luò)上的最后一個(gè)退出點(diǎn)。在此基礎(chǔ)上,文獻(xiàn)[5]通過不同長(zhǎng)度的子路徑組合,自適應(yīng)地選擇合適的子路徑集合提升模型表現(xiàn)。文獻(xiàn)[6]提出了基于分支在某一個(gè)時(shí)間點(diǎn),當(dāng)應(yīng)用程序分配的資源用盡,需要強(qiáng)制退出網(wǎng)絡(luò)時(shí)可在任何時(shí)間退出的多分支結(jié)構(gòu)。通過該方法保證在某個(gè)時(shí)間點(diǎn)需要輸出網(wǎng)絡(luò)模型的結(jié)果時(shí),網(wǎng)絡(luò)可以把最近的分支的分類結(jié)果作為輸出。然而,多分支結(jié)構(gòu)中淺層分支由于層數(shù)少,特征提取能力不夠,導(dǎo)致其精度過低。
知識(shí)蒸餾是模型壓縮技術(shù)中最常用的壓縮方法之一,傳統(tǒng)知識(shí)蒸餾的核心思想是通過引入復(fù)雜但精度較高的模型作為教師引導(dǎo)模型精簡(jiǎn)的學(xué)生模型,將復(fù)雜學(xué)習(xí)能力強(qiáng)的教師模型學(xué)到的特征的知識(shí)通過蒸餾提取出來,傳遞給參數(shù)量小、學(xué)習(xí)能力弱的學(xué)生模型。其主要包含兩個(gè)核心內(nèi)容,分別是蒸餾內(nèi)容和蒸餾方式。
在蒸餾內(nèi)容上主要分為基于Softmax輸出層的標(biāo)簽蒸餾和基于中間層的特征蒸餾。2014年,Hinton等為了加強(qiáng)輸出的概率分布中學(xué)習(xí)到的網(wǎng)絡(luò)知識(shí),通過將全連接層的輸出經(jīng)過超參數(shù)T進(jìn)行平滑輸出,該輸出被稱為軟目標(biāo)(Soft Label),最后通過軟目標(biāo)和真實(shí)標(biāo)簽一起指導(dǎo)學(xué)生網(wǎng)絡(luò)的訓(xùn)練。文獻(xiàn)[15]提出了一種基于中間特征蒸餾的FitNet方法,其將教師網(wǎng)絡(luò)的軟目標(biāo)和發(fā)掘教師網(wǎng)絡(luò)的中間特征圖的知識(shí)相結(jié)合教授學(xué)生知識(shí),使得學(xué)生網(wǎng)絡(luò)可以比教師網(wǎng)絡(luò)更深更窄,提高學(xué)生網(wǎng)絡(luò)的性能。文獻(xiàn)[16]提出了基于圖的蒸餾方法,提出分對(duì)數(shù)(Logits)圖和表示(Representation)圖來傳遞多個(gè)知識(shí)。
在蒸餾方式上大部分工作是通過在教師和學(xué)生兩個(gè)模型之間轉(zhuǎn)移知識(shí)的方式進(jìn)行知識(shí)的傳遞和教授。但學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)之間性能差距大,會(huì)導(dǎo)致學(xué)生網(wǎng)絡(luò)不能充分學(xué)習(xí)教師網(wǎng)絡(luò)的知識(shí),從而使學(xué)生網(wǎng)絡(luò)的效果不佳。針對(duì)這個(gè)問題,文獻(xiàn)[13]和文獻(xiàn)[17]通過助教網(wǎng)絡(luò)的方式減少二者之間的差距,從而提升學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)效率。文獻(xiàn)[13]提出了使用多個(gè)教師助手來密集引導(dǎo)知識(shí)蒸餾的方法,通過每一位助教迭代地指導(dǎo)每一位更小的助教,并通過隨機(jī)投放一名教師和助教來提高學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)能力。文獻(xiàn)[17]通過使用多步知識(shí)蒸餾,利用一個(gè)中等規(guī)模的網(wǎng)絡(luò)充當(dāng)教師助理的角色作為教師和學(xué)生之間的橋梁,從而減少二者之間的差距。
考慮到1.1節(jié)中所提出的淺層分支精度低的問題,研究人員在基于1.2節(jié)中的傳統(tǒng)知識(shí)蒸餾技術(shù)上提出了利用深層分支來指導(dǎo)淺層分支的自蒸餾方案。自蒸餾方法主要是將教師模型和學(xué)生模型集成到一個(gè)網(wǎng)絡(luò)模型上,將淺層的網(wǎng)絡(luò)作為學(xué)生模型,將深層的網(wǎng)絡(luò)作為教師模型,從后向前進(jìn)行知識(shí)的傳遞。自蒸餾的方法無需訓(xùn)練教師和尋找學(xué)生網(wǎng)絡(luò),這樣可以有效避免傳統(tǒng)知識(shí)蒸餾中兩個(gè)模型之間傳遞知識(shí)的問題(教師網(wǎng)絡(luò)向?qū)W生網(wǎng)絡(luò))。文獻(xiàn)[11]提出了一種自蒸餾方式(BOT),針對(duì)多分支結(jié)構(gòu)通過深層網(wǎng)絡(luò)教授淺層網(wǎng)絡(luò),能夠有效提升淺層網(wǎng)絡(luò)的性能。文獻(xiàn)[18]提出了一種基于自蒸餾的多出口體系結(jié)構(gòu)訓(xùn)練方法。該方法通過匹配早期退出的輸出概率,鼓勵(lì)早期退出模仿稍后更準(zhǔn)確的退出。文獻(xiàn)[19]從提高分類器精度的角度入手,充分利用同一網(wǎng)絡(luò)中不同分類器的知識(shí),來提高每個(gè)分類器的準(zhǔn)確性。其通過分支增強(qiáng)技術(shù)將一個(gè)單分類器網(wǎng)絡(luò)轉(zhuǎn)換成一個(gè)多分類器網(wǎng)絡(luò)。文獻(xiàn)[20]通過多個(gè)中間分類器自適應(yīng)網(wǎng)絡(luò)模型,提出了采用梯度均衡算法來解決不同分類器之間的學(xué)習(xí)沖突,采用內(nèi)聯(lián)子網(wǎng)絡(luò)協(xié)作方法和知識(shí)蒸餾算法增強(qiáng)分類器之間協(xié)作的方法。
以上這些工作通過將深層的網(wǎng)絡(luò)作為教師,淺層網(wǎng)絡(luò)作為學(xué)生的教授方式。鑒于此,最新的工作通過在特定的中間層上附加出口來創(chuàng)建輔助分類器,相較于之前教師指導(dǎo)學(xué)生的方法,該工作的每個(gè)出口和最終輸出都扮演著教師和學(xué)生的角色,將主網(wǎng)絡(luò)和出口在內(nèi)的整個(gè)網(wǎng)絡(luò)同時(shí)訓(xùn)練。其提出的蒸餾方式與傳統(tǒng)的教師網(wǎng)絡(luò)只教授學(xué)生的方式不同,其提出網(wǎng)絡(luò)中所有分類的集成可以作為一個(gè)教師,使用集成訓(xùn)練網(wǎng)絡(luò)中的各個(gè)分支以提高網(wǎng)絡(luò)的分類性能。由于每個(gè)從輸入到出口或最終輸出的分類器都有不同的結(jié)構(gòu),因此分類器的輸出和特征存在多樣性和互補(bǔ)性,訓(xùn)練整個(gè)網(wǎng)絡(luò)的損失函數(shù)由傳統(tǒng)的交叉熵?fù)p失函數(shù)和各出口的蒸餾損失函數(shù)以及最終輸出的損失函數(shù)組成。但是,深層網(wǎng)絡(luò)與淺層網(wǎng)絡(luò)的能力之間存在較大差距,將整個(gè)網(wǎng)絡(luò)集成的結(jié)果直接教淺層各個(gè)分支,會(huì)影響淺層網(wǎng)絡(luò)的學(xué)習(xí)性能。針對(duì)這個(gè)問題,本文提出了一種新的自蒸餾方法,借助助教網(wǎng)絡(luò)的思想,讓集成的網(wǎng)絡(luò)作為教師指導(dǎo)最后的分支,同時(shí)使最后分支的輸出作為教師指導(dǎo)網(wǎng)絡(luò)中的其他分支,利用這種方式縮小集成分支與淺層分支的性能差距,提升淺層分支的性能。
一種支持自蒸餾方法的多分支結(jié)構(gòu)如圖1所示,其中F表示特征,Z表示logits。根據(jù)CNN的深度和原始結(jié)構(gòu),將CNN劃分為幾個(gè)淺層區(qū)域,在每個(gè)部分后設(shè)置一個(gè)分類器。根據(jù)文獻(xiàn)[11],其退出分支由Bottleneck層和全連接層組成,其中增加Bottleneck層的主要作用是減少淺層分類器之間的影響。
圖1所示的自蒸餾方法一般由3個(gè)損失函數(shù)組成。首先,從標(biāo)簽到深層分類器和淺層分類器的交叉熵?fù)p失,它是由訓(xùn)練數(shù)據(jù)集中的標(biāo)簽和每個(gè)分類器的輸出計(jì)算得到。通過這種方式能夠使數(shù)據(jù)集中的知識(shí)直接從標(biāo)簽傳遞到所有分類器。具體公式如下:
圖1 一種自蒸餾方法示意圖
其次,KL散度(Kullback-Leibler Divergence)是兩個(gè)概率分布間差異的非對(duì)稱度量,通過KL散度算法計(jì)算學(xué)生和教師之間輸出分布的差異,能夠?qū)⑸顚泳W(wǎng)絡(luò)的知識(shí)教授到淺層網(wǎng)絡(luò)的分支。如公式(2)所示,(x)在概率函數(shù)中表示真實(shí)分布,在自蒸餾中表示教師網(wǎng)絡(luò)的輸出分布,(x)在概率函數(shù)中表示需擬合的概率分布,在自蒸餾中表示學(xué)生網(wǎng)絡(luò)的輸出分布,當(dāng)(x)=(x)時(shí)表明學(xué)生網(wǎng)絡(luò)已經(jīng)充分學(xué)習(xí)教師網(wǎng)絡(luò)的性能,因此在自蒸餾中要最小化KL散度的值,使得學(xué)生網(wǎng)絡(luò)的性能提升。
最后,通過計(jì)算特征圖中學(xué)生和教師之間的特征差異,使得網(wǎng)絡(luò)更好地提取知識(shí)以提升自身的能力。
本文提出了一種基于多分支的自蒸餾方法,如圖2所示,其中F表示特征,Z表示logits。本文與文獻(xiàn)[11]的結(jié)構(gòu)設(shè)計(jì)相同,根據(jù)初始網(wǎng)絡(luò)的深度和結(jié)構(gòu),將初始網(wǎng)絡(luò)分為多個(gè)淺層網(wǎng)絡(luò),在每一個(gè)淺層網(wǎng)絡(luò)的后面添加Bottleneck層,構(gòu)成一個(gè)分支。本文采用了交叉熵?fù)p失函數(shù)訓(xùn)練網(wǎng)絡(luò),通過交叉熵表示教師實(shí)際教授的知識(shí)和期望學(xué)生網(wǎng)絡(luò)學(xué)到的知識(shí)之間的差距,二者之間交叉熵的值越小,說明兩個(gè)分布概率越接近,網(wǎng)絡(luò)效果越好,公式如下所示:
圖2 SDA自蒸餾方法示意圖
其中,q表示第個(gè)分支的Softmax輸出,表示蒸餾的溫度,表示實(shí)際的標(biāo)簽輸出。
本文在訓(xùn)練網(wǎng)絡(luò)時(shí)使用了兩種蒸餾內(nèi)容,分別是基于標(biāo)簽蒸餾和基于中間特征的蒸餾方式。
(1)基于標(biāo)簽蒸餾,本文使用了KL散度來衡量學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的輸出分布差異,以表示學(xué)生網(wǎng)絡(luò)對(duì)于教師網(wǎng)絡(luò)知識(shí)的學(xué)習(xí)程度。為了使得淺層分類器能夠更好地接近深層分類器的表現(xiàn),公式定義如下所示:
其中,q表示網(wǎng)絡(luò)最后分支的Softmax輸出。
(2)在進(jìn)行特征蒸餾時(shí),對(duì)于網(wǎng)絡(luò)輸入數(shù)據(jù)將被提取的可以表示學(xué)習(xí)過程的知識(shí),主要是通過各種手段從網(wǎng)絡(luò)的中間隱藏層提取。這些知識(shí)被轉(zhuǎn)移到學(xué)生網(wǎng)絡(luò)中,實(shí)現(xiàn)知識(shí)的升華和提高,從而提高學(xué)生網(wǎng)絡(luò)的性能。基于特征的知識(shí)蒸餾公式定義如下所示:
其中,F表示第個(gè)分支的特征輸出,F表示網(wǎng)絡(luò)最后分支的特征輸出。
(3)模型中淺層網(wǎng)絡(luò)層數(shù)少,對(duì)圖片信息的提取能力不足導(dǎo)致其正確分類的能力弱。而隨著網(wǎng)絡(luò)層數(shù)的增加,深層網(wǎng)絡(luò)通過對(duì)圖片的多層操作使得其對(duì)圖片分類的能力逐漸加強(qiáng)。由于二者之間差距大,若用深層網(wǎng)絡(luò)直接指導(dǎo)淺層網(wǎng)絡(luò),會(huì)導(dǎo)致淺層網(wǎng)絡(luò)學(xué)習(xí)效果不佳。本文使用網(wǎng)絡(luò)中所有分支的集成指導(dǎo)網(wǎng)絡(luò)中最后分支,減少淺層分類器和深層分類器之間的差距,公式定義如下所示:
其中,H表示網(wǎng)絡(luò)中所有分支輸出logits的均值,H表示H的Softmax輸出,H表示網(wǎng)絡(luò)中所有分支輸出特征的均值。
綜上所述,本文將整個(gè)網(wǎng)絡(luò)的損失函數(shù)定義為:
該損失函數(shù)綜合考慮了教師實(shí)際教授的知識(shí)和期望學(xué)生網(wǎng)絡(luò)學(xué)到的知識(shí)之間的差距,通過縮小二者的差距以提升網(wǎng)絡(luò)的表現(xiàn)力;充分考慮了學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的輸出分布差異,通過KL散度表示學(xué)生網(wǎng)絡(luò)對(duì)于教師知識(shí)的學(xué)習(xí)程度;充分考慮了網(wǎng)絡(luò)的中間隱藏層信息的提取,最后充分考慮了深層網(wǎng)絡(luò)和淺層網(wǎng)絡(luò)之間的性能差距,使得淺層網(wǎng)絡(luò)能夠充分學(xué)習(xí)深層網(wǎng)絡(luò)的性能。
本文使用經(jīng)典的CIFAR10和CIFAR100數(shù)據(jù)集,其中CIFAR10數(shù)據(jù)集由10個(gè)類的60000個(gè)32×32彩色圖像組成,每個(gè)類有6000個(gè)圖像。有50000個(gè)訓(xùn)練圖像和10000個(gè)測(cè)試圖像。CIFAR100有100個(gè)類,每個(gè)類包含600個(gè)圖像,每類各有500個(gè)訓(xùn)練圖像和100個(gè)測(cè)試圖像。本文選擇了經(jīng)典的深度神經(jīng)網(wǎng)絡(luò)ResNet一系列變種模型作為主網(wǎng)絡(luò),分別是ResNet18、ResNet50、ResNet101,在自蒸餾網(wǎng)路中分支的安放位置為每個(gè)ResNet塊結(jié)構(gòu)后。本文使用了最早提出多分支結(jié)構(gòu)的BranchyNet的交叉熵作為基準(zhǔn),同時(shí)對(duì)比了BOT和EED方法。所有訓(xùn)練網(wǎng)絡(luò)的學(xué)習(xí)率為0.1,一次批處理圖像設(shè)為128,訓(xùn)練輪次設(shè)為200。
在本組實(shí)驗(yàn)中,我們對(duì)比只使用標(biāo)簽蒸餾技術(shù)下各個(gè)方案的性能。通過實(shí)驗(yàn)結(jié)果可知,在CIFAR10上,與基準(zhǔn)相比,本文的訓(xùn)練方法使得各個(gè)分支的精度和集成的精度均高于基準(zhǔn)。其中,在ResNet18和ResNet101上,每個(gè)分支的精度提升幅度相對(duì)較低,但是在ResNet50上每個(gè)分支的精度提升幅度較為明顯,1/4分支提升了1.51%,2/4分支提升了1.02%,3/4分支提升了0.76%,4/4分支提升了0.75%,集成的精度提升了0.88%。在CIFAR100上,SDA的結(jié)果與基準(zhǔn)相比,在ResNet18和ResNet50上,每個(gè)分支的精度提升幅度相對(duì)較低,但在ResNet101上每個(gè)分支的精度提升幅度較為明顯,1/4分支提升了0.46%,2/4分支提升了0.77%,3/4分支提升了1.1%,4/4分支提升了0.73%,集成的精度提升了0.47%。由實(shí)驗(yàn)結(jié)果可知,本文將標(biāo)簽與特征相結(jié)合的蒸餾方式的各分支精度均有提升。
本組實(shí)驗(yàn)中,本文與工作BOT和最新工作EED在蒸餾內(nèi)容相同的前提下,對(duì)比了不同蒸餾方式的表現(xiàn)。從表1可以得出,與BOT相比,使用分支集成的輸出教授各個(gè)分支的方法,在各分支的精度上是優(yōu)于其他蒸餾方式的。由圖3可知,分支集成的輸出是優(yōu)于最后分支的輸出,因此使用分支集成的輸出教授網(wǎng)絡(luò)中其他分支分類器的效果更好。但是深層網(wǎng)絡(luò)的性能與淺層網(wǎng)絡(luò)的性能存在差距,使用分支集成的輸出教授淺層的分支,可能會(huì)進(jìn)一步拉大差距,影響淺層網(wǎng)絡(luò)的學(xué)習(xí)效果。本文所提方法在最后分支輸出和分支集成輸出的精度高于EED。因此可知,先將分支集成的輸出教授最后分支的輸出,同時(shí)使用最后分支的輸出教授網(wǎng)絡(luò)中的各個(gè)分支,能夠降低分支集成與淺層網(wǎng)絡(luò)的性能差距,使得淺層網(wǎng)絡(luò)能夠更好地學(xué)習(xí)深層網(wǎng)絡(luò)的知識(shí),提升淺層網(wǎng)絡(luò)的性能。
圖3 SDA與EED[12]方法在ResNet50,CIFAR100上最后分支與分支集成的精度對(duì)比
表1 SDA與其他自蒸餾方法精度對(duì)比表
本組實(shí)驗(yàn)中,在CIFAR100數(shù)據(jù)集上,ResNe18和ResNe50模型上探索了不同的溫度參數(shù)對(duì)網(wǎng)絡(luò)模型的性能影響。由圖4和圖5的實(shí)驗(yàn)結(jié)果可知,隨著溫度的升高,網(wǎng)絡(luò)蒸餾的效果并非越來越好。當(dāng)=1時(shí)相當(dāng)于無溫度影響,其效果未受到影響。當(dāng)=3時(shí),與基準(zhǔn)相比,各分支的精度和集成的精度均有提升,效果最好。當(dāng)=7時(shí),在ResNet18上各分支的精度均高于基準(zhǔn),但與T=3時(shí)的效果相比,第2支、第3支的精度高于它的精度值,但第1支、第4支和最后集成的精度值卻小于它的精度值。在ResNet50上,當(dāng)=7時(shí)與基準(zhǔn)相比,除了第1支分支的精度,其他分支及集成的精度均是低于基準(zhǔn)的。因此,在自蒸餾方法中,選擇合適的蒸餾溫度是至關(guān)重要的。
圖4 SDA在ResNet18,CIFAR100上不同溫度參數(shù)下的精度對(duì)比
圖5 SDA在ResNet50,CIFAR100上不同溫度參數(shù)的精度對(duì)比
在本組實(shí)驗(yàn)中,我們?cè)贑IFAR100和ResNet50模型上探索了不同參數(shù)對(duì)網(wǎng)絡(luò)性能的影響。由圖6可知,在不同的參數(shù)下,與基準(zhǔn)相比,網(wǎng)絡(luò)的第1、第2、第3支的精度均不同程度地提高了,范圍在0.19%~1.51%之間,但第4支和集成的結(jié)果在不同情況下出現(xiàn)了精度下降的情況,下降范圍為0.15%~0.17%。由此可知,不同的參數(shù)對(duì)淺層的網(wǎng)絡(luò)精度影響較少,但對(duì)深層的網(wǎng)絡(luò)(最后輸出集成的結(jié)果)精度是有一定影響的。
圖6 SDA在ResNet50,CIFAR100上不同α參數(shù)的精度對(duì)比
本文提出一種基于多分支的自蒸餾方法(SDA),通過借助助教網(wǎng)絡(luò)的思想,使用網(wǎng)絡(luò)中各分支的輸出作為集成的輸出教授最后的輸出,同時(shí)使用最后的輸出教授各個(gè)分支的自蒸餾方式可以提高CNN的分類性能。實(shí)驗(yàn)證明,本文方法的性能表現(xiàn)是優(yōu)于其他自蒸餾方法的,尤其是在多分類的數(shù)據(jù)集CIFAR100上,效果提升最為明顯。我們分別驗(yàn)證了本文方法在蒸餾內(nèi)容上,相較于僅有蒸餾標(biāo)簽的方法,將標(biāo)簽蒸餾與中間特征蒸餾相結(jié)合的方法能夠使網(wǎng)絡(luò)模型效果更好。其次,本文驗(yàn)證了在蒸餾內(nèi)容相同的情況下,SDA與BOT和EED等不同蒸餾方式的比較,SDA方法能夠有效提升網(wǎng)絡(luò)的性能。在未來工作中,我們將會(huì)進(jìn)一步研究自蒸餾方法在其他領(lǐng)域的應(yīng)用。