黃曉琪,王 莉,李 鋼
1.太原理工大學 大數據學院,山西 晉中030600
2.太原理工大學 軟件學院,山西 晉中030600
隨著人工智能的飛速發(fā)展,深度學習技術已在計算機視覺、語音識別、自然語言處理等多個研究領域取得了諸多成績。其中在生成圖像方面取得的進展有:Rematas等[1]使用足球比賽視頻數據訓練網絡,從而提取3D 網格信息,進行動態(tài)3D 重建。Oord 等[2]利用自回歸模型產生了清晰的合成圖。Dosovitskiy等[3]訓練了一個反卷積網絡,根據一組指示形狀、位置和照明的圖形代碼生成三維椅子效果圖。何新宇等[4]提出了一種基于深度卷積神經網絡的肺炎圖像識別模型用于肺炎圖像的識別。近年來,生成對抗網絡在生成高質量圖像方面顯示出了巨大優(yōu)勢,它是由GoodFellow 等[5]提出的,該模型在工業(yè)界和學術界都有廣泛的應用,除了圖像外,它還可以應用于視頻和語音領域[6-7]。
隨著深度學習的不斷發(fā)展,越來越多的生成對抗網絡模型被提出,在生成圖像方面取得了越來越好的效果:Reed等[8]提出了GAN-INT-CLS模型,首次利用GAN有效地生成以文本描述為條件的64×64圖像。然而,在許多情況下,它們合成的圖像缺少逼真的細節(jié)和生動的物體部分,例如鳥的喙、眼睛和翅膀;此外,它們無法合成更高分辨率的圖像(例如128×128或256×256)。Reed等[9]為了更好地根據文本描述控制圖像中物體的具體位置,提出了GAWWN(Generative Adversarial What-Where Network)模型,把額外的位置信息與文本一起作為約束條件加入到生成器和判別器的訓練中。Wang等[10]利用提出的樣式結構生成對抗網絡(Style and Structure Generative Adversarial Networks,S2-GAN)模型以結構生成和樣式生成兩部分相結合的方法實現室內場景圖像的生成。Zhang 等[11]在網絡層次結構中引入了層次嵌套對抗性目標,提出了高清晰生成對抗網絡(High-Definition Generative Adversarial Network,HDGAN)模型,規(guī)范了中間層的表示,并幫助生成器捕獲復雜的圖像統(tǒng)計信息。Denton 等[12]在拉普拉斯金字塔框架內建立了多個GAN 模型,以前一層級的輸出為條件生成殘差圖像,然后作為下一層級的輸入,最后生成圖像。
上述這些文本-圖像對抗模型的判別器都使用卷積網絡[13]提取圖像特征,由于在卷積神經網絡中,上一層神經元傳遞到下一層神經元中的是個標量,標量不能表示出高層特征與低層特征之間的空間關系。另外,它的池化層會丟失大量有價值的信息,因此,卷積神經存在特別大的局限性。2017年年底,Hinton等[14]發(fā)表的論文Dynamic routing between capsules提出更深刻的算法及膠囊網絡架構,膠囊網絡采用到神經膠囊,上一層神經膠囊輸出到下一層神經膠囊中的是個向量,向量可以表示出組件的朝向和空間上的相對關系,極大地彌補了卷積網絡存在的不足。
改進后的模型使用膠囊網絡實現圖片的分類,當標題向量和圖片向量拼接后,依次進入初級膠囊層和第二膠囊層進行處理,由于膠囊層間傳遞的是向量,它很好地考慮了對象間的空間關系。經實驗驗證,加入膠囊網絡后有效提高了生成圖片的真實性和多樣性。
圖1 說明了GAN-CLS 的結構。生成器定義為G:RZ×RT→RD,判別器定義為D:RD×RT→{0,1},T是描述嵌入的維度,Z是輸入生成器中噪聲的維度。
圖1 文本條件GAN-CLS架構
在生成器中,首先從噪聲分布z∈Rz~Ν(0,1)中進行采樣,使用文本編碼器φ對文本標題T進行編碼,然后再使用連接層將嵌入的描述φ(t)壓縮為小尺寸,然后使用leaky-ReLU 激活函數對其進行處理,最后連接到噪聲矢量z。接下來,推理過程就像在一個正常的反卷積網絡中一樣:通過生成器G將其前饋;一個合成圖像x是通過x←G(z,φ(t))生成的。圖像生成對應于生成器G中基于查詢文本和噪聲樣本的前饋推理。
在判別器D 中,首先使用空間批處理歸一化和leaky-ReLU 激活函數執(zhí)行多個層的步長為2 的卷積處理,然后使用全連接層降低描述嵌入φ(t)的維數,其次對其進行校正。當判別器的空間維度為4×4時,在空間上復制描述嵌入,并執(zhí)行深度連接,然后執(zhí)行1×1 的卷積和校正,再執(zhí)行4×4的卷積,以從D計算最終分數,最后對所有卷積層執(zhí)行批處理規(guī)范化。
其中,判別器的損失函數如下所示:
生成器的損失函數如下所示:
上述的兩個公式中,x表示真實的圖片,t表示真實的標題向量,x表示生成的圖片,t表示錯誤的標題向量。在原始模型中,先訓練1 次判別器再訓練1 次生成器,如此迭代交替進行訓練,當生成器和判別器達到最優(yōu)后,輸入標題到生成器來生成預期的圖片。
改進后的判別器如圖2 所示,主要由以下3 部分組成:卷積層、初級膠囊層和第二膠囊層。卷積層主要負責提取輸入圖像的低級特征,包含有256 個步長為1 的9×9 的卷積核。初級膠囊層負責將卷積層提取到的特征組合起來,該層由32 個膠囊組成,每個膠囊又含有8個步長為2的9×9×256卷積核。第二膠囊層用來實現特征的分類,得到的向量模長表示圖片所屬的類,該層共有2個膠囊,該膠囊層由全連接層構成。
圖2 使用膠囊網絡的判別器
此判別器的工作流程:首先將圖片輸入判別器,經卷積網絡提取圖像特征后得到圖片特征向量;其次將英文標題輸入到文本編碼器得到標題特征向量;最后將圖片特征向量和標題特征向量進行拼接得到新的特征向量。將新的特征向量輸入到初級膠囊層,該層將這些低級特征向量組合起來,將組合后的特征向量輸入到第二膠囊層,該層負責將組合后的特征向量進行分類,得到的特征向量的模值越大表示圖片和標題匹配度越高。
膠囊網絡能夠比卷積神經網絡更好地學習特征和它們之間的關系,因為初級膠囊層查看輸入圖像的所有特征,而路由過程確定有助于第二膠囊層膠囊的全局特征集。其中第二膠囊層的網絡結構和原理如圖3所示。
圖3 第二膠囊網絡的內部原理
判別器的損失如下所示:
其中,m+=0.8,m-=0.2,當K類實體存在時,Tk=1,否則,Tk=0,λ=0.5 用作下權重因子,以防止訓練早期活動向量的收縮,vk表示膠囊層k的輸出。
生成器的損失函數如下所示:
其中,P(z)代表先驗分布,P(α)表示標題向量的真實分布,x表示生成的圖片,t表示真實的標題向量。
對判別器進行改進后,由于膠囊網絡能夠很好地表達底層對象之間的空間關系,因此判別器相對原始模型來說,其判別能力更強。先訓練1次判別器,再訓練1次生成器,如此迭代訓練,直到二者達到最優(yōu),二者能力相對原來模型都更優(yōu),因此生成器生成圖片的質量更高。
神經膠囊網絡不僅可以用于文本分類,也可以用于多標簽遷移和圖片分類的任務,在本文中,重點將它用于圖片分類的任務,圖3所示為第二膠囊網絡的結構和工作原理。
第二膠囊層由很多類似人腦神經的神經元組成,它輸出的向量的維度比較高,另外,該激活向量的模值越大,表示標題向量和圖像向量越匹配,生成的圖像越符合標題描述。另外在很多實際應用中,該向量的模長還可以代表某對象存在的可能性,如該模長比較長,則代表該對象存在的概率很高。圖3 表示了第二膠囊網絡的內部工作原理,主要分為四個流程,如下所示:
(1)仿射變換。首先將xn作為第二膠囊網絡的輸入向量,然后把這些輸入向量和對應的權重矩陣wnj進行相乘后就可以得到向量Xn,權重矩陣可以表示出底層特征和高層特征之間的關系,比如空間關系和其他的一些重要關系,這就是仿射變換的實現過程。
(2)標量加權。經過仿射變換后得到向量Xn,接著用耦合系數cnj與該向量進行相乘,其中耦合系數決定著某個低層膠囊的輸出向量作為哪個高層膠囊的輸入,同時耦合系數是使用動態(tài)路由的方法來實現更新。
(3)累加求和。經過標量加權后,可以得到一些向量,對這些向量進行求和。
(4)把進行累加求和后得到的結果輸入函數Squash()進行處理后得到非線性處理的結果。非線性處理的過程可以使用如下公式表示:
其中,sj是Xn經過標量加權操作后得到的相關向量再經過累加求和得到的結果,sj經過Squash()函數得到vj。Squash()函數不會改變該向量的方向,但會將該向量的長度控制在小于1的范圍內,Squash()函數處理后的結果即為第二膠囊層的輸出。
在實驗數據集的選擇和參數的設置上,本文引用了鳥類圖像的CUB 數據集和花圖像的Oxford-102 數據集。CUB擁有11 788張鳥類圖片,屬于200個不同類別中的一個。牛津102數據集包含了來自102個不同類別的8 189幅花卉圖片。在實驗中將這些劃分為不相交的訓練集和測試集。CUB 有150 個訓練類和50 個測試類,而Oxford-102 有82 個訓練和20 個測試類。在進行小批量選擇訓練時,隨機選取圖像視圖和其中一個標題。它的實現是建立在dcgan-tensorflow之上的。值得注意的是在模型中參與訓練、測試以及最后輸入模型的標題都是英文標題。
本文采用了3種圖像評估方法:
(1)FID(Frechet Inception Distance)Score[15]是近年來提出的圖像評價標準,它不僅考慮合成圖像的分布,而且還要考慮如何比較它和真實數據的分布。它直接測量合成數據分布p(?)和真實數據分布pr(?)之間的距離。實際上,在實際應用中,圖像由inception 模型用視覺特征進行編碼。假設特征嵌入遵循多維高斯分布,合成數據的高斯均值和方差(m,C)從合成數據分布p(?)中獲得,真實數據的高斯均值和方差(mr,Cr)從真實數據分布pr(?)處獲得。合成數據的高斯分布和真實數據的高斯分布差異由以下公式計算:
FID的值越低,表示真實數據和合成數據的距離越小。
(2)使用數值評估方法Inception score[16]進行定量評估。數值評估方法如下所示:
其中,x表示一個生成樣本,y是初始模型預測的標簽,p(y)是邊緣分布,p(y|x)是條件分布。這個指標背后的意義是好的模型應該生成多樣而且有意義的圖片。因此,KL 散度在邊際分布p(y)和條件分布p(y|x)之間應該足夠大。
為了證明將卷積網絡替換為膠囊網絡后可以提高模型的性能,本節(jié)將實驗中原始模型以及改進后的模型得到的結果和最近使用的比較流行的文本到圖像生成的網絡模型進行各種類型結果對比,使用IS評估指標和FID 評估指標來評價圖像的質量。以本章的實驗中設置的參數為依據,共計全部使用了40 000張生成的圖像來評估模型的性能。
本文中的GAN-CLS模型為實驗使用的原始生成對抗模型,它的模型如圖1 所示,在此模型中加入流行插值后的GAN-CLS-INT[8]是近期比較經典的文本生成圖像模型,進一步加入位置信息和條件約束后稱為GAWWN[9],本文融入膠囊網絡后的模型為GAN-CLS-SA。
從表1可以看出,本章模型GAN-CLS-SA和其他的一些經典模型相比較,結果都有了提高。本文改進的模型與GAWWN模型的結果相比,Oxford-102數據集上IS的結果提高了0.17,CUB 數據集上IS 的結果提高了0.23,這表明改進后的模型生成的圖片特征更豐富,更有意義。Oxford-102 數據集上FID 的結果降低了3.59,CUB數據集上FID的結果降低了4.73,這表明本文改進模型生成的圖片更逼近真實圖片的數據分布。由表1觀察可知,模型GAN-CLS-SA 和GAN-CLS 在Oxford-102數據集上的結果對比后可知FID和IS的數值分別降低和提高了14.49%和22.60%,而在CUB 數據集上FID和IS的結果分別降低和提高了9.64%和26.28%。
表1 定量結果
最后詳細地比較了GAN-CLS、GAWWN 和GANCLS-SA三種模型生成圖像的視覺質量。從定性的角度比較了上述三種模型在CUB數據集中鳥類測試圖像和Oxford-102 數據集中花朵測試圖像的文本描述標題下生成的相同圖像。最終結果比較分別如圖4和圖5所示。
圖4 相同花標題對應結果對比
圖5 相同鳥標題對應結果對比
通過觀察圖4 能夠發(fā)現,GAN-CLS 生成的花的形狀沒有其他兩種模型飽滿,顏色不夠準確。雖然GAWWN生成的圖像中花的形狀相對前者飽滿,但顏色方面也不夠準確或可信,相比之下,GAN-CLS-SA 生成的樣本在顏色方面相對GAWWN比較準確或較可信。
圖5為三種GAN模型在CUB和Oxford-102兩種數據集上的結果,從圖中可以看出,GAN-CLS-SA 模型生成的圖片中鳥的基本形狀和顏色方面更加符合標題描述,邊緣和細節(jié)更加逼真,與其他模型相比取得了較優(yōu)的結果。
本文在GAN-CLS 模型的上,首先將判別器中的卷積網絡替換為膠囊網絡,用膠囊網絡實現對圖片的分類。在Oxford-102花卉數據集和CUB鳥類數據集上的實驗結果表明,本文中提出的模型效果優(yōu)于基于原始生成對抗網絡的模型的效果,證明了對卷積網絡替換為膠囊網絡,提高了生成圖像的質量。除了完成文本生成圖像的任務外,生成對抗網絡模型還可以完成圖像到圖像的生成。未來,將進一步通過注意力機制來優(yōu)化網絡結構。