莫建文, 賈 鵬
(桂林電子科技大學 信息與通信學院,廣西 桂林 541004)
半監(jiān)督分類是當前機器學習領域的一個重要研究方向。區(qū)別于傳統(tǒng)的有監(jiān)督分類算法,它能夠基于標記樣本,不依賴外界交互,自動地利用大量廉價的未標記樣本提升學習性能。由于現(xiàn)實應用中利用未標記樣本提升算法性能的巨大需求,半監(jiān)督算法迅速成為了研究熱點。
早期半監(jiān)督學習算法中有較大影響力的是半監(jiān)督自訓練法和協(xié)同訓練法。自訓練法可看作是早期利用無標記樣本的一種原始框架。先用少量標記樣本初始化,然后分類大量無標記樣本,選擇可靠的偽標記樣本擴充訓練集,直到收斂。Li等[1]提出一種基于最優(yōu)路徑森林的自訓練方法,其中所有樣本作為最優(yōu)路徑森林的頂點相互連接,利用特征空間的結構和分布,幫助自訓練法給未標記數(shù)據(jù)貼標簽。Hyams等[2]基于自訓練法,利用基于MC-dropout的置信區(qū)間作為置信度測量方法,得到改進的偽標記方法。Chen等[3]采用協(xié)同訓練策略,并結合SVM算法進行分類器設計。該算法利用2個分類器對未標記樣本進行分類,從而擴展標記樣本,提高分類器性能。
隨著深度學習在圖像處理方面取得突破,半監(jiān)督深度學習算法研究成為了自然的需求。許多學者通過優(yōu)化網(wǎng)絡結構構建的半監(jiān)督分類學習框架都能取得不錯的分類效果。如Springenberg[4]從損失函數(shù)的角度對網(wǎng)絡優(yōu)化,提出CAT-GAN,它改變了網(wǎng)絡的訓練誤差,通過正則化信息最大化的框架提高了模型的魯棒性。Salimans等[5]從判別器角度對網(wǎng)絡優(yōu)化,提出Improved-GAN,它在判別器后接一個分類器,訓練網(wǎng)絡得到一個分類判別器,可以直接將數(shù)據(jù)分為原始圖像類別和生成圖像假的類別,提高了生成樣本質(zhì)量和模型的穩(wěn)定性。有學者將無監(jiān)督深度學習網(wǎng)絡與不同的有監(jiān)督分類算法結合,也能得到較好的半監(jiān)督分類模型。Rasmus等[6]利用去噪自編碼構造梯形網(wǎng)絡用于半監(jiān)督分類任務,提高網(wǎng)絡對更深層次圖像特征的學習能力。付曉等[7]將生成對抗網(wǎng)絡與編碼器相結合,更好地提取特征。Larsen等[8]將變分自編碼器和生成對抗網(wǎng)絡結合,將生成對抗網(wǎng)絡判別器學習到的特征表示用于變分自編碼重構。Chen等[9]提出半監(jiān)督記憶網(wǎng)絡,第一次將記憶機制和半監(jiān)督深度學習相結合,利用模型學習過程中產(chǎn)生的記憶信息,使預測結果更可靠。但是,半監(jiān)督深層生成模型隨著網(wǎng)絡層數(shù)和參數(shù)的增多,會導致模型出現(xiàn)過擬合問題,并且半監(jiān)督的方法未能充分利用大量未標記數(shù)據(jù)來輔助少量的有標記數(shù)據(jù)進行學習。
為了充分利用未標記樣本,提高半監(jiān)督深層生成模型的分類性能與泛化性能,在梯形網(wǎng)絡框架的基礎上,結合mix_up線性插值數(shù)據(jù)增強方法對訓練數(shù)據(jù)進行預處理,并進一步討論了虛擬對抗訓練對模型魯棒性的影響,提出一種基于改進梯形網(wǎng)絡的半監(jiān)督虛擬對抗訓練模型(ILN-SS VAT)。本模型在梯形網(wǎng)絡框架的基礎上,引入虛擬對抗訓練的正則化方法,結合mix_up線性插值法,完成數(shù)據(jù)增強操作,提高圖像分類性能。首先,對訓練數(shù)據(jù)進行凸融合,使用mix_up線性插值的方法得到新的擴充訓練集;在梯形網(wǎng)絡的輸入層引入虛擬對抗噪聲,并且保持編碼器輸出一致,對網(wǎng)絡進行正則化;最后,模型以分類損失、重構損失和虛擬對抗損失相結合的方式調(diào)整參數(shù),訓練分類器。
ILN-SS VAT模型主要包括數(shù)據(jù)增強處理、虛擬對抗訓練和訓練分類器3個部分。梯形網(wǎng)絡框架共有3個網(wǎng)絡支路:有噪編碼器、解碼器和無噪編碼器,其中編、解碼器用VGG網(wǎng)絡[10]或Π模型[11]。有噪編碼器各層得到的特征變量通過跳躍連接(skip connection)映射到對應的解碼器上,而無噪編碼器則輔助解碼器進行無監(jiān)督訓練,以達到對有噪數(shù)據(jù)的最佳映射效果。模型在訓練時,先對訓練數(shù)據(jù)做mix_up增強處理,使用線性插值的方法得到新的擴展數(shù)據(jù)輸入網(wǎng)絡。然后在梯形網(wǎng)絡框架上計算對抗擾動,引入虛擬對抗損失構建平滑性正則化約束。最后,模型以分類損失、重構損失和虛擬對抗損失相結合的方式共同調(diào)整網(wǎng)絡參數(shù),訓練得到分類器。
為了解決半監(jiān)督分類模型中有標記樣本數(shù)不足的問題,在訓練之前,模型采用mix_up數(shù)據(jù)增強方法進行數(shù)據(jù)預處理。mix_up[12]是一種運用在計算機視覺中的對圖像進行混類增強的算法,它可以將不同類之間的圖像進行混合,從而擴充訓練數(shù)據(jù)集。
λ=Beta(α,β);
(1)
(2)
(3)
(4)
圖1 ILN-SS VAT模型結構
mix_up是鄰域風險最小化的一種形式,它令模型在處理樣本之間的區(qū)域時表現(xiàn)為線性,相比于其他數(shù)據(jù)增強方法,這種線性建模減少了預測訓練樣本以外數(shù)據(jù)的不適應性,提高了鄰域內(nèi)的平滑性。此外,擴充訓練數(shù)據(jù)集有助于消除對錯誤標簽的記憶、網(wǎng)絡的敏感性以及對抗訓練的不穩(wěn)定性,提高模型泛化能力。
ILN-SS VAT模型利用mix_up數(shù)據(jù)增強和虛擬對抗訓練的優(yōu)勢,在梯形網(wǎng)絡框架的基礎上,以有監(jiān)督分類損失、無監(jiān)督重構損失和虛擬對抗損失相結合的方式共同調(diào)整網(wǎng)絡參數(shù),訓練得到分類器,綜合增強模型泛化能力,并且提高圖像分類精度。
1.2.1 虛擬對抗損失
虛擬對抗訓練是一種有效的正則化技術[13]。通過在實際數(shù)據(jù)點上應用小的隨機擾動來生成人工數(shù)據(jù)點,鼓勵模型為真實和擾動的數(shù)據(jù)點提供類似的輸出。從分布散度的意義上,在虛擬對抗方向上的擾動能夠極大地改變輸出分布,虛擬對抗方向定義在未標記訓練數(shù)據(jù)點上,使得當前的模型輸出分布極大地偏離當前狀態(tài)。
(5)
(6)
(7)
其中:p(y|x,θ)為模型輸出分布;q(y|x)為輸出標簽的真實分布;D[p,q]為KL散度,用于評估p、q之間的距離。式(6)表示在r的L1范數(shù)小于某個值的情況下,找到使式(5)最大的radv,即為擾動方向,最小化兩輸出之間的KL散度即得到對抗損失Vloss,1。
虛擬對抗損失定義為模型的后驗分布與每個輸入點周圍局部擾動的魯棒性,相對于對抗訓練的優(yōu)點是,不需要標簽信息就可以定義對抗方向,因此適用于半監(jiān)督學習。
由于未標記數(shù)據(jù)的輸出標簽真實分布是未知的,當標記樣本個數(shù)很大時,可用當前的模型輸出近似代替未知的真實標簽,并基于虛擬標簽計算對抗方向。
(8)
(9)
(10)
將式(8)、(10)結合,得到總的損失:
(11)
虛擬對抗損失反映了當前模型在每個輸入數(shù)據(jù)點的局部平滑度,當其減小時,會使得模型在每個數(shù)據(jù)點處更加平滑。相比于其它正則化約束,ILN-SS VAT模型針對對抗方向的擾動進行輸出平滑,可提高模型對噪聲的魯棒性,防止過擬合。
1.2.2 重構損失
梯形網(wǎng)絡框架中編碼器的每層都有一個跳躍連接(skip connection)到解碼層,有利于恢復編碼器丟棄的信息,減輕編碼器最高層特征表示的壓力,也可以避免梯度消失的問題,這使得梯形網(wǎng)絡框架能夠與有監(jiān)督算法兼容。
向有噪編碼器的每層施加隨機高斯噪聲nl(l=0,1,2,…),得到有噪聲輸入為(x′l為未添加噪聲數(shù)據(jù))
(12)
(13)
(14)
有噪編碼各層的特征變量通過跳躍連接映射到對應的解碼層,將無噪編碼器每層的特征變量作為目標值,通過無監(jiān)督訓練盡可能多地恢復出未添加噪聲數(shù)據(jù)的信息特征。按照有噪編碼器的思路,對解碼器的每層都進行批歸一化處理,不同的是在此之前需要進行降噪。根據(jù)Pezeshki等[14]的研究成果,已知有噪數(shù)據(jù)的特征變量μ和無噪數(shù)據(jù)的先驗分布ε,可得到最優(yōu)降噪函數(shù)
g(l)(x)=εx+(1-ε)μ=(x-μ)ε+μ。
(15)
解碼器每層的特征變量表示為
(16)
同樣的,對無噪編碼的每層進行批歸一化處理:
(17)
(18)
其中,λl超參數(shù)是第l層占的比重。通過最小化重構損失函數(shù),使輸出的重構樣本盡可能多的恢復原有數(shù)據(jù)信息。
1.2.3 分類損失
ILN-SS VAT模型中的分類器采用Softmax來構建有監(jiān)督分類損失。將有標記樣本輸入有噪編碼器得到標簽預測值,計算預測值與真實標簽之間的交叉熵,得到有監(jiān)督分類損失:
(19)
根據(jù)式(11)、(18)和(19),可得訓練分類器總的損失函數(shù):
L=Vloss+Rloss+Closs。
(20)
虛擬對抗損失、有監(jiān)督損失和無監(jiān)督損失都可以通過梯度下降法達到最小化,因此,采用將其結合的方式共同調(diào)整網(wǎng)絡參數(shù),通過虛擬對抗訓練的方式進一步提高模型的泛化性。
通過在MNIST數(shù)據(jù)集、SVHN上對ILN-SS VAT模型的學習特征能力和分類性能進行評估。
MNSIT手寫字符數(shù)據(jù)集:MNIST有10個類別,包括60 000個訓練樣本,10 000個測試樣本。圖像均為單通道黑白圖像,大小為28×28的手寫字符。參考文獻[4-6,13],從訓練樣本中分別選取N1=100、N2=1 000個有標記數(shù)據(jù)作有監(jiān)督訓練,其余為無標記數(shù)據(jù)。
SVHN街牌號數(shù)據(jù)集:SVHN包括73 257個訓練樣本和26 032個測試樣本。圖像為彩色圖像,大小為32×32,每張圖片上有一個或多個數(shù)字,且圖像類別以識別的正中數(shù)據(jù)為準。參考文獻[4-6,13],從訓練樣本中分別選取N1=100、N2=1 000個有標記數(shù)據(jù)作有監(jiān)督訓練,其余為無標記數(shù)據(jù)。
實驗硬件采用Intel Xeon E5-2687 W CPU、32 GiB內(nèi)存和GTX 1080 GPU平臺;軟件采用Windows系統(tǒng)、Python語言、Tensorflow深度學習框架。
ILN-SS VAT模型,對較簡單的MNIST數(shù)據(jù)集,網(wǎng)絡結構采用Π模型,主要由9個卷積層、3個池化層和1個全連接層組成。對較復雜的SVHN,網(wǎng)絡結構采用VGG-19,由16個卷積層、5個池化層和3個全連接層組成。模型為減少梯度對參數(shù)大小的依懶性,對每層編、解碼結構都采用了批歸一化處理,且設置批次大小為64。另外,實驗中將遍歷次數(shù)設置為總樣本數(shù)除以批次大小,以增加多樣性。為保證模型的收斂速度,采用指數(shù)衰減的方式更新學習率,并且定義初始學習率為0.02(MNIST)、0.003(SVHN)。實驗設置迭代150個epoch(MNIST),180個epoch(SVHN)。
為了進一步分析ILN-SS VAT模型中2個模塊的有效性,以梯形網(wǎng)絡框架為基準,提出2種混合方案。方案A不使用mix_up數(shù)據(jù)增強;方案B不使用虛擬對抗訓練。實驗結果如表1所示。
表1顯示,ILN-SS VAT模型顯著提高了數(shù)據(jù)集的分類精度。由此可發(fā)現(xiàn),mix_up數(shù)據(jù)增強和虛擬對抗訓練都是提升模型性能的重要因素。對MNIST數(shù)據(jù)集分別采樣100和1 000個有標記樣本,mix_up的改進率分別為0.32%和0.02%,虛擬對抗訓練的改進率分別為0.63%和0.07%,通過結合這2個模塊,改進率分別為0.67%和0.15%。結果顯示,利用mix_up數(shù)據(jù)增強和虛擬對抗損失相結合的模式,使得梯形網(wǎng)絡具備更強的學習能力,同時,也表明模型更具有效性。
表1 ILN-SS VAT模型及混合方案在MNIST數(shù)據(jù)集上的分類精度 %
圖2 ILN-SS VAT模型生成圖片與MNIST原始數(shù)據(jù)圖片對比
梯形網(wǎng)絡架構中,由于解碼器是編碼器逆運算的一個過程,也可通過觀察解碼后的圖片質(zhì)量評估模型的收斂程度。圖2和圖3分別顯示了ILN-SS VAT模型趨于收斂時,生成圖片與原始輸入圖片的對比情況。圖2對MNIST數(shù)據(jù)集選取1 000個有標記數(shù)據(jù),迭代訓練150次趨于收斂,可以看出,已基本恢復出原始輸入圖片的信息,生成圖片質(zhì)量較高。圖3對SVHN數(shù)據(jù)集選取1 000個有標記數(shù)據(jù),迭代訓練180次趨于收斂,生成的圖片能夠識別出正中的數(shù)字,基本能夠與部分原始輸入圖片相匹配。觀察圖2、3可以發(fā)現(xiàn),ILN-SS VAT模型對數(shù)據(jù)學習能力較強,并且在處理不同復雜程度的數(shù)據(jù)集時,都有很好的魯棒性。
圖3 ILN-SS VAT模型生成圖片與SVHN原始數(shù)據(jù)圖片對比
為了驗證提出的ILN-SS VAT模型的優(yōu)勢,按照提出的實驗配置和采樣不同有標記樣本數(shù)目,以分類精度為評價標準,與當前主要半監(jiān)督深層生成模型進行對比,實驗結果如表2、3所示。
表2 MNIST數(shù)據(jù)集上的分類精度 %
表3 SVHN數(shù)據(jù)集上的分類精度 %
實驗結果顯示,ILN-SS VAT模型具有更高的分類精度,證明了在半監(jiān)督學習中處理對抗擾動的重要性,同時也證明了數(shù)據(jù)增強對模型分類性能提升的有效性。通過訓練不同復雜程度的數(shù)據(jù)集,ILN-SS VAT模型仍具有較強的學習能力,表明其有很強的泛化能力。
為驗證ILN-SS VAT模型的泛化性,對SVHN數(shù)據(jù)集選取1 000個有標記數(shù)據(jù)進行訓練,并以基礎梯形網(wǎng)絡框架(VGG+Softmax)為基準實驗,對模型趨于收斂時的測試集的損失作對比試驗分析,如圖4所示。
圖4 模型趨于收斂時測試損失對比
從圖4可看出,基礎梯形網(wǎng)絡框架在訓練時,測試集損失開始會隨著迭代進行慢慢下降,隨著訓練次數(shù)的增多,損失漸漸增大,相比而言,提出的ILN-SS VAT模型,隨著訓練的進行,測試集損失趨于穩(wěn)定。這表明,ILN-SS VAT模型結合mix_up數(shù)據(jù)增強和虛擬對抗訓練的模式能有效改善過擬合的問題。
為了進一步提高半監(jiān)督深層生成模型的分類精度,減少過擬合,在梯形網(wǎng)絡框架基礎上,結合mix_up數(shù)據(jù)增強和虛擬對抗訓練,提出了一種基于改進梯形網(wǎng)絡的半監(jiān)督虛擬對抗訓練模型(ILN-SS VAT)。ILN-SS VAT模型相對其他方法有以下幾點優(yōu)勢:1)用mix_up對訓練數(shù)據(jù)做增強處理得到新的擴展數(shù)據(jù),解決了半監(jiān)督分類模型有標記樣本較少的問題;2)對梯形網(wǎng)絡框架施加虛擬對抗噪聲,通過構建平滑性正則化約束,可有效增強模型的泛化能力;3)利用梯形網(wǎng)絡的優(yōu)勢,通過對有監(jiān)督分類損失、無監(jiān)督重構函數(shù)和虛擬對抗損失總和的梯度下降來達到最小化,優(yōu)化網(wǎng)絡參數(shù),得到分類性能更好的分類器。實驗結果表明,針對不同復雜程度的圖像數(shù)據(jù)集,ILN-SS VAT模型可利用少量的有標記數(shù)據(jù)訓練得到更好的分類精度。同時,該模型也有一定的不足,訓練時存在參數(shù)過多的問題,并且采用數(shù)據(jù)增強擴充訓練數(shù)據(jù)導致訓練較為耗時,在之后的工作中將繼續(xù)研究如何能夠保證分類精度,同時又能有效減少模型訓練時間。