王愛麗, 薛 冬, 吳海濱, 王敏慧
(哈爾濱理工大學(xué) 測控技術(shù)與通信工程學(xué)院,黑龍江 哈爾濱 150080)
圖像識別是計算機(jī)視覺領(lǐng)域中常見的任務(wù)之一,通過提取數(shù)字圖像中有效的特征信息,來賦予圖像各自的標(biāo)簽類別,從而完成識別過程。目前,圖像識別已經(jīng)廣泛應(yīng)用于醫(yī)療診斷[1]、交通標(biāo)志識別[2]等領(lǐng)域中。
近年來,深度學(xué)習(xí)的快速發(fā)展使得其在圖像識別領(lǐng)域取得了很多成果。相比于傳統(tǒng)的人工識別方法,使用深度學(xué)習(xí)方法可以提取到圖像更多的深層次特征,從而提高識別準(zhǔn)確率。然而,深度學(xué)習(xí)方法需要大量的訓(xùn)練數(shù)據(jù)才能取得好的識別效果,所以對于數(shù)據(jù)量較小的數(shù)據(jù)集,識別效果并不理想,而生成對抗網(wǎng)絡(luò)(GAN)的提出很大程度上緩解了這種問題[3]。
GAN的核心思想來源于博弈論的納什均衡[4],它含有一個生成器和一個判別器,生成器在訓(xùn)練過程中將生成的假樣本送入判別器來擴(kuò)充訓(xùn)練集的數(shù)據(jù)量,這使得判別器學(xué)習(xí)到更多的圖像特征,從而提高分類準(zhǔn)確率。因此,基于GAN的這種優(yōu)點,許多學(xué)者將其應(yīng)用到圖像分類、識別領(lǐng)域。Zhu等人[5]將GAN應(yīng)用到高光譜數(shù)據(jù)的分類中,提出了利用空間特征和光譜特征的1D-GAN和3D-GAN,顯著提高了高光譜數(shù)據(jù)的分類精度。劉坤等人[6]使用半監(jiān)督GAN實現(xiàn)了X光圖像的分類,充分利用有標(biāo)簽的數(shù)據(jù)提高分類效果。楊旺功等人[7]利用深度卷積生成對抗網(wǎng)絡(luò)(DCGAN)實現(xiàn)了花朵圖像的分類,實驗結(jié)果表明所提方法的分類準(zhǔn)確率高且穩(wěn)定性好。Kuang等人[8]首次將GAN用于肺結(jié)節(jié)惡性腫瘤的無監(jiān)督分類中,實驗表明此方法使用少量標(biāo)記樣本就能取得較好的分類結(jié)果。
雖然GAN在圖像處理領(lǐng)域具有較好的表現(xiàn),但是GAN存在訓(xùn)練太過自由,生成的圖像不可控等問題,因此Mirza等人[9]提出了條件生成對抗網(wǎng)絡(luò)(CGAN),在生成器和判別器中加入圖像的類別標(biāo)簽,使得生成的圖片可以人為控制。因此,本文將CGAN和條件批處理歸一化[10-11](CBN)結(jié)合,提出CBN-CGAN網(wǎng)絡(luò)。CBN利用類別標(biāo)簽對每一類數(shù)據(jù)進(jìn)行批處理歸一化,使網(wǎng)絡(luò)學(xué)習(xí)到更多的特征,提高模型魯棒性。
條件生成對抗網(wǎng)絡(luò)是在GAN的基礎(chǔ)上發(fā)展而來的,通過把附加信息y加入到生成器G和判別器D的輸入中,使得GAN的無監(jiān)督學(xué)習(xí)變?yōu)橛斜O(jiān)督學(xué)習(xí)。訓(xùn)練階段,隨機(jī)噪聲z與條件變量y同時輸入到G中,得到盡量服從真實數(shù)據(jù)分布Pdata的生成數(shù)據(jù)G(z|y);然后將真實數(shù)據(jù)x、條件變量y和G生成的數(shù)據(jù)G(z|y)同時輸入到D中,最后輸出一個標(biāo)量來估計輸入數(shù)據(jù)來自真實數(shù)據(jù)的概率。其目標(biāo)函數(shù)為:
.
(1)
D的目標(biāo)是實現(xiàn)對數(shù)據(jù)來源的二分類判別,G的目標(biāo)是最大化D判斷輸入數(shù)據(jù)錯誤的概率。因此,這兩個相互對抗并迭代優(yōu)化的過程使得G和D的性能不斷提升。當(dāng)最終D無法正確判別數(shù)據(jù)真假時,可以認(rèn)為G已經(jīng)學(xué)到了真實的數(shù)據(jù)分布,二者達(dá)到平衡狀態(tài)。
深度學(xué)習(xí)算法的一個缺點是網(wǎng)絡(luò)在訓(xùn)練階段容易出現(xiàn)無法收斂的問題。因此,谷歌公司通過對輸入數(shù)據(jù)進(jìn)行批量歸一化(BN)處理,提高了網(wǎng)絡(luò)的訓(xùn)練速度,解決梯度消失的問題[12]。但是在條件生成模型中存在一個問題,不同類別的訓(xùn)練數(shù)據(jù)放在一起做批處理歸一化不妥當(dāng)。因為在一個批量的數(shù)據(jù)中含有不同類別的數(shù)據(jù),而每個類別的數(shù)據(jù)在計算后得到的均值和方差的數(shù)值也是不同的,因此在還原數(shù)據(jù)階段,應(yīng)該采用不同的標(biāo)準(zhǔn)化、平移和縮放處理的方式來還原每一類數(shù)據(jù)。
條件批量歸一化[10-11](CBN)處理數(shù)據(jù)時不使用整個批量的統(tǒng)計數(shù)據(jù),而是在各類數(shù)據(jù)特征圖內(nèi)部做歸一化,依賴類別標(biāo)簽還原出原始的每一類數(shù)據(jù)。CBN對數(shù)據(jù)的批處理如下:
,
(2)
(3)
,
(4)
.
(5)
式(2)和(3)分別計算了數(shù)據(jù)的均值和方差;式(4)和(5)對數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化、平移和縮放的處理。
本文網(wǎng)絡(luò)模型將CGAN與CBN相結(jié)合,稱為CBN-CGAN網(wǎng)絡(luò),網(wǎng)絡(luò)模型結(jié)構(gòu)如圖1所示,主要包括3方面的改進(jìn)。
圖1 CBN-CGAN網(wǎng)絡(luò)結(jié)構(gòu)圖
(1)將生成器和判別器的網(wǎng)絡(luò)結(jié)構(gòu)改為全卷積的網(wǎng)絡(luò)結(jié)構(gòu),借助卷積神經(jīng)網(wǎng)絡(luò)(CNN)強(qiáng)大的提取特征的能力,更好地學(xué)習(xí)圖像的特征。在判別器輸出末端加入Softmax分類器,使模型適用于多分類任務(wù)。
(2)在生成器和判別器的網(wǎng)絡(luò)層中添加CBN,充分利用類別標(biāo)簽來對每一類數(shù)據(jù)進(jìn)行批量化處理,使得網(wǎng)絡(luò)充分學(xué)習(xí)到特征圖中每個類別的特征,提高生成圖像的質(zhì)量和識別精度。同時,提升模型穩(wěn)定性,加快收斂速度,緩解梯度消失的問題。
(3)提出新的目標(biāo)損失函數(shù),使模型最終的輸出既包括對圖像來源的真假判別,即將真實圖像判斷為真,將生成器生成的圖像判斷為假,又包括對圖像類別標(biāo)簽的多分類結(jié)果。
本文網(wǎng)絡(luò)模型的生成器是一個全卷積的網(wǎng)絡(luò)結(jié)構(gòu),使用反卷積層代替池化層,反卷積核的大小為4×4,并且去掉全連接層來增加模型穩(wěn)定性。首先,輸入端是兩個反卷積層,它們分別將100維的噪聲z和n維的類別標(biāo)簽(假設(shè)數(shù)據(jù)集有n個類別)進(jìn)行維度轉(zhuǎn)換,轉(zhuǎn)換成(256,4,4)的三維張量,然后將二者連接成(512,4,4)的三維張量作為下一個反卷積層的輸入。最后,3個反卷積層將這個3維張量轉(zhuǎn)換成(1,32,32)的圖像。本文在中間兩個反卷積層后添加CBN,CBN對數(shù)據(jù)進(jìn)行歸一化處理得到均值和方差,利用類別標(biāo)簽控制偏置因子和縮放因子的取值,從而在映射原始數(shù)據(jù)時充分還原出每一類數(shù)據(jù),使網(wǎng)絡(luò)充分學(xué)習(xí)到各類數(shù)據(jù)的特征分布。
CBN還可以解決初始化差和模式崩塌的問題,同時確保梯度傳播到模型的深層。此外,為了不降低模型穩(wěn)定性,在輸入端和輸出端的反卷積層后不添加CBN。其次,本文在中間兩個反卷積層后添加ReLU 激活函數(shù)層,提高模型的學(xué)習(xí)速度,在最后一個反卷積層后添加Tanh激活函數(shù)層。CBN-CGAN的生成器網(wǎng)絡(luò)結(jié)構(gòu)如圖2所示。
圖2 CBN-CGAN生成器網(wǎng)絡(luò)結(jié)構(gòu)圖
本文判別器網(wǎng)絡(luò)結(jié)構(gòu)和生成器網(wǎng)絡(luò)結(jié)構(gòu)類似,也是一個全卷積的網(wǎng)絡(luò)結(jié)構(gòu),含有5個卷積層,卷積核的大小為4×4,不含池化層和全連接層。在輸入端,訓(xùn)練數(shù)據(jù)和類別標(biāo)簽分別通過卷積層進(jìn)行維度轉(zhuǎn)換,轉(zhuǎn)換成(64,16,16)的三維張量,再將二者連接成(128,16,16)的三維張量。接著依次經(jīng)過兩個卷積層輸出(512,4,4)的三維張量到達(dá)輸出端。輸出端分為兩部分:一部分判斷數(shù)據(jù)的真實來源,另一部分輸出分類結(jié)果。
與生成器相同,本文在除了輸入端和輸出端的卷積層后添加CBN,提升訓(xùn)練效果。本文在判別器中使用LeakyReLU激活函數(shù),它是ReLU激活函數(shù)的改進(jìn)版,在判別器中表現(xiàn)得更好。因此,本文在中間兩個卷積層后添加LeakyReLU激活函數(shù)層。在輸出端,本文在最后兩個卷積層后分別使用Sigmoid分類器和Softmax分類器完成不同的任務(wù),一是Sigmoid分類器輸出數(shù)據(jù)的真假判別結(jié)果,二是Softmax分類器輸出真實數(shù)據(jù)和生成器生成的數(shù)據(jù)的分類結(jié)果。
在分類時,假設(shè)數(shù)據(jù)集有n個類別。首先,每個由生成器生成的假數(shù)據(jù)通過網(wǎng)絡(luò)向前傳遞,并通過獲取概率預(yù)測向量的最大值來分配一個標(biāo)簽。因此,這些假數(shù)據(jù)可用于在網(wǎng)絡(luò)中使用這些標(biāo)簽進(jìn)行訓(xùn)練。此外,假數(shù)據(jù)不屬于任何類型的真實數(shù)據(jù)。由于真實數(shù)據(jù)和假數(shù)據(jù)的不同,所以創(chuàng)建了一個新的類別標(biāo)簽(n+1類)劃分它們,每個假數(shù)據(jù)都被賦予這個新的類別標(biāo)簽。本文采用這種方法最后輸出n+1個類別的識別結(jié)果。CBN-CGAN的判別器網(wǎng)絡(luò)結(jié)構(gòu)如圖3所示。
圖3 CBN-CGAN判別器網(wǎng)絡(luò)結(jié)構(gòu)圖
本文設(shè)計的網(wǎng)絡(luò)模型的目標(biāo)損失函數(shù)包括兩部分:一部分是判斷數(shù)據(jù)來源真假的LS,另一部分是對真實數(shù)據(jù)和生成器生成的數(shù)據(jù)進(jìn)行分類的LC。該目標(biāo)損失函數(shù)如下所示:
,
(6)
,
(7)
,
(8)
式中,LS將輸入的真實數(shù)據(jù)判斷為真,將生成器生成的數(shù)據(jù)判斷為假。LC分為兩部分,一部分是真實數(shù)據(jù)對應(yīng)的分類結(jié)果,它應(yīng)該被檢測為前n種類別中其唯一對應(yīng)的類別,即檢測為真實類別的概率為1;另一部分是使生成器生成的假數(shù)據(jù)可以被分類為第n+1類,即假數(shù)據(jù)被檢測成真實類別的概率為0,也即假數(shù)據(jù)對應(yīng)的分類結(jié)果為第n+1類的概率為1。對于判別器,其最終目的是最大化LS+LC;生成器的目的是最小化LS-LC。
本文實驗在Windows操縱系統(tǒng)下進(jìn)行,基于開源深度學(xué)習(xí)框架Pytorch,使用的編程語言為Python,實驗設(shè)備包括Intel(R) Core(TM) i5-6500 CPU @ 3.2 GHz處理器,16 GB運行內(nèi)存(RAM),NVIDIA GeForce GTX 1060 GPU。實驗所用的數(shù)據(jù)集為MNIST數(shù)據(jù)集,它是機(jī)器學(xué)習(xí)領(lǐng)域最常見的手寫字體數(shù)據(jù)集,包含0~9的10類手寫數(shù)字圖像,每類數(shù)字包含60 000個訓(xùn)練樣本和10 000個測試樣本,每幅圖像是28×28像素的灰度圖像。
實驗階段,為了讓生成器盡可能學(xué)習(xí)到所有樣本的數(shù)據(jù)分布,需要對樣本進(jìn)行歸一化處理,對類別標(biāo)簽數(shù)據(jù)進(jìn)行one-hot編碼處理。網(wǎng)絡(luò)輸入的圖像大小為32×32,批處理大小設(shè)置為32,由于數(shù)據(jù)集為灰度的手寫數(shù)字圖像,特征相對較少,因此訓(xùn)練迭代了30個Epoch。在訓(xùn)練階段,網(wǎng)絡(luò)采用Adam優(yōu)化器,學(xué)習(xí)率設(shè)置為0.000 2,權(quán)重衰減設(shè)置為0.000 5。
為了評估本文所設(shè)計的CBN-CGAN網(wǎng)絡(luò)模型的性能,將本文網(wǎng)絡(luò)模型與傳統(tǒng)的深度學(xué)習(xí)網(wǎng)絡(luò)做了對比。實驗評價標(biāo)準(zhǔn)一方面是生成圖像的質(zhì)量,另一方面是識別的準(zhǔn)確率。圖4為本文方法與CGAN和深度卷積生成對抗網(wǎng)絡(luò)[13](DCGAN)在Epoch次數(shù)為30時對生成圖像所做的對比。從對比圖中可以直接看出,本文提出的CBN-CGAN網(wǎng)絡(luò)比其他網(wǎng)絡(luò)生成的圖像質(zhì)量更好,更清晰,例如數(shù)字“2”和數(shù)字“8”的輪廓更明顯,更容易辨別。
圖4 不同方法生成MNIST圖像對比結(jié)果
表1為本文識別方法與其他方法的識別精度對比。其中決策樹(Decision Tree)的最大深度設(shè)置為100;支持向量機(jī)(SVM)中的核函數(shù)設(shè)置為徑向基核函數(shù)(rbf),rbf的系數(shù)默認(rèn)為“auto”,懲罰參數(shù)設(shè)置為100;隨機(jī)森林(Random Forest)中樹的數(shù)目設(shè)置為200;卷積神經(jīng)網(wǎng)絡(luò)(CNN)的網(wǎng)絡(luò)結(jié)構(gòu)與本文判別器網(wǎng)絡(luò)結(jié)構(gòu)相似,卷積核大小也為4×4,網(wǎng)絡(luò)中加入ReLU激活函數(shù),BN層,Softmax分類器,數(shù)據(jù)輸入之前做歸一化處理,學(xué)習(xí)率設(shè)置為0.001;在CGAN和DCGAN的判別器輸出端分別添加Softmax分類器,使這兩個網(wǎng)絡(luò)可以完成分類任務(wù),學(xué)習(xí)率都設(shè)為0.000 2。
表1 MNIST數(shù)據(jù)集識別準(zhǔn)確率和識別時間
從表1可以看出,本文所提出的CBN-CGAN網(wǎng)絡(luò)模型達(dá)到的識別準(zhǔn)確率最高,為99.43%,分別比DCGAN、CGAN、CNN、隨機(jī)森林、SVM和決策樹高出0.57%,0.86%,1.07%,2.28%,2.83%和11.5%的準(zhǔn)確率。雖然本文模型在識別時所消耗的時間有一定程度的增加,但是卻提高了準(zhǔn)確率,因此證明了本文所提方法在手寫體數(shù)字識別時可以更好地提取特征,有效提高了識別準(zhǔn)確率。
圖5為本文所提出的CBN-CGAN網(wǎng)絡(luò)的判別器損失函數(shù)曲線與原始CGAN網(wǎng)絡(luò)的判別器損失函數(shù)曲線的對比。從圖中可以看出本文所提出的網(wǎng)絡(luò)判別器的收斂速度更快,且隨著迭代次數(shù)的增加越來越穩(wěn)定。
圖5 判別器損失函數(shù)曲線對比
本文結(jié)合條件生成對抗網(wǎng)絡(luò)與條件批處理歸一化提出CBN-CGAN、改進(jìn)生成器和判別器的結(jié)構(gòu),提出新的損失函數(shù)。利用CBN融合類別標(biāo)簽的優(yōu)勢,使網(wǎng)絡(luò)提取更多的圖像特征,提高了模型的魯棒性。在MNIST數(shù)據(jù)集上的實驗結(jié)果表明,相比于其他方法,本文所提出的方法雖然消耗的時間較長,但是生成的圖像質(zhì)量更好,同時手寫體數(shù)字識別的準(zhǔn)確率達(dá)到99.43%,驗證了本文所提出的CBN-CGAN模型在手寫體數(shù)字識別領(lǐng)域中的有效性。