周兆京,王曉茹,姜竹青,門(mén)愛(ài)東,馬 龍
(1.北京郵電大學(xué) 人工智能學(xué)院,北京 100876;2.北京市經(jīng)濟(jì)管理學(xué)校 信息技術(shù)系,北京 100089;3.中國(guó)人民解放軍96962部隊(duì),北京 102206)
近年來(lái),圖像超分辨率已成為一項(xiàng)重要的研究課題,它在目標(biāo)檢測(cè)、人臉識(shí)別和信息取證等方面有著重要的應(yīng)用價(jià)值。超分辨率旨在學(xué)習(xí)一種非線性映射,從低分辨率圖像中重建出高分辨率圖像。傳統(tǒng)插值算法主要是基于領(lǐng)域像素點(diǎn)進(jìn)行加權(quán)計(jì)算來(lái)生成高分辨率圖像,但僅能獲得原圖本身像素點(diǎn)的領(lǐng)域信息,無(wú)法生成新的高頻信息,導(dǎo)致計(jì)算而得的高分辨率圖像往往不夠清晰。
隨著深度學(xué)習(xí)的快速發(fā)展,研究者們提出了許多基于卷積神經(jīng)網(wǎng)絡(luò)的超分辨率方法[1-4]。SRCNN[2]是首個(gè)將卷積神經(jīng)網(wǎng)絡(luò)引入超分辨率領(lǐng)域的算法,取得了遠(yuǎn)勝傳統(tǒng)算法的性能表現(xiàn)。SRCNN通過(guò)學(xué)習(xí)低分辨率圖像和高分辨率圖像的映射關(guān)系,可以重建出低分辨率圖像中缺失的高頻分量。增加卷積神經(jīng)網(wǎng)絡(luò)深度可以設(shè)計(jì)出卓有成效的超分辨率模型,進(jìn)一步提高超分辨率網(wǎng)絡(luò)的重建效果,但是其計(jì)算復(fù)雜度和內(nèi)存占用量也急劇提升,直接在計(jì)算資源有限的設(shè)備端(如電視或手機(jī))上實(shí)現(xiàn)它們是一個(gè)巨大的挑戰(zhàn)。
為了解決這些難題,研究者們投入大量精力去研究如何壓縮與加速神經(jīng)網(wǎng)絡(luò)模型[5]。大量的研究工作集中在網(wǎng)絡(luò)剪枝、網(wǎng)絡(luò)量化(將網(wǎng)絡(luò)權(quán)值從浮點(diǎn)數(shù)量化為8比特整型值)、輕量化結(jié)構(gòu)設(shè)計(jì)和知識(shí)蒸餾4個(gè)方面去優(yōu)化神經(jīng)網(wǎng)絡(luò)的推理速度。其中,知識(shí)蒸餾是一種新穎的模型壓縮方法,它通過(guò)將效果卓越的大網(wǎng)絡(luò)中的知識(shí)傳遞到一個(gè)層數(shù)更淺、復(fù)雜度更低的小網(wǎng)絡(luò)中去減輕計(jì)算負(fù)擔(dān),無(wú)需改變網(wǎng)絡(luò)原本的結(jié)構(gòu)特點(diǎn)。通過(guò)傳遞大網(wǎng)絡(luò)的高級(jí)特征表達(dá),小網(wǎng)絡(luò)能夠接收到比數(shù)據(jù)集里的標(biāo)簽更強(qiáng)有力的監(jiān)督信息。
本文主要研究了知識(shí)蒸餾在超分辨率網(wǎng)絡(luò)上的運(yùn)用,通過(guò)大量實(shí)驗(yàn)論述了目前知識(shí)蒸餾方法在超分辨率上的局限性,并提出了一種新的知識(shí)蒸餾訓(xùn)練框架,基于編解碼器的結(jié)構(gòu)提取高分辨率圖像中的先驗(yàn)知識(shí),再將其傳遞給小網(wǎng)絡(luò),從而有效提升小網(wǎng)絡(luò)的超分辨率重建效果,實(shí)現(xiàn)超分辨率網(wǎng)絡(luò)的輕量化。
知識(shí)蒸餾[6]的主要思想是引導(dǎo)性能較弱、模型體積小的網(wǎng)絡(luò)模仿性能更優(yōu)、模型體積大的網(wǎng)絡(luò),以獲得更優(yōu)越的性能表現(xiàn)。一般而言,在知識(shí)蒸餾的訓(xùn)練模式中,網(wǎng)絡(luò)層數(shù)深、復(fù)雜度高且性能更優(yōu)的大模型被定義為大網(wǎng)絡(luò),網(wǎng)絡(luò)層數(shù)淺、復(fù)雜度低且性能平平的小模型被定義為小網(wǎng)絡(luò)。其關(guān)鍵在于如何定義大網(wǎng)絡(luò)中的知識(shí),以及如何傳遞知識(shí)。目前,知識(shí)蒸餾方法按照知識(shí)的定義可大致分為3類:基于softmax層輸出的知識(shí)蒸餾[6]、基于特征的知識(shí)蒸餾[7-9]和基于特征關(guān)系的知識(shí)蒸餾[10-11]。這些方法在許多視覺(jué)任務(wù)上都取得了不錯(cuò)的實(shí)驗(yàn)結(jié)果。
然而,大多數(shù)知識(shí)蒸餾方法都是面向高級(jí)視覺(jué)任務(wù)而言的,如圖像分類、目標(biāo)檢測(cè)等,而鮮有針對(duì)低級(jí)視覺(jué)任務(wù)的研究。為了實(shí)現(xiàn)超分辨率網(wǎng)絡(luò)的壓縮,本文將經(jīng)典的知識(shí)蒸餾方法直接運(yùn)用至超分辨率網(wǎng)絡(luò)上,對(duì)比小網(wǎng)絡(luò)采用蒸餾訓(xùn)練前后的重建效果,判斷其方法的有效性,詳細(xì)結(jié)果見(jiàn)表1。
表1 知識(shí)蒸餾在超分辨率網(wǎng)絡(luò)的應(yīng)用Tab.1 Application of knowledge distillation in super-resolution
表1采用文獻(xiàn)[14]的EDSR作為基準(zhǔn)網(wǎng)絡(luò),大網(wǎng)絡(luò)采用與EDSR原論文相同的設(shè)置(包含32個(gè)殘差模塊、256個(gè)通道數(shù)),小網(wǎng)絡(luò)采用結(jié)構(gòu)規(guī)模更小的EDSR(包含4個(gè)殘差模塊、64個(gè)通道數(shù)),在Set5和Set14這兩個(gè)標(biāo)準(zhǔn)數(shù)據(jù)集上進(jìn)行了4倍超分辨率測(cè)試,使用峰值信噪比(point signal to noise ratio, PSNR)衡量重建圖像的質(zhì)量。PSNR是圖像復(fù)原任務(wù)中使用最普遍的圖像質(zhì)量評(píng)估方法,其值越大,說(shuō)明生成圖像與真實(shí)圖像越相近,圖像質(zhì)量越好。Soft-target[6]是基于softmax輸出的知識(shí)蒸餾方法,其超分辨率任務(wù)輸出的是高分辨率圖像而并無(wú)softmax輸出,實(shí)際實(shí)驗(yàn)中則修改為用大網(wǎng)絡(luò)生成的圖像監(jiān)督小網(wǎng)絡(luò)。AT[8]和FitNet[9]都是典型的基于特征的知識(shí)蒸餾方法,F(xiàn)SP[10]、CCKD[12]和SPKD[13]則是基于特征關(guān)系知識(shí)蒸餾的代表。
觀察表1可得,這些知識(shí)蒸餾方法運(yùn)用到超分辨率任務(wù)上,大多數(shù)都難以起到提升小網(wǎng)絡(luò)性能的積極作用,只有基于輸出圖像的soft-target方法能帶來(lái)一些改進(jìn)。不同于分類等高級(jí)視覺(jué)任務(wù),在超分辨率等低級(jí)視覺(jué)任務(wù)中,像素間的局部和全局關(guān)系尤為重要??紤]到這一特點(diǎn),諸如FSP和AT等基于特征關(guān)系的方法是從特征圖中抽象出更高級(jí)的特征表達(dá)知識(shí),可能改變了圖像本身的空間信息,從而影響了模型恢復(fù)圖像的能力。而對(duì)于soft-target提供的加成,這可能是因?yàn)閷W(xué)習(xí)大網(wǎng)絡(luò)的輸出圖像比直接學(xué)習(xí)真實(shí)圖像更簡(jiǎn)單,一定程度上降低了模型的訓(xùn)練難度。圖1展現(xiàn)了FSP與基準(zhǔn)網(wǎng)絡(luò)的訓(xùn)練過(guò)程對(duì)比情況,橫軸是訓(xùn)練批量的迭代次數(shù),豎軸是在Set5數(shù)據(jù)集上測(cè)試的PSNR結(jié)果,baseline是未經(jīng)過(guò)知識(shí)蒸餾訓(xùn)練的小網(wǎng)絡(luò),F(xiàn)SP是采用FSP蒸餾訓(xùn)練的小網(wǎng)絡(luò)。觀察圖1可得,知識(shí)蒸餾在訓(xùn)練前期能為小網(wǎng)絡(luò)提供一個(gè)良好的助力,加快其收斂速度,但收斂后和基準(zhǔn)網(wǎng)絡(luò)逐漸趨于一致。
圖1 FSP與基準(zhǔn)網(wǎng)絡(luò)的訓(xùn)練過(guò)程對(duì)比Fig.1 Comparison of the training process between FSP and baseline
綜上可得,目前的知識(shí)蒸餾方法直接運(yùn)用在超分辨率任務(wù)上效果甚微,僅能在訓(xùn)練前期起到加速效果。為了分析知識(shí)蒸餾在超分辨率等圖像回歸任務(wù)上效果不佳的原因,本文將超分辨率網(wǎng)絡(luò)EDSR的中間層特征圖進(jìn)行了可視化,見(jiàn)圖2。圖2中,左圖是輸入圖像,右圖是EDSR首個(gè)殘差塊輸出的特征圖。觀察EDSR的特征圖可以發(fā)現(xiàn),該特征圖與原圖輪廓保持著高度一致,這是由超分辨率任務(wù)本身特性決定的。分類網(wǎng)絡(luò)的特征圖往往趨向于關(guān)注目標(biāo)的局部特征,其稀疏性較高,而超分辨率網(wǎng)絡(luò)提取的特征仍然與輸入圖像趨于一致,每一部分細(xì)節(jié)都影響著成像質(zhì)量,這無(wú)疑增大了通過(guò)蒸餾傳遞特征知識(shí)的難度。由此可知,要想知識(shí)蒸餾在超分辨率上發(fā)揮作用,關(guān)鍵在于大網(wǎng)絡(luò)能提供更加有益于小網(wǎng)絡(luò)訓(xùn)練的特征知識(shí)。
圖2 超分辨率網(wǎng)絡(luò)EDSR的特征圖Fig.2 Feature map of EDSR
受文獻(xiàn)[10]的啟發(fā),本文針對(duì)超分辨率任務(wù)提出了一種新的知識(shí)蒸餾訓(xùn)練框架,借助編解碼器的結(jié)構(gòu),保留真實(shí)圖像中的高頻信息,為特征知識(shí)提供更精準(zhǔn)、有用的特征,助力小網(wǎng)絡(luò)的訓(xùn)練。
本文的知識(shí)蒸餾訓(xùn)練框架如圖3所示。圖3中,HR表示訓(xùn)練集中的高分辨率圖像,LR表示高分辨率圖像對(duì)應(yīng)的低分辨率圖像。
圖3 基于編解碼器的特征知識(shí)蒸餾Fig.3 Feature knowledge distillation based on codec
不同于傳統(tǒng)知識(shí)蒸餾中直接從已訓(xùn)練好的大網(wǎng)絡(luò)中提取知識(shí),本文知識(shí)蒸餾訓(xùn)練框架對(duì)大網(wǎng)絡(luò)增加了一些結(jié)構(gòu)上的訓(xùn)練約束,并使用高分辨圖像作為大網(wǎng)絡(luò)的輸入對(duì)其進(jìn)行訓(xùn)練。大網(wǎng)絡(luò)分為編碼器和解碼器兩部分,編碼器對(duì)輸入的高分辨率圖像進(jìn)行壓縮編碼,它將輸入圖像投影到一個(gè)低維特征空間中,生成更緊湊的特征,然后再輸入解碼器中重構(gòu)出高分辨率圖像,使大網(wǎng)絡(luò)能為超分辨率任務(wù)提取更好的特征表示。
(1)
解碼器的損失函數(shù)LD具體定義為
(2)
(2)式中,λ為退化損失的平衡系數(shù);H、W分別表示高分辨率圖像的高和寬。(2)式由2部分組成,第1部分是生成圖像與真實(shí)圖像的MAE損失,第2部分是生成圖像退化后與解碼器輸入的MAE損失。
編碼器、解碼器共同進(jìn)行訓(xùn)練,大網(wǎng)絡(luò)能從高分辨率圖像中提取高頻信息,從而在網(wǎng)絡(luò)中生成精確的特征知識(shí),整個(gè)大網(wǎng)絡(luò)的總的損失函數(shù)LT具體定義為
LT=βTLE+LD
(3)
(3)式中,βT是編碼器損失函數(shù)的平衡系數(shù)。
小網(wǎng)絡(luò)采用與大網(wǎng)絡(luò)中解碼器相同的結(jié)構(gòu),只是采用低分辨率圖像作為輸入,并使用解碼器的網(wǎng)絡(luò)權(quán)重對(duì)其進(jìn)行初始化,為小網(wǎng)絡(luò)提供一個(gè)良好的訓(xùn)練起點(diǎn)。雖然小網(wǎng)絡(luò)和解碼器具有相同的初始參數(shù),但由于其輸入不同,兩者提取到的特征也大相徑庭。
不同于解碼器,小網(wǎng)絡(luò)的損失函數(shù)不僅包含生成圖像的重建損失以及退化損失,還包含蒸餾損失。這里的知識(shí)蒸餾則與傳統(tǒng)知識(shí)蒸餾類似,是為了將解碼器的特征知識(shí)遷移至小網(wǎng)絡(luò)。
在蒸餾過(guò)程中計(jì)算中間層特征的分布,將特征圖的分布信息定義為知識(shí),通過(guò)最大均值差異(max mean discrepancy,MMD)[12]衡量大網(wǎng)絡(luò)(解碼器)和小網(wǎng)絡(luò)之間的特征分布差異。以MMD作為蒸餾損失函數(shù),引導(dǎo)小網(wǎng)絡(luò)中間層的激活分布模擬大網(wǎng)絡(luò)的激活分布。
(4)
(5)
(5)式中:‖·‖2是L2正則化;G是Gram矩陣,矩陣中每一項(xiàng)為gij=(fi)Tfj。Gram矩陣是特征圖向量化后內(nèi)積的結(jié)果,能反應(yīng)特征之間的相關(guān)程度。
小網(wǎng)絡(luò)總的損失函數(shù)LS可表示為
LS=Lsr+λLF+βSLdistill
(6)
(6)式中:Lsr為超分辨率任務(wù)常用的重建損失,由生成圖像與真實(shí)高分辨率圖像計(jì)算MAE損失而得;λ是退化損失函數(shù)的平衡系數(shù),與大網(wǎng)絡(luò)中的參數(shù)設(shè)置相同,λLF為退化損失,與解碼器相同,將生成圖像經(jīng)退化后與輸入圖像計(jì)算MAE損失;βSLdistill為特征蒸餾損失,其中βS為蒸餾損失的平衡系數(shù)。
參照EDSR的訓(xùn)練設(shè)置[14-15],本文使用DIV2K數(shù)據(jù)集進(jìn)行訓(xùn)練,其中包含800張高分辨率圖像,低分辨率圖像通過(guò)對(duì)高分辨圖像進(jìn)行雙三次下采樣生成而得。每個(gè)訓(xùn)練批次大小為32張圖片。
本文在Set5、Set14、B100和Urban100等標(biāo)準(zhǔn)數(shù)據(jù)集上評(píng)估提出的方法,并使用亮度通道上計(jì)算的PSNR作為評(píng)估指標(biāo)。
首先訓(xùn)練大網(wǎng)絡(luò),經(jīng)過(guò)反復(fù)實(shí)驗(yàn),設(shè)置超參數(shù)βT=10-4,βS=10-3,λ=0.1。表2所示為在不同βT值情況下訓(xùn)練小網(wǎng)絡(luò),并在Set5數(shù)據(jù)集上驗(yàn)證兩倍超分的結(jié)果。在這部分實(shí)驗(yàn)中,退化支路沒(méi)有引入其中,以免對(duì)編解碼器的調(diào)參實(shí)驗(yàn)造成影響。當(dāng)βT為0時(shí),編解碼器失去了壓縮HR中高頻特征的功能,成為了簡(jiǎn)單的線性映射。若參數(shù)βT設(shè)置太大,編碼器損失函數(shù)將促使編碼器生成的低維特征與低分辨率圖像趨于同質(zhì)。在這種情況下,本文提出的知識(shí)蒸餾框架將不能從高分辨率圖像的先驗(yàn)知識(shí)中獲益,并且小網(wǎng)絡(luò)的特征蒸餾與傳統(tǒng)蒸餾方法毫無(wú)差別,性能提升微乎其微。
表2 超參數(shù)βT的實(shí)驗(yàn)結(jié)果Tab.2 Results of balance parameters βT
經(jīng)過(guò)大量實(shí)驗(yàn)得知,參數(shù)βT設(shè)置為10-4時(shí),能在保證解碼器學(xué)習(xí)超分辨率映射和編碼器提取高分辨率的先驗(yàn)知識(shí)之間取得折衷,能通過(guò)特征蒸餾使小網(wǎng)絡(luò)達(dá)到最佳效果。
小網(wǎng)絡(luò)在不同數(shù)據(jù)集上進(jìn)行2倍放大的蒸餾結(jié)果如表3所示,進(jìn)行4倍放大結(jié)果如表4所示。經(jīng)過(guò)蒸餾訓(xùn)練后,小網(wǎng)絡(luò)在各個(gè)基準(zhǔn)數(shù)據(jù)集上的性能都有所提升,在網(wǎng)絡(luò)參數(shù)量和計(jì)算復(fù)雜度未增加的情況下,依靠本文提出的知識(shí)蒸餾方法,小網(wǎng)絡(luò)PSNR在進(jìn)行2倍超分辨率時(shí)能提升0.17~0.28 dB;在4倍超分辨率時(shí)能提升0.11~0.18 dB。與傳統(tǒng)知識(shí)蒸餾方法相比,本文方法大大提升了小網(wǎng)絡(luò)性能,能將知識(shí)蒸餾高效地運(yùn)用在超分辨率任務(wù)上。這主要有以下兩個(gè)原因:①大網(wǎng)絡(luò)中編解碼器結(jié)構(gòu)能有效捕捉高分辨率圖像中的高頻信息,并通過(guò)特征蒸餾傳遞給小網(wǎng)絡(luò);②退化支路的約束縮小了超分辨率任務(wù)的解空間,加速了小網(wǎng)絡(luò)的收斂。傳統(tǒng)知識(shí)蒸餾方法未對(duì)大網(wǎng)絡(luò)進(jìn)行專門(mén)化訓(xùn)練,使得其蘊(yùn)含的高頻信息包含太多噪聲,無(wú)法指導(dǎo)小網(wǎng)絡(luò)訓(xùn)練。
表3—表4中的SRKD[18]和PISR[19]也是兩種結(jié)合了知識(shí)蒸餾的超分辨率方法。其中,SRKD是基于傳統(tǒng)知識(shí)蒸餾AT[9]的方式,通過(guò)學(xué)習(xí)大網(wǎng)絡(luò)的特征圖進(jìn)行蒸餾,只是修改了AT中的知識(shí)定義,因此,它并不能脫離傳統(tǒng)知識(shí)蒸餾在超分任務(wù)上的桎梏,PISR類似于本文方法,通過(guò)對(duì)小網(wǎng)絡(luò)的預(yù)訓(xùn)練提取高分辨率圖像中原有的先驗(yàn)信息來(lái)提高蒸餾效果。本文按照文獻(xiàn)[18-19]的參數(shù)設(shè)置,在DIV2K數(shù)據(jù)集上完成了復(fù)現(xiàn)。結(jié)果表明,無(wú)論是在2倍還是在4倍超分辨率上,本文方法都取得了更高的PSNR指標(biāo)。
表3 小網(wǎng)絡(luò)在不同數(shù)據(jù)集上進(jìn)行2倍放大的蒸餾結(jié)果Tab.3 Small network distillation results with 2x magnification on different data sets
表4 小網(wǎng)絡(luò)在不同數(shù)據(jù)集上進(jìn)行4倍放大的蒸餾結(jié)果Tab.4 Small network distillation results with 4x magnification on different data sets
將本文方法訓(xùn)練而得的小網(wǎng)絡(luò)和文獻(xiàn)[20-21]的超分辨率網(wǎng)絡(luò)進(jìn)行對(duì)比,從表3—表4可得,DRCN和MemNet的參數(shù)量都大于小網(wǎng)絡(luò),相應(yīng)地計(jì)算復(fù)雜度更高,因而在各個(gè)數(shù)據(jù)集上它們都取得了比未蒸餾小網(wǎng)絡(luò)更高的PSNR結(jié)果。但經(jīng)過(guò)本文方法的訓(xùn)練,在不需對(duì)網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行特殊設(shè)計(jì)情況下,蒸餾后的小網(wǎng)絡(luò)并未增加運(yùn)算復(fù)雜度或提高網(wǎng)絡(luò)參數(shù)量,它的性能表現(xiàn)已經(jīng)超過(guò)了DRCN,同時(shí)在耗時(shí)上也縮減至與MemNet相近的水準(zhǔn),這說(shuō)明本文的知識(shí)蒸餾方法能針對(duì)超分辨率網(wǎng)絡(luò)實(shí)現(xiàn)良好的輕量化效果。
為了充分證明有效性,本文還對(duì)方法中的編解碼器結(jié)構(gòu)、退化支路、蒸餾損失函數(shù)進(jìn)行了消融實(shí)驗(yàn)。在Set5上進(jìn)行2倍放大,對(duì)PSNR進(jìn)行定量計(jì)算分析,比較其超分辨率性能。
表5所示為在不同模塊組合下訓(xùn)練而得小網(wǎng)絡(luò)的消融實(shí)驗(yàn)結(jié)果。
表5 消融實(shí)驗(yàn)的結(jié)果Tab.5 Results of ablation studies
從表5可以看出,相比于傳統(tǒng)蒸餾方法,編解碼器結(jié)構(gòu)使得特征蒸餾在超分辨率任務(wù)上更加有效,小網(wǎng)絡(luò)受益于大網(wǎng)絡(luò)解碼器的網(wǎng)絡(luò)權(quán)重,這為小網(wǎng)絡(luò)提供了一個(gè)良好的訓(xùn)練起點(diǎn),并且遷移了大網(wǎng)絡(luò)的重構(gòu)能力;MMD損失函數(shù)提供了比MAE更好的結(jié)果,基于MMD的蒸餾損失促使小網(wǎng)絡(luò)和大網(wǎng)絡(luò)的特征圖保持一致分布,一定程度上避免了大網(wǎng)絡(luò)特征圖中的噪聲影響;退化支路提供了更好的性能表現(xiàn),能帶來(lái)0.1 dB的PSNR提升,這表明縮小解空間有助于超分辨率任務(wù)。
本文主要研究了如何使用知識(shí)蒸餾對(duì)超分辨率網(wǎng)絡(luò)進(jìn)行輕量化,全面分析了現(xiàn)有知識(shí)蒸餾方法直接運(yùn)用到超分辨率網(wǎng)絡(luò)的局限性,并提出了一種基于編解碼器的知識(shí)蒸餾訓(xùn)練框架,能有效提取高分辨率圖像中的先驗(yàn)知識(shí),再通過(guò)特征蒸餾將其傳遞給小網(wǎng)絡(luò)。本文方法顯著地提高了超分辨率網(wǎng)絡(luò)的性能,實(shí)現(xiàn)了超分辨率網(wǎng)絡(luò)的輕量化。