張 晶 鞠佳良 任永功
(遼寧師范大學(xué)計(jì)算機(jī)與人工智能學(xué)院 遼寧 大連 116081)
深度學(xué)習(xí)憑借對(duì)樣本高維特征的非線性表達(dá)及數(shù)據(jù)信息的抽象表示,極大地推進(jìn)了語音識(shí)別、計(jì)算機(jī)視覺等人工智能方法在工業(yè)中的應(yīng)用.1989 年LeCun等人[1]提出深度卷積網(wǎng)絡(luò)LeNet 模型,在手寫體圖像識(shí)別領(lǐng)域取得了突破性進(jìn)展,為深度學(xué)習(xí)的發(fā)展提供了前提和基礎(chǔ).為進(jìn)一步提升深度神經(jīng)網(wǎng)絡(luò)模式識(shí)別及圖像處理精度,推廣其在工業(yè)中的應(yīng)用,國內(nèi)外學(xué)者不斷優(yōu)化及改進(jìn)網(wǎng)絡(luò)結(jié)構(gòu).隨著模型層數(shù)逐步增加,模型參數(shù)和架構(gòu)愈加龐大,算法對(duì)存儲(chǔ)、計(jì)算等資源的需求不斷增長,導(dǎo)致大模型網(wǎng)絡(luò)失效等問題[2],例如Resnet50,VGG16 等大型神經(jīng)網(wǎng)絡(luò),盡管在圖像分類應(yīng)用上表現(xiàn)出卓越性能,但其冗余參數(shù)導(dǎo)致較高計(jì)算成本和內(nèi)存消耗.同時(shí),多媒體、5G 技術(shù)、移動(dòng)終端的快速發(fā)展,邊緣計(jì)算設(shè)備廣泛部署,使網(wǎng)絡(luò)應(yīng)用需求逐步增加.手機(jī)、平板電腦、移動(dòng)攝像機(jī)等便攜式近端設(shè)備相比于固定設(shè)備存在數(shù)十倍的計(jì)算、存儲(chǔ)等能力差距,為大規(guī)模網(wǎng)絡(luò)近端遷移與運(yùn)行帶來困難.如何提升邊緣設(shè)備計(jì)算、識(shí)別及分類能力,實(shí)現(xiàn)大規(guī)模深度學(xué)習(xí)網(wǎng)絡(luò)的近端部署成為有意義的工作.基于此,Buciluǎ等人[3]提出神經(jīng)網(wǎng)絡(luò)模型壓縮方法,將信息從大模型或模型集合傳輸?shù)叫枰?xùn)練的小型模型,而不降低模型精度.同時(shí),大規(guī)模神經(jīng)網(wǎng)絡(luò)模型中包含的大量參數(shù)存在一定功能稀疏性,使網(wǎng)絡(luò)結(jié)構(gòu)出現(xiàn)過參數(shù)化等問題,即使在網(wǎng)絡(luò)性能敏感的大規(guī)模場(chǎng)景中,仍包含產(chǎn)生重復(fù)信息的神經(jīng)元與鏈接.知識(shí)蒸餾 (knowledge distillation, KD)將高性能大規(guī)模網(wǎng)絡(luò)作為教師網(wǎng)絡(luò)指導(dǎo)小規(guī)模學(xué)生網(wǎng)絡(luò)[4],實(shí)現(xiàn)知識(shí)精煉與網(wǎng)絡(luò)結(jié)構(gòu)壓縮,成為模型壓縮、加速運(yùn)算、大規(guī)模網(wǎng)絡(luò)近端部署的重要方法.
然而,隨著人們對(duì)隱私保護(hù)意識(shí)的增強(qiáng)以及法律、傳輸?shù)葐栴}的加劇,針對(duì)特定任務(wù)的深度網(wǎng)絡(luò)訓(xùn)練數(shù)據(jù)往往難以獲取,使Data-Free 環(huán)境下的神經(jīng)網(wǎng)絡(luò)模型壓縮,即在避免用戶隱私數(shù)據(jù)泄露的同時(shí)得到一個(gè)與數(shù)據(jù)驅(qū)動(dòng)條件下壓縮后準(zhǔn)確率相似的模型,成為一個(gè)具有重要實(shí)際意義的研究方向.Chen 等人[5]提出Data-Free 環(huán)境知識(shí)蒸餾框架DAFL (data-free learning of student networks, DAFL),建立教師端生成器,生成偽樣本訓(xùn)練集,實(shí)現(xiàn)知識(shí)蒸餾并獲得與教師網(wǎng)絡(luò)性能近似的小規(guī)模學(xué)生網(wǎng)絡(luò).然而,該方法在復(fù)雜數(shù)據(jù)集上將降低學(xué)生網(wǎng)絡(luò)識(shí)別準(zhǔn)確率,其主要原因有3 個(gè)方面:
1)判別網(wǎng)絡(luò)優(yōu)化目標(biāo)不同.模型中教師網(wǎng)絡(luò)優(yōu)化生成器產(chǎn)生偽數(shù)據(jù),實(shí)現(xiàn)學(xué)生網(wǎng)絡(luò)知識(shí)蒸餾,使學(xué)生網(wǎng)絡(luò)難以獲得與教師網(wǎng)絡(luò)一致的優(yōu)化信息構(gòu)建網(wǎng)絡(luò)模型.
2)誤差信息優(yōu)化生成器.教師端生成器的構(gòu)建過度信任教師網(wǎng)絡(luò)對(duì)偽數(shù)據(jù)的判別結(jié)果,利用誤差信息優(yōu)化并生成質(zhì)量較差的偽訓(xùn)練樣本,知識(shí)蒸餾過程學(xué)生網(wǎng)絡(luò)難以有效利用教師網(wǎng)絡(luò)潛在先驗(yàn)分布信息.
3)學(xué)生網(wǎng)絡(luò)泛化性低.模型中生成數(shù)據(jù)僅依賴于教師網(wǎng)絡(luò)訓(xùn)練損失,導(dǎo)致生成數(shù)據(jù)特征多樣性缺失,降低學(xué)生網(wǎng)絡(luò)判別性.
如圖1 所示,MNIST 數(shù)據(jù)集中類別為1 和7 時(shí)圖像特征有較大差異,而圖1 右側(cè)中DAFL 方法的學(xué)生網(wǎng)絡(luò)得到的2 類數(shù)據(jù)統(tǒng)計(jì)特征直方圖相當(dāng)近似,該模型訓(xùn)練得到的小規(guī)模學(xué)生網(wǎng)絡(luò)針對(duì)特征相似圖像難以獲得更魯棒的判別結(jié)果.為提升DAFL 模型中學(xué)生的網(wǎng)絡(luò)準(zhǔn)確率及泛化性,提出新的雙生成器網(wǎng)絡(luò)架構(gòu)DG-DAFL(double generators-DAFL,DG-DAFL),圖1 右側(cè)中由DG-DAFL 框架訓(xùn)練得到學(xué)生網(wǎng)絡(luò)判別器特征統(tǒng)計(jì)直方圖對(duì)比,即1 類和7 類特征統(tǒng)計(jì)結(jié)果有一定差距,為后續(xù)分類提供了前提.
為解決Data-Free 環(huán)境知識(shí)蒸餾、保證網(wǎng)絡(luò)識(shí)別精度與泛化性,本文提出雙生成器網(wǎng)絡(luò)架構(gòu)DGDAFL,學(xué)生端生成器在教師端生成器的輔助下充分利用教師網(wǎng)絡(luò)潛在先驗(yàn)知識(shí),產(chǎn)生更適合學(xué)生網(wǎng)絡(luò)訓(xùn)練的偽訓(xùn)練樣本,利用生成器端樣本分布差異,避免DAFL 學(xué)生網(wǎng)絡(luò)對(duì)單一教師網(wǎng)絡(luò)端生成器樣本依賴,保證生成器樣本多樣性,提升學(xué)生網(wǎng)絡(luò)判別器識(shí)別泛化性.本文貢獻(xiàn)有3 方面:
Fig.1 Comparison of normalized statistical results for approximate sample characteristics圖1 近似樣本特征歸一化統(tǒng)計(jì)結(jié)果對(duì)比
1)針對(duì)Data-Free 知識(shí)蒸餾問題提出雙生成器網(wǎng)絡(luò)架構(gòu)DG-DAFL,建立教師生成器網(wǎng)絡(luò)與學(xué)生生成器網(wǎng)絡(luò),生成偽樣本.優(yōu)化教師生成器網(wǎng)絡(luò)的同時(shí),學(xué)生網(wǎng)絡(luò)判別器優(yōu)化學(xué)生生成器網(wǎng)絡(luò),實(shí)現(xiàn)生成器與判別器分離,避免誤差判別信息干擾生成器構(gòu)建.同時(shí),使網(wǎng)絡(luò)任務(wù)及優(yōu)化目標(biāo)一致,提升學(xué)生網(wǎng)絡(luò)性能.該結(jié)構(gòu)可被拓展于解決其他任務(wù)的Data-Free 知識(shí)蒸餾問題.
2)通過增加教師網(wǎng)絡(luò)及學(xué)生網(wǎng)絡(luò)生成器端樣本分布差異度量,避免單生成器網(wǎng)絡(luò)結(jié)構(gòu)中學(xué)生網(wǎng)絡(luò)訓(xùn)練過度依賴教師生成器網(wǎng)絡(luò)樣本,產(chǎn)生泛化性較低等問題.同時(shí),該差異度量可使得學(xué)生網(wǎng)絡(luò)生成數(shù)據(jù)在保證分布近似條件下的樣本多樣性,進(jìn)一步提升學(xué)生網(wǎng)絡(luò)識(shí)別魯棒性.
3)所提出框架在流行分類數(shù)據(jù)集Data-Free 環(huán)境下,學(xué)生網(wǎng)絡(luò)參數(shù)量僅為教師網(wǎng)絡(luò)的50%時(shí),仍取得了令人滿意的識(shí)別性能.同時(shí),進(jìn)一步驗(yàn)證并分析了近似樣本數(shù)據(jù)集的分類問題,取得了更魯棒的結(jié)果.
針對(duì)大規(guī)模神經(jīng)網(wǎng)絡(luò)的近端部署與應(yīng)用,網(wǎng)絡(luò)模型壓縮及加速成為人工智能領(lǐng)域的研究熱點(diǎn).目前的模型壓縮方法包括網(wǎng)絡(luò)剪枝[6]、參數(shù)共享[7]、量化[8]、網(wǎng)絡(luò)分解[9]、緊湊網(wǎng)絡(luò)設(shè)計(jì),其中知識(shí)蒸餾憑借靈活、直觀的知識(shí)抽取及模型壓縮性能受到學(xué)者廣泛關(guān)注.2015 年,Hinton 等人[4]提出知識(shí)蒸餾模型,構(gòu)建教師網(wǎng)絡(luò)、學(xué)生網(wǎng)絡(luò)及蒸餾算法3 部分框架,引入溫度(temperature,T)系數(shù),使卷積神經(jīng)網(wǎng)絡(luò)softmax層的預(yù)測(cè)標(biāo)簽由硬標(biāo)簽(hard-label)轉(zhuǎn)換為軟標(biāo)簽(soft-label),利用龐大、參數(shù)量多的教師網(wǎng)絡(luò)監(jiān)督訓(xùn)練得到體量、參數(shù)量更少且分類性能與教師網(wǎng)絡(luò)更近似的學(xué)生網(wǎng)絡(luò)[3-4,10-11].根據(jù)知識(shí)蒸餾操作的不同部分,分為目標(biāo)(logits)蒸餾[12-16]與特征圖蒸餾[17-22]兩類.logits 知識(shí)蒸餾模型主要目標(biāo)集中在構(gòu)建更為有效的正則化項(xiàng)及優(yōu)化方法,在硬標(biāo)簽(hard-label)監(jiān)督訓(xùn)練下得到泛化性能更好的學(xué)生網(wǎng)絡(luò).Zhang 等人[16]提出深度互學(xué)習(xí)(deep mutual learning,DML)模型,利用交替學(xué)習(xí)同時(shí)強(qiáng)化學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò).然而,教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)的性能差距使蒸餾過程難以收斂.基于此,Mirzadeh 等人[14]提出助教知識(shí)蒸餾(teacher assistant knowledge distillation,TAKD)模型,引入中等規(guī)模助教網(wǎng)絡(luò),縮小教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)之間過大的性能差距,達(dá)到逐步蒸餾的目的.特征圖知識(shí)蒸餾模型通過直接將樣本表征從教師網(wǎng)絡(luò)遷移至學(xué)生網(wǎng)絡(luò)[17-18,20],或?qū)⒂?xùn)練教師網(wǎng)絡(luò)模型樣本結(jié)構(gòu)遷移至學(xué)生網(wǎng)絡(luò)[19,21-22],實(shí)現(xiàn)知識(shí)抽取.該類方法充分利用大規(guī)模教師網(wǎng)絡(luò)對(duì)樣本的高維、非線性特征表達(dá)及樣本結(jié)構(gòu),獲得更高效的學(xué)生網(wǎng)絡(luò).
Data-Free 環(huán)境中用于訓(xùn)練模型的真實(shí)數(shù)據(jù)往往難以獲取,使知識(shí)蒸餾模型失效.對(duì)抗生成網(wǎng)絡(luò)(generative adversarial network,GAN)技術(shù)的發(fā)展,激發(fā)了該類環(huán)境下知識(shí)蒸餾領(lǐng)域方法的進(jìn)步.2014 年,Goodfellow 等人[23]提出GAN 模型,通過模型中生成器與鑒別器的極大極小博弈,二者相互競(jìng)爭(zhēng)提升各自生成和識(shí)別能力[24],可用于生成以假亂真的圖片[25]、影片[26]等的無監(jiān)督學(xué)習(xí)方法.GAN中的生成器可合成數(shù)據(jù)直接作為訓(xùn)練數(shù)據(jù)集,或用于訓(xùn)練數(shù)據(jù)集增廣及生成難樣本支持學(xué)生網(wǎng)絡(luò)訓(xùn)練.Nguyen 等人[27]利用預(yù)訓(xùn)練的GAN 生成器作為模型反演的先驗(yàn),構(gòu)建偽訓(xùn)練數(shù)據(jù)集.Bhardwaj 等人[28]利用10%的原始數(shù)據(jù)和預(yù)訓(xùn)練教師模型生成合成圖像數(shù)據(jù)集,并將合成圖像用于知識(shí)蒸餾.Liu 等人[29]與Zhang 等人[30]均利用無標(biāo)簽數(shù)據(jù)提升模型效果,分別提出無標(biāo)簽數(shù)據(jù)蒸餾的光流學(xué)習(xí)(learning optical flow with unlabeled data distillation, DDFlow)模型[29]與圖卷積網(wǎng)絡(luò)可靠數(shù)據(jù)蒸餾(reliable data distillation on graph convolution network, RDDGCN)模型[30].其中RDDGCN 模型利用教師網(wǎng)絡(luò)對(duì)所生成的未標(biāo)注數(shù)據(jù)給予新的訓(xùn)練注釋,構(gòu)建訓(xùn)練數(shù)據(jù)集訓(xùn)練學(xué)生網(wǎng)絡(luò).有研究借助大規(guī)模預(yù)訓(xùn)練數(shù)據(jù)集提升模型效果,Yin 等人[31]提出的Deep-Inversion 方法將圖像更新?lián)p失與教師、學(xué)生之間的對(duì)抗性損失結(jié)合,教師網(wǎng)絡(luò)通過對(duì)Batch Normalization 層中所包含通道的均值和方差進(jìn)行推導(dǎo),在大規(guī)模ImageNet 數(shù)據(jù)集上預(yù)訓(xùn)練深度網(wǎng)絡(luò)后合成圖像作為訓(xùn)練樣本集.Lopes 等人[32]進(jìn)一步利用教師網(wǎng)絡(luò)先驗(yàn)信息,通過教師網(wǎng)絡(luò)激活層重構(gòu)訓(xùn)練數(shù)據(jù)集以實(shí)現(xiàn)學(xué)生網(wǎng)絡(luò)知識(shí)蒸餾.文獻(xiàn)[28-32]所述方法均利用少量訓(xùn)練數(shù)據(jù)或常用的預(yù)訓(xùn)練數(shù)據(jù)集信息,在Data-Free 環(huán)境中仍難以解決無法直接獲取真實(shí)且可用于訓(xùn)練小規(guī)模學(xué)生網(wǎng)絡(luò)的先驗(yàn)信息等問題.
基于此,DAFL 框架借助GAN 學(xué)習(xí)模型,將預(yù)訓(xùn)練好的教師網(wǎng)絡(luò)作為判別器網(wǎng)絡(luò),構(gòu)建并優(yōu)化生成器網(wǎng)絡(luò)模型,生成更加接近真實(shí)樣本分布的偽數(shù)據(jù),為高精度、小規(guī)模學(xué)生網(wǎng)絡(luò)的知識(shí)蒸餾與網(wǎng)絡(luò)壓縮提供有效先驗(yàn)信息,框架如圖2 所示.首先,通過函數(shù)one_hot獲得偽標(biāo)簽,利用損失函數(shù)將GAN 中判別器的輸出結(jié)果從二分類轉(zhuǎn)換為多分類,以實(shí)現(xiàn)多分類任務(wù)的知識(shí)蒸餾;其次,采用信息熵?fù)p失函數(shù)、特征圖激活損失函數(shù)、類別分布損失函數(shù)優(yōu)化生成器,為學(xué)生網(wǎng)絡(luò)訓(xùn)練提供數(shù)據(jù);最終,實(shí)現(xiàn)在沒有原始數(shù)據(jù)驅(qū)動(dòng)條件下,通過知識(shí)蒸餾方法使學(xué)生網(wǎng)絡(luò)參數(shù)減少一半,且具有與教師網(wǎng)絡(luò)近似的分類準(zhǔn)確率.然而,DAFL 框架中生成器優(yōu)化過程完全信任判別器針對(duì)Data-Free 環(huán)境中初始生成偽樣本的先驗(yàn)判別,忽略了偽樣本所構(gòu)造偽標(biāo)簽帶來的誤差,干擾生成器優(yōu)化,直接影響學(xué)生網(wǎng)絡(luò)性能.同時(shí),教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)執(zhí)行不同任務(wù)時(shí)存在學(xué)生網(wǎng)絡(luò)過度依賴教師網(wǎng)絡(luò)生成器樣本,降低Data-Free 環(huán)境下模型學(xué)習(xí)泛化性.
為了提升生成樣本質(zhì)量,F(xiàn)ang 等人[33]提出無數(shù)據(jù)對(duì)抗蒸餾(data-free adversarial distillation,DFAD)模型,通過訓(xùn)練一個(gè)外部生成器網(wǎng)絡(luò)合成數(shù)據(jù),使學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)輸出差異最大化圖像.Han 等人[34]提出魯棒性和多樣性的Data-Free 知識(shí)蒸餾(robutness and diversity seeking data-free knowledge distillation,RDSKD)方法在生成器訓(xùn)練階段引入指數(shù)懲罰函數(shù),提升生成器生成圖像的多樣性.Nayak 等人[35]提出零樣本知識(shí)蒸餾模型,僅利用教師網(wǎng)絡(luò)參數(shù)對(duì)softmax層空間建模生成訓(xùn)練樣本.同時(shí),Micaelli 等人[36]提出零樣本對(duì)抗性信息匹配模型,利用教師網(wǎng)絡(luò)特征表示的信息生成訓(xùn)練樣本.為避免零樣本學(xué)習(xí)中先驗(yàn)信息缺失降低學(xué)生網(wǎng)絡(luò)學(xué)習(xí)準(zhǔn)確率等問題,Kimura等人[37]與 Shen 等人[38]分別提出偽樣本訓(xùn)練模型與網(wǎng)絡(luò)嫁接模型,二者均借助少量確定性監(jiān)督樣本,將知識(shí)從教師模型提取到學(xué)生神經(jīng)網(wǎng)絡(luò)中.為充分利用教師網(wǎng)絡(luò)先驗(yàn)信息,Storkey 等人[39]提出zero-shot知識(shí)蒸餾方法,將教師網(wǎng)絡(luò)同時(shí)定義為樣本鑒別器.同時(shí),Radosavovic 等人[40]提出全方位監(jiān)督學(xué)習(xí)模型.
文獻(xiàn)[5, 33-40]所述的Data-Free 環(huán)境中知識(shí)蒸餾模型所需的訓(xùn)練數(shù)據(jù)通常由已訓(xùn)練教師模型的特征表示生成,該類數(shù)據(jù)包含部分教師網(wǎng)絡(luò)先驗(yàn)信息,在無數(shù)據(jù)可用的情況下顯示出了很大的潛力.然而,Data-Free知識(shí)蒸餾仍是一項(xiàng)非常具有挑戰(zhàn)性的任務(wù),主要集中在如何生成高質(zhì)量、多樣化、具有針對(duì)性的訓(xùn)練數(shù)據(jù),進(jìn)而獲得更高精度、高泛化性的小規(guī)模學(xué)生網(wǎng)絡(luò).
針對(duì)提升Data-Free 環(huán)境中知識(shí)蒸餾方法有效性與泛化性,本文受DAFL 模型的啟發(fā),提出DG-DAFL網(wǎng)絡(luò)架構(gòu),如圖3 所示.包括4 部分網(wǎng)絡(luò)結(jié)構(gòu):教師端生成器網(wǎng)絡(luò)GT、學(xué)生端生成器網(wǎng)絡(luò)GS、教師端判別器網(wǎng)絡(luò)NT、學(xué)生端判別器網(wǎng)絡(luò)NS.DG-DAFL 利用教師端與學(xué)生端判別器網(wǎng)絡(luò)NT與NS,同時(shí)優(yōu)化生成器網(wǎng)絡(luò)GT與GS,保證學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)優(yōu)化目標(biāo)一致,避免真實(shí)樣本標(biāo)簽類別先驗(yàn)信息缺失時(shí)生成器過度信任教師網(wǎng)絡(luò)判別結(jié)果,產(chǎn)生質(zhì)量較低的偽樣本,降低學(xué)生網(wǎng)絡(luò)判別性能.同時(shí),通過增加生成器端偽樣本分布損失,保證學(xué)生端生成器網(wǎng)絡(luò)訓(xùn)練樣本多樣性,提升學(xué)生網(wǎng)絡(luò)學(xué)習(xí)泛化性.DG-DAFL 框架的訓(xùn)練過程可總結(jié)為3 個(gè)步驟:教師端輔助生成器GT構(gòu)建、最優(yōu)化學(xué)生端生成器GS構(gòu)建、學(xué)生網(wǎng)絡(luò)NS與教師網(wǎng)絡(luò)NT知識(shí)蒸餾.
本文構(gòu)建雙生成器網(wǎng)絡(luò)架構(gòu)GT與GS,通過教師網(wǎng)絡(luò)提取訓(xùn)練樣本先驗(yàn)信息,訓(xùn)練教師端生成器網(wǎng)絡(luò)GT,使生成的偽樣本分布更近似于真實(shí)樣本.由于真實(shí)樣本標(biāo)簽缺失,GT難以得到來自于NT準(zhǔn)確、充分的樣本分布先驗(yàn)信息,實(shí)現(xiàn)最優(yōu)化訓(xùn)練.因此,本文僅利用教師端生成器網(wǎng)絡(luò)GT作為訓(xùn)練學(xué)生端生成器網(wǎng)絡(luò)GS的輔助網(wǎng)絡(luò),強(qiáng)化生成偽樣本質(zhì)量,提升學(xué)生網(wǎng)絡(luò)判別準(zhǔn)確率.
隨機(jī)樣本Z(T)作為教師端生成器網(wǎng)絡(luò)GT(Z(T);θg)的初始輸入,經(jīng)網(wǎng)絡(luò)計(jì)算后得到偽樣本x(iT),i=1,2,…,N,其中 θg為GT網(wǎng)絡(luò)參數(shù).同時(shí),偽樣本集X(T)作為教師網(wǎng)絡(luò)判別器NT(X(T);θd)的輸入,可得到該網(wǎng)絡(luò)判別結(jié)果,結(jié)合先驗(yàn)信息構(gòu)造損失函數(shù)LGT,反饋訓(xùn)練生成器網(wǎng)絡(luò)GT,得到更真實(shí)樣本分布的偽訓(xùn)練樣本集,用于學(xué)生網(wǎng)絡(luò)知識(shí)蒸餾.為獲得優(yōu)化反饋信息,LGT由3 部分構(gòu)成:
Fig.3 Architecture and learning process of DG-DAFL圖3 DG-DAFL 架構(gòu)及學(xué)習(xí)過程
最小化預(yù)測(cè)標(biāo)簽與真實(shí)標(biāo)簽交叉熵值,學(xué)習(xí)教師網(wǎng)絡(luò)判別器先驗(yàn)信息,使GT生成與真實(shí)樣本分布更為接近的偽樣本集.
2)借助DAFL 中模型訓(xùn)練過程,NT網(wǎng)絡(luò)中多卷積層所提取的特征向量中更具判別性的神經(jīng)元將被激活,即偽樣本X(T)經(jīng)預(yù)訓(xùn)練網(wǎng)絡(luò)NT逐層非線性特征計(jì)算后得到特征向量,其中更大激活值可包含更多的真實(shí)樣本特征先驗(yàn)信息,特征圖激活損失函數(shù)可被表示為
該損失在生成器優(yōu)化過程中減小偽樣本經(jīng)卷積濾波器后激活值更大的特征,得到更接近真實(shí)樣本特征表達(dá).
3)為充分利用預(yù)訓(xùn)練教師網(wǎng)絡(luò)樣本分布及類別先驗(yàn)信息,構(gòu)建預(yù)訓(xùn)練集樣本類平衡分布損失Lie-T.定義p={p1,p2,…,pk}為k類樣本集中的每類樣本出現(xiàn)的概率,當(dāng)各類樣本為均勻分布時(shí),即pk=,所含信息量最大.為保證教師網(wǎng)絡(luò)判別結(jié)果的均衡性、多樣性,充分利用預(yù)訓(xùn)練樣本分布信息,以教師網(wǎng)絡(luò)優(yōu)化生成器在該類數(shù)據(jù)集下等概率生成各類樣本,構(gòu)建信息熵?fù)p失函數(shù):
結(jié)合式(1)~(3),可得到用于優(yōu)化輔助生成器GT的目標(biāo)函數(shù)為
其中 α 和 β 為平衡因子.利用式(4)保證GT優(yōu)化過程充分利用教師網(wǎng)絡(luò)保存的訓(xùn)練樣本分布等先驗(yàn)信息,即可獲得更近似于真實(shí)數(shù)據(jù)的高質(zhì)量偽樣本數(shù)據(jù)集.
根據(jù)2.1 節(jié)所述的教師端生成器GT的優(yōu)化過程,借助教師端判別器網(wǎng)絡(luò)NT包含的真實(shí)樣本先驗(yàn)信息.然而,由于函數(shù)one_hot所構(gòu)建的偽樣本標(biāo)簽將帶來大量噪音,當(dāng)GT對(duì)NT完全信任時(shí),其優(yōu)化過程將引入錯(cuò)誤信息,使學(xué)生端判別器網(wǎng)絡(luò)NS訓(xùn)練階段難以生成與真實(shí)樣本分布近似的偽樣本集,影響學(xué)生網(wǎng)絡(luò)判別準(zhǔn)確率.同時(shí),當(dāng)NS的訓(xùn)練將完全依賴于網(wǎng)絡(luò)GT生成偽樣本時(shí)將降低模型NS的泛化性.
為解決上述問題,本文在學(xué)生網(wǎng)絡(luò)端引入生成器GS,如圖2 所示.利用GT信息輔助GS優(yōu)化,生成更接近真實(shí)分布且更具多樣性的訓(xùn)練樣本.首先,雙生成器GT與GS通過隨機(jī)初始樣本同時(shí)生成偽樣本矩陣X(T)與X(S),其中,X(T)通過NT計(jì)算并由式(4)構(gòu)建損失反饋訓(xùn)練生成器GT,生成新的教師端偽樣本集X′(T);其次,X(S)同時(shí)經(jīng)NT與NS計(jì)算,為充分借助教師網(wǎng)絡(luò)先驗(yàn)數(shù)據(jù)分布信息度量分布差異,利用式(5)優(yōu)化NS:
此時(shí),利用初步訓(xùn)練得到的NS結(jié) 合當(dāng)前生成偽樣本集X(S)與式(4),構(gòu)建反饋損失函數(shù)=Loh-S+αLα-S+βLie-S,優(yōu)化當(dāng)前學(xué)生網(wǎng)絡(luò)生成器GS.該模型可保證教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)執(zhí)行相同任務(wù),提升學(xué)生網(wǎng)絡(luò)學(xué)習(xí)能力.同時(shí),通過對(duì)學(xué)生網(wǎng)絡(luò)優(yōu)化避免對(duì)缺失真實(shí)標(biāo)簽判別結(jié)果的過分信任,降低生成器優(yōu)化效果.最后,GS生成新的學(xué)生端偽樣本集X′(S).為使GS獲得更多樣本先驗(yàn)信息保證生成樣本與真實(shí)樣本分布一致性,同時(shí),保證生成偽樣本多樣性,提升學(xué)生網(wǎng)絡(luò)模型泛化性,本文采用KL 散度獲得2 個(gè)優(yōu)化得到的偽樣本集X′(T)與X′(S)隨分布差異,如式(6)所示:
本文僅期望學(xué)生網(wǎng)絡(luò)生成器GS所 得的樣本集X′(S)在分布上與先驗(yàn)樣本分布更為接近.此時(shí),構(gòu)建學(xué)生網(wǎng)絡(luò)生成器優(yōu)化損失表達(dá),如式(7)所示,實(shí)現(xiàn)最優(yōu)化生成器GS的構(gòu)建.
其中,γ為平衡因子.
本文利用優(yōu)化得到的學(xué)生端生成器GS,更新偽樣本集X′(S)作為訓(xùn)練數(shù)據(jù)輔助學(xué)生網(wǎng)絡(luò)構(gòu)建.
教師網(wǎng)絡(luò)NT與學(xué)生網(wǎng)絡(luò)NS同時(shí)接受學(xué)生端生成器獲得的優(yōu)化為樣本集X′(S),由于模型差異,網(wǎng)絡(luò)結(jié)構(gòu)相對(duì)復(fù)雜的教師網(wǎng)絡(luò)輸出結(jié)果優(yōu)于網(wǎng)絡(luò)結(jié)構(gòu)相對(duì)簡(jiǎn)單的學(xué)生網(wǎng)絡(luò).為提升模型壓縮效果,借助知識(shí)蒸餾技術(shù),將二者softmax 層上輸出結(jié)果進(jìn)行交叉熵函數(shù)計(jì)算,使學(xué)生網(wǎng)絡(luò)的輸出更近似教師網(wǎng)絡(luò)的輸出,提升學(xué)生網(wǎng)絡(luò)NS的性能.知識(shí)蒸餾損失函數(shù)為
結(jié)合偽樣本訓(xùn)練,在此損失函數(shù)約束下,實(shí)現(xiàn)在相同任務(wù)下較為稀疏的大規(guī)模網(wǎng)絡(luò)到緊湊小規(guī)模網(wǎng)絡(luò)的壓縮及知識(shí)蒸餾.
本文在3 個(gè)流行圖像數(shù)據(jù)集上驗(yàn)證了所提出方法的有效性,并與近年Data-free 環(huán)境下較為流行的知識(shí)蒸餾模型,包括DAFL, DFAD, RDSKD 模型在精度、魯棒性、泛化性上進(jìn)行對(duì)比與分析.同時(shí),通過對(duì)模型消融實(shí)驗(yàn)結(jié)果的統(tǒng)計(jì),討論模型框架結(jié)構(gòu)設(shè)計(jì)的合理性.本文進(jìn)一步設(shè)置實(shí)驗(yàn)數(shù)據(jù),驗(yàn)證DGDAFL 模型的泛化性.實(shí)驗(yàn)運(yùn)行在Intel Core i7-8700及NVIDIA Geforce RTX 2070 硬件環(huán)境,及Windows10操作系統(tǒng)、Python3 語言環(huán)境、Pytorch 深度學(xué)習(xí)框架上.
本文為了更全面地驗(yàn)證模型效果,采用4 種評(píng)價(jià)指標(biāo):準(zhǔn)確率(Accuracy)、精確率(Precision)、召回率(Recall)、特異度(Specificity).
準(zhǔn)確率(Accuracy)指分類模型中正確樣本量占總樣本量的比重,其計(jì)算公式為
精確率(Precision)指分類結(jié)果預(yù)測(cè)為陽性的正確比重,計(jì)算公式為
召回率(Recall)指真實(shí)值為陽性的正確比重,其計(jì)算公式為
特異度(Specificity)指真實(shí)值為陽性的正確比重,其計(jì)算公式為
式(9)~(12)中,TP為模型正確預(yù)測(cè)為正例樣本量,TN為模型正確預(yù)測(cè)為反例樣本量,F(xiàn)P為模型錯(cuò)誤預(yù)測(cè)為正例樣本量,F(xiàn)N為模型錯(cuò)誤預(yù)測(cè)為反例樣本量.
本文引入雙生成器端損失在充分利用教師網(wǎng)絡(luò)先驗(yàn)樣本分布信息條件下,保證生成樣本多樣性,如式(7)所示,其中 γ為平衡因子.為保證實(shí)驗(yàn)的公平性,γ值的選取采用確定范圍{0.01,0.1,1,10,100}內(nèi)值遍歷選取方法,如圖4 中所示,γ取值將對(duì)學(xué)生網(wǎng)絡(luò)模型識(shí)別結(jié)果產(chǎn)生較大影響.當(dāng)γ=10時(shí),MNIST 與USPS 數(shù)據(jù)集均達(dá)到Accuracy統(tǒng)計(jì)的最高值.因此,本文驗(yàn)證實(shí)驗(yàn)中的所有數(shù)據(jù)集,均設(shè)置γ=10.
Fig.4 Effect of γ on model performance圖4 參數(shù) γ值對(duì)模型性能的影響
1)MNIST 手寫體數(shù)據(jù)集
MNIST 數(shù)據(jù)集為10 分類手寫體數(shù)據(jù)集,由像素大小為28×28 的70 000 張圖像組成,本文中隨機(jī)選取60 000 張圖像為訓(xùn)練數(shù)據(jù)集,10 000 張圖像為測(cè)試數(shù)據(jù)集,部分樣本可視化結(jié)構(gòu)如圖5 所示.
Fig.5 Sample visualization of MNIST dataset圖5 MNIST 數(shù)據(jù)集中樣本可視化
本數(shù)據(jù)集實(shí)驗(yàn)中,利用LeNet-5 作為教師網(wǎng)絡(luò)實(shí)現(xiàn)該數(shù)據(jù)集分類模型訓(xùn)練.構(gòu)建學(xué)生網(wǎng)絡(luò)LeNet-5-half,其網(wǎng)絡(luò)結(jié)構(gòu)與教師網(wǎng)絡(luò)相同,每層通道數(shù)相比教師網(wǎng)絡(luò)少一半,計(jì)算成本相比教師網(wǎng)絡(luò)少50%,可實(shí)現(xiàn)網(wǎng)絡(luò)壓縮.表1 中統(tǒng)計(jì)并對(duì)比了所提算法在MNIST 數(shù)據(jù)集上的Accuracy值.
表1 中對(duì)10 次實(shí)驗(yàn)統(tǒng)計(jì)的均值可見,利用真實(shí)數(shù)據(jù)訓(xùn)練得到教師網(wǎng)絡(luò)的Accuracy=0.989 4.由噪聲數(shù)據(jù)隨機(jī)生成偽樣本作為訓(xùn)練集,在教師網(wǎng)絡(luò)指導(dǎo)下,利用知識(shí)蒸餾可得到Accuracy=0.867 8 的學(xué)生網(wǎng)絡(luò),該狀態(tài)下僅利用教師網(wǎng)絡(luò)前期訓(xùn)練得到的判別信息,不借助樣本分布信息,難以達(dá)到滿意的蒸餾效果.DAFL 方法中,通過教師網(wǎng)絡(luò)模型判別結(jié)果回傳損失,優(yōu)化生成器網(wǎng)絡(luò),生成與真實(shí)樣本分布更為接近的偽樣本數(shù)據(jù),訓(xùn)練學(xué)生網(wǎng)絡(luò),模型Accuracy值可達(dá)到0.968 7.本文提出的DG-DAFL 方法相比DAFL方法,避免了單一生成器網(wǎng)絡(luò)對(duì)教師網(wǎng)絡(luò)在無標(biāo)簽偽樣本集上判別結(jié)果過度信任所產(chǎn)生的無效先驗(yàn)優(yōu)化失敗問題,同時(shí),學(xué)生網(wǎng)絡(luò)端生成器在教師端生成器的輔助下產(chǎn)生更適合學(xué)生端生成器的訓(xùn)練樣本,保證生成樣本的多樣性,提升識(shí)別泛化性.同時(shí),RDSKD模型通過增加正則化項(xiàng)提升樣本多樣性,針對(duì)不同類樣本特征較為近似的MNIST 數(shù)據(jù)集取得了比DAFL與DFAD 模型更好的分類性能.DG-DAFL 模型中,學(xué)生網(wǎng)絡(luò)Accuracy值提升至0.980 9,其網(wǎng)絡(luò)性能十分接近教師網(wǎng)絡(luò),同時(shí),根據(jù)10 次實(shí)驗(yàn)運(yùn)行結(jié)果的均值與方差可知DG-DAFL 模型獲得了更好的魯棒性.
Table 1 Classification Results on MNIST Dataset表1 MNIST 數(shù)據(jù)集上的分類結(jié)果
2)AR 人臉數(shù)據(jù)集
AR 數(shù)據(jù)集為包含100 類的人臉數(shù)據(jù)集,由圖像尺寸為120×165 的2 600 張圖片組成,其中前50 類為男性樣本,后50 類為女性樣本,每類包含26 張人臉圖,包括不同的面部表情、照明條件、遮擋情況,是目前使用最為廣泛的標(biāo)準(zhǔn)數(shù)據(jù)集.在實(shí)驗(yàn)中,本文將每類的20 張圖片作為訓(xùn)練集,剩余的6 張作為測(cè)試集,通過此方式對(duì)網(wǎng)絡(luò)性能進(jìn)行評(píng)價(jià).AR 數(shù)據(jù)集可視化結(jié)果如圖6 所示.
本數(shù)據(jù)集實(shí)驗(yàn)中,利用ResNet34 作為教師網(wǎng)絡(luò),ResNet18 作為學(xué)生網(wǎng)絡(luò).ResNet34 與ResNet18 采用相同的5 層卷積結(jié)構(gòu),ResNet34 在每層卷積結(jié)構(gòu)中的層數(shù)更多,其所消耗的計(jì)算成本更高;ResNet34 的Flops 計(jì)算量為3.6×109,ResNet18 的Flops 計(jì)算量為1.8×109.表2中統(tǒng)計(jì)并對(duì)比了所提方法在AR 數(shù)據(jù)集上的Accuracy結(jié)果.
Fig.6 Sample visualization results of AR dataset圖6 AR 數(shù)據(jù)集的可視化結(jié)果
Table 2 Classification Results on AR Dataset表2 AR 數(shù)據(jù)集上的分類結(jié)果
實(shí)驗(yàn)統(tǒng)計(jì)結(jié)果如表2 所示.教師網(wǎng)絡(luò)經(jīng)包含真實(shí)標(biāo)簽數(shù)據(jù)集訓(xùn)練后Accuracy=0.865.Data-Free 環(huán)境下,DAFL 模型中經(jīng)知識(shí)蒸餾后學(xué)生網(wǎng)絡(luò)的Accuracy=0.676 7.AR 數(shù)據(jù)集相比MNIST 數(shù)據(jù)集,圖像類別數(shù)量提升,圖像復(fù)雜度及細(xì)節(jié)增加,不同類別間樣本特征分布更為近似,難以判別.DAFL 模型中生成器優(yōu)化過程完全依賴教師網(wǎng)絡(luò)判別結(jié)果,導(dǎo)致生成大量用于訓(xùn)練學(xué)生網(wǎng)絡(luò)的噪音樣本,使學(xué)生網(wǎng)絡(luò)判別準(zhǔn)確率與魯棒性下降.DFAD 模型忽略教師網(wǎng)絡(luò)對(duì)樣本生成所提供的先驗(yàn)信息,難以獲得與原訓(xùn)練樣本分布更為近似的生成樣本,極大影響學(xué)生網(wǎng)絡(luò)識(shí)別準(zhǔn)確率.RDSKD 模型面對(duì)的復(fù)雜特征樣本集同樣面臨未充分利用預(yù)訓(xùn)練教師網(wǎng)絡(luò)樣本先驗(yàn)信息,導(dǎo)致知識(shí)蒸餾效果下降,學(xué)生網(wǎng)絡(luò)的Accuracy僅為0.52.本文通過構(gòu)建雙生成器模型DG-DAFL,在充分利用教師網(wǎng)絡(luò)的潛在樣本先驗(yàn)知識(shí)的同時(shí),構(gòu)造生成器端損失,避免對(duì)誤差樣本信息過學(xué)習(xí),生成更有效且與真實(shí)樣本分布一致的偽樣本.在AR 較為復(fù)雜的數(shù)據(jù)集上,本文所提出的DG-DAFL 模型的Accuracy=0.718 3.
3)USPS 手寫體數(shù)據(jù)集
USPS 數(shù)據(jù)集為10 類別分類數(shù)據(jù)集,由像素大小為16×16 的9 298 張灰度圖像組成,該數(shù)據(jù)集相比于MNIST 數(shù)據(jù)集包含的樣本量更多,樣本尺寸更小,且樣本表達(dá)更為模糊、抽象,為識(shí)別帶來了困難,USPS數(shù)據(jù)集可視化結(jié)果如圖7 所示.本文實(shí)驗(yàn)中,隨機(jī)選取7 291 張與2007 張圖像分別構(gòu)建教師網(wǎng)絡(luò)的訓(xùn)練集與測(cè)試集.
Fig.7 Sample visualization results of USPS dataset圖7 USPS 數(shù)據(jù)集的可視化結(jié)果
教師網(wǎng)絡(luò)選擇與MNIST 數(shù)據(jù)集下相同的網(wǎng)絡(luò)結(jié)構(gòu)LeNet-5,學(xué)生網(wǎng)絡(luò)結(jié)構(gòu)為LeNet-5-half.表3 中統(tǒng)計(jì)并對(duì)比了所提出方法在USPS 數(shù)據(jù)集上的Accuracy結(jié)果.
Table 3 Classification Results on USPS Dataset表3 USPS 數(shù)據(jù)集上的分類結(jié)果
由表3 可知,教師網(wǎng)絡(luò)分類Accuracy=0.96,在此基礎(chǔ)上實(shí)現(xiàn)DAFL 模型.學(xué)生網(wǎng)絡(luò)的Accuracy=0.926 7.DFAD 模型在USPS 數(shù)據(jù)集上的Accuracy=0.889 9,由于教師網(wǎng)絡(luò)過度信任生成樣本集中包含的噪音等樣本,影響知識(shí)蒸餾效果及模型魯棒性.RDSKD 模型同樣存在忽略生成樣本質(zhì)量等問題,降低學(xué)生網(wǎng)絡(luò)準(zhǔn)確率.DG-DAFL 通過引入學(xué)生端生成器的雙生成器方法,解決單生成器網(wǎng)絡(luò)結(jié)構(gòu)中學(xué)生網(wǎng)絡(luò)訓(xùn)練過度依賴教師生成器網(wǎng)絡(luò)樣本產(chǎn)生的泛化性較低等問題.同時(shí),學(xué)生網(wǎng)絡(luò)生成器所生成的數(shù)據(jù)在保證分布近似條件下的樣本多樣性,進(jìn)一步提升學(xué)生網(wǎng)絡(luò)識(shí)別泛化性的基礎(chǔ)上,學(xué)生網(wǎng)絡(luò)在USPS 數(shù)據(jù)集下獲得了更高的準(zhǔn)確率及魯棒性.
1)DG-DAFL 消融分析
為進(jìn)一步討論所提DG-DAFL 模型中學(xué)生端生成器GS優(yōu)化過程的合理性及損失函數(shù)各部分的必要性,本節(jié)在MNIST 數(shù)據(jù)集上實(shí)現(xiàn)消融實(shí)驗(yàn)并分析實(shí)驗(yàn)結(jié)果.表4 統(tǒng)計(jì)并對(duì)比了不同損失函數(shù)部分對(duì)Data-Free 環(huán)境下模型準(zhǔn)確率的影響.
Table 4 Ablation Experiment Results on MNIST Dataset表4 MNIST 數(shù)據(jù)集上消融實(shí)驗(yàn)結(jié)果
在消融實(shí)驗(yàn)中,利用真實(shí)數(shù)據(jù)訓(xùn)練的教師網(wǎng)絡(luò)分類Accuracy=0.983 9;學(xué)生端生成器GS在沒有任何損失函數(shù)優(yōu)化的情況下,利用隨機(jī)生成樣本并結(jié)合教師網(wǎng)絡(luò)知識(shí)蒸餾,Accuracy達(dá)到0.868 7.若僅利用對(duì)隨機(jī)偽樣本判別結(jié)果所構(gòu)造的任一損失函數(shù),包括偽標(biāo)簽損失、信息熵?fù)p失、特征損失,優(yōu)化學(xué)生網(wǎng)絡(luò)生成器GS,均難以得到滿意的判別結(jié)果,其主要原因在于學(xué)生網(wǎng)絡(luò)判別器未經(jīng)過真實(shí)樣本訓(xùn)練不包含真實(shí)先驗(yàn)信息,難以指導(dǎo)生成器訓(xùn)練.若僅利用雙生成器端KL 散度作為優(yōu)化信息,教師端生成器GT經(jīng)教師網(wǎng)絡(luò)優(yōu)化包含部分真實(shí)樣本先驗(yàn)信息,可對(duì)GS生成樣本產(chǎn)生一定的先驗(yàn)監(jiān)督作用,輔助生成器GS生成相近的輸出分布,在KL 散度損失單獨(dú)優(yōu)化下,學(xué)生網(wǎng)絡(luò)性能有小幅度提升.當(dāng)3 種損失函數(shù)與生成器損失結(jié)合后,生成器GS獲得更多樣本先驗(yàn)信息,保證生成樣本與真實(shí)樣本的分布一致性,并保證生成偽樣本的多樣性,提升學(xué)生網(wǎng)絡(luò)模型的準(zhǔn)確率.
2)DG-DAFL 泛化性分析
為驗(yàn)證所提出的DG-DAFL 模型具有更好的泛化性,本文基于MNIST 數(shù)據(jù)集,構(gòu)建實(shí)驗(yàn)數(shù)據(jù)集MNIST-F(訓(xùn)練集Tra 與測(cè)試集Te).其中0~9 為類別編號(hào),由于樣本類別編號(hào)1 和7、0 和8、6 和9 等具有判別特征上的相似性,將混淆分類模型,為識(shí)別帶來難度.本文縮小易混淆類別訓(xùn)練樣本規(guī)模,具體將原始數(shù)據(jù)集中的訓(xùn)練樣本類別編號(hào)為1,6,8 的樣本量減半,測(cè)試數(shù)據(jù)量保持不變,其詳細(xì)描述如表5 所示,表5 中nTra 與nTe 分別為原始訓(xùn)練集與原始測(cè)試集.
Table 5 Description of Generalizability Test Dataset表5 泛化性測(cè)試數(shù)據(jù)集描述
數(shù)據(jù)集MNIST-F 實(shí)驗(yàn)中,教師網(wǎng)絡(luò)結(jié)構(gòu)為LeNet-5,學(xué)生網(wǎng)絡(luò)結(jié)構(gòu)為LeNet-5-half.本文分別統(tǒng)計(jì)及對(duì)比了DAFL 模型與所提出DG-DAFL 模型的分類Accuracy,結(jié)果如表6 所示.
Table 6 Classification Results on MNIST-F Dataset表6 MNIST-F 數(shù)據(jù)集上的分類結(jié)果
表6 所示的是不同算法在MNIST-F 數(shù)據(jù)集下的泛化性測(cè)試結(jié)果.DAFL 算法的Accuracy=0.942 5,DGDAFL 算法的Accuracy=0.969 5,相比在MNIST 數(shù)據(jù)集下的測(cè)試結(jié)果,DAFL 算法的Accuracy值下降0.026 2,DG-DAFL 的算法Accuracy值下降0.011 4,當(dāng)在易混淆類別訓(xùn)練不足的情況下,本文所提出的DG-DAFL模型相比DAFL 模型具有更好的泛化性和魯棒性.DG-DAFL 模型中的學(xué)生網(wǎng)絡(luò)NS的訓(xùn)練數(shù)據(jù)不完全依賴于教師端生成器GT,避免在DAFL 模型下由于函數(shù)one_hot構(gòu)建的偽樣本標(biāo)簽帶來的大量噪聲,解決學(xué)生網(wǎng)絡(luò)NS魯棒性的問題.為便于觀察與分析,本文統(tǒng)計(jì)并對(duì)比了DAFL 與DG-DAFL 模型在MNISTF 數(shù)據(jù)集上的其他評(píng)價(jià)標(biāo)準(zhǔn)結(jié)果,如表6 和表7 所示.
由表7 與表8 可知,泛化性測(cè)試下DG-DAFL模型總體上比DAFL 模型在精確率、召回率、特異度指標(biāo)上均有所提升.類別1,6,8 中訓(xùn)練樣本量減少為一半的情況下,本文所提出的模型DG-DAFL 在這3類上均獲得了更好的性能.原因在于DG-DAFL 模型下,訓(xùn)練數(shù)據(jù)由雙生成器生成,其更具多樣性,避免了單一生成器容易導(dǎo)致生成數(shù)據(jù)泛化性低的問題.
Table 7 Statistical Results of DAFL Model for Different Categories表7 DAFL 模型針對(duì)不同類別統(tǒng)計(jì)結(jié)果
Fig.8 Confusion matrix for teacher network generalization test圖8 教師網(wǎng)絡(luò)泛化性測(cè)試的混淆矩陣
Fig.9 Confusion matrix for DAFL generalization test圖9 DAFL 模型泛化性測(cè)試的混淆矩陣
Fig.10 Confusion matrix for DG-DAFL generalization test圖10 DG-DAFL 模型泛化性測(cè)試的混淆矩陣
圖8~10 通過MNIST-F 數(shù)據(jù)集下各類別的分類結(jié)果樣本量及誤分類樣本量的混淆矩陣,可更為清晰地觀察到DG-DAFL 模型的效果更加接近教師網(wǎng)絡(luò),分類效果較優(yōu).在真實(shí)標(biāo)簽為0,5,6,8,9 上的分類中,DAFL 模型比DG-DAFL 模型出現(xiàn)更多錯(cuò)誤分類,其原因?yàn)镈AFL 模型的訓(xùn)練數(shù)據(jù)僅依賴于教師網(wǎng)絡(luò),教師網(wǎng)絡(luò)生成的偽標(biāo)簽帶來大量噪聲影響生成器性能,降低學(xué)生網(wǎng)絡(luò)性能.DG-DAFL 模型中學(xué)生網(wǎng)絡(luò)的訓(xùn)練數(shù)據(jù)取決于教師端生成器和學(xué)生端生成器2 方面的影響,避免過度依賴教師網(wǎng)絡(luò)端生成器的情況,使得在DG-DAFL 模型的訓(xùn)練過程中,生成訓(xùn)練數(shù)據(jù)更加接近真實(shí)數(shù)據(jù),且保證生成圖像的多樣性.同時(shí),可觀察到DAFL 模型在易混淆的類別中將1 類樣本被誤分類為7 類樣本,0,6,8 類樣本由于模型泛化性較低而被互相混淆,產(chǎn)生錯(cuò)誤的分類.
本文針對(duì)Data-Free 環(huán)境中網(wǎng)絡(luò)壓縮及知識(shí)蒸餾問題,借助DAFL 模型通過構(gòu)建生成器獲得偽訓(xùn)練樣本的學(xué)習(xí)方式,提出DG-DAFL 網(wǎng)絡(luò)框架.該框架設(shè)計(jì)雙生成器網(wǎng)絡(luò)結(jié)構(gòu),保證教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)完成一致學(xué)習(xí)任務(wù),并實(shí)現(xiàn)樣本生成器與教師網(wǎng)絡(luò)分離,避免DAFL 模型中生成器完全信任教師網(wǎng)絡(luò)判別結(jié)果,產(chǎn)生失效優(yōu)化問題.同時(shí),在學(xué)生網(wǎng)絡(luò)生成器訓(xùn)練過程中,構(gòu)造雙生成器端偽樣本分布損失,在充分利用教師網(wǎng)絡(luò)潛在樣本分布先驗(yàn)信息的同時(shí)避免過度依賴,生成更具多樣性的偽樣本集.本文在3 個(gè)流行的數(shù)據(jù)集上驗(yàn)證了算法的有效性,并構(gòu)造數(shù)據(jù)集進(jìn)一步分析了算法的泛化性及魯棒性.然而,Data-Free 環(huán)境中生成的偽訓(xùn)練樣本的質(zhì)量將影響學(xué)生網(wǎng)絡(luò)性能,接下來本文工作將圍繞充分挖掘教師網(wǎng)絡(luò)預(yù)訓(xùn)練樣本結(jié)構(gòu)特征等先驗(yàn)知識(shí),構(gòu)建更高質(zhì)量的學(xué)生網(wǎng)絡(luò)訓(xùn)練樣本集.DG-DAFL 方法代碼及模型已開源:https://github.com/LNNU-computer-research-526/DG-DAFL.git.
作者貢獻(xiàn)聲明:張晶主要負(fù)責(zé)模型提出、算法設(shè)計(jì)及論文撰寫;鞠佳良負(fù)責(zé)算法實(shí)現(xiàn)、實(shí)驗(yàn)驗(yàn)證及論文撰寫;任永功負(fù)責(zé)模型思想設(shè)計(jì)及寫作指導(dǎo).