王光博,陳 亮
(沈陽理工大學 自動化與電氣工程學院,沈陽 110159)
隨著科學技術的持續(xù)發(fā)展,深度學習[1]作為機器學習的分支已經(jīng)在計算機視覺領域中取得了巨大的進步,且基于神經(jīng)網(wǎng)絡的圖像分類技術取得了快速發(fā)展[2]。由于深度學習模型的參數(shù)多,訓練一般需要依靠大量的監(jiān)督數(shù)據(jù),且有些真實樣本因涉及隱私、安全問題無法收集,例如在金融、軍事[3]和醫(yī)學領域,由于無法獲得足夠多的訓練數(shù)據(jù),模型訓練往往達不到預期效果。因此,需要研究小樣本圖像分類任務[4],讓模型通過每類少數(shù)標簽樣本獲得分類的能力。
小樣本學習降低了獲取數(shù)據(jù)集的難度。在小樣本學習領域常用的四種方法有數(shù)據(jù)增強、遷移學習[5]、正則化和元學習[6]。利用數(shù)據(jù)增強、遷移學習或正則化的方法并不能從根本上解決小樣本中過擬合的現(xiàn)象。元學習的主要目的是讓機器學會學習,利用任務之間的共性,使模型從少量標簽樣本中進行算法學習,確保元學習器能夠快速習得解決新學習任務的能力。Santoro等[7]提出采用長短時記憶(LSTM)網(wǎng)絡解決單樣本學習問題,Shyam等[8]提出采用循環(huán)神經(jīng)網(wǎng)絡進行樣本間的動態(tài)比較。度量學習[9]可以看作元學習的一種形式,基于度量的元學習采用不同的度量方法對樣本進行相似性度量,其中比較有代表性的研究成果有Sung等[10]提出的具有高效度量方式的關系網(wǎng)絡等。張碧陶等[11]提出了融合強化學習和關系網(wǎng)絡的小樣本算法,使改進的網(wǎng)絡結(jié)構(gòu)獲得了良好的分類效果。魏勝楠等[12]在關系網(wǎng)絡中引入自注意力機制,使模型能夠提取每個類別特定的信息,提高了分類準確率。王年等[13]在關系網(wǎng)絡中加入了感受野塊,增強了網(wǎng)絡的度量能力。
本文在關系網(wǎng)絡的基礎上提出一種新的小樣本圖像分類方法。該方法改進原關系網(wǎng)絡的結(jié)構(gòu),將嵌入模型的第3個卷積塊替換成inception塊[14],增強嵌入模塊的特征提取能力;替換原關系網(wǎng)絡中用于相關性計算的激活函數(shù),實現(xiàn)更好的信息流動,有利于模型的訓練;為便于模型訓練,改進原關系網(wǎng)絡的損失函數(shù),使模型具有更好的泛化能力,有效提高小樣本圖像分類的準確度。
元學習通過學習已知的類別,完成新測試任務的分類。數(shù)據(jù)集分成訓練集、驗證集以及測試集,且屬于不同的標簽域。在訓練階段,將數(shù)據(jù)集分成多個元任務,隨機抽取C種類別,每種類別中有K個標記樣本,由C×K個樣本構(gòu)成支持集,C種類別剩余的樣本構(gòu)成查詢集,目的是使模型在C×K個樣本里學會區(qū)分C種類別,稱為C-wayK-shot問題。如果K等于1,稱為單樣本學習,如果K大于1,則為小樣本學習。本文將選擇5-way 1-shot和5-way 5-shot兩種方式在三個小樣本檢測任務常用的數(shù)據(jù)集Omniglot、MiniImagenet以及TieredImageNet上完成模型的訓練和測試。
關系網(wǎng)絡包括嵌入模塊和關系模塊,其結(jié)構(gòu)如圖1所示。
圖1 關系網(wǎng)絡結(jié)構(gòu)框圖
嵌入模塊的功能是對輸入圖片的特征信息進行提取,該模塊由4個卷積層和2個最大池化層構(gòu)成;關系模塊的主要作用是計算支持樣本與查詢樣本的相似度,由兩個卷積層和兩個全連接層構(gòu)成,最后利用Sigmoid函數(shù)計算查詢樣本和支持樣本的相似程度,從而判斷兩幅圖像是否屬于同一類別,計算表達式為
式中:ri,j為支持集樣本與查詢集樣本的關系得分;m表示查詢集樣本的數(shù)量;φ、φ分別為嵌入?yún)?shù)和關系參數(shù);fφ是嵌入模塊的映射函數(shù);gφ是關系模塊的映射函數(shù);C(·)為提取到的查詢樣本和支持樣本的特征;yi和yj分別表示查詢樣本和支持樣本的標簽;I(·)是指數(shù)函數(shù),如果支持集樣本和查詢集樣本是不同類別,值為0,如果支持集樣本和查詢集樣本是相同類別,值為1;argmin表示通過最小化均方誤差損失優(yōu)化嵌入模型和關系模型的參數(shù)。
本文基于改進關系網(wǎng)絡的小樣本學習模型結(jié)構(gòu)框圖如圖2所示,改進了原關系網(wǎng)絡的嵌入模塊和關系模塊。
圖2 改進關系網(wǎng)絡結(jié)構(gòu)框圖
本文將原關系網(wǎng)絡嵌入模塊的第3個卷積塊替換成inception塊,增強嵌入模塊的特征提取能力。關系模塊采用3個卷積塊和2個全連接層,通過縮放函數(shù)輸出關系得分。在原結(jié)構(gòu)的基礎上增加一個卷積塊的目的是為了能夠?qū)⑻卣餍畔⑦M一步卷積,并將第3個卷積層采用全局平均池化處理,避免過擬合現(xiàn)象的發(fā)生,通過Mish全連接層激活,最后一個全連接層使用Sigmoid和縮放函數(shù)計算查詢樣本和支持樣本的相似程度。因為縮放函數(shù)能夠使輸出的特征向量維持在一個特定的范圍之內(nèi),進而降低特征向量的影響程度。采用縮放函數(shù)加快了梯度的收斂速度,更換原關系網(wǎng)絡中用于相關性計算的激活函數(shù)和損失函數(shù),以有利于網(wǎng)絡模型訓練。通過對查詢集圖像和支持集圖像的特征向量進行相關性計算得到相似度分數(shù),分數(shù)最高的類即為預測的分類。
1.3.1 inception塊
inception塊的設計思想是用并聯(lián)的方法使不同的卷積層進行組合,經(jīng)過卷積層處理的結(jié)果矩陣進行拼接,形成一個更深的矩陣。inception塊是讓網(wǎng)絡決定需要什么樣卷積層以及是否需要池化操作,能夠?qū)σ恍┐蟪叽绲木仃囘M行降維操作,進而降低計算量,能夠在不同尺寸上聚合圖像信息,從不同尺度上提取特征。本文將原關系網(wǎng)絡嵌入模塊中的第三個卷積塊替換成inception塊,能夠有效地改進網(wǎng)絡的深度和寬度,提升模型的準確率,避免過擬合現(xiàn)象。
本文采用的inception塊共有三部分,結(jié)構(gòu)如圖3所示。為降低通道的個數(shù),第一部分采用1×1的卷積,并進一步減少計算量,再通過2個3×3的卷積將感受野進行放大,在減少計算量的同時不會降低網(wǎng)絡的性能。為獲得更多的表層信息,第二部分采用了常規(guī)3×3大小的卷積,同時可以保留更多的紋理信息,第三部分采用3×3最大池化層提取不同尺度的特征。最后將這三部分得到的不同特征圖拼接在一起,得到多尺度特征。結(jié)合三個部分得到的不同特征映射,能夠增強嵌入模塊的特征表達能力。
圖3 inception塊結(jié)構(gòu)圖
1.3.2 激活函數(shù)
關系網(wǎng)絡中的激活函數(shù)采用ReLU函數(shù)。ReLU函數(shù)為分段性函數(shù),函數(shù)正值的收斂比較快,負值的梯度為0,會出現(xiàn)對應的參數(shù)不更新的情況。Mish激活函數(shù)的形式為
兩種激活函數(shù)的比較如圖4所示。
由圖4可見,Mish激活函數(shù)是連續(xù)的光滑函數(shù),避免了ReLU激活函數(shù)的奇異點,有更好的泛化能力和優(yōu)化能力;Mish函數(shù)曲線是上無邊界和有下邊界,上無邊界防止了梯度飽和的情況發(fā)生,有下邊界與ReLU激活函數(shù)的硬零邊界不同,會保留較小的負值,能夠穩(wěn)定網(wǎng)絡梯度流,從而實現(xiàn)更好的信息流動。因此,本文采用Mish取代Re-LU作為激活函數(shù)。
圖4 激活函數(shù)對比圖
1.3.3 損失函數(shù)
關系網(wǎng)絡采用的損失函數(shù)是均方誤差MSE。MSE損失函數(shù)的公式為
式中:yi是輸入樣本的真實值;f(xi)是預測值。如果輸入樣本真實值和預測值之差比1小,誤差會變得更小,如果輸入樣本真實值和預測值之差比1大,誤差會變得更大,因此MSE損失函數(shù)的缺點是對離群點敏感。
本文改進網(wǎng)絡的損失函數(shù)采用修正后的平均絕對誤差SmoothL1,其公式為
式中f'(xi)是預測值。
圖5為兩個損失函數(shù)的對比圖。
圖5 損失函數(shù)對比圖
從圖5可以看出,均方誤差損失函數(shù)的特點是連續(xù)光滑,同時函數(shù)上的每一個點都可導,便于網(wǎng)絡模型更好地收斂;本文網(wǎng)絡模型采用的損失函數(shù)與均方誤差損失函數(shù)比較,存在對離群點不敏感的優(yōu)勢,無論差值多大,其懲罰都不變,有著穩(wěn)定的梯度,不會出現(xiàn)梯度爆炸的現(xiàn)象,便于模型更好地訓練。
數(shù)據(jù)集劃分為訓練集Dtrain、支持集Dsupport以及測試集Dtest,Dtrain對網(wǎng)絡進行元訓練,Dtest對網(wǎng)絡的泛化性能進行測試。在元訓練期間,從訓練集中隨機選取一些樣本作為支持集,將余下的樣本組成查詢集,并且Dsupport和Dtest是一樣的標簽空間,Dtrain和Dtest無交集。
本文的網(wǎng)絡整體模型如圖6所示。首先通過損失函數(shù)訓練嵌入模塊,然后查詢集樣本和支持集樣本分別進入嵌入模塊進行特征提取得到特征向量,確定嵌入網(wǎng)絡參數(shù)。將提取到的特征向量進行組合后輸入關系模塊進行計算,使用損失函數(shù)進行訓練,得到相似度分數(shù)。當查詢集樣本和支持集樣本是不同類別時,相似度分數(shù)接近0,當查詢集樣本和支持集樣本是同類別時,相似度分數(shù)接近1。網(wǎng)絡參數(shù)的訓練過程為
圖6 網(wǎng)絡整體模型
網(wǎng)絡模型訓練的過程就是模擬小樣本分類的場景。對比標簽yi和網(wǎng)絡模型的相似度分數(shù)ri,j,通過累加求和過程獲得最終的損失值。
本文改進的關系網(wǎng)絡模型在Omniglot、Mini-Imagenet和TieredImageNet三個小樣本任務常用的數(shù)據(jù)集上完成實驗。Omniglot數(shù)據(jù)集[15]由不同人繪制的字符組成,總計1 623個類別的字符;MiniImageNet數(shù)據(jù)集包含100種類別的60 000張圖片,由600張圖像構(gòu)成一個類;TieredImageNet數(shù)據(jù)集是數(shù)量較大的小樣本學習數(shù)據(jù)集,一共包含608個類,總計有779 165張圖像,比MiniImageNet數(shù)據(jù)集中的類別有更大的語義差距,從而提供了更嚴格的泛化測試。
實驗條件:Intel(R)Core(TM)i5-10300H,2.50 GHz、16 GB內(nèi)存,NVIDIA GeForceRTX2060顯卡,Windows10操作系統(tǒng),基于Pytorch深度學習框架。
各數(shù)據(jù)集的設置說明如表1所示。
表1 各數(shù)據(jù)集的設置說明
將Omniglot數(shù)據(jù)集劃分成三個部分,其中訓練集由1 200類圖像組成,驗證集由123類圖像組成,測試集由300類圖像組成,實驗結(jié)果由測試集中隨機生成的1 000個批次的分類精確度平均值表示;在MiniImageNet數(shù)據(jù)集中,訓練集由64類圖像組成,驗證集由16類圖像組成,測試集由20類圖像組成,實驗結(jié)果由測試集中隨機生成的1 000個批次的分類精確度平均值表示;TieredImageNet數(shù)據(jù)集中的訓練集由351類圖像組成,驗證集由97類圖像組成,測試集由160類圖像組成,最終實驗結(jié)果由測試集中隨機生成的600個批次的分類精確度平均值表示。為使分類的效果更好,將數(shù)據(jù)集圖像采用翻轉(zhuǎn)方式對其進行擴充,以達到數(shù)據(jù)增強的效果。
將改進后的關系網(wǎng)絡在Omniglot數(shù)據(jù)集上的運行結(jié)果與原關系網(wǎng)絡進行比較,改進前后模型的準確度如表2所示。
表2 改進前后模型在Omniglot數(shù)據(jù)集上的準確度
本文的網(wǎng)絡模型在Omniglot數(shù)據(jù)集的5-way 5-shot上的準確度為99.8%±0.32%,對比原關系網(wǎng)絡提升的效果并不明顯,但是在5-way 1-shot上的準確度為99.7%±0.32%,比原關系網(wǎng)絡大約提高了0.1%。
將改進后的關系網(wǎng)絡在MiniImageNet數(shù)據(jù)集上的運行結(jié)果與原關系網(wǎng)絡進行比較,如表3所示。
表3 改進前后模型在MiniImageNet數(shù)據(jù)集上的準確度
本文的網(wǎng)絡模型在5-way 1-shot上的準確度為54.24%±0.79%,比原關系網(wǎng)絡模型的準確度提高了3.8%,在5-way 5-shot上的準確度為69.05%±0.71%,比關系網(wǎng)絡提高了3.73%。
本文改進的網(wǎng)絡模型和原關系網(wǎng)絡在Mini-ImageNet數(shù)據(jù)集中的5-way 1-shot和5-way 5-shot任務上的迭代次數(shù)和分類準確度如圖7所示。
圖7 兩種情況下的分類準確度
由圖9可見,在20 000次之前的訓練中,圖像分類的準確度在不斷增加,在40 000次之后的訓練中,分類準確度的變化不大且基本保持平穩(wěn)。本文的改進方法在在5-way 1-shot和5-way 5-shot兩種情況下都表現(xiàn)出更好的性能。
將本文改進的網(wǎng)絡在TieredImageNet數(shù)據(jù)集上的運行結(jié)果與原關系網(wǎng)絡比較,如表4所示。
本文網(wǎng)絡模型在5-way 1-shot任務上的準確度為58.69%,比關系網(wǎng)絡提高了4.21%,在5-way 5-shot上的準確度為75.36%,比關系網(wǎng)絡提高了4.05%,兩種情況下都表現(xiàn)出優(yōu)異的性能。
2.4.1 網(wǎng)絡結(jié)構(gòu)分析
為進一步驗證引入inception塊對模型圖像分類準確度的影響,在MiniImageNet數(shù)據(jù)集上分別在模型沒有引入inception塊和引入inception塊的兩種情況下進行實驗,實驗結(jié)果如表5所示。
表5 網(wǎng)絡結(jié)構(gòu)分析
由表5可知,在5-way 1-shot情況下,引入inception塊的分類準確度提高1.17%,在5-way 5-shot情況下提高1.03%。證實了引入inception塊的網(wǎng)絡能夠增強模型的分類準確度。
2.4.2 魯棒性分析
為進一步驗證本文改進網(wǎng)絡具有的魯棒性,在保證驗證集和測試集不變的情況下,在MinImagenet數(shù)據(jù)集上的5-way 1-shot和5-way 5-shot實驗中,對模型分類準確度隨測試集類別數(shù)變化的情況進行比較。在MinImagenet數(shù)據(jù)集中以10類為間隔,從100類到10類依次改變測試集的類別數(shù),兩種情況下的分類準確度如圖8所示。從圖8可以看出,在MinImagenet數(shù)據(jù)集上,隨著測試集中類別數(shù)量的減少,模型的分類準確度逐漸降低,但模型依然可以保持75%以上的分類準確度,表明本文改進模型的魯棒性明顯優(yōu)于原關系網(wǎng)絡模型。
圖8 兩種情況下的分類準確度
2.4.3 評估實驗
為驗證數(shù)據(jù)集多樣性對于本文改進網(wǎng)絡的影響,在Omniglot和MinImagenet數(shù)據(jù)集上進行實驗,結(jié)果如圖9所示。圖9給出了模型分類準確率隨數(shù)據(jù)集多樣性的變化,經(jīng)過對比分析可以發(fā)現(xiàn):數(shù)據(jù)集多樣性越高,分類準確率越低;在同樣類別的情況下,樣本數(shù)越多,準確率越高;在樣本數(shù)相同的情況下,類別數(shù)越多,準確率越低。
圖9 C-way K-shot任務
本文在原關系網(wǎng)絡的基礎結(jié)構(gòu)上,引入了inception塊增加網(wǎng)絡的寬度,更換了原關系網(wǎng)絡模型中的損失函數(shù)和激活函數(shù),有利于網(wǎng)絡訓練,同時保留了網(wǎng)絡結(jié)構(gòu)的簡單性以及快速的訓練和測試過程。在Omniglot數(shù)據(jù)集、MiniImageNet數(shù)據(jù)集和TieredImageNet數(shù)據(jù)集上的實驗結(jié)果表明,改進的關系網(wǎng)絡比原關系網(wǎng)絡的準確率有所提高,可以提升小樣本學習的泛化能力。