劉紹華, 杜 康, 佘春東, 楊 傲
(北京郵電大學(xué)電子工程學(xué)院, 北京 100080)
在過去幾年的時間里,計算機(jī)視覺領(lǐng)域迅猛發(fā)展,目標(biāo)檢測作為計算機(jī)視覺的一個重要分支已經(jīng)被運(yùn)用于現(xiàn)實中的各種場景下,如智能手機(jī)相機(jī)中的人臉自動抓取功能,停車場中檢測車位占用情況的攝像頭,以及軍事項目中無人機(jī)對于基地目標(biāo)的檢測任務(wù)等。
目前,一些主流的目標(biāo)檢測網(wǎng)絡(luò)都采取殘差神經(jīng)網(wǎng)絡(luò)(residual neural network, ResNet),如ResNet50[1]或者Hourglass[2]作為主干網(wǎng)絡(luò),同時多數(shù)實驗中也將這兩種網(wǎng)絡(luò)作為參照。ResNet50是一種深度殘差網(wǎng)絡(luò),可以有效地解決深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練中的梯度消失和梯度爆炸問題,從而提高網(wǎng)絡(luò)的準(zhǔn)確率。Hourglass網(wǎng)絡(luò)具有良好的特征提取以及多尺度建模的能力,同時在目標(biāo)檢測等任務(wù)中表現(xiàn)出色。但是在移動設(shè)備、嵌入式平臺、物聯(lián)網(wǎng)等資源受限的環(huán)境下,上述兩種深度神經(jīng)網(wǎng)絡(luò)需要消耗大量的計算資源,同時其緩慢的訓(xùn)練推理速度以及龐大的空間占用,使得人們不得不將目光聚焦于輕量化目標(biāo)檢測網(wǎng)絡(luò)。
輕量化目標(biāo)檢測網(wǎng)絡(luò)可以降低計算復(fù)雜度和減少空間占用,在各種資源受限或?qū)τ嬎闼俣?、能耗、存儲等有要求的?yīng)用場景中具有廣泛的應(yīng)用前景。同時其也面臨新的挑戰(zhàn),比如如何提出有效的模型訓(xùn)練方案,使輕量化目標(biāo)檢測模型仍然能夠保持優(yōu)異的性能。
本文介紹了一種基于輕量化CenterNet的多教師聯(lián)合知識蒸餾方案。該方案能有效解決模型輕量化帶來的性能惡化問題,可以大大縮小教師模型和學(xué)生模型之間的性能差距。將多個大規(guī)模復(fù)雜模型作為教師網(wǎng)絡(luò),以教導(dǎo)作為學(xué)生網(wǎng)絡(luò)的輕量化模型進(jìn)行訓(xùn)練。相比于使用傳統(tǒng)訓(xùn)練方式的輕量化目標(biāo)檢測網(wǎng)絡(luò),使用該知識蒸餾方案可以在相同的訓(xùn)練輪數(shù)后達(dá)到更優(yōu)的檢測性能。
本文提出了針對CenterNet目標(biāo)檢測網(wǎng)絡(luò)特有的知識蒸餾模塊——多教師聯(lián)合知識蒸餾。該知識蒸餾模塊采用的蒸餾損失(knowledge distillation loss, LossKD)分為4部分:類別蒸餾損失LossCLSKD、長寬預(yù)測蒸餾損失LossWHKD、目標(biāo)中心點偏移蒸餾損失LossOFFKD、主干網(wǎng)絡(luò)蒸餾損失LossBACKBONEKD。在單教師知識蒸餾實驗中,使用加入ImageNet[3]預(yù)訓(xùn)練權(quán)重的ResNet50作為教師網(wǎng)絡(luò),對比多種輕量化模型的訓(xùn)練效果,驗證了該知識蒸餾模塊的有效性。引入蒸餾損失權(quán)重參數(shù)W,對不同部分蒸餾損失設(shè)置不同的權(quán)重,通過調(diào)整不同蒸餾損失部分的權(quán)重,可以快速地優(yōu)化蒸餾效果。在多教師聯(lián)合蒸餾與單教師蒸餾的對比實驗中,使用加入ImageNet預(yù)訓(xùn)練權(quán)重的ResNet50、Hourglass作為多教師網(wǎng)絡(luò),提出使用蒸餾注意力機(jī)制對不同教師網(wǎng)絡(luò)自適應(yīng)分配不同權(quán)重,驗證了多教師聯(lián)合蒸餾指導(dǎo)出的學(xué)生網(wǎng)絡(luò)在泛化能力以及性能指標(biāo)方面的提升。
在VOC2007數(shù)據(jù)集上,以MobileNetV2輕量化網(wǎng)絡(luò)為例,相較于傳統(tǒng)的CenterNet(主干網(wǎng)絡(luò)為ResNet50),所提方案的參數(shù)大小壓縮了74.7%,推理速度提升了70.5%,在平均精度(mean average Precision, mAP)上僅有1.99的降低,取得了更好的“性能-速度”平衡。同樣經(jīng)過100輪訓(xùn)練,使用多教師聯(lián)合知識蒸餾的訓(xùn)練方式相較于普通訓(xùn)練方式,mAP提升了11.30。
在目標(biāo)檢測領(lǐng)域,最早出現(xiàn)的是雙階段目標(biāo)檢測,如基于區(qū)域的卷積神經(jīng)網(wǎng)絡(luò)(region-based convolutional neural network, R-CNN)[4]、空間金字塔池化網(wǎng)絡(luò)(spatial pyramid pooling network, SPP-net)[5],第一個階段首先產(chǎn)生候選區(qū)域,該區(qū)域包含目標(biāo)大概的位置信息,第二個階段的任務(wù)包含對候選區(qū)域進(jìn)行分類和位置微調(diào)。由此可見,雙階段目標(biāo)檢測在產(chǎn)生候選區(qū)域時會消耗大量的內(nèi)存空間,在訓(xùn)練速度和推理速度上也略顯不足。隨著人們對于目標(biāo)檢測模型性能要求的提升,單階段目標(biāo)檢測應(yīng)運(yùn)而生。單階段目標(biāo)檢測算法可以在同一個階段直接產(chǎn)生物體的類別概率和位置坐標(biāo)值。相比于雙階段目標(biāo)檢測算法,單階段目標(biāo)檢測沒有產(chǎn)生候選區(qū)域的階段,整體流程更為簡單。同時,單階段目標(biāo)檢測算法又可以分為基于錨框和無錨框兩類,以基于錨框為代表的模型有更快的R-CNN(Faster RCNN)[6]、YOLOv2[7]、YOLOv3[8];以無錨框為代表的模型有:CornerNet[9]、CenterNet[10]、ExtremeNet[11]。兩者區(qū)別在于有沒有利用錨框提取候選目標(biāo)框。比如YOLOv3,其首先按一定的規(guī)則在圖片上生成一系列位置固定的錨框,將這些錨框看作是可能的候選區(qū)域,再對錨框是否包含目標(biāo)物體進(jìn)行預(yù)測,如果經(jīng)預(yù)測判斷包含目標(biāo)物體,則還需要預(yù)測所包含物體的類別以及預(yù)測框相對于錨框位置需要調(diào)整的位移大小。CenterNet則采用關(guān)鍵點檢測,通過特征提取后輸出的熱力圖峰值來確定待檢測目標(biāo)物體的中心點和種類,再通過對每個中心點預(yù)測其對應(yīng)的預(yù)測框來完成檢測任務(wù)。
CenterNet網(wǎng)絡(luò)的組成包括主干網(wǎng)絡(luò)特征提取模塊、特征圖解碼模塊以及分類頭。CenterNet的解碼器一般采用反卷積模塊Deconv進(jìn)行3次上采樣,得到分辨率更高的特征圖。解碼器輸出的特征圖隨后傳入到CenterNet的分類頭中,假設(shè)在目標(biāo)檢測任務(wù)中有C個待檢測類,那么分類頭將輸出C+4個通道的特征圖,其中預(yù)測的熱力圖占有C個通道的輸出,每一個類別都有一張熱力圖與之對應(yīng)。在每一張熱力圖上,若某個坐標(biāo)處有物體目標(biāo)的中心點,即在該坐標(biāo)處產(chǎn)生一個關(guān)鍵點。預(yù)測框的長寬大小占2個通道的輸出,預(yù)測框的中心點偏移占有2個通道的輸出。CenterNet的損失包括3部分,熱力圖類別預(yù)測損失LossCLS、目標(biāo)長寬預(yù)測損失LossWH、目標(biāo)中心點偏移損失LossOFF。其中,LossCLS損失函數(shù)采用改進(jìn)的Focal Loss, LossWH和LossOFF的損失函數(shù)均采用L1 Loss。 標(biāo)準(zhǔn)CenterNet的結(jié)構(gòu)示意圖如圖1所示。
圖1 標(biāo)準(zhǔn)CenterNet的結(jié)構(gòu)Fig.1 Structure of standard CenterNet
近年來,隨著深度學(xué)習(xí)應(yīng)用的不斷拓展和移動設(shè)備的廣泛普及,輕量化神經(jīng)網(wǎng)絡(luò)成為了深度學(xué)習(xí)領(lǐng)域的熱門話題之一。輕量化神經(jīng)網(wǎng)絡(luò)是指在保持神經(jīng)網(wǎng)絡(luò)模型準(zhǔn)確性的前提下,通過設(shè)計輕量化算法和結(jié)構(gòu),減小神經(jīng)網(wǎng)絡(luò)的參數(shù)數(shù)量和計算復(fù)雜度,以實現(xiàn)在較低硬件資源條件下的高效運(yùn)行。
ShuffleNet[12]是由曠視科技提出的一種輕量級卷積神經(jīng)網(wǎng)絡(luò)模型。它通過使用逐通道的1×1卷積和通道重排來減少計算量。通道重排是指將輸入的特征圖進(jìn)行分組,然后對每個分組內(nèi)的通道進(jìn)行混洗,從而將不同的通道分布到不同的分組中。該方法可以有效減少模型中的參數(shù)量和計算量。MobileNetV2[13]使用了瓶頸設(shè)計以及倒殘差模塊,滿足了在移動設(shè)備等資源受限的環(huán)境下進(jìn)行圖像分類和目標(biāo)檢測的需求。EfficientNet[14]是一個多尺度卷積神經(jīng)網(wǎng)絡(luò),該網(wǎng)絡(luò)通過利用深度網(wǎng)絡(luò)和多尺度特征提取,在保證準(zhǔn)確性的情況下減小了神經(jīng)網(wǎng)絡(luò)參數(shù)量和計算復(fù)雜度。小型基于語義理解的深度雙向預(yù)訓(xùn)練Transformer(tiny bidirectional encoder representation from Transformer, TinyBERT)[15]是一個輕量化的基于Transformer的雙向編碼器表示(bidirectional encoder representation from Transformer, BERT)模型,該模型通過蒸餾技術(shù)將原版BERT模型的知識遷移到較小的模型中,從而在保證準(zhǔn)確性的情況下,減小了模型的參數(shù)數(shù)量和計算復(fù)雜度。2020年華為推出GhostNet[16],該網(wǎng)絡(luò)通過Ghost模塊,即一種基于低秩分解的特征重用技術(shù),將參數(shù)量減小了50%,同時在ImageNet分類任務(wù)中超過了MobileNetV3。MobileBERT[17]相較于原版BERT模型減少了90%的計算量,同時在自然語言理解(natural language understanding, NLU)任務(wù)中取得了不錯的性能。Bias Loss[18]是一種輕量化網(wǎng)絡(luò)的新型損失函數(shù),將注意力集中在一組有價值的數(shù)據(jù)點上,并防止特性差的樣本誤導(dǎo)優(yōu)化過程。Mobile-Former[19]將MobileNet和Transformer進(jìn)行并行設(shè)計,可以實現(xiàn)局部和全局特征的雙向融合。在分類和目標(biāo)檢測任務(wù)中,性能遠(yuǎn)超其他輕量級網(wǎng)絡(luò)。
知識蒸餾作為一種新興的模型壓縮訓(xùn)練方法,目前已成為深度學(xué)習(xí)領(lǐng)域的一個研究熱點和重點[20]。Hinton等[21]認(rèn)為,學(xué)生模型可以利用教師模型傳遞的“暗知識”更好地理解和掌握特征知識,從而增強(qiáng)其泛化能力。最早,知識蒸餾算法主要針對分類問題,分類問題的共同點是模型最后會有一個Softmax層,其輸出值對應(yīng)了相應(yīng)類別的概率值。知識蒸餾使用軟標(biāo)簽,能夠提供高效的遷移泛化能力。軟標(biāo)簽是指輸入圖像通過教師模型得到Softmax層的輸出。軟標(biāo)簽相比硬標(biāo)簽,有著更高的熵、更小的梯度變化,因此學(xué)生模型相比教師模型可以使用更少的數(shù)據(jù)和更大的學(xué)習(xí)率。在使用軟標(biāo)簽時,一般會引入溫度系數(shù)T:
(1)
式中:T為蒸餾溫度,T=1時,即為普通的Softmax。蒸餾溫度T的升高,可以使標(biāo)簽變得更軟。通過使用軟標(biāo)簽,可以獲得比硬標(biāo)簽更豐富的泛化信息,從而避免學(xué)生模型過擬合。雖然知識蒸餾已經(jīng)獲得了廣泛的應(yīng)用,但是學(xué)生模型的性能通常接近于教師模型。特別地,若學(xué)生模型和教師模型使用相同的網(wǎng)絡(luò),學(xué)生模型的性能將超越教師模型[22],性能更差的教師模型反而教出了更好的學(xué)生模型[23]。一般的知識蒸餾可以統(tǒng)一抽象為如下形式:
L=αLsoft+βLhard
(2)
式中:α和β是超參數(shù),β是學(xué)生網(wǎng)絡(luò)在硬標(biāo)簽上的損失,α是學(xué)生網(wǎng)絡(luò)在教師網(wǎng)絡(luò)生成的軟標(biāo)簽上的損失。為了更好地理解知識蒸餾的作用,一些研究工作從數(shù)學(xué)或?qū)嶒炆蠈χR蒸餾的作用機(jī)制進(jìn)行了證明和解釋,主要分為以下幾類:
(1) 軟標(biāo)簽為學(xué)生模型提供正則化約束。這一結(jié)論最早可以追溯到通過貝葉斯優(yōu)化來控制網(wǎng)絡(luò)超參數(shù)的對比試驗[24],其表明了教師模型的軟標(biāo)簽為學(xué)生模型提供了顯著的正則化。軟標(biāo)簽正則化的作用是雙向的,因此能將知識從較弱的教師模型遷移到能力更強(qiáng)大的學(xué)生模型中[25-26]。軟標(biāo)簽通過標(biāo)簽平滑訓(xùn)練提供了正則化[25-26],標(biāo)簽平滑通過避免過分相信訓(xùn)練樣本的真實標(biāo)簽來防止訓(xùn)練的過擬合[25]。
(2) 軟標(biāo)簽為學(xué)生模型提供了“特權(quán)信息”?!疤貦?quán)信息”可以被用于教師模型的解釋和評估[27]。教師模型在訓(xùn)練的過程中,將軟目標(biāo)的“暗知識”遷移到學(xué)生模型中,而學(xué)生模型在測試的過程中并不能使用“暗知識”。從這個角度看,知識蒸餾通過軟標(biāo)簽來為學(xué)生模型傳遞“特權(quán)信息”。
(3) 軟標(biāo)簽引導(dǎo)了學(xué)生模型優(yōu)化的方向。Phuong等[28]從模型訓(xùn)練的角度證明了軟標(biāo)簽?zāi)軌蛑笇?dǎo)學(xué)生模型的優(yōu)化方向。Cheng 等[29]從數(shù)學(xué)角度驗證了使用軟標(biāo)簽比使用原始數(shù)據(jù)進(jìn)行的優(yōu)化學(xué)習(xí)具有更高的學(xué)習(xí)速度和更好的性能。
同時,為了解決教師網(wǎng)絡(luò)特征圖與學(xué)生網(wǎng)絡(luò)特征圖對應(yīng)不一致的問題,Lin等[30]提出了Target-aware模型Transformer模型,使得整個學(xué)生網(wǎng)絡(luò)在蒸餾過程中能夠分別模仿教師網(wǎng)絡(luò)的每一個空間組件。通過這種方式,可以提高匹配能力,進(jìn)而提高知識蒸餾的性能。針對前景和背景特征圖之間的不均勻差異將對蒸餾產(chǎn)生負(fù)面影響的問題,Yang等[31]提出局部和全局知識蒸餾(focal and global knowledge distillation, FGD)。局部蒸餾分離了前景和背景,迫使學(xué)生將注意力集中在老師的關(guān)鍵像素和通道上。全局蒸餾重建了不同像素之間的關(guān)系,并將其從教師傳遞給學(xué)生,以補(bǔ)償局部蒸餾中丟失的全局信息。
注意力機(jī)制是近幾年深度學(xué)習(xí)中重要的方法之一,在不同領(lǐng)域得到了廣泛的研究。在發(fā)展過程中,注意力機(jī)制首先被應(yīng)用在自然語言處理任務(wù)之中。而如今,在人工智能領(lǐng)域,注意力機(jī)制作為神經(jīng)網(wǎng)絡(luò)體系結(jié)構(gòu)的一個重要組成部分,在數(shù)據(jù)挖掘、時間序列預(yù)測、計算機(jī)視覺等領(lǐng)域都有著廣泛的應(yīng)用。顧名思義,注意力機(jī)制就是通過對人類閱讀、聽說中的注意力行為進(jìn)行模擬操作。注意力機(jī)制可以做到重點關(guān)注有利于任務(wù)完成的信息,而抑制那些不重要的信息。這項技術(shù)也可以自然地應(yīng)用在目標(biāo)檢測模型中:在執(zhí)行檢測任務(wù)時,網(wǎng)絡(luò)會更重視特征更加重要的通道,因此正確利用注意力機(jī)制給通道分配不同的注意力分?jǐn)?shù)有利于網(wǎng)絡(luò)使用更加重要的特征圖進(jìn)行特征融合,從而提高目標(biāo)檢測的準(zhǔn)確性。
Transformer[32]模型利用多頭注意力機(jī)制加速了模型的訓(xùn)練,加強(qiáng)了不同位置信息的關(guān)聯(lián)性,一定程度地解決了RNN系列模型的長距離依賴問題。Transformer模型僅利用注意力機(jī)制就實現(xiàn)了機(jī)器翻譯等復(fù)雜的任務(wù),并且取得了目前最好的效果,向相關(guān)研究人員展示了注意力機(jī)制的強(qiáng)大作用。
在本節(jié)中,首先介紹CenterNet單教師知識蒸餾模塊,然后提出了多教師知識蒸餾模塊的網(wǎng)絡(luò)結(jié)構(gòu),并對在多教師聯(lián)合知識蒸餾模塊中加入的蒸餾注意力機(jī)制、損失權(quán)重參數(shù)W進(jìn)行了描述,最后闡述了聯(lián)合訓(xùn)練與凍結(jié)解凍訓(xùn)練的模型訓(xùn)練方式。
為了達(dá)到模型輕量化的目的,目前常用的方法是將模型的特征提取網(wǎng)絡(luò)替換為輕量化網(wǎng)絡(luò)。若在模型輕量化后依然采用常規(guī)的訓(xùn)練方式,即使用LossCLS、LossWH、LossOFF組成的損失函數(shù),很難完成對于模型參數(shù)的有效更新,導(dǎo)致訓(xùn)練出的模型檢測性能不佳。于是,為了兼顧模型輕量化以及模型檢測的性能,提出了專門針對于CenterNet的知識蒸餾方案。
目前的知識蒸餾算法大多應(yīng)用于分類任務(wù),很難直接應(yīng)用于目標(biāo)檢測模型。這其中的難點主要有兩個方面:① 檢測任務(wù)中大量的背景樣本使得正負(fù)樣本不均衡,導(dǎo)致模型的分類更加困難;② 目標(biāo)檢測的網(wǎng)絡(luò)更加復(fù)雜,尤其對于雙階段網(wǎng)絡(luò)而言,雙階段的網(wǎng)絡(luò)由于RPN輸出的不確定性,導(dǎo)致教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的候選區(qū)不能精確對齊。而對于這兩個難點,目標(biāo)檢測模型CenterNet均有良好的解決方案。針對難點①,根據(jù)CenterNet的范式,該網(wǎng)絡(luò)沒有復(fù)雜的正負(fù)樣本采樣,只有物體的中心點是正樣本,其他都是負(fù)樣本。針對于難點②,如第1節(jié)所介紹,CenterNet為單階段目標(biāo)檢測網(wǎng)絡(luò),結(jié)構(gòu)相對簡單?;贑enterNet的知識蒸餾模塊的結(jié)構(gòu)如圖2所示。
圖2 基于CenterNet的知識蒸餾模塊結(jié)構(gòu)Fig.2 Knowledge distillation module structure based on CenterNet
所采用的總體蒸餾模塊由兩部分構(gòu)成,分別為基于主干模塊的蒸餾以及基于分類頭的蒸餾。在模型端到端的訓(xùn)練過程中,總體網(wǎng)絡(luò)參數(shù)更新由原來的分類頭損失以及蒸餾損失共同決定,模型總體損失定義如下:
LossTOTAL=LossCLS+LossWH+LossOFF+LossKD
(3)
關(guān)于教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)的選擇,在基于CenterNet多教師聯(lián)合知識蒸餾任務(wù)中,選用ResNet50、Hourglass等作為教師網(wǎng)絡(luò),它們擁有更復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu)以及更多的參數(shù)運(yùn)算量,但模型可以提供更高的檢測性能。選用ResNet18、MobileNetV2等作為學(xué)生網(wǎng)絡(luò),模型簡單并擁有較低的計算復(fù)雜度,足夠輕量化,但是其檢測性能也相對較低。綜上所述,本文采取的教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)可以提供較好的平衡,以滿足知識蒸餾的需求,契合教師帶動學(xué)生、“先富帶動后富”的思想,使作為學(xué)生網(wǎng)絡(luò)的輕量化模型得到性能提升。
根據(jù)第2.1節(jié)圖2所展示的知識蒸餾的框架,提出基于CenterNet的蒸餾損失函數(shù)。
在基于主干模塊的蒸餾上,使用MSE損失作為該部分蒸餾損失:
LossBACKBONEKD=MSE(ReLU(FMt),ReLU(FMs))
(4)
式中:FMt為教師網(wǎng)絡(luò)主干模塊輸出的特征圖;FMs為學(xué)生網(wǎng)絡(luò)主干模塊輸出的特征圖。
根據(jù)第1.2節(jié)介紹的CenterNet分類頭的特殊結(jié)構(gòu),擬采取的蒸餾損失也從分類頭的這3類優(yōu)化目標(biāo)著手。對于輸出的熱力圖,把教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)輸出的特征圖通過一個ReLU層將負(fù)數(shù)部分去掉,然后使用MSE損失生成LossCLSKD,如下所示:
LossCLSKD=MSE(ReLU(hmt),ReLU(hms))
(5)
式中:hmt為教師網(wǎng)絡(luò)輸出的熱力圖,hms為學(xué)生網(wǎng)絡(luò)輸出的熱力圖。
對于LossWH和LossOFF,根據(jù)CenterNet只學(xué)習(xí)正樣本的原理,將教師網(wǎng)絡(luò)輸出熱力圖用非極大值抑制操作后,產(chǎn)生掩膜,再將掩膜與寬高和中心點偏移的對應(yīng)特征圖相乘作為新特征圖。最后,將教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)所有對應(yīng)的像素點使用L1 Loss作為該部分的蒸餾損失。具體操作過程如下:
LossWHKD=L1(wht·mask/T,whs·mask/T)
(6)
LossOFFKD=L1(offt·mask/T,offs·mask/T)
(7)
式中:wht與offt為教師網(wǎng)絡(luò)輸出的特征圖,whs與offs為學(xué)生網(wǎng)絡(luò)的特征圖;T為溫度系數(shù);“·”為張量乘法符號。
根據(jù)CenterNet損失的定義,CenterNet損失主要分為主干模塊損失、熱力圖損失、預(yù)測框?qū)捀邠p失以及中心點偏移損失,這些損失對于目標(biāo)檢測性能的重要程度不盡相同。于是給予不同類別損失不同的權(quán)重W,將分類頭部分總體蒸餾損失定義如下:
LossKD=w1*LossCLSKD+w2*LossWHKD+
w3*LossOFFKD+w4*LossBACKBONEKD
(8)
式中:w1、w2、w3、w4為模型訓(xùn)練前人工設(shè)置的損失權(quán)重的超參數(shù),需根據(jù)訓(xùn)練效果進(jìn)行調(diào)節(jié)。
為了進(jìn)一步提升知識蒸餾的效果,可以借鑒Transformer模型中多頭注意力機(jī)制的作用,融合多個教師網(wǎng)絡(luò)生成的特征圖以及軟標(biāo)簽。對于同一張圖片,不同的教師網(wǎng)絡(luò)生成的軟標(biāo)簽分布是不同的,如圖3所示。
圖3 不同網(wǎng)絡(luò)對于同一張圖片生成的軟標(biāo)簽分布Fig.3 Soft label distribution generated for the same image by different networks
由于不同教師網(wǎng)絡(luò)的性能效果存在差異(與標(biāo)簽之間的損失),若直接使用Softmax對損失進(jìn)行權(quán)重分配,容易發(fā)生權(quán)重的極限偏移,使權(quán)重分布更加尖銳(極端情況為其中一個教師網(wǎng)絡(luò)權(quán)重接近于1,其余教師網(wǎng)絡(luò)權(quán)重接近于0),此時多教師蒸餾網(wǎng)絡(luò)退化為單教師蒸餾網(wǎng)絡(luò)。為避免上述情況發(fā)生,得到更加平緩的權(quán)重分布,需要先對不同教師網(wǎng)絡(luò)的損失進(jìn)行歸一化,如下所示:
(9)
(10)
(11)
多教師聯(lián)合知識蒸餾模塊如圖4所示,使用N個不同的教師網(wǎng)絡(luò)對同一學(xué)生網(wǎng)絡(luò)進(jìn)行知識蒸餾。
圖4 多教師聯(lián)合知識蒸餾模型Fig.4 Multi-teacher joint knowledge distillation network
根據(jù)知識蒸餾的特性,使用教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)聯(lián)合訓(xùn)練的方式。具體來說,教師網(wǎng)絡(luò)在硬標(biāo)簽上單獨訓(xùn)練N個輪次,學(xué)生網(wǎng)絡(luò)在硬標(biāo)簽上單獨訓(xùn)練N個輪次,然后使用教師網(wǎng)絡(luò)指導(dǎo)學(xué)生網(wǎng)絡(luò),并保留原本的硬標(biāo)簽損失,再訓(xùn)練N個輪次。實驗證明,這種訓(xùn)練方式得到的模型性能是最佳的。
借鑒遷移學(xué)習(xí)的思想,神經(jīng)網(wǎng)絡(luò)主干特征提取部分所提取到的特征是通用的,所以將主干模塊凍結(jié)起來訓(xùn)練可以加快訓(xùn)練效率,也可以防止訓(xùn)練好的主干模塊權(quán)值被破壞。在凍結(jié)訓(xùn)練階段,模型主干模塊的初始參數(shù)加載了其在ImageNet上預(yù)訓(xùn)練得到的模型參數(shù),在新任務(wù)的訓(xùn)練過程中參數(shù)被凍結(jié)了,主干網(wǎng)絡(luò)權(quán)重不發(fā)生改變,只改變解碼模塊以及分類頭模塊的權(quán)重,占用的顯存較小,僅對網(wǎng)絡(luò)進(jìn)行微調(diào)。 在解凍訓(xùn)練階段,模型的主干網(wǎng)絡(luò)不被凍結(jié),主干模塊、解碼模塊以及分類頭模塊的參數(shù)都可以得到更新。在解凍訓(xùn)練階段,模型訓(xùn)練占用的顯存較大,凍結(jié)解凍訓(xùn)練方式如圖5所示。
圖5 解凍和凍結(jié)訓(xùn)練Fig.5 Unfreeze and freeze training
在本節(jié)中,首先介紹了實驗使用的數(shù)據(jù)集、硬件設(shè)備以及訓(xùn)練方法。其次,對比了多種輕量化主干網(wǎng)絡(luò)是否使用知識蒸餾訓(xùn)練的檢測效果,評價指標(biāo)主要包括不同主干網(wǎng)絡(luò)下CenterNet的目標(biāo)檢測性能mAP和模型參數(shù)大小。同時,通過消融實驗對比單教師與多教師知識蒸餾,證明了多教師聯(lián)合知識蒸餾帶來的性能提升。基于CenterNet的知識蒸餾訓(xùn)練方案可以分為以下3步。
(1) 訓(xùn)練教師模型:使用硬標(biāo)簽,即正常的標(biāo)簽訓(xùn)練教師模型。
(2) 計算“軟標(biāo)簽”:利用訓(xùn)練好的教師模型來計算“軟標(biāo)簽”,也就是教師模型“軟化后”的輸出。
(3) 訓(xùn)練學(xué)生模型:在學(xué)生模型硬標(biāo)簽損失函數(shù)的基礎(chǔ)上加入教師模型輸出軟標(biāo)簽的損失函數(shù),共同指導(dǎo)學(xué)生模型的訓(xùn)練。
采用VOC2007作為數(shù)據(jù)集。其中,訓(xùn)練集包括19 352張分辨率為500×375的圖片和對應(yīng)標(biāo)簽,驗證集采用隨機(jī)分割訓(xùn)練集的方式生成,每一輪訓(xùn)練中,訓(xùn)練集和驗證集的比例為9∶1。測試集由3 870張分辨率為500×375的圖片和對應(yīng)標(biāo)簽組成。在將數(shù)據(jù)集導(dǎo)入模型時,需要對圖片進(jìn)行尺寸歸一化和數(shù)據(jù)增強(qiáng)的預(yù)處理,將圖片尺寸歸一化為512×512。
使用2塊GPU Tesla T4對不同模型進(jìn)行ImageNet[22]預(yù)訓(xùn)練權(quán)重加載的訓(xùn)練。使用ResNet50作為教師網(wǎng)絡(luò),對學(xué)生網(wǎng)絡(luò)ResNet18、ResNet34、MobileNetV2、EfficientNet-b0等都設(shè)置100輪訓(xùn)練。在訓(xùn)練過程中,采用第3節(jié)中的聯(lián)合訓(xùn)練、凍結(jié)訓(xùn)練與解凍訓(xùn)練結(jié)合的方法,將100輪訓(xùn)練分為2步, 1~50輪采用學(xué)習(xí)率為1e-3、權(quán)重衰減為5e-5的Adam優(yōu)化器,使用凍結(jié)訓(xùn)練方式;51~100輪采用學(xué)習(xí)率為1e-4、權(quán)重衰減為5e-5的Adam優(yōu)化器,使用解凍訓(xùn)練方式。采用不同主干模塊的CenterNet模型大小、參數(shù)大小、浮點運(yùn)算次數(shù)(floating point operations per second,FLOPs)如表1所示。
表1 不同結(jié)構(gòu)CenterNet的模型大小、參數(shù)大小、FLOPsTable 1 Model size, parameter size, FLOPs of different structures of CenterNet
經(jīng)過100輪訓(xùn)練后,不同模型使用普通訓(xùn)練以及知識蒸餾訓(xùn)練,得到數(shù)據(jù)集中“汽車”類別的精度(average precision, AP)如圖6所示。
圖6 不同模型在兩種訓(xùn)練方式下對汽車類別的APFig.6 AP of different models trained in two ways in car category
對同一輕量化網(wǎng)絡(luò)模型,分別使用普通訓(xùn)練方式和基于CenterNet知識蒸餾的訓(xùn)練方式,在VOC2007數(shù)據(jù)集上得到的mAP的交并比(intersection over union, IoU)閾值取0.5,如表2所示。
表2 不同輕量化結(jié)構(gòu)CenterNet的mAP、訓(xùn)練輪數(shù)以及預(yù)訓(xùn)練 模型加載情況(VOC數(shù)據(jù)集下)Table 2 mAP, number of training rounds, and loading of pre-trained models in different light-weight structures of CenterNet (with VOC dataset)
可以看出,在100輪訓(xùn)練后,使用知識蒸餾訓(xùn)練方式的輕量級特征提取網(wǎng)絡(luò)相較于使用普通訓(xùn)練方式的輕量化特征提取網(wǎng)絡(luò)可以得到更高的mAP,實驗的4種輕量化模型的mAP平均提升了4.52。
為了測試多教師知識蒸餾的效果,采用ResNet50和Hourglass組成多教師網(wǎng)絡(luò)對輕量化網(wǎng)絡(luò)進(jìn)行知識蒸餾。在實驗中主要采用了ResNet18、ResNet34、MobileNetV2、EfficientNet-b0作為BACKBONE的CenterNet目標(biāo)檢測模型進(jìn)行消融實驗。
使用4塊GPU Tesla T4對不同模型在VOC2007數(shù)據(jù)集上進(jìn)行ImageNet[22]預(yù)訓(xùn)練權(quán)重加載的解凍訓(xùn)練與凍結(jié)訓(xùn)練。訓(xùn)練輪數(shù)設(shè)置為100輪,其中1~50輪訓(xùn)練采用學(xué)習(xí)率為1e-3,權(quán)重衰減為5e-4的Adam優(yōu)化器,使用凍結(jié)訓(xùn)練方式;51~100輪采用學(xué)習(xí)率為1e-4,權(quán)重衰減為5e-5的Adam優(yōu)化器,使用解凍訓(xùn)練方式。經(jīng)過訓(xùn)練后,不同模型對于VOC2007數(shù)據(jù)集中數(shù)據(jù)集中“摩托車”類別的AP如圖7所示。
圖7 不同模型在兩種訓(xùn)練方式下對于摩托車類別的APFig.7 AP of different models trained in two ways in motorbike category
不同模型使用單教師知識蒸餾與多教師知識蒸餾,在VOC2007數(shù)據(jù)集訓(xùn)練得到的mAP如表3所示。
表3 單教師知識蒸餾和多教師知識蒸餾對比Table 3 Comparison of single teacher knowledge distillation and multi-teacher knowledge distillation
可以看出,在100輪訓(xùn)練后,使用多教師聯(lián)合知識蒸餾訓(xùn)練方式相較于單教師知識蒸餾訓(xùn)練得到的輕量級特征提取網(wǎng)絡(luò)可以得到更高的mAP。實驗中4種輕量化模型的mAP平均提升了1.83。
根據(jù)第2節(jié)對于損失權(quán)重參數(shù)W的定義,W由w1、w2、w3、w4組合而成,其中w1、w2、w3、w4分別為LossCLSKD、LossWHKD、LossOFFKD、LossBACKBONEKD的權(quán)重。由式(8)可知損失權(quán)重參數(shù)W選擇的不同會直接影響蒸餾損失的構(gòu)成。在模型訓(xùn)練前,需要人工設(shè)置損失權(quán)重參數(shù)W的可選取值,W∈{0.1,1,10,100,1 000},采取網(wǎng)格搜索算法枚舉多種蒸餾權(quán)重組合進(jìn)行實驗。在VOC2007數(shù)據(jù)集下,不同權(quán)重組合的模型性能如表4所示。
表4 由不同損失權(quán)重參數(shù)W蒸餾訓(xùn)練得到的CenterNet性能 對比(主干網(wǎng)絡(luò)為ResNet18)Table 4 Comparison of CenterNet performance obtained by distillation training with different loss weight parameters W (withResNet18 as backbone)
根據(jù)表4中使用不同損失權(quán)重參數(shù)W訓(xùn)練出的模型的檢測性能效果對比,對于VOC2007數(shù)據(jù)集而言,當(dāng)w1,w2,w3,w4分別為1 000,1,0.1,100時,mAP相較w1,w2,w3,w4分別為0.1,10,1 000,1時提升了6.48,說明了W超參數(shù)對于模型調(diào)優(yōu)的重要性,即如果找到更加適配模型的損失權(quán)重參數(shù),可以使訓(xùn)練效果得到提升。其中,w1,w2,w3,w4分別為1 000,1,0.1,100時訓(xùn)練出的ResNet18-CenterNet模型的檢測效果圖如圖8所示。
圖8 ResNet18-CenterNet目標(biāo)檢測效果Fig.8 ResNet18-CenterNet object’s detection effect
綜上所述,當(dāng)訓(xùn)練條件發(fā)生變化時,可以選擇更換損失權(quán)重參數(shù)W,以快速達(dá)到更好的訓(xùn)練效果以及檢測效果。
使用BACKBONE為MobileNetV2的CenterNet目標(biāo)檢測模型,使用普通訓(xùn)練與多教師聯(lián)合知識蒸餾訓(xùn)練兩種不同方式下得到的模型進(jìn)行目標(biāo)檢測的測試,效果對比如圖9所示,圖9(a)、圖9(b)、圖9(c)左側(cè)為普通訓(xùn)練方式檢測效果圖,圖9(a)、圖9(b)、圖9(c)右側(cè)為多教師蒸餾檢測效果圖。
圖9 普通訓(xùn)練方式檢測效果圖(各分圖左側(cè))及多教師蒸餾訓(xùn)練檢測效果圖(各分圖右側(cè))Fig.9 Detection performance diagram of the regular training(left) and detection performance diagram of the multi-teacher distillation training (right)
由圖9可見,使用了多教師知識蒸餾訓(xùn)練得到的模型,無論在檢測的置信度還是檢測出的目標(biāo)的數(shù)量上,都更加優(yōu)異。圖9(c)圖片存在多目標(biāo)重疊的情況,使用多教師聯(lián)合知識蒸餾后的模型也能很好地解決這一問題,可將多重疊目標(biāo)進(jìn)行準(zhǔn)確分割。
本文回顧了目前國內(nèi)外對于目標(biāo)檢測領(lǐng)域的研究現(xiàn)狀,總結(jié)了一些輕量化目標(biāo)檢測模型的實現(xiàn)方法以及一些優(yōu)秀的知識蒸餾方案。本文提出了基于CenterNet的多教師聯(lián)合知識蒸餾訓(xùn)練方案,定義了基于CenterNet的蒸餾損失函數(shù),提出了利用蒸餾注意力機(jī)制來有效融合多個教師網(wǎng)絡(luò)的“暗知識”,可應(yīng)用于多種輕量化主干網(wǎng)絡(luò)CenterNet的訓(xùn)練過程。實驗結(jié)果表明,本蒸餾訓(xùn)練方案在VOC2007數(shù)據(jù)集上表現(xiàn)出了比普通訓(xùn)練方案更加優(yōu)異的性能。以使用了本文的多教師聯(lián)合蒸餾訓(xùn)練后的MobileNet-CenterNet為例,相較于傳統(tǒng)的CenterNet(主干網(wǎng)絡(luò)為ResNet50),參數(shù)大小壓縮了74.7%,推理速度提升了70.5%,而mAP只有1.99的降低,取得了更好的“性能-速度”平衡。同樣經(jīng)過100輪訓(xùn)練,多教師聯(lián)合知識蒸餾訓(xùn)練方式相較于普通訓(xùn)練方式,mAP提升了11.30。未來,還將繼續(xù)改進(jìn)CenterNet多教師聯(lián)合知識蒸餾損失函數(shù),以更優(yōu)的蒸餾結(jié)構(gòu)訓(xùn)練得到更好的輕量化目標(biāo)檢測模型。