趙小強, 蔣紅梅
(1. 蘭州理工大學 電氣工程與信息工程學院, 甘肅 蘭州 730050; 2. 蘭州理工大學 甘肅省工業(yè)過程先進控制重點實驗室, 甘肅 蘭州 730050; 3. 蘭州理工大學 國家級電氣與控制工程實驗教學中心, 甘肅 蘭州 730050)
基于有監(jiān)督的深度卷積神經(jīng)網(wǎng)絡(luò)已經(jīng)在很多應用領(lǐng)域取得了顯著的成就[1],然而在有監(jiān)督學習中,一個泛化能力強的模型往往需要大量的標記數(shù)據(jù).數(shù)據(jù)的收集和標注過程費時費力、成本高,這使得基于有監(jiān)督的深度卷積神經(jīng)網(wǎng)絡(luò)算法難以訓練出一個泛化能力較強的模型.此外,在真實場景中,由于環(huán)境的變化,訓練數(shù)據(jù)集和真實數(shù)據(jù)集之間存在分布差異,導致有監(jiān)督的模型泛化能力較差[2].針對上述問題,領(lǐng)域適應(domain adaptation,DA)[3]應運而生,利用其他領(lǐng)域的標記樣本在領(lǐng)域間建立橋梁,同時對無標記的目標域進行適配,從而提高目標域的預測質(zhì)量.領(lǐng)域適應在計算機視覺和自然語言等領(lǐng)域有著較為廣泛的應用,如使用電商網(wǎng)站上已經(jīng)分類好的圖片數(shù)據(jù)對手機拍攝圖片進行分類[4],使用已有情感標記的電影評論數(shù)據(jù)對快餐評論數(shù)據(jù)進行情感預測[5],這些問題都屬于領(lǐng)域適應研究的范疇.領(lǐng)域適應消除了訓練數(shù)據(jù)和測試數(shù)據(jù)必須服從獨立同分布的限制,解決了標準監(jiān)督學習所面臨的訓練數(shù)據(jù)不足、訓練數(shù)據(jù)與測試數(shù)據(jù)存在分布偏移的問題.
近年來,DA算法取得了快速發(fā)展.早期的典型算法是基于最大均值差異(maximum mean difference,MMD),其核心思想是將不同域的樣本映射在同一特征空間,通過最小化不同域的特征表示的MMD,減小域間差異.基于MMD的DA算法主要包含:深度域混淆(deep domain confusion,DDC)[6]在兩個權(quán)值共享的CNN的特征層之間加入一個適應層,通過最小化源域及目標域特征之間的MMD減小域間差異,但MMD與核函數(shù)直接關(guān)聯(lián),如果選擇了表現(xiàn)較差的核函數(shù)則MMD表現(xiàn)也會較差,從而導致最終目標域的預測質(zhì)量較差;深度適應網(wǎng)絡(luò)(deep adaptation network,DAN)[7]在DDC的基礎(chǔ)上,為了避免單個MMD中核函數(shù)的選擇比較困難,以及獲取更多的有用信息,采用一個多核的MMD適配多個全連接層;聯(lián)合適配網(wǎng)絡(luò)(joint adaptation network,JAN)[8]則用一個聯(lián)合的MMD適配多個全連接層.
相比較早期的淺層領(lǐng)域適應,應用深層神經(jīng)網(wǎng)絡(luò)的領(lǐng)域適應算法在共享特征提取和分類精度上表現(xiàn)更為優(yōu)越.近年來,隨著生成對抗網(wǎng)絡(luò)(generative adversarial network,GAN)[9]不斷發(fā)展,基于對抗學習(adversarial learning,AL)[10]的算法得到廣泛研究.領(lǐng)域?qū)股窠?jīng)網(wǎng)絡(luò)(domain adversarial neural network,DANN)[11]是首次將對抗學習引入領(lǐng)域適應的方法,其通過特征提取器與判別器之間的對抗來獲得域不變特征,但未考慮特征的類別信息,導致目標域的預測精確度較低;受條件生成對抗網(wǎng)絡(luò)(conditional generative adversarial nets,CGAN)[12]的啟發(fā),對抗判別領(lǐng)域適應(adversarial discriminative domain adaptation,ADDA)[13]提出了一種基于對抗學習方法的統(tǒng)一框架,并根據(jù)是否使用生成器、使用何種損失函數(shù)和是否跨域共享權(quán)重等角度對現(xiàn)有方法進行了總結(jié),取得了較好的結(jié)果.
然而,在上述的領(lǐng)域適應算法中,對目標樣本的適應性較差,且通用的熵最小化函數(shù)使易轉(zhuǎn)移樣本的梯度大、難轉(zhuǎn)移樣本的梯度小,從而導致目標域的預測精確度較低.針對上述問題,本文提出基于可調(diào)節(jié)判別器的領(lǐng)域適應(A-DADA)算法.首先,設(shè)計了可調(diào)節(jié)判別器,在目標域中,利用兩個可調(diào)節(jié)判別器(2K維)分類概率的差值作為對抗訓練的衡量指標,旨在減少已對齊的目標樣本對抗訓練的次數(shù),增加未對齊目標樣本的對抗訓練次數(shù);同時,將平方熵損失作為最小熵損失函數(shù),降低了易轉(zhuǎn)移樣本的梯度幅度,提高了難轉(zhuǎn)移樣本的訓練效率.然后,使用隨機梯度下降法(stochastic gradient descent,SGD)以可調(diào)節(jié)的學習率策略對A-DADA網(wǎng)絡(luò)進行訓練,得到該網(wǎng)絡(luò)的具有良好遷移性的模型.最后,對目標域中的測試集樣本進行測試,經(jīng)多個對比實驗驗證,本文的算法具有更好的性能.
在DA中,通常包含兩個數(shù)據(jù)集,分別為源域數(shù)據(jù)集和目標域數(shù)據(jù)集.設(shè)Xs為含有標簽信息的源域數(shù)據(jù)集,對應的標簽信息數(shù)據(jù)集為Ys,ys為Xs中的一個樣本xs對應的標簽,Xt為無標簽信息的目標域數(shù)據(jù)集,xt為Xt中的一個樣本.源域和目標域存在分布差異,本文的目的是設(shè)計一種模型來盡可能減小源域和目標域之間的差異[14-15],使得使用源域數(shù)據(jù)訓練出的模型,能夠有效地適用于目標域.
多數(shù)基于特征對齊的領(lǐng)域適應算法的基本結(jié)構(gòu)通常由兩部分組成,分別為特征提取器G及分類器F.首先,將樣本xs或xt輸入到特征提取器網(wǎng)絡(luò)G,獲得對應的特征圖,將特征圖輸入到分類器網(wǎng)絡(luò)F,輸出K維的類概率分布向量p(y|x),其中K為類別數(shù)[16];然后,特征提取器與分類器之間進行對抗訓練直到達到最優(yōu)結(jié)果.還有些DA算法考慮了類邊界和目標樣本之間的關(guān)系,其基本結(jié)構(gòu)由三部分組成,即在上述結(jié)構(gòu)的基礎(chǔ)上,又加入一個分類器,分別為F1、F2.首先,訓練兩個分類器F1和F2,使目標樣本的特征差異最大化,從而有效地檢測源域支持外的目標樣本;其次,通過訓練特征提取器G“欺騙”判別器F1和F2,使目標樣本的差異最小化,鼓勵在源域的支持范圍內(nèi)生成目標樣本.但上述算法對目標域樣本的遷移性考慮不足,同時使用一般的熵損失函數(shù)在易轉(zhuǎn)移樣本的梯度幅度大,在難轉(zhuǎn)移樣本的梯度幅度小,從而導致類別不平衡.
本文主要從以下兩個方面解決上述問題,首先,提出可調(diào)節(jié)判別器,減少已對齊目標樣本的對抗訓練的次數(shù),增加未對齊目標樣本的對抗訓練的次數(shù);其次,利用平方熵損失函數(shù),旨在降低易轉(zhuǎn)移樣本的梯度幅度,增加難轉(zhuǎn)移樣本的梯度幅度.
本文中的兩個判別器D1、D2的輸出設(shè)定為2K維向量[17],第一個K維是源域的類分布,第二個K維是目標域的類分布,從而同時學習域和類變量的對齊.在目標域的對抗訓練中,首先,目標域數(shù)據(jù)無標簽信息,使用對應的偽標簽y′t;其次,判別器為2K維,因此判別器的正確輸出應該是[0,y′t],而特征提取器“欺騙”判別器將其分類到源域中,即[y′t,0];最終,把兩個判別器的分類概率的距離ldt作為權(quán)重應用到對抗訓練損失函數(shù)上,旨在減少已對齊的目標樣本對抗訓練的次數(shù),增加未對齊目標樣本分布,在單個判別器中實現(xiàn)域級別和類級別的對抗訓練的次數(shù),從而構(gòu)成可調(diào)節(jié)判別器.因此,可調(diào)節(jié)判別器在目標域樣本上的對抗損失為
(1)
其中:f(x)=F(G(x));fD1(x)=D1(G(x));fD2(x)=D2(G(x));G為特征提取器;F為分類器.
為了進一步提高目標樣本的適應性,對目標域樣本在類預測器上的熵損失函數(shù)進行熵最小化,一般采用的熵損失函數(shù)為香濃熵損失函數(shù),表示為
(2)
考慮到二分類的情況,其對應的梯度函數(shù)為
(3)
由式(3)可知,高概率類別的梯度比中概率類別的梯度大得多.然而,香農(nóng)熵損失函數(shù)最小化是由目標樣本的高概率類別主導,忽略了中低概率類別,因此,本文提出平方熵損失函數(shù)替代香農(nóng)熵損失函數(shù),其表示如下:
(4)
對應的梯度函數(shù)為
(5)
由式(5)可知,熵損失函數(shù)的梯度與對應的類別概率為線性關(guān)系.與香農(nóng)熵最小化方法相比,雖然高概率類別仍然有較大的梯度,但它的主導作用已經(jīng)減弱,使得中概率類別具有與高概率類別相差不大的訓練梯度,因此,平方熵損失函數(shù)對不同的類別具有更均衡的梯度.
基于可調(diào)節(jié)判別器領(lǐng)域適應的損失函數(shù)主要包含三部分:第一部分為類預測器的基于源域數(shù)據(jù)的分類損失函數(shù)lCE(f(x),y)和基于目標域數(shù)據(jù)的平方熵損失函數(shù)lte(F),第二部分為兩個可調(diào)節(jié)判別器對應的分類損失ldsc1、ldsc2、ldtc1、ldtc2和對抗損失ldsa1、ldsa2、ldta1、ldta2,第三部分為兩個可調(diào)節(jié)判別器輸出同一域的類概率差值的絕對值之和ld.基于可調(diào)節(jié)判別器領(lǐng)域適應的損失函數(shù)如下式所示(可調(diào)節(jié)判別器的損失函數(shù)以D1為例,D2與其具有相同的形式):
(6)
其中:lCE(f(x),y)=-〈y,logf(x)〉,為交叉熵損失函數(shù).
本文提出的基于可調(diào)節(jié)判別器的領(lǐng)域適應的結(jié)構(gòu)(圖1)由四部分組成,分別為特征提取器G、兩個可調(diào)節(jié)判別器D1、D2和一個類預測器F.其中,類預測器的輸出為K維向量,可調(diào)節(jié)判別器的輸出為2K維向量.對于源域樣本xs和目標域樣本xt,首先,使用共享的特征提取器G來提取樣本特征,分別得到源域樣本的特征G(xs)和目標域樣本的特征G(xt);其次,源域樣本的特征分別輸入到兩個可調(diào)節(jié)判別器D1、D2及類預測器F中,而目標域樣本的特征不需要輸入到類預測器中,分別得到源域樣本的類別預測概率和目標域樣本的類別預測概率;然后,對于目標域,將兩個判別器輸出的類別預測概率的差值作為權(quán)重應用在判別器的對抗損失上,得到關(guān)于目標域的對抗損失;最后,將平方熵損失函數(shù)作為熵最小化損失函數(shù),以提高類別的平衡性.
圖1 A-DADA structure diagram
A-DADA算法流程如下所示:
Input:源域數(shù)據(jù)集Ds=(Xs,Ys),目標域數(shù)據(jù)集Dt=Xt,訓練次數(shù)分別為K1、K2,Batch Size的大小為n.
Step1:采用ImageNet[21]預訓練模型的參數(shù)初始化網(wǎng)絡(luò)層參數(shù);
Step2: forkin 1:K1do
Step2.2:xsn通過網(wǎng)絡(luò)G得到G(xsn),再分別通過F、D1、D2網(wǎng)絡(luò)計算得到F(G(xsn))、D1(G(xsn))和D2(G(xsn));
Step2.3:根據(jù)式(6)中對應的損失函數(shù)lsc、ldsc1和ldsc2的計算公式,訓練分類器和可調(diào)節(jié)判別器對源域樣本進行正確分類,目標函數(shù)為
Step3: forkin 1:K2do
Step3.2:xsn、xtn通過網(wǎng)絡(luò)G得到G(xsn)和G(xtn),再分別通過F、D1、D2網(wǎng)絡(luò)計算得到F(G(xsn))、F(G(xtn))、D1(G(xsn))、D1(G(xtn))和D2(G(xtn));
Step3.3:根據(jù)式(6)中對應的損失函數(shù)lsc、lte、ldsc2、ldtc2、ldsc1、ldtc1和ld的計算公式,訓練類預測器和判別器,目標函數(shù)為
λdsc1ldsc1+λdtc1ldtc1-λdld
Step3.4:根據(jù)式(6)中對應的損失函數(shù)ldsa1、ldta1、ldsa2、ldta2和ld的計算公式,訓練特征提取器,目標函數(shù)如下:
Step4:end for
本文實驗使用的數(shù)據(jù)集Office-31[19]是一個基于圖片領(lǐng)域適應的應用較為廣泛的數(shù)據(jù)集,一共包含4 652張圖片,分為31個類別,這些圖片源于3個不同的領(lǐng)域,分別為Amazon(A)、Webcam(W)和DSLR(D).其中,Amazon為電商網(wǎng)站Amazon.com的商品展示圖片;Webcam為圖像處理軟件Webcam處理后的圖片;DSLR為數(shù)碼單反相機拍攝的圖片.實驗中,將這3個領(lǐng)域的數(shù)據(jù)集設(shè)置6種遷移任務,即A→W、D→W、A→D、W→D、D→A和W→A.
本文實驗所用的深度學習框架為Pytorch[20],在搭載GPU為GTX1080Ti的服務器實驗環(huán)境下使用Python3.6,在網(wǎng)絡(luò)框架中,使用ResNet-50作為基礎(chǔ)的特征提取器,其初始學習率為0.004,其中ResNet的初始參數(shù)為使用ImageNet[21]預訓練的模型參數(shù).類預測器和兩個判別器由兩個全連接層構(gòu)成,其初始學習率為0.04,對于優(yōu)化器的設(shè)置,采用動量為0.9的SGD來更新參數(shù),同時采用與文獻[22]相同的優(yōu)化策略,學習率ηp由公式ηp=η0/(1+αp)β計算所得,其中p指模型訓練完成程度,范圍為0~1.0,并設(shè)置η0=0.01、α=10和β=0.75的優(yōu)化器參數(shù)組合,其他參數(shù)的最優(yōu)設(shè)置見表1.
表1 參數(shù)的最優(yōu)設(shè)置
為了客觀地比較算法的優(yōu)劣,實驗依次使用ResNet、DANN、ADDA、JAN算法和本文算法作對比實驗,基于Office-31數(shù)據(jù)集的實驗結(jié)果見表2.相較于其他算法,本文的A-DADA算法在多個遷移任務上具有更好的性能,與ResNet、DANN、ADDA、JAN算法相比,平均精確度分別提高了10.7%、4.6%、3.9%、2.5%;在Office-31數(shù)據(jù)集的6個遷移任務中得到提升,尤其在A→W和A→D兩個任務上有較大的提升.但是在源域數(shù)據(jù)集較小的兩個遷移任務D→A和W→A的精確度較低,說明本文的算法還存在一定局限性,其主要原因是對于源域數(shù)據(jù)集較小的領(lǐng)域,模型的適應性能被弱化.但由于本文算法較好地考慮了類別平衡及目標樣本的適應性,使其在整體性能上優(yōu)于其他對比算法.
表2 基于Office-31數(shù)據(jù)集的實驗結(jié)果
本文對Office-31數(shù)據(jù)集上的任務A→W的訓練曲線進行可視化,如圖2所示,進一步分析算法的穩(wěn)定性和收斂性,其中前10 000次迭代是無領(lǐng)域適應時的目標樣本的預測平均精確度.由圖可知,未加入領(lǐng)域適應時,目標樣本的預測精確度較低且處于震蕩狀態(tài),加入領(lǐng)域適應后,目標樣本測試的平均精確度快速上升,并最終趨于穩(wěn)定.
圖2 目標樣本的平均精確度Fig.2 Average accuracy of the target sample
為了進一步驗證算法的有效性,使用本文算法在Office-31數(shù)據(jù)集上的特征可視化圖片如圖3所示,紅點表示源域數(shù)據(jù),藍點表示目標域數(shù)據(jù).從圖3可以看出,在無領(lǐng)域適應時,目標域數(shù)據(jù)散亂地分布,也未觀察到任何關(guān)于目標域間的分類信息及源域和目標域域間的適應信息,這說明源域數(shù)據(jù)與目標域數(shù)據(jù)之間存在較大差異;而在使用A-DADA算法對其進行領(lǐng)域適應分類后,源域數(shù)據(jù)和目標域數(shù)據(jù)的類間距離變小,具有相同類別的源域樣本和目標域樣本較好地擬合在一起,進一步驗證了本文算法的有效性.
圖3 T-SNE feature visualization
為了提高基于對抗學習的領(lǐng)域適應(DA)對目標樣本的適應性,本文提出了A-DADA算法.算法的網(wǎng)絡(luò)結(jié)構(gòu)主要由特征提取器G、兩個可調(diào)節(jié)判別器D1、D2和一個類預測器F連接組成.該算法將源域和目標域數(shù)據(jù)輸入到網(wǎng)絡(luò)中,經(jīng)過特征提取器G與可調(diào)節(jié)判別器間的對抗訓練及判別器間的對抗訓練,使該網(wǎng)絡(luò)在目標域上具有更好的適應性.與ResNet-50、DANN、ADDA、JAN算法相比,本文算法在Office-31數(shù)據(jù)集上的平均精確度得到了提高,從而有效地提高了目標域的預測精確度.
在下一步研究中,將探求如何解決在源域數(shù)據(jù)集較小的兩個領(lǐng)域中的遷移能力弱化的問題,從而使模型更具適應性.