摘要: 針對目前小樣本圖像分類推斷置信度有待提高的問題, 提出一個新的結(jié)合元置信轉(zhuǎn)導(dǎo)推理、 數(shù)據(jù)混淆方法和按特征線性調(diào)制方法的模型. 首先, 利用轉(zhuǎn)導(dǎo)推理在訓(xùn)練過程中能學(xué)習(xí)到推斷數(shù)據(jù)的性質(zhì), 可以有針對性地學(xué)習(xí); 其次, 在網(wǎng)絡(luò)結(jié)構(gòu)中結(jié)合數(shù)據(jù)混淆方法, 加強(qiáng)對關(guān)鍵特征的提取, 提升模型的特征發(fā)現(xiàn)能力; 最后, 在轉(zhuǎn)導(dǎo)推理框架中加入按特征線性調(diào)制變換以改進(jìn)模型的小樣本查詢能力. 在標(biāo)準(zhǔn)數(shù)據(jù)集Mini-ImageNet和Tiered-ImageNet上進(jìn)行實(shí)驗(yàn)的結(jié)果表明, 該模型在這兩個數(shù)據(jù)集上執(zhí)行5-way 1-shot任務(wù)時準(zhǔn)確率分別提升了3.21,3.36個百分點(diǎn), 在5-way 5-shot任務(wù)上準(zhǔn)確率分別提升了2.89,1.89個百分點(diǎn). 實(shí)驗(yàn)結(jié)果驗(yàn)證了該方法的有效性.
關(guān)鍵詞: 小樣本學(xué)習(xí); 轉(zhuǎn)導(dǎo)推理; 數(shù)據(jù)擾動; 按特征線性調(diào)制變換
中圖分類號: TP39""文獻(xiàn)標(biāo)志碼: A""文章編號: 1671-5489(2024)06-1439-08
Transductive Inference Based Improvement for Few-Shot Learning
FU Haitao, JIN Chenlei, YANG Yajie, FENG Yuxuan
(College of Information Technology, Jilin Agricultural University, Changchun 130118, China)
Abstract: Aiming at the problem of the need to improve "confidence level in few-shot image classification inference at present, we proposed
a new model that combined meta-confidence transductive inference, data obfuscation method, and feature-wise linear modulation method. Firstly, by using transductive inference, the model could
learn properties of inference data during training process, and achieve targeted learning. Secondly, combining "data obfuscation methods "in the network architecture
to enhance the extraction of key features, and "improve the feature discovery ability of the "model. "Finally, feature-wise linear modulation transformation was added to "the transductive
inference framework to improve the model’s few-shot query capabilities. The results of experiments conducted on "standard datasets Mini-ImageNet and Tiered-ImageNet show
that the model improves "accuracy by "3.21 and 3.36 percentage points respectively when performing "5-way 1-shot tasks on these two datasets, and by 2.89 and 1.89 percentage points "respectively
on 5-way 5-shot tasks. The experimental results validate the effectiveness of the proposed method.
Keywords: few-shot learning; transductive inference; data perturbation; feature-wise linear modulation transformation
當(dāng)前深度神經(jīng)網(wǎng)絡(luò)的發(fā)展主要依靠大量標(biāo)注樣本的監(jiān)督學(xué)習(xí)[1]. 在實(shí)際應(yīng)用中, 數(shù)據(jù)標(biāo)注是一項(xiàng)勞動密集型工作, 難以保證數(shù)據(jù)質(zhì)量. 當(dāng)某類別的樣本數(shù)量不足或類別分布不均衡時會影響深度神經(jīng)網(wǎng)絡(luò)在該類別上的分類性能[2]. 受人類強(qiáng)大學(xué)習(xí)能力的啟發(fā), 在樣本不足或分布不均衡情況下有效學(xué)習(xí)的深度學(xué)習(xí)模型, 即為小樣本學(xué)習(xí)[3].
在小樣本學(xué)習(xí)中, 常采用的數(shù)據(jù)增強(qiáng)[4]、 遷移學(xué)習(xí)[5]和元學(xué)習(xí)等方法, 或是這些方法的組合. 數(shù)據(jù)增強(qiáng)是一種直觀的解決方案, 通過上采樣擴(kuò)展數(shù)據(jù)集, 可增加數(shù)據(jù)量并平衡類別. 例如, 采用隨機(jī)過采樣和參數(shù)化過采樣進(jìn)行數(shù)據(jù)增強(qiáng), 再結(jié)合元學(xué)習(xí)進(jìn)行求解設(shè)計(jì), 能在一定程度上改善小樣本學(xué)習(xí)[6]. 基于元學(xué)習(xí)的原型網(wǎng)絡(luò)和孿生網(wǎng)絡(luò)聯(lián)合框架也取得了良好的結(jié)果[7]. 在結(jié)合數(shù)據(jù)增強(qiáng)的小樣本學(xué)習(xí)中, "除對訓(xùn)練數(shù)據(jù)進(jìn)行增強(qiáng)外, 還有基于特征的數(shù)據(jù)增強(qiáng)方案. 假設(shè)同一類型的數(shù)據(jù)特征服從同一高斯分布, 可先基于該假設(shè)對支持集特征進(jìn)行聚類, 再通過元學(xué)習(xí)對查詢集特征進(jìn)行分類, 最后在同分布特征中采樣, 實(shí)現(xiàn)按特征的數(shù)據(jù)增強(qiáng)也實(shí)現(xiàn)了較好的性能[8].
遷移學(xué)習(xí)通過將在源領(lǐng)域數(shù)據(jù)中學(xué)習(xí)到的模型遷移到目標(biāo)領(lǐng)域, 從而提高目標(biāo)領(lǐng)域的學(xué)習(xí)性能. 其典型流程分為預(yù)訓(xùn)練和微調(diào)兩個階段. 預(yù)訓(xùn)練使用相對充足的源領(lǐng)域數(shù)據(jù)進(jìn)行訓(xùn)練, 微調(diào)則利用目標(biāo)領(lǐng)域數(shù)據(jù)對網(wǎng)絡(luò)進(jìn)行調(diào)整以適應(yīng)特定任務(wù). 如殘差遷移網(wǎng)絡(luò)(residual transfer networks, RTN)就是通過遷移源領(lǐng)域的特征學(xué)習(xí)提升模型在目標(biāo)領(lǐng)域的泛化能力
, 表現(xiàn)了良好的小樣本學(xué)習(xí)效果[9].
元學(xué)習(xí), 尤其是度量學(xué)習(xí), 則是一類通過自動學(xué)習(xí)知識解決相似性分類問題的方法. 度量學(xué)習(xí)通過計(jì)算待分類樣本與已知樣本之間的相似度, 在度量空間內(nèi)完成分類推斷[10]. 典型方法有Koch等[11]提出的基于孿生網(wǎng)絡(luò)的單樣本學(xué)習(xí)方法, 以及Vinyals等[12]提出的匹配網(wǎng)絡(luò), 都是利用相似度度量實(shí)現(xiàn)小樣本分類. 原型網(wǎng)絡(luò)通過比較類別原型之間的距離進(jìn)行分類[13], 而關(guān)系網(wǎng)絡(luò)則通過神經(jīng)網(wǎng)絡(luò)輸出相似度得分[14]進(jìn)行分類.
基于以上小樣本學(xué)習(xí)技術(shù), 為擴(kuò)展深度神經(jīng)網(wǎng)絡(luò)的知識學(xué)習(xí)能力, 本文采用基于元學(xué)習(xí)的轉(zhuǎn)導(dǎo)推理方法, 對小樣本學(xué)習(xí)進(jìn)行改進(jìn). 首先, 引入?yún)^(qū)域混淆, 通過將輸入圖像劃分為多個區(qū)域, 按區(qū)域進(jìn)行隨機(jī)排列, 增強(qiáng)模型在少量數(shù)據(jù)中進(jìn)行特征學(xué)習(xí)的能力, 改進(jìn)小樣本學(xué)習(xí)效果; 其次, 設(shè)計(jì)骨干網(wǎng)絡(luò)中的按特征線性調(diào)制變換(feature-wise linear modulation, FiLM), 通過對神經(jīng)網(wǎng)絡(luò)中間特征進(jìn)行仿射變換, 訓(xùn)練提高轉(zhuǎn)導(dǎo)推理結(jié)果的置信度, 從而提升小樣本學(xué)習(xí)性能.
1"預(yù)備知識
1.1"小樣本學(xué)習(xí)問題
在小樣本學(xué)習(xí)任務(wù)中, 為解決在樣本數(shù)據(jù)量小的情況下對新類別進(jìn)行準(zhǔn)確分類的問題, 通常使用基類數(shù)據(jù)集Dbase訓(xùn)練模型, 并使模型能泛化到新類Dnovel圖像數(shù)據(jù)上完成分類, 基類數(shù)據(jù)集與新類數(shù)據(jù)集相互獨(dú)立, Dbase∩Dnovel=. 在小樣本學(xué)習(xí)的圖像分類問題中, 每個元任務(wù)由支持集和查詢集組成, 其中: 支持集為S={{xi,s}Ks=1,yi}Ni=1, 支持集中類別數(shù)為N, 每種類別包含K個標(biāo)記樣本; 查詢集為Q={{xi,q}Qq=1,yi}Ni=1, 查詢集和支持集來自同分布的N個類別, 每個類別中包含Q個測試樣本. 小樣本學(xué)習(xí)也稱為N-way K-shot分類問題.
1.2"轉(zhuǎn)導(dǎo)推理
通過對已有的訓(xùn)練用例進(jìn)行表征學(xué)習(xí), 并據(jù)此進(jìn)行推斷是常見的監(jiān)督學(xué)習(xí)范式, 屬于歸納推理[15]. 轉(zhuǎn)導(dǎo)推理與歸納推理同屬監(jiān)督學(xué)習(xí), 其核心概念在于使用已知的訓(xùn)練樣本學(xué)習(xí)對測試樣本的推斷[16]. 在模型的訓(xùn)練階段引入訓(xùn)練集和測試集, 以實(shí)現(xiàn)對這兩個數(shù)據(jù)集的充分利用. 相對于歸納推理有一定的推斷和遷移的性能提升, 當(dāng)數(shù)據(jù)改變時訓(xùn)練要重新進(jìn)行. 轉(zhuǎn)導(dǎo)推理常見的方法主要包括轉(zhuǎn)導(dǎo)支持向量機(jī)(trans-support vector machine, TSVM)和標(biāo)簽傳播算法(label propagation algorithm, LPA). 轉(zhuǎn)導(dǎo)支持向量機(jī)假設(shè)數(shù)據(jù)集合與其對應(yīng)的標(biāo)簽集合之間存在特定的幾何關(guān)系, 與支持向量機(jī)相同也致力于尋找具有分類能力的超平面. 標(biāo)簽傳播算法是在已知部分?jǐn)?shù)據(jù)的情況下, 根據(jù)數(shù)據(jù)之間的特征相似性推斷數(shù)據(jù)集中缺失標(biāo)簽的機(jī)器學(xué)習(xí)方法, 算法包含了特征傳播和標(biāo)簽傳播兩部分. 它們的目的都是通過有標(biāo)記與無標(biāo)記的數(shù)據(jù)進(jìn)行綜合考慮, 從而提升識別效果. 在該過程中, 利用有標(biāo)簽的數(shù)據(jù)確定支持向量, 而未標(biāo)記的數(shù)據(jù)被用來描述數(shù)據(jù)的分布情況, 以提高分類效果.
轉(zhuǎn)導(dǎo)推理通過學(xué)習(xí)的方式使模型對輸入數(shù)據(jù)進(jìn)行轉(zhuǎn)換、 分類、 標(biāo)注、 生成等處理, 同時利用神經(jīng)網(wǎng)絡(luò)對多個任務(wù)進(jìn)行權(quán)值調(diào)節(jié), 使目標(biāo)函數(shù)達(dá)到最優(yōu). 在計(jì)算機(jī)視覺、 圖像分類、 自然語言處理和生物信息學(xué)等領(lǐng)域應(yīng)用廣泛.
1.3"元置信轉(zhuǎn)導(dǎo)網(wǎng)絡(luò)模型
元置信轉(zhuǎn)導(dǎo)(meta-confidence transduction, MCT)[17]網(wǎng)絡(luò)在解決小樣本問題上, 先利用轉(zhuǎn)導(dǎo)推理得到查詢樣本的置信度加權(quán)平均值, 再通過得到的值對查詢樣本的類別原型進(jìn)行更新. 但在現(xiàn)實(shí)問題中, 查詢樣本可能來自一個未知的分布領(lǐng)域, 這樣通過轉(zhuǎn)導(dǎo)推理得到的值可能不可靠, 導(dǎo)致預(yù)測錯誤. 為解決該問題, 引入了元學(xué)習(xí)方法, 先利用元學(xué)習(xí)中的距離度量給未標(biāo)記的每個類的查詢樣本分配置信度得分, 再對獲得置信度的原網(wǎng)絡(luò)進(jìn)行更新, 從而提高模型對未知分布查詢樣本的轉(zhuǎn)導(dǎo)推理能力.
在小樣本學(xué)習(xí)任務(wù)中, 即使對模型進(jìn)行了元學(xué)習(xí), 但由于數(shù)據(jù)的稀缺問題, 模型置信度本質(zhì)上還是不可靠的. 為輸出更可靠的置信度, 可以強(qiáng)制模型輸出一致性的預(yù)測, 同時擾動模型或數(shù)據(jù). 數(shù)據(jù)擾動共有兩個置信度, 一個置信度來自原始圖像, 另一個是對原始圖像進(jìn)行水平翻轉(zhuǎn)后的圖像. 采用這種方法可以在不損失信息的前提下, 對數(shù)據(jù)施加一定的干擾, 從而得到訓(xùn)練與測試中的擾動置信度. 模型中的擾動則是通過偽隨機(jī)過程決定是否刪除某些網(wǎng)絡(luò)模塊帶來的.
2"基于MCT的小樣本方法
2.1"區(qū)域混淆
人類有小樣本學(xué)習(xí)和細(xì)粒度學(xué)習(xí)的能力, 如在僅獲取少量局部區(qū)域樣本的條件下, 就能區(qū)分更有代表性的樣本特征, 從而實(shí)現(xiàn)分類和識別任務(wù). 例如旅游愛好者能通過少量圖像的區(qū)域特征對景點(diǎn)進(jìn)行分類. 相比于自然語言處理, 在監(jiān)督學(xué)習(xí)條件下對比亂序前后的句子可以使模型注意到區(qū)分度高的詞, 在視覺領(lǐng)域, 將圖的局部區(qū)域擾亂前后的數(shù)據(jù)采用同一個標(biāo)簽, 端到端地訓(xùn)練網(wǎng)絡(luò)也可以使模型在分類任務(wù)中更好地聚焦在關(guān)鍵分類區(qū)域上. 即先把輸入圖像劃分為多個局部區(qū)域部分, 再利用區(qū)域融合算法(region confusion mechanism, RCM)[18]對局部區(qū)域進(jìn)行打亂. 通過加強(qiáng)對細(xì)節(jié)特征的學(xué)習(xí)充分利用小樣本的數(shù)據(jù).
RCM擾亂圖像局部區(qū)域的空間分布如圖1所示. 首先將圖像I切分為N×N個子區(qū)域, 用Ri,j標(biāo)記每個區(qū)域, 其中i,j分別表示圖像塊在原圖中的行列位置, i和j的取值范圍為[0,N). 對局部區(qū)域的空間分布進(jìn)行打亂, 其實(shí)就是為R的第j行生成一個隨機(jī)向量qj, 而第i個元素的值為qj,i=i+r, r的取值范圍為[-b,b], 且符合均勻分布U(-b,b), 0≤blt;N, 其中b是一個可調(diào)整的參數(shù). 在第j行通過對qj進(jìn)行排序獲得新區(qū)域的排列σrowj, 數(shù)量關(guān)系要滿足:
σrowj(i)-ilt;2b,""i∈{1,2,…,M},(1)
也可以用上述關(guān)系進(jìn)行驗(yàn)證.
同理, 在行向量和列向量上做完該操作后即可得到一個新的“破壞”后的圖像. 對區(qū)域逐列地應(yīng)用變換σcoli:
σcol(j)-jlt;2b,""j∈{1,2,…,M}.(2)
原圖中(i,j)區(qū)域被重構(gòu)后新的坐標(biāo)位置為
σ(i,j)=(σrowj(i),σcoli(j)).(3)
該方法對原始圖像的區(qū)域分布進(jìn)行重構(gòu)后, 對不同類別之間提取細(xì)微差別的能力得到了提高, 解決了數(shù)據(jù)來自未知分布時, 置信度預(yù)測不可靠的問題.
2.2"按特征的線性調(diào)制變換
按特征的線性調(diào)制變換(FiLM)方法[19]來自于視覺問答, 使用該結(jié)構(gòu)在神經(jīng)網(wǎng)絡(luò)的中間表示數(shù)據(jù)上進(jìn)行仿射變換, 該模式可在視覺問答實(shí)踐中按條件進(jìn)行選擇特征.
即通過條件信息的特征仿射變換影響神經(jīng)網(wǎng)絡(luò)的計(jì)算, 小樣本學(xué)習(xí)中支持集/查詢集的學(xué)習(xí)模式與視覺問答模式相仿, 所以在小樣本學(xué)習(xí)模型上可引入這類變換方法. 該方法通過學(xué)習(xí)兩個任意的函數(shù)f和h改進(jìn)特征學(xué)習(xí), 這兩個函數(shù)也可以通過神經(jīng)網(wǎng)絡(luò)分支實(shí)現(xiàn), 它們共享參數(shù)便于更好地學(xué)習(xí), 本文將這兩個網(wǎng)絡(luò)作為一個網(wǎng)絡(luò), 類似于孿生網(wǎng)絡(luò)的設(shè)計(jì). 最終通過可學(xué)習(xí)參數(shù)γi,c和βi,c影響網(wǎng)絡(luò):
γi,c=fc(xi),(4)βi,c=hc(xi),(5)
其中下標(biāo)表示第i個輸入的第c個特征. γi,c和βi,c用來調(diào)節(jié)神經(jīng)網(wǎng)絡(luò)中的激活Fi,c, 然而調(diào)節(jié)特征圖僅需兩個參數(shù), 計(jì)算效率較高, 在圖像風(fēng)格化、 視覺回答和語言識別領(lǐng)域表現(xiàn)優(yōu)異. 本文將其用在查詢集與支持集的相似性比較上:
FiLM(Fi,cγi,c,βi,c)=γi,cFi,c+βi,c.(6)
一個卷積神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)上的單一FiLM層如圖2所示, 在MCT主干網(wǎng)絡(luò)中進(jìn)行FiLM變換, 可以自適應(yīng)地調(diào)整特征圖, γ和β的不同組合能實(shí)現(xiàn)多種方式對特征的調(diào)整, 其中⊙表示Hadamard乘積, 是一種常用的矩陣計(jì)算方法. 將兩個行、 列數(shù)相同的矩陣A和B對應(yīng)位置元素相乘后的結(jié)果放入矩陣C中的對應(yīng)位置, 可得到Hadamard乘積C: C(i,j)=A(i,j)⊙B(i,j), 其中i和j表示行列.
2.3"整體模型設(shè)計(jì)
本文對原始的MCT網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行優(yōu)化, 以提高模型的泛華能力和分類性能, 改進(jìn)的MCT網(wǎng)絡(luò)結(jié)構(gòu)如圖3所示. 在支持集輸入階段, 采用區(qū)域混淆算法對圖像的區(qū)域進(jìn)行打亂, 以增強(qiáng)數(shù)據(jù)的多樣性和模型的泛化能力.
在數(shù)據(jù)擾動后, 引入FiLM變換, 根據(jù)輸入數(shù)據(jù)的特征動態(tài)調(diào)整網(wǎng)絡(luò)的權(quán)重. 為更好地使網(wǎng)絡(luò)模型理解并區(qū)分不同的特征, 提高分類的準(zhǔn)確性, 利用距離度量進(jìn)行學(xué)習(xí). 在查詢集
輸入階段, 采用與支持集相同的網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行特征提取, 以確保特征的一致性. 為提高模型在處理新數(shù)據(jù)時能快速準(zhǔn)確地進(jìn)行分類, 提取到的特征進(jìn)行特征嵌入, 最終與支持集數(shù)據(jù)的特征進(jìn)行k-means聚類度量.
2.4"算法流程
基于MCT方法的小樣本學(xué)習(xí)網(wǎng)絡(luò)算法流程如下.
為捕捉數(shù)據(jù)的不確定性, 數(shù)據(jù)經(jīng)過區(qū)域擾動后把原有圖像劃分為多個區(qū)域后進(jìn)行重構(gòu), 重構(gòu)后的數(shù)據(jù)經(jīng)過模型擾動后進(jìn)行FiLM變換, 通過對中間特征應(yīng)用條件信息的仿射變換影響轉(zhuǎn)導(dǎo)推理的置信度.
對查詢樣本進(jìn)行元學(xué)習(xí), 能產(chǎn)生一個可以提高性能的距離度量值dθ, 其中θ表示距離度量關(guān)系由可學(xué)習(xí)的參數(shù)θ決定. 設(shè)a1,a2∈, 使用歐氏距離定義有歸一化的按實(shí)例度量縮放gIθ和成對度量縮放gPθ, 分別表示為
dIθ(a1,a2)=a1‖a1‖2gIθ(a1)-a2‖a2‖2gIθ(a2)22,(7)
dPθ(a1,a2)=a1‖a1‖2gPθ(a1,a2)-a2‖a2‖2gPθ(a1,a2)22.(8)
為得到最優(yōu)距離函數(shù)gθ∈{gIθ,gPθ}, 計(jì)算查詢樣本概率后, 優(yōu)化距離函數(shù)gθ中θ的值, 使dθ∈{dIθ,dPθ}的損失盡量小, 其中dIθ為按實(shí)例度量縮放的距離度量值, dPθ為按成對度量縮放的距離度量值.
LτI(θ,φ)是一個期望損失函數(shù), 在元學(xué)習(xí)過程中用于優(yōu)化模型參數(shù), 并在所有任務(wù)上取平均值, 其中τ表示任務(wù)索引
或任務(wù)分布, θ為定義自適應(yīng)距離度量參數(shù), φ表示模型參數(shù). 在元學(xué)習(xí)中, 通常需要在多個任務(wù)上訓(xùn)練模型, 每個任務(wù)τ可能包含不同的數(shù)據(jù)分布或不同
的類別. I表示損失函數(shù)的類型或?qū)嵗?LI表示在特定任務(wù)τ上對每個實(shí)例進(jìn)行預(yù)測時的損失, 表示集合的大小, 用來對所有樣本的損失進(jìn)行平均. LτI(θ,φ)定義為
LτI(θ,φ)= "1∑(,)∈-log p(yx,s;θ,?)
= "1∑(,)∈d?(fθ(x),P(T)c)+∑
Cc′=1exp{-d?(fθ(x),P(T)c′)},(9)
其中: -log p(yx,s;θ,?)表示Softmax歸一化項(xiàng); (x,y)表
示輸入樣本及其對應(yīng)標(biāo)簽; (yx,s;θ,?)表示給定輸入x和支持集s時, 模型預(yù)測樣本x屬于類別y的概率, 此概率通過Softmax函數(shù)計(jì)算得到; 距離項(xiàng)dφ(fθ(x),P(T)c)是模型預(yù)測的嵌入fθ(x)與類別原型P(T)c之間的距離, 該距離度量是通過參數(shù)θ學(xué)習(xí)的, 可以是歐氏距離; exp{-d?(fθ(x),P(T)c′)}是模型對類別c′預(yù)測概率的指數(shù)部分, d?是通過參數(shù)?學(xué)習(xí)的距離度量函數(shù), fθ(x)是模型的嵌入函數(shù), 將輸入x映射到嵌入空間中, P(T)c是類別c的原型, 是在第T步更新后的類別中心.
在此基礎(chǔ)上, 使用Soft k-means算法描述轉(zhuǎn)導(dǎo)推理得到置信度加權(quán)平均值后對原始網(wǎng)絡(luò)進(jìn)行更新的主要過程. Soft k-means和Hard k-means這兩種算
法, 前者得到一個置信度, 類似于概率; 后者得到0和1, 0表示不屬于該類別, 1表示屬于該類別.
定義一個包含支持集S和查詢集Q的集, 并將Sc定義為類c的支持集, Qx為所有查詢樣本的集合, 其中Qx={X1,X2,…,Xc×m}. 計(jì)算每個類c={1,2,…,C}的初始原型
P(0)c=1Sc∑x∈Scfθ(x); 以t值為基礎(chǔ), 計(jì)算樣本x的置信度, 即屬于類c的概率p(t-1)c(x),q(t-1)c(x)=exp{-d(fθ(x),P(t-1)c)}∑Cc′=1(-d(fθ(x),P(t-1)c′)),(10)
其中t={1,2,…,T}, x∈Qx, d(·,·)表示歐氏距離, P(t-1)表示第(t-1)步需要更新的原型網(wǎng)絡(luò). 基于所有樣本x的置信度(概率)更新類c的原型. 網(wǎng)絡(luò)的具體訓(xùn)練過程如下.
算法1"利用區(qū)域混淆和FiLM變換改進(jìn)后的MCT網(wǎng)絡(luò)算法.
1) BEGIN
輸入: θ0,G0,J0,train,Nepisode(任務(wù)量);輸出: θ,G,J;
2) θ,G,J←θ0,G0,J0//模型參數(shù)初始化;
3) FOR x FROM 1 to Nepisode DO
4)"從train中隨機(jī)采樣一個任務(wù)δ(x)={S(x),Q(x)};
5) θ←σ(δ(x))//函數(shù)σ( )是將參數(shù)進(jìn)行區(qū)域混淆, 按式(3)描述過程得到區(qū)域混淆重構(gòu)后的新坐標(biāo);
6) θ′←FiLM(θγi,c,βi,c)//FiLM變換, 按式(6)完成;
7) G←式(7)//按式(7)得到按實(shí)例度量縮放的距離度量值
G←式(8)//按式(8)得到按成對度量縮放的距離度量值
8) G′←φ(G)//用函數(shù)φ( )對距離函數(shù)進(jìn)行優(yōu)化, 使dθ損失最??;
9) END FOR
10) J0←0//初始化網(wǎng)絡(luò)原型;
11) J←J0+式(10)//使用Soft k-means算法更新原型;
12) θ,G,J//更新網(wǎng)絡(luò)參數(shù);
13) END.
3"實(shí)"驗(yàn)
3.1"小樣本學(xué)習(xí)問題的評判標(biāo)準(zhǔn)
小樣本學(xué)習(xí)中5-way 1-shot分類任務(wù)表示在數(shù)據(jù)集中包含5個類別, 每個類別中的1個樣本作為訓(xùn)練樣本. 小樣本學(xué)習(xí)是指在實(shí)驗(yàn)中, 只給模型提供較少的樣本(1個或幾個),檢驗(yàn)?zāi)P湍芊裾_地推廣到新的類別上. 同樣在5-way 5-shot任務(wù)中選擇5類數(shù)據(jù)樣本, 并在每個類別中隨機(jī)抽取5個訓(xùn)練樣本, 模型需要從這些訓(xùn)練樣本中學(xué)習(xí)如何識別這5個類別, 當(dāng)進(jìn)行分類任務(wù)時, 模型能正確地完成分類.
分類準(zhǔn)確率(ACC)是衡量小樣本學(xué)習(xí)分類模型性能的重要指標(biāo), 主要是通過計(jì)算模型正確分類的圖像數(shù)量nR與總預(yù)測圖像數(shù)量n的比值得到, 公式為
ACC=nR/n.(11)
3.2"實(shí)驗(yàn)及參數(shù)配置
本文在數(shù)據(jù)集Mini-ImageNet和Tiered-ImageNet上對比并驗(yàn)證了本文方法. 激活函數(shù)使用Hardswish[20], 與ReLU,Sigmoid和Tanh激活函數(shù)不同, 該激活函數(shù)在本文問題中具有計(jì)算效率的優(yōu)勢. Mini-ImageNet: 該數(shù)據(jù)集由數(shù)據(jù)集ImageNet的一個子集組成, 常被用于小樣本圖像分類算法中. 數(shù)據(jù)集Mini-ImageNet包含100個類, 每個類有600張圖像, 共有60 000張數(shù)據(jù)集ImageNet的圖像, 并且每張圖像的處理長寬為84. 本文從100個類別中選取20個類別進(jìn)行測試, 16個用于檢驗(yàn), 剩余的64個用于訓(xùn)練. Tiered-ImageNet: 該數(shù)據(jù)集同樣也是基于數(shù)據(jù)集ImageNet, 但它采用了一種不同的分層采用方法, 可以更好地模擬實(shí)際應(yīng)用場景中的數(shù)據(jù)分布和類別之間的關(guān)系, 與數(shù)據(jù)集Mini-ImageNet相比包含了更廣泛的類別. 本文從20個上層類別中選擇351個類別樣本進(jìn)行訓(xùn)練; 6個不同類別的上層類別中選擇97個類別樣本進(jìn)行驗(yàn)證; 8個不同類別的上層類別中選擇160個類別樣本進(jìn)行測試.
3.3"實(shí)驗(yàn)結(jié)果和分析
對數(shù)據(jù)集圖像結(jié)構(gòu)進(jìn)行全局破壞重構(gòu), 并將FiLM變換結(jié)合主干網(wǎng)絡(luò)ResNet-12改進(jìn)的MCT網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練和測試. 本文在數(shù)據(jù)集上分別進(jìn)行了5-way 1-shot和5-way 5-shot的兩種實(shí)驗(yàn), 目的是與原有的MCT算法和其他同類型的小樣本圖像分類算法公平對比. 在實(shí)驗(yàn)中以數(shù)據(jù)集Mini-ImageNet的500次測試的平均精度作為最終結(jié)果, 以數(shù)據(jù)集Tiered-ImageNet的800次測試的平均精度作為最終結(jié)果. 除圖像處理部分和FiLM變換外, 其余MCT的原有參數(shù)保持不變.
基于數(shù)據(jù)集Mini-ImageNet和Tiered-ImageNet, 各小樣本分類模型在5-way 1-shot和5-way 5-shot任務(wù)上進(jìn)行實(shí)驗(yàn). 實(shí)驗(yàn)結(jié)果列于表1, 包括TPN[21],DPGN[22],MetaOptNet[23],MCT[17](*表示原MCT網(wǎng)絡(luò)在同一設(shè)置下的復(fù)現(xiàn))及本文的改進(jìn)結(jié)果.
由表1可見, 改進(jìn)后的本文(instance)模型在數(shù)據(jù)集Mini-ImageNet上的學(xué)習(xí)精度分別達(dá)到了(79.74±0.85)%,(88.45±0.12)%; 在數(shù)據(jù)集Tiered-ImageNet上的學(xué)習(xí)精度分別達(dá)
到了(83.11±0.65)%,(89.01±0.21)%. 實(shí)驗(yàn)結(jié)果表明, 打亂圖像區(qū)域分布和FiLM變換的方法對解決MCT網(wǎng)絡(luò)在小樣本圖像分類問題上性能較好.
3.4"消融實(shí)驗(yàn)
為進(jìn)一步驗(yàn)證改進(jìn)后的MCT網(wǎng)絡(luò)模型的性能, 在數(shù)據(jù)集Mini-ImageNet上進(jìn)行5-way 1-shot和5-way 5-shot的實(shí)驗(yàn). 將原MCT網(wǎng)絡(luò)同一設(shè)置下的復(fù)現(xiàn)結(jié)果作為基線, 分別測試進(jìn)行
數(shù)據(jù)擾動(Sr)和FiLM變化后的數(shù)據(jù), 實(shí)驗(yàn)結(jié)果列于表2. 表2的實(shí)驗(yàn)結(jié)果進(jìn)一步驗(yàn)證了改進(jìn)后的MCT網(wǎng)絡(luò)模型的有效性, 優(yōu)于原網(wǎng)絡(luò)模型, 取得了良好的效果.
綜上所述, 為提高M(jìn)CT網(wǎng)絡(luò)在小樣本圖像分類問題中的準(zhǔn)確性, 基于該網(wǎng)絡(luò)模型, 通過區(qū)域混淆機(jī)制, 使神經(jīng)網(wǎng)絡(luò)能根據(jù)區(qū)域的圖像細(xì)節(jié)信息實(shí)現(xiàn)分類, 降低對圖像全局結(jié)構(gòu)的依賴性. 利
用FiLM變換使網(wǎng)絡(luò)更適用于支持/查詢的學(xué)習(xí)模式, 有效改進(jìn)了當(dāng)數(shù)據(jù)來自未知分布時置信度不可靠的問題. 通過在經(jīng)典數(shù)據(jù)集上的實(shí)試驗(yàn)結(jié)果表明, 改進(jìn)后的本文(instance-FiLM)
網(wǎng)絡(luò)模型相比原模型在5-way 1-shot任務(wù)中的分類準(zhǔn)確率提高了3.21,3.36個百分點(diǎn), 在5-way 5-shot任務(wù)中相比原模型提高了2.89,1.89個百分點(diǎn).
參考文獻(xiàn)
[1]"BENGIO Y, LECUN Y, HINTON G. Deep Learning for AI [J]. Communications of the ACM, 2021, 64(7): 58-65.
[2]"LI Y D, HAO Z B, LEI H. Survey of Convolutional Neural Network [J]. Journal of Computer Applications, 2016, 36(9): 2508-2515.
[3]"HU X, CHEN S. A Survey of Few-Shot Learning Based on Machine Learning [J].
Intelligent Computer and Applications, 2021(7): 191-195.
[4]"朱曉慧, 錢麗萍, 傅偉. 圖像數(shù)據(jù)增強(qiáng)技術(shù)研究綜述 [J]. 軟件導(dǎo)報, 2021, 20(5): 230-236. (ZHU X H, QIAN L P, FU W. Overview of Research
on Image Data Enhancement Technology [J]. Software Guide, 2021, 20(5): 230-236.)
[5]"PAN S J, YANG Q. A Survey on Transfer Learning [J]. IEEE Transactions on Knowledge and Data Engineering, 2009, 22(10): 1345-1359.
[6]"OCHAL M, PATACCHIOLA M, STORKEY A J, et al. Few-Sho
t Learning with Class Imbalance [J]. IEEE Transactions on Artificial Intelligence, 2021, 4: 1348-1358.
[7]"YAN J, FENG K Y, ZHAO H Y, et al. Siamese-Prototypical Network with Data Augme
ntation Pre-training for Few-Shot Medical Image Classification [C]//2022 2nd Inte
rnational Conference on Frontiers of Electronics, Information and Computation Technologies (ICFEICT). Piscataway, NJ: IEEE, 2022: 387-391.
[8]"ZHANG X, HUANG W G, WANG R, et al. Multi-stage Distribution Correction: A Pr
omising Data Augmentation Method for Few-Shot Fault Diagnosis [J]. Engineering Application of Artificial Intelligence, 2023, 123: 106477-1-106477-16.
[9]"LONG M S, ZHU H, WANG J M, et al. Unsupervised Domain Adaptation with Residual T
ransfer Networks [J]. Advances in Neural Information Processing Systems, 2016, 29: 22-24.
[10]"SHEN Y Y, YAN Y, WANG H Z. Recent Advances on Supervised Distance Metric Learn
ing Algorithms [J]. Acta Automatica Sinica, 2014, 40(12): 2673-2686.
[11]"KOCH G, ZEMEL R, SALAKHUTDINOV R. Siamese Neural Networ
ks for One-Shot Image Recognition [C]//Proceedings of the 32nd International Conference on Machine Learning. [S.l.]: JMLR, 2015: 1-8.
[12]"VINYALS O, BLUNDELL C, LILLICRAP T, et al. Matching Networks for One Shot Learning [C]//Advances in Neural Information Processing Systems. Cambridge: MIT Press, 2016: 3630-3638.
[13]"SNELL J, SWERSKY K, ZEMEL R S. Prototypical Network
s for Few-Shot Learning [C]//Proceeding of the 31st International Conference on Neural Information Processing Systems. New York: ACM, 2017: 4080-4090.
[14]"SUNG F, YANG Y, ZHANG L, et al. Learning to Compare: Relation Network for
Few-Shot Learning [C]//2018 IEEE Conference on Computer Vision and Pattern Recognition. Pidcataway, NJ: IEEE Computer Society, 2018: 1199-1208.
[15]"KITSON N K, CONSTANTINOU A C, GUO Z G, et al. A Sur
vey of Bayesian Network Structure Learning [J]. Artificial Intelligence Review, 2023, 56(8): 8721-8814.
[16]"BOUSQUET O. Transductive Learning: Motivation, Models, Algorithms [J]. Jo
urnal of Machine Learning Research, 2002, 2002(14): 135-168.
[17]"KYE S M, LEE H, KIM H, et al. Transductive Few-Shot Learning with Meta-L
earned Confidence [J]. Journal of Machine Learning Research, 2020, 1(2): 1-3.
[18]"CHEN Y, BAI Y L, ZHANG W, et al. Destruction and Construction Learning for Fine-Grained Image Recognition [C]//Conference on Computer Vision and Pattern Recognition. Piscataway, NJ: IEEE, 2019: 5157-5166.
[19]"PEREZ E, STRUB F, DE VRIES H, et al. FiLM: Visual Reasoning with a General
Conditioning Layer [C]//Proceedings of the Thirty-Second AAAI Conference on Artificial Intelligence. Palo Alto: AAAI Press, 2018: 3942-3951.
[20]"RAMACHANDRAN P, ZOPH B, LE Q V. Searching for Activation Functions [EB/OL]. (2017-10-16)[2023-02-15]. https://arxiv.org/abs/1710.05941.
[21]"LIU Y B, LEE J, PARK M, et al. Learning to Propagate Labels: Transductive Propagation Network for Few-Shot Learning [EB/OL]. (2018-03-25)[2023-03-10]. https://arxiv.org/abs/1805.10002.
[22]"YANG L, LI L L, ZHANG Z L, et al. DPGN: Distribution Propagation Graph Network
for Few-Shot Learning [C]//2020 IEEE/CVF Computer Vision and Pattern Recognition. Piscataway, NJ: IEEE, 2020: 13387-13396.
[23]"LEE K, MAJI S, RAVICHANDRAN A, et al. Meta-Learnin
g with Differentiable Convex Optimization [C]//IEEE Conference on Computer Vision and Pattern Recognition. Piscataway, NJ: IEEE, 2019: 10657-10665.
(責(zé)任編輯: 韓"嘯)