林振元,林紹輝+,姚益武,何高奇,王長波,馬利莊
1.華東師范大學(xué) 計算機(jī)科學(xué)與技術(shù)學(xué)院,上海 200062
2.北京大學(xué) 信息科學(xué)技術(shù)學(xué)院,北京 100871
知識蒸餾(knowledge distillation,KD)[1-5]是一種常見的模型壓縮方法,在大多數(shù)現(xiàn)有的KD 方法中,使用基于logits[1]或來自教師的特征信息[2]的方法可以將知識從教師網(wǎng)絡(luò)轉(zhuǎn)移到學(xué)生模型,但在這其中需要訪問整個訓(xùn)練數(shù)據(jù)。本文將這些KD 方法稱為數(shù)據(jù)驅(qū)動的KD 方法。然而在現(xiàn)實中,由于隱私、保密或傳輸限制,在蒸餾過程中原始訓(xùn)練樣本通常不可用。例如,患者的醫(yī)療數(shù)據(jù)是保密的,不會公開共享以泄露患者的隱私。如果沒有數(shù)據(jù)的幫助,這些方法在獲取不到原始數(shù)據(jù)的情況下將無法使用。
許多工作[6-8]使用生成對抗網(wǎng)絡(luò)研究無數(shù)據(jù)模型壓縮。然而,這些研究都關(guān)注于提高從特定的單一模型反演數(shù)據(jù)的性能,導(dǎo)致生成的數(shù)據(jù)缺乏多樣性和泛化性。
一方面,從某一特定模型反演知識會使合成圖像有偏差。由于生成的樣本是從單一的教師模型反演學(xué)習(xí)得到的,只含有教師網(wǎng)絡(luò)所包含的結(jié)構(gòu)先驗知識,導(dǎo)致這些合成的數(shù)據(jù)不能用于蒸餾到其他的模型。如圖1 所示,在相同的設(shè)定下分別將DAFL(data-free learning)[6]、DFQ(data-free quantization)[9]、DeepInversion[10]、CMI(contrastive model inversion)[7]方法合成的數(shù)據(jù)直接用于訓(xùn)練不同架構(gòu)的網(wǎng)絡(luò),實驗結(jié)果表明同一個方法得到的訓(xùn)練數(shù)據(jù)用于訓(xùn)練不同網(wǎng)絡(luò)時效果差異很大,而且與CIFAR-10 原始數(shù)據(jù)相比性能上仍存在較大的差距。以Inception-V3 為例,現(xiàn)有的方法CMI[7]所合成的數(shù)據(jù)與原始數(shù)據(jù)得到的性能仍然相差了10 個百分點,而且使用合成的數(shù)據(jù)來訓(xùn)練不同的網(wǎng)絡(luò)結(jié)構(gòu)很不穩(wěn)定,不同網(wǎng)絡(luò)的準(zhǔn)確率有較大的方差,說明先前的方法合成的數(shù)據(jù)可能包含了某一種網(wǎng)絡(luò)結(jié)構(gòu)的先驗知識以至于無法很好推廣適用于其他的模型的訓(xùn)練。因此,這種方法顯然無法拓展至多種網(wǎng)絡(luò)進(jìn)行壓縮。而使用不同的教師網(wǎng)絡(luò)進(jìn)行多次多個模型的壓縮將顯著增加多個模型的訓(xùn)練時間和數(shù)據(jù)內(nèi)存存儲。另外,Chen等[6]使用特定的教師模型(ResNet-34[11])合成數(shù)據(jù)去訓(xùn)練其他模型,例如ResNet-18、WRN-16-1,WRN-16-1的最終性能明顯低于ResNet-18 的性能。因此本文的目的在于所合成的數(shù)據(jù)可以直接用于訓(xùn)練其他結(jié)構(gòu)的網(wǎng)絡(luò)。
圖1 跨模型無數(shù)據(jù)蒸餾的結(jié)果概述Fig.1 Overview of results of cross-model data-free distillation
另一方面,目前的工作在判別器中使用信息熵[6]或?qū)W生-教師分歧[9]來生成多樣化的圖像,由于缺乏與歷史生成的圖像的比較,生成圖片的多樣性仍然有所欠缺。在這種情況下,該類算法在生成的圖像中會遇到重復(fù)模式,生成器極有可能生成與歷史實例高度相似的實例。
為了解決這些問題,本文提出了一種多教師對比知識反演的無數(shù)據(jù)蒸餾方法(multi-teacher contrastive knowledge inversion,MTCKI),圖2 描述了所提出方法的工作流程。MTCKI 算法在實際應(yīng)用中,也有著巨大的需求。例如,模型的供應(yīng)端(公司和企業(yè))是會有很多不同網(wǎng)絡(luò)架構(gòu)的預(yù)訓(xùn)練模型,而客戶端需要部署一個小模型在自己的終端設(shè)備上。本文提出了一種供應(yīng)端-客戶端合作的模式,供應(yīng)端將已經(jīng)訓(xùn)練好的多個教師模型提供給客戶,而不提供原始的訓(xùn)練數(shù)據(jù),而客戶端只通過這些訓(xùn)練好的教師網(wǎng)絡(luò)去得到一個學(xué)生網(wǎng)絡(luò)用于部署。單個學(xué)生可以訪問多個教師從而得到多個教師網(wǎng)絡(luò)提供的全面指導(dǎo),由此訓(xùn)練出的學(xué)生模型對模型偏差具有較強(qiáng)的魯棒性。本文首先提出了基于多教師集成的模型反演,充分反演來自教師的更豐富的信息以生成可泛化的數(shù)據(jù)。同時,本文進(jìn)一步提出了多教師和學(xué)生之間的對比交互正則化,其中包含教師內(nèi)對比和師生對比,以提高合成數(shù)據(jù)的多樣性。具體來說,教師內(nèi)部對比用于逐步合成具有與歷史樣本不同模式的新樣本。本文還提出了師生對比,師生對比旨在使得生成器合成的圖片能讓學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)映射到相同的表示空間中,采用對比學(xué)習(xí)的方法拉近同一物體的多視角表示,并區(qū)分開不同物體的特征。學(xué)生網(wǎng)絡(luò)學(xué)到的不僅是學(xué)生網(wǎng)絡(luò)所擅長提取的特征,比如鳥的嘴,還能從與教師網(wǎng)絡(luò)的表示的拉近過程中明白鳥的嘴、翅膀、眼睛、羽毛都可以被看作同一物體的不同視角,從而學(xué)習(xí)到更好的特征表示?;谝陨显?,生成器所合成的圖片融合了多視角的特征信息使得合成的圖片具有泛化性和多樣性,一次生成的圖片數(shù)據(jù)集能夠用于蒸餾或從頭訓(xùn)練多個不同的學(xué)生網(wǎng)絡(luò)。本文方法以對抗的方式訓(xùn)練圖像生成和知識轉(zhuǎn)移的過程,最終可以獲得高精度的學(xué)生模型和高質(zhì)量的生成數(shù)據(jù)。
圖2 多教師對比知識反演的無數(shù)據(jù)模型壓縮方法整體架構(gòu)Fig.2 Overall framework of multi-teacher contrastive knowledge inversion for data-free distillation
本文的主要貢獻(xiàn)總結(jié)如下:
(1)提出了一個新的無數(shù)據(jù)知識蒸餾框架,從多個可用的教師模型中提取“多視角”知識,同時提高學(xué)生模型精度和合成高質(zhì)量數(shù)據(jù)。
(2)設(shè)計了一種對比交互方式,充分利用來自多位師生的知識,生成具有高泛化性和多樣性的合成數(shù)據(jù)。一次生成的圖片數(shù)據(jù)集能夠用于蒸餾或從頭訓(xùn)練多個不同的學(xué)生網(wǎng)絡(luò)。
(3)實驗表明本文方法優(yōu)于現(xiàn)有的方法。本文方法不僅合成了更接近原始數(shù)據(jù)集分布的高保真數(shù)據(jù),而且還達(dá)到了與在原始數(shù)據(jù)集上訓(xùn)練的預(yù)訓(xùn)練模型相媲美的結(jié)果。
知識蒸餾[1]旨在通過從大型教師網(wǎng)絡(luò)轉(zhuǎn)移知識來提高小型學(xué)生網(wǎng)絡(luò)的性能。產(chǎn)生的知識可來自類后驗概率[1]或中間特征[1-2,7,12-13]。目前已有利用多個教師構(gòu)建更豐富和有啟發(fā)性的信息來訓(xùn)練學(xué)生模型的研究,其中知識來自集成logits[14-15]或特征[16-18]。例如,Lan 等[14]構(gòu)造了一個多分支結(jié)構(gòu),每個分支表示學(xué)生,并對每個分支進(jìn)行融合得到教師網(wǎng)絡(luò),將最終的集成logits作為蒸餾知識。You等[17]使用多個教師網(wǎng)絡(luò)的結(jié)合來提取不同實例中間層中的三元組排序關(guān)系,鼓勵與學(xué)生保持一致。本文方法在以下兩方面與之前的方法完全不同:(1)本文的框架以無數(shù)據(jù)的方式構(gòu)建,這相比之前的數(shù)據(jù)驅(qū)動的知識蒸餾更加靈活;(2)本文考慮了多位教師之間的內(nèi)部和相互關(guān)系,與基于多教師的知識蒸餾相比,它可以提取更豐富的知識進(jìn)行蒸餾。
無數(shù)據(jù)知識蒸餾的關(guān)鍵是在無需真實圖像的情況下進(jìn)行圖像合成。一般可以大致分為兩類:(1)在先驗知識上使用梯度下降直接學(xué)習(xí)圖像,如激活統(tǒng)計[19]和批量正則化(batch normalization,BN)統(tǒng)計[10];(2)對抗性訓(xùn)練以在噪聲輸入上學(xué)習(xí)生成器。DAFL[6]和DFQ[9]在第一階段使用生成對抗網(wǎng)絡(luò)(generative adversarial networks,GAN)生成圖像,可進(jìn)一步用于學(xué)習(xí)學(xué)生模型。最近,ZAQ(zero-shot adversarial quantization)[20]提出了一個兩級差異建模框架,用對抗的方式對學(xué)生和老師之間的中間特征進(jìn)行差異估計,并通過知識轉(zhuǎn)移來學(xué)習(xí)學(xué)生。訓(xùn)練后,無需重新訓(xùn)練即可同時獲得合成圖像和學(xué)生模型。ZeroQ[21]、Knowledge Within[22]以及MixMix[23]使用合成的數(shù)據(jù)集來執(zhí)行無數(shù)據(jù)量化。然而,這些方法是模型定制的,生成的圖像不能推廣到其他模型進(jìn)行蒸餾。與這些方法不同,本文方法提出了多教師和學(xué)生之間的對比交互,以生成高泛化和高多樣性的圖像。雖然MixMix[23]也利用多教師使用合成的數(shù)據(jù)集來執(zhí)行無數(shù)據(jù)量化,但本文方法利用最終特征信息和師生交互來更好地提高合成圖像的泛化性和多樣性。此外,學(xué)生和圖像生成的學(xué)習(xí)是以端到端的方式訓(xùn)練的,這與MixMix中的兩步訓(xùn)練完全不同。
對比學(xué)習(xí)[24-28]已廣泛應(yīng)用于無監(jiān)督學(xué)習(xí),能夠?qū)W習(xí)有效的特征表示以提高下游任務(wù)的性能。實例級對比是一種簡單而有效的策略,旨在將正樣本和錨點拉近,同時將其推離表示空間中的負(fù)樣本。例如,He 等[26]使用記憶庫來存儲來自動量編碼器的負(fù)樣本,并使用InfoNCE損失[27]從查詢編碼器和動量編碼器之間的表示中構(gòu)建對比。Chen等[24]用大批量數(shù)據(jù)替換記憶庫,讓兩個網(wǎng)絡(luò)在不同輸入增強(qiáng)上進(jìn)行對比。對比學(xué)習(xí)的思想同樣也有被應(yīng)用于知識蒸餾[29-31]。例如,Tian等[29]通過最大化教師和學(xué)生表示之間的互信息,將對比學(xué)習(xí)與知識蒸餾相結(jié)合。然而,這些方法中的對比知識是由真實數(shù)據(jù)和一個教師網(wǎng)絡(luò)形成的,然而,本文方法不需要任何真實數(shù)據(jù),只需構(gòu)建多教師和學(xué)生之間的對比。
為了更好地說明所提出的方法,本文首先使用一個預(yù)訓(xùn)練的教師網(wǎng)絡(luò)介紹了三個廣泛使用的模型反演損失。令fT(x,θt)和fS(x,θs)分別表示來自輸入圖像x的教師和學(xué)生編碼器的輸出,其中參數(shù)分別為θt和θs。由于預(yù)訓(xùn)練教師中給定固定參數(shù),本文通過省略θt將fT(x,θt)表示為fT(x)。=G(z,θg)是參數(shù)為θg的生成器G從噪聲輸入z合成的圖像。本文的目標(biāo)是通過減小教師網(wǎng)絡(luò)帶來的偏差來生成具有多樣性的高保真數(shù)據(jù)集,以替代原始圖像X。
(1)One-hot 預(yù)測損失。它用于生成器合成與教師網(wǎng)絡(luò)訓(xùn)練數(shù)據(jù)相兼容的圖像,使教師能夠?qū)Α首龀鰋ne-hot 的預(yù)測[5]。因此,給定一個預(yù)定義的類c,本文將one-hot預(yù)測損失表示為:
這里的CE是指交叉熵?fù)p失。
(2)BN 層中的特征正則化損失。BN 層已廣泛用于CNN,它通過在訓(xùn)練期間用平均統(tǒng)計量對特征圖進(jìn)行歸一化來緩解協(xié)變量偏移。訓(xùn)練后,這些統(tǒng)計數(shù)據(jù)存儲了有關(guān)X的豐富信息(例如:運行均值μ(x)和運行方差σ2(x))。因此,Yin 等[10]通過最小化所有層的和x的統(tǒng)計數(shù)據(jù)之間的距離來提出特征正則化:
(3)對抗蒸餾損失。通過對抗性蒸餾損失以鼓勵合成圖像使學(xué)生-教師產(chǎn)生較大的分歧[10,32-33],可以表示為:
其中,KL是KL散度,τ是溫度參數(shù)。
如上所述,本文整合了無數(shù)據(jù)蒸餾的基本框架,無數(shù)據(jù)蒸餾的整體模型反演損失可以通過組合公式(1)~(3)來表示:
其中,λi,i=1,2,3 是平衡參數(shù)。
文獻(xiàn)[34]提出了多視圖假設(shè),即“多視圖”結(jié)構(gòu)非常普遍存在于許多現(xiàn)實世界的數(shù)據(jù)集中。這些數(shù)據(jù)中存在多個特征,可用于正確分類圖像。例如,通過觀察翅膀、身體大小或嘴巴的形狀,可以將鳥類圖像分類為鳥類。模型往往只需要獲取一部分的特征,由于大部分的圖像可以被正確分類,模型便不再學(xué)習(xí)額外的特征。在現(xiàn)有的無數(shù)據(jù)蒸餾方法中,即使學(xué)生可以提取單一老師學(xué)習(xí)的所有特征,他們?nèi)匀粺o法“看到”該特定教師未發(fā)現(xiàn)的特征,從而限制了學(xué)生的表現(xiàn)。除此之外,由于圖像的合成受限于教師網(wǎng)絡(luò),生成器合成的圖像缺乏多視圖結(jié)構(gòu),以至于學(xué)生網(wǎng)絡(luò)難以看到物體的全部特征,這也就限制了合成數(shù)據(jù)的泛化性能。即使某些模型缺少單個學(xué)生可以學(xué)習(xí)多視圖知識的視圖,基于集成的方法也可以收集到大部分這些視圖。受文獻(xiàn)[14,34]的啟發(fā),本文首先考慮多個集成教師來構(gòu)建一個可靠的多分支模型。整體的框架如圖2所示,本文的框架包含多個教師網(wǎng)絡(luò)、一個學(xué)生網(wǎng)絡(luò)以及一個生成器。本文選擇所有教師的平均最終輸出作為模型預(yù)測,而不是按文獻(xiàn)[14]使用門控組件。此外,本文使用不同的教師來獲取各種統(tǒng)計知識,以提高合成圖像的多視圖結(jié)構(gòu),從而提升數(shù)據(jù)的泛化性能。因此,方程式中的模型反演損失式(4)可以重新表述為:
對比學(xué)習(xí)[23,25-26]以自監(jiān)督方式在特征表示上取得了巨大成功,可以有效地轉(zhuǎn)移到下游任務(wù),例如分割和目標(biāo)檢測。實例級對比是一種簡單而有效的策略,目的在于將錨點拉近正實例,同時將其推離表示空間中的負(fù)實例。MOCO(momentum contrast)[26]算法使用記憶庫(比如存儲來自歷史數(shù)據(jù)的特征)通過將當(dāng)前的實例與歷史存儲的實例的匹配來進(jìn)行對比,從而學(xué)習(xí)圖像特征表示。它啟發(fā)了本文使用記憶庫進(jìn)行對比學(xué)習(xí)來生成具有高度多樣性的數(shù)據(jù)。
受此啟發(fā),任意選取生成器合成的同一批圖像中的一張圖像為待測圖像,將待測圖像的表示和數(shù)據(jù)增強(qiáng)后的待測圖像的表示作為正樣本對,生成器合成的同一批圖像中待測圖像以外的圖像的表示作為負(fù)樣本,并將生成器合成的歷史圖像的表示作為負(fù)樣本。本文首先引入一個頭部投影網(wǎng)絡(luò)h將輸入投影到一個新的特征空間中。因此,本文可以獲得每個帶有參數(shù)的教師的輸出。本文遵循MOCO的流程,并通過InfoNCE[27]為每個教師編碼器獨立地構(gòu)造教師內(nèi)對比損失(intra-teacher contrastive loss),可以表示為:
教師內(nèi)對比損失可以幫助生成器逐步合成一些與歷史樣本不同的新樣本。然而,它只獨立考慮了教師的實例級對比,本文希望通過不同網(wǎng)絡(luò)對物體不同視圖下的特征關(guān)系進(jìn)行對比學(xué)習(xí),從而使得學(xué)生網(wǎng)絡(luò)以及生成器對于數(shù)據(jù)中的多視圖知識的分布學(xué)習(xí)到更好的表征。換句話說,同一個物體在不同視圖下的表征應(yīng)當(dāng)是相似的,不同物體的表征則遠(yuǎn)離?;谏鲜鏊枷?,學(xué)生網(wǎng)絡(luò)學(xué)到的不僅是學(xué)生網(wǎng)絡(luò)所擅長提取的特征,比如鳥的嘴,還能從與教師網(wǎng)絡(luò)的表示的拉近過程中明白鳥的嘴、翅膀、眼睛、羽毛都可以被看作同一物體的不同視角,從而學(xué)習(xí)到更好的特征表示。故本文進(jìn)一步提出了師生對比,旨在使生成器合成的圖片能讓學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)映射到相同的表示空間中,采用對比學(xué)習(xí)的方法拉近同一物體的多視角表示,并將不同物體的特征區(qū)分開來。首先,從當(dāng)前批次中的第i個圖像構(gòu)造學(xué)生的特征,表示為=h(fS(,θs),θh)。然后,本文將學(xué)生的特征和相同的第i圖像中教師的特征進(jìn)行拉近,并將和負(fù)實例的表示推遠(yuǎn),包括記憶庫和其他不包括當(dāng)前批次中的第i個圖像實例。因此,師生對比損失可以表述為:
其中,Neg是負(fù)樣本的集合,可以定義為:
這里,D(s)是教師網(wǎng)絡(luò)索引集,為學(xué)生模型輸出的歷史圖像記憶庫中的第j個負(fù)樣本的特征表示。通過結(jié)合式(9)和式(10),本文可以將多教師和學(xué)生之間的對比交互損失表示為:
本文通過最小化式(11)來反演出來自多個教師的更豐富的知識。它有效地生成具有多樣性和更真實的圖像。需要注意的是,與MOCO不同,本文的框架是以對抗的方式進(jìn)行訓(xùn)練,不需要動量編碼器。
本文方法包含兩個階段:通過生成器G生成圖像以及從教師蒸餾知識到學(xué)生網(wǎng)絡(luò)。對于圖像生成,本文結(jié)合了模型反演損失和對比交互損失Lci,可以表示為:
其中,λ是和Lci之間的平衡參數(shù)。對于知識蒸餾,本文的目標(biāo)是將知識從多教師集成的預(yù)測結(jié)果蒸餾到學(xué)生網(wǎng)絡(luò),則式(8)改為:
本文的框架在兩階段過程中進(jìn)行訓(xùn)練,如算法1所示,其中生成器和學(xué)生交替更新。在每次迭代中,首先訓(xùn)練生成器使得其輸出的圖片通入教師網(wǎng)絡(luò)后的統(tǒng)計量信息逼近存儲在教師BN層中的統(tǒng)計數(shù)據(jù),使得特征圖處于一個合理的范圍內(nèi)。隨后使用對比學(xué)習(xí)與歷史樣本進(jìn)行對比,融合教師網(wǎng)絡(luò)多視角的信息,并消除存儲在圖像中的模型結(jié)構(gòu)所帶來的偏差信息。然后訓(xùn)練學(xué)生網(wǎng)絡(luò)使其輸出與教師集合預(yù)測的輸出之間的距離最小化。通過交替更新學(xué)生和生成器,算法收斂到最優(yōu)點。
算法1多教師對比知識反演的算法
(1)數(shù)據(jù)集和模型。本文在不同的網(wǎng)絡(luò)架構(gòu)上評估提出的方法,包括ResNets[11]、帶BN 層的VGG[35]、WRN[36]、Inception-V3[37]和MobileNet-v2[38]。在3 個廣泛使用的數(shù)據(jù)集CIFAR-10、CIFAR-100 和Caltech-101[39]上進(jìn)行了實驗用于測試合成圖像的質(zhì)量,并訓(xùn)練教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)。本文選擇ResNet-34、VGG-11、WRN-40-2 和Inception-V3 作為教師模型。選擇WRN-16-1、ResNet-18、WRN-16-2、WRN-40-1 和VGG-8作為學(xué)生模型,并對其進(jìn)行評估。本文在表1中總結(jié)了這些在原始CIFAR-10/100和Caltech-101數(shù)據(jù)集上訓(xùn)練的教師的準(zhǔn)確率,其中“Ensemble”表示ResNet-34、VGG-11 和WRN-40-2 集成后的準(zhǔn)確率。將本文方法與現(xiàn)有的最先進(jìn)的方法DAFL[6]、DFQ[9]、Deepinv(deep inversion)[10]、CMI[7]進(jìn)行了比較。
表1 在不同數(shù)據(jù)集上預(yù)訓(xùn)練教師網(wǎng)絡(luò)的準(zhǔn)確率Table 1 Accuracy of pre-trained teachers on different datasets 單位:%
(2)實驗設(shè)置細(xì)節(jié)。本文使用PyTorch來實現(xiàn)提出的多教師對比知識反演,算法1中的優(yōu)化問題在具有24 GB顯存的NVIDIA GTX 3090 GPU上運行來進(jìn)行實驗。權(quán)重衰減設(shè)置為0.000 1,動量設(shè)置為0.9。對于數(shù)據(jù)集CIFAR-10 和CIFAR-100,本文將小批量(minibatch)大小、總訓(xùn)練回合(epoch)數(shù)和初始學(xué)習(xí)率分別設(shè)置為256、200和0.1。學(xué)習(xí)率在120、150、175和190 個epoch 上按0.1 的比例衰減。對于數(shù)據(jù)集Caltech-101,本文首先從原始數(shù)據(jù)集中隨機(jī)抽取20%的圖像作為測試集,并將所有圖像的大小調(diào)整為128×128。本文使用更大的生成器來合成圖像,教師數(shù)設(shè)置為3(在3.3 節(jié)中有對集成教師網(wǎng)絡(luò)個數(shù)的影響的分析)。
對于在數(shù)據(jù)集Caltech-101 上的實驗,將批量大小設(shè)定為32,合成圖像大小尺寸為128×128,epoch為400,學(xué)習(xí)率在250、300、350 和375 個epoch 上按0.1衰減,同時遵循了CMI中對于超參數(shù)的設(shè)定,λ1、λ2、λ3分別設(shè)置為0.5、1.0 和0.5,其余訓(xùn)練參數(shù)設(shè)置為與CIFAR-10/100 相同。對于超參數(shù)λ,本文使用[0.1,1.0]范圍內(nèi)的交叉驗證來確定多教師模型反演損失和對比交互損失之間的最佳權(quán)衡。
(3)生成器和頭部映射層的結(jié)構(gòu)。生成器G的內(nèi)部結(jié)構(gòu)由一個全連接層(fully connected layers,F(xiàn)C)、三個卷積層組成,其中一個卷積層是由一個卷積、批量歸一化和LeakyReLU 組成。輸入噪聲的維度設(shè)置為256。對于頭部投影架構(gòu),本文使用兩個全連接層將網(wǎng)絡(luò)的輸出表示映射到同樣的256維。
(4)評價指標(biāo)。本文選擇學(xué)生的準(zhǔn)確率和生成的圖像與原始數(shù)據(jù)之間的FID(Frechet inception distance score)作為評估標(biāo)準(zhǔn)。FID 是生成對抗網(wǎng)絡(luò)GAN 中常見的衡量指標(biāo),用于衡量兩個數(shù)據(jù)集的相似程度,分?jǐn)?shù)越低兩者的分布越接近。
本文在數(shù)據(jù)集CIFAR-10、CIFAR-100 和Caltech-101 上進(jìn)行實驗。CIFAR-10 是一個常用的分類數(shù)據(jù)集,圖像均勻分布在10個類別中。它總共有50 000張訓(xùn)練圖像和10 000張測試圖像,所有這些圖像的大小都是32×32 像素。CIFAR-100中的圖像與CIFAR-10相同,只是它們分為100個類別。Caltech-101是一個包含101個類別的圖像分類數(shù)據(jù)集。每個類別的樣本數(shù)量從40到800不等,每張圖像的大小約為300×200。
本文選擇ResNet-34、VGG-11和WRN-40-2作為本文的多個教師。在數(shù)據(jù)集CIFAR-10、CIFAR-100和Caltech-101中集成的預(yù)訓(xùn)練教師達(dá)到95.83%、80.08%和67.08%的準(zhǔn)確率。本文以定量和定性的方式將本文方法與最優(yōu)方法(state-of-the-art,SOTA)進(jìn)行比較。
(1)客觀指標(biāo)分析。表2記錄了本文方法和先前的方法在不同數(shù)據(jù)集CIFAR-10、CIFAR-100和Caltech-101上的比較結(jié)果。本文可以觀察到:①本文方法在所有3 個數(shù)據(jù)集上都優(yōu)于現(xiàn)有方法。例如,當(dāng)在CIFAR-10 數(shù)據(jù)集上蒸餾到相同的WRN-16-1 時,本文方法達(dá)到了91.59%的準(zhǔn)確率,比最佳的CMI 基線提高了2.49個百分點。對于CIFAR-100,在蒸餾到相同的WRN-16-2時,本文比CMI高出了2.08個百分點的準(zhǔn)確率。對于更復(fù)雜的場景Caltech-101,本文方法在蒸餾到MobileNet-V2 時與Deepinv 相比增加了3.89 個百分點的準(zhǔn)確率。②在本文所采用的多教師集成的準(zhǔn)確率和CIFAR-10 上的一個特定ResNet-34的準(zhǔn)確率(≈95.7%)幾乎一致時,本文方法在提取同一個學(xué)生時相比其他基線實現(xiàn)了顯著的性能提升。這也就表明模型性能的提升來自于多教師結(jié)構(gòu)和提出的對比交互損失,而不是簡單來自于強(qiáng)教師。③教師和學(xué)生之間的同構(gòu)結(jié)構(gòu)有助于提高學(xué)生在所有基線中的表現(xiàn)。例如,在CIFAR-10上,本文使用相同的WRN-16-1 作為學(xué)生,相比于ResNet-34 作為教師,WRN-40-2作為教師時顯著提高了學(xué)生WRN-16-1的準(zhǔn)確率。④值得注意的是,本文的預(yù)訓(xùn)練教師沒有使用MobileNet-V2,然而本文的合成圖像仍然可以有效地訓(xùn)練模型。而且本文方法已經(jīng)和使用原始數(shù)據(jù)訓(xùn)練的MobileNet-V2的準(zhǔn)確度非常接近。這意味著使用本文提出的多教師對比知識反演方法的合成圖像對于各種模型的訓(xùn)練具有很高的泛化性。⑤與其他方法相比,本文用不同的學(xué)生模型生成的數(shù)據(jù)集的FID值都是最低的,并且方差較小。這意味著本文的合成圖像與原始數(shù)據(jù)集最一致。本文方法在CIFAR-10 數(shù)據(jù)集上的FID 值(即≈52.20)甚至可以與一些使用原始數(shù)據(jù)的GAN方法[8]相媲美。
表2 在不同數(shù)據(jù)集上無數(shù)據(jù)蒸餾方法的結(jié)果Table 2 Results of data-free distillation on different datasets
(2)主觀視覺分析。本文進(jìn)一步將提出的方法與現(xiàn)有方法的合成質(zhì)量進(jìn)行比較,如圖3 所示。與DAFL[6]、DFQ[9]、Deepinv[10]、CMI[7]相比,可以明顯看出本文的多教師對比知識反演所生成的圖像質(zhì)量最高。例如,DAFL 使用CIFAR-10 數(shù)據(jù)集上的預(yù)訓(xùn)練教師生成的圖像類似噪聲圖像。Deepinv 能夠生成具有視覺特征的圖像,但物體顏色與背景顏色接近,風(fēng)格單一。因此,它與原始的CIFAR-10 數(shù)據(jù)集相距甚遠(yuǎn)。DFQ 和本文的合成圖像之間的比較表明,本文提出的方法可以生成更多樣化的圖像,而DFQ 則遇到了明顯的模式崩潰問題。盡管CMI合成的圖像在顏色和風(fēng)格上似乎有一些改進(jìn),但它們?nèi)匀贿^于模糊而無法區(qū)分。本文方法在對象輪廓的清晰度、顏色匹配的合理性方面提高了圖像質(zhì)量。對于CIFAR-10數(shù)據(jù)集,本文方法生成更多樣化的語義圖像,例如不同姿勢的馬的特寫和各種類型的卡車。即使是像船后面的天際線這樣的微小細(xì)節(jié)也能夠清晰生成。對于CIFAR-100數(shù)據(jù)集,合成圖像提供了豐富的語義信息,肉眼可以很輕松識別圖3中顯示的對象,如熊貓、自行車、鮮花。
圖3 不同方法反演生成的圖片展示Fig.3 Images inverted from pre-trained model by different methods
為了評估本文方法的有效性,包括多教師的引入,對比交互損失、泛化性和多樣性。本文選擇CIFAR-10數(shù)據(jù)集中的預(yù)訓(xùn)練模型進(jìn)行消融實驗。
(1)超參數(shù)λ的敏感性。本文首先評估λ的敏感性。如表3所示,本文對不同學(xué)生網(wǎng)絡(luò)設(shè)定下超參數(shù)敏感性做了實驗,發(fā)現(xiàn)當(dāng)λ設(shè)置為0.2 時蒸餾到不同的學(xué)生網(wǎng)絡(luò)能夠達(dá)到相對最佳的精度。為了方便討論,本文將所有實驗的λ設(shè)置為0.2。
表3 蒸餾到不同網(wǎng)絡(luò)結(jié)構(gòu)時的超參數(shù)λ對結(jié)果的影響Table 3 Effect of hyper-parameter λ for distilling student networks
(2)集成教師網(wǎng)絡(luò)個數(shù)的影響。本文進(jìn)行了多教師集成的幾種組合,其中教師的數(shù)量從1 到4。為了幫助學(xué)生學(xué)習(xí)更多樣化的知識,本文選擇了異構(gòu)教師網(wǎng)絡(luò),即不同網(wǎng)絡(luò)結(jié)構(gòu)的模型作為教師。如表4所示,更多的教師相對來說可以達(dá)到更高的準(zhǔn)確率。隨著教師數(shù)量的增加,學(xué)生和教師集成的測試準(zhǔn)確率的增長速度放緩,終于接近一個上限。當(dāng)教師數(shù)量設(shè)置為3,達(dá)到了相對飽和的性能??紤]到計算開銷,本文將實驗中多教師的網(wǎng)絡(luò)個數(shù)設(shè)定為3。
(3)對比交互損失的作用。本文研究了所提出的不同模塊的貢獻(xiàn),包括多教師、教師內(nèi)對比學(xué)習(xí)和師生對比學(xué)習(xí)。本文將每個模塊單獨關(guān)閉做cutoff來檢測其有效性。如表5所示,本文使用mt(multi-teacher)、itcl(intra-teacher contrastive learning)、tscl(teacher-student contrastive learning)分別代表多教師、教師內(nèi)對比學(xué)習(xí)和師生對比學(xué)習(xí)。實驗數(shù)據(jù)表明使用多教師進(jìn)行無數(shù)據(jù)蒸餾時直接將性能提高了5.7 個百分點。使用教師內(nèi)對比損失函數(shù)可提升性能4.43 個百分點。當(dāng)在多教師的基礎(chǔ)上加入教師內(nèi)對比損失時,WRN-16-1 的準(zhǔn)確率相比于原始方法達(dá)到了大約8 個百分點的增益。在此基礎(chǔ)上,本文進(jìn)一步添加了學(xué)生-教師對比損失,對性能實現(xiàn)了進(jìn)一步提升,使得本文的模型最終達(dá)到91.59%的準(zhǔn)確率。這是由于教師模型中提取“多視角”知識并將其很好地融合到學(xué)生模型中,同時使用了對比交互方式,充分利用來自多位師生的知識,生成具有高泛化性和多樣性的合成數(shù)據(jù)。
表5 不同組件在蒸餾過程中對算法的影響Table 5 Effect of different component combinations on algorithm during distillation
(4)合成數(shù)據(jù)的泛化性能分析。本文使用WRN-16-1作為學(xué)生,使用多教師對比知識反演方法得到的數(shù)據(jù)和CMI方法反演的數(shù)據(jù)從頭開始訓(xùn)練不同結(jié)構(gòu)的網(wǎng)絡(luò),由此來評估數(shù)據(jù)是否可以用于訓(xùn)練多種不同的網(wǎng)絡(luò)。為了公平比較,在這兩個方法合成數(shù)據(jù)時采用的訓(xùn)練參數(shù)和策略是相同的。
結(jié)果如表6 所示,與CMI 相比,本文方法實現(xiàn)了大幅提升(可高達(dá)8個百分點的提升)。此外,與原始CIFAR-10 數(shù)據(jù)集相比,使用本文方法的合成數(shù)據(jù)在從零開始訓(xùn)練教師方面達(dá)到了非常接近的準(zhǔn)確性。注意到本文并沒有使用Inception-V3 作為教師網(wǎng)絡(luò)之一,而本文的合成圖像仍可以有效地訓(xùn)練該模型。這意味著使用MTCKI的合成圖像對于各種模型的訓(xùn)練具有很高的泛化性。
表6 將合成數(shù)據(jù)直接用于從頭訓(xùn)練網(wǎng)絡(luò)效果對比Table 6 Comparison of training model from scratch using inverted data 單位:%
(5)數(shù)據(jù)多樣性分析。為了進(jìn)一步評估本文方法在數(shù)據(jù)多樣性方面的有效性,本文使用T-SNE[40]工具可視化MTCKI 和CMI 合成圖像的數(shù)據(jù)分布情況。如圖4 所示,對于本文方法,數(shù)據(jù)整體的分布較為分散,圖片的特征分布較廣,有效分開不同類別的數(shù)據(jù)分布,而具有相同類別的數(shù)據(jù)被很好地聚合。此分布與原始CIFAR-10 數(shù)據(jù)集十分接近。而CMI的數(shù)據(jù)點較為密集,圖片的特征較為相似,表明不同類別的合成圖像沒有被解開。與CMI 相比,本文方法表現(xiàn)出更好的數(shù)據(jù)多樣性。
圖4 CIFAR-10原始數(shù)據(jù)、CMI合成數(shù)據(jù)、MTCKI合成數(shù)據(jù)分布對比Fig.4 Distribution comparison among original CIFAR-10 data and data inverted by MTCKI and CMI
由于生成對抗的方法在收斂時可能會出現(xiàn)不穩(wěn)定的情況,本文進(jìn)一步分析了本文方法的收斂性和不同epoch 下圖像變化的情況。如圖5 所示,本文方法可以穩(wěn)定地收斂。與其他基線進(jìn)行了可視化比較,本文方法需要更少的訓(xùn)練epoch 來收斂,且收斂到的損失最低。值得注意的是,在訓(xùn)練過程中,由于豐富的多教師信息和對比交互的有效性,如圖6 所示,第10個epoch合成的圖像已經(jīng)具有多樣化的語義信息和組織良好的物體輪廓。除此之外,本文還客觀分析了對比交互損失對運算復(fù)雜度的影響,本文將其分成測試時間和訓(xùn)練時間兩部分。在測試時間上,加入對比交互損失不會對最終的測試時間有影響,因為該損失相當(dāng)于模型訓(xùn)練中的正則化項,測試過程中模型將不參與該部分計算。在訓(xùn)練時間上,對比交互損失確實會增加模型訓(xùn)練內(nèi)存和時間開銷。當(dāng)使用對比交互損失在單卡NVIDIA 3090GPU上訓(xùn)練200 個epoch,需花費16.6 h,而不使用對比交互損失在單卡NVIDIA 3090GPU上訓(xùn)練200個epoch需要11.9 h。雖然對比交互損失在訓(xùn)練上會增大開銷,但是在一次訓(xùn)練過程中合成的圖片可以用于多次從頭訓(xùn)練一個新的網(wǎng)絡(luò)或用于有數(shù)據(jù)的知識蒸餾且準(zhǔn)確率相比先前的方法都有較大提升,一定程度上節(jié)省了后續(xù)的開銷,并提高了模型精度。
圖5 不同方法在訓(xùn)練過程中的損失曲線對比Fig.5 Training loss curves of different methods during training
圖6 不同回合階段的合成圖片的質(zhì)量Fig.6 Quality of generated images in different epochs
本文提出了一種基于多教師對比知識反演的無數(shù)據(jù)知識蒸餾框架(MTCKI),該框架在提高學(xué)生網(wǎng)絡(luò)表現(xiàn)的同時,以對抗的方式生成高保真度的訓(xùn)練數(shù)據(jù)。首先,本文提出了一種供應(yīng)端-客戶端合作的模式,用于數(shù)據(jù)保護(hù)下的模型壓縮,然后構(gòu)建了一個新的無數(shù)據(jù)知識蒸餾框架,從多個教師模型中提取“多視角”知識并將其很好地融合到學(xué)生模型中。此外,本文建立了多教師和學(xué)生之間的對比交互以提高合成圖像的多樣性。本文提出的MTCKI能將一次生成的圖片數(shù)據(jù)用于蒸餾或從頭訓(xùn)練多個不同的學(xué)生網(wǎng)絡(luò)。本文綜合評估了MTCKI 在各種CNN 架構(gòu)上的性能,實驗結(jié)果表明,MTCKI 不僅生成視覺上效果不錯的圖像,而且在性能上優(yōu)于現(xiàn)有的無數(shù)據(jù)蒸餾方法。