摘 要:生成對(duì)抗網(wǎng)絡(luò)(generative adversarial network,GAN)已成為圖像生成問題中常用的模型之一,但是GAN的判別器在訓(xùn)練過程中易出現(xiàn)梯度消失而導(dǎo)致訓(xùn)練不穩(wěn)定,以致無(wú)法獲得最優(yōu)化的GAN而影響生成圖像的質(zhì)量。針對(duì)該問題,設(shè)計(jì)滿足Lipschitz條件的譜歸一化卷積神經(jīng)網(wǎng)絡(luò)(CNN with spectral normalization,CSN)作為判別器,并采用具有更強(qiáng)表達(dá)能力的Transformer作為生成器,由此提出圖像生成模型TCSNGAN。CSN判別器網(wǎng)絡(luò)結(jié)構(gòu)簡(jiǎn)單,解決了GAN模型的訓(xùn)練不穩(wěn)定問題,且能依據(jù)數(shù)據(jù)集的圖像分辨率配置可調(diào)節(jié)的CSN模塊數(shù),以使模型達(dá)到最佳性能。在公共數(shù)據(jù)集CIFAR-10和STL-10上的實(shí)驗(yàn)結(jié)果表明,TCSNGAN模型復(fù)雜度低,生成的圖像質(zhì)量?jī)?yōu);在火災(zāi)圖像生成中的實(shí)驗(yàn)結(jié)果表明,TCSNGAN可有效解決小樣本數(shù)據(jù)集的擴(kuò)充問題。
關(guān)鍵詞: 生成對(duì)抗網(wǎng)絡(luò);圖像生成;Transformer;Lipschitz判別器
中圖分類號(hào): TP183文獻(xiàn)標(biāo)志碼:A 文章編號(hào): 1001-3695(2024)04-038-1221-07
doi: 10.19734/j.issn.1001-3695.2023.07.0357
TCSNGAN: image generation model based on Transformer and CNN with spectral normalization
Qian Huimin, Mao Qiuling, Chen Shi, Han Yixing, Lyu Benjie
Abstract:GAN has become one of the commonly-used image generation models. However, the discriminator of GAN is prone to the vanishing gradient problem in the training process, which leads to the instability of training. So that it is difficult to obtain the optimal GAN, and the quality of generation image is poor. To solve this problem, it designed a CNN with spectral normalization which satisfied the Lipchitz condition as the discriminator. Together with the Transformer generator, this paper proposed an image generation model, namely TCSNGAN (Transformer CSN GAN). The network structure of discriminator was simple, which could solve the problem of training instability of GAN model, and could configure the number of adjustable CSN modules according to the image resolution of data sets to achieve the optimal performance of the model. Experiments on public datasets CIFAR-10 and STL-10 show that the proposed TCSNGAN model has low complexity, and the generated image quality is good. And the experiments of fire image generation task demonstrates the effectiveness of small-sample dataset augmentation. Key words:generative adversarial networks; image generation; Transformer; Lipschitz discriminator
0 引言近年來(lái),生成式模型在圖像生成領(lǐng)域已取得了一定進(jìn)展?;谏墒侥P偷膱D像生成,旨在采用生成式模型學(xué)習(xí)圖像的像素值概率分布規(guī)律,繼而預(yù)測(cè)生成相似圖像。但是,現(xiàn)有生成式模型的性能在對(duì)分布規(guī)律復(fù)雜的圖像建模時(shí)仍存在不足[ 2]。目前常用的生成式模型有自回歸模型、變分自編碼器、流生成模型和生成對(duì)抗網(wǎng)絡(luò)等。相比于其他模型,GAN無(wú)須知道顯式的真實(shí)數(shù)據(jù)分布,只需要輸入隨機(jī)噪聲便可以生成接近于真實(shí)圖像的樣本[ 3,4]。GAN及其改進(jìn)模型已成為圖像生成問題常用的模型之一[5]。2014年,Goodfellow等人[6]提出基于多層感知機(jī) (multilayer perceptron,MLP)的生成對(duì)抗網(wǎng)絡(luò)MLP-GAN。該模型由生成器和判別器兩部分組成,前者捕獲數(shù)據(jù)分布生成符合真實(shí)分布的圖像,而后者則判別輸入是真實(shí)圖像還是生成器產(chǎn)生的圖像,兩者進(jìn)行最小最大博弈,直至納什平衡。但是,MLP-GAN的判別器訓(xùn)練很不穩(wěn)定,以致模型生成的圖像與真實(shí)圖像差異較大。隨著卷積神經(jīng)網(wǎng)絡(luò)的發(fā)展,Radfor等人[7]提出了深度卷積生成對(duì)抗網(wǎng)絡(luò)(deep convolution generative adversarial network,DCGAN),其判別器和生成器均采用CNN,DCGAN既提高了生成圖像的質(zhì)量,也提高了訓(xùn)練的穩(wěn)定性。BigGAN[8]、StyleGAN[9]這些模型生成的圖像質(zhì)量均逐漸接近真實(shí)圖像的質(zhì)量。
2017年,Vaswani等人[10]提出Transformer,且已廣泛應(yīng)用于計(jì)算機(jī)視覺領(lǐng)域,如目標(biāo)檢測(cè)的DETR(detection with Transformers)模型[11]、視覺分類的ViT(vision Transformer)模型[12]等。相較于CNN,Transformer的優(yōu)勢(shì)是可以建立圖像的全局依賴關(guān)系,從而獲取更多的圖像全局信息[13]。受此啟發(fā),Jiang等人[14]提出生成器和判別器均使用Transformer的圖像生成模型TransGAN。鑒于Transformer對(duì)圖像全局信息的表達(dá)能力,TransGAN能在48×48分辨率數(shù)據(jù)集下生成質(zhì)量較好的圖像。但是,它的計(jì)算量很大,訓(xùn)練時(shí)間很長(zhǎng)。另一方面,基于GAN的圖像生成模型易出現(xiàn)訓(xùn)練不穩(wěn)定。研究發(fā)現(xiàn),當(dāng)生成圖像和真實(shí)圖像的像素值分布不重疊或者重疊部分可以忽略不計(jì)時(shí),判別器模型訓(xùn)練時(shí)易出現(xiàn)梯度消失,從而導(dǎo)致GAN訓(xùn)練不穩(wěn)定[15]。因此,Arjovsky等人[16]提出使用Wasserstein距離判別真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的分布是否相同,以解決以上情況下判別器模型在訓(xùn)練時(shí)易出現(xiàn)梯度消失的問題。之后,Gulrajani等人[17]提出在Wasserstein距離中增加梯度懲罰項(xiàng),以進(jìn)一步提高判別器訓(xùn)練的穩(wěn)定性。Miyato等人[18]提出在判別器模型訓(xùn)練過程中,對(duì)參數(shù)矩陣進(jìn)行譜歸一化,使得模型參數(shù)梯度被限制在一個(gè)范圍內(nèi),以此實(shí)現(xiàn)判別器的穩(wěn)定訓(xùn)練。進(jìn)一步地,Zhou等人[19]提出欲保證GAN的判別器訓(xùn)練穩(wěn)定,本質(zhì)上要使判別器的函數(shù)滿足Lipschitz 條件。TransGAN的判別器是Transformer,其模型函數(shù)不滿足Lipschitz條件,因此,TransGAN在訓(xùn)練過程中易出現(xiàn)不穩(wěn)定,從而不能獲得最佳模型。
針對(duì)上述問題,本文設(shè)計(jì)了模型結(jié)構(gòu)簡(jiǎn)單且滿足Lipschitz條件的判別器——譜歸一化卷積神經(jīng)網(wǎng)絡(luò)(CNN with spectral normalization,CSN),結(jié)合具有更強(qiáng)表達(dá)能力的Transformer作為生成器,由此提出基于Transformer和CSN的圖像生成模型TCSNGAN(Transformer-and-CSN-based GAN)。在公共數(shù)據(jù)集CIFAR-10和STL-10上的實(shí)驗(yàn)表明,TCSNGAN模型復(fù)雜度較低,生成的圖像質(zhì)量高。本文工作主要貢獻(xiàn)為:
a)針對(duì)現(xiàn)有GAN的判別器易在訓(xùn)練過程出現(xiàn)不穩(wěn)定的問題,設(shè)計(jì)了譜歸一化卷積神經(jīng)網(wǎng)絡(luò)判別器,并在判別器結(jié)構(gòu)中設(shè)置了可調(diào)節(jié)CSN模塊數(shù),為不同分辨率數(shù)據(jù)集進(jìn)行最佳配置,以在最簡(jiǎn)潔結(jié)構(gòu)下生成最佳質(zhì)量的圖像;設(shè)計(jì)的CSN判別器結(jié)構(gòu)簡(jiǎn)潔,滿足Lipschitz 條件,訓(xùn)練穩(wěn)定。
b)提出了TCSNGAN圖像生成模型,采用Transformer生成器、CSN判別器;Transformer生成器的自注意力機(jī)制使模型具有更強(qiáng)的表達(dá)能力;結(jié)構(gòu)簡(jiǎn)潔的CSN判別器使得圖像生成模型的計(jì)算復(fù)雜度低,且性能良好。
c)在CSN判別器訓(xùn)練過程中,對(duì)生成器的生成圖像和真實(shí)圖像均進(jìn)行DiffAugment數(shù)據(jù)增強(qiáng),提高了生成圖像的質(zhì)量。
1 生成式網(wǎng)絡(luò)的Lipschitz判別器
生成式網(wǎng)絡(luò)的期望訓(xùn)練過程是交替訓(xùn)練生成器(generator,G)和判別器(discriminatory,D)。生成器的目標(biāo)是捕獲真實(shí)圖像的分布,以便生成更接近真實(shí)圖像的新圖像;判別器的目標(biāo)則是盡可能地區(qū)分真實(shí)圖像和生成圖像。生成器和判別器之間進(jìn)行兩者極小極大博弈,直至收斂至納什平衡點(diǎn),即生成器達(dá)到最小值,判別器達(dá)到最大值[11]。
假設(shè)生成器G的輸入為隨機(jī)噪聲向量 z ,G( z)表示噪聲z 經(jīng)生成器G生成的圖像,其數(shù)據(jù)服從分布p z ;x為真實(shí)圖像,其數(shù)據(jù)服從分布px;函數(shù)D(·)表示判別器D的輸入“·”為真實(shí)圖像的概率。GAN的優(yōu)化過程可由以下最大-最小值函數(shù)V(D,G)[20]表示:
其中:E是期望函數(shù)。判別器和生成器的優(yōu)化目標(biāo)完全相反,前者期望V(D,G)達(dá)到最大值,后者期望V(D,G)達(dá)到最小值,因此需要小心平衡判別器和生成器的訓(xùn)練程度[16]。
平衡判別器和生成器的訓(xùn)練程度并不容易,訓(xùn)練過程中常會(huì)出現(xiàn)以下兩種情況[21]:a)判別器的判別能力過強(qiáng),導(dǎo)致生成器過早收斂,由此,生成器不再進(jìn)行多樣化嘗試,只生成重復(fù)模式的圖像,此時(shí)GAN的訓(xùn)練過程出現(xiàn)模式崩塌[22];b)生成器的生成能力過強(qiáng),導(dǎo)致判別器過早收斂,由此判別器無(wú)法指導(dǎo)生成器的訓(xùn)練,此時(shí)GAN的訓(xùn)練過程提前終止。對(duì)此,Zhou等人[19]提出欲保證GAN的判別器訓(xùn)練穩(wěn)定,本質(zhì)上要使判別器模型的函數(shù)滿足Lipschitz條件。滿足Lipschitz條件的判別器稱為L(zhǎng)ipschitz判別器。
則稱f符合Lipschitz條件。其中,dX、dY分別是實(shí)數(shù)集子集X、Y上的度量,滿足Lipschitz條件的最小常數(shù)K為函數(shù)f的Lipschitz常數(shù)。由以上定義可知,Lipschitz 條件限定函數(shù)f在某一局部區(qū)域的變化幅度不會(huì)超過Lipschitz常數(shù)K。對(duì)于神經(jīng)網(wǎng)絡(luò)而言,如果其函數(shù)滿足Lipschitz條件,則其梯度的值也不會(huì)超過對(duì)應(yīng)的Lipschitz常數(shù)K,由此在神經(jīng)網(wǎng)絡(luò)函數(shù)的訓(xùn)練過程中就不會(huì)出現(xiàn)梯度爆炸,這就可增強(qiáng)網(wǎng)絡(luò)訓(xùn)練的穩(wěn)定性。
WGAN首次提出采用權(quán)重剪枝方法使得GAN的判別器滿足Lipschitz條件[16]。但是,剪枝使得判別器權(quán)重矩陣的結(jié)構(gòu)被破壞,以致WGAN收斂速度較慢[14]。Miyato等人[18]提出在訓(xùn)練判別器時(shí),通過譜歸一化操作將每一層網(wǎng)絡(luò)的權(quán)重矩陣限制在一個(gè)范圍內(nèi),從而使判別器函數(shù)滿足Lipschitz 條件。受此啟發(fā),本文通過在CNN中引入譜歸一化層,設(shè)計(jì)譜歸一化卷積神經(jīng)網(wǎng)絡(luò)判別器。
2 TCSNGAN模型鑒于Transformer對(duì)圖像中全局依賴關(guān)系的建模能力,并考慮生成式網(wǎng)絡(luò)的訓(xùn)練穩(wěn)定性問題,本文提出了一種新的生成式對(duì)抗網(wǎng)絡(luò)模型TCSNGAN,它由Transformer生成器和CSN判別器構(gòu)成。
2.1 Transformer生成器Transformer生成器的結(jié)構(gòu)如圖1所示。為了方便說(shuō)明各部分輸入輸出向量尺寸的變化,以CIFAR-10數(shù)據(jù)集為例(數(shù)據(jù)集圖像尺寸為3×32×32)[23]。生成器的輸入是滿足正態(tài)分布的隨機(jī)噪聲向量,這里取向量的尺寸為1×256。需要說(shuō)明的是,噪聲向量的維度不能太小,否則在圖像生成過程中易出現(xiàn)模式崩潰的現(xiàn)象。
首先,輸入的向量經(jīng)全連接層(fully connected,F(xiàn)C)映射成維度為C×H×W(C=1024,H=8,W=8)的張量 T 0,位置編碼器(positional encoding)為 T 0中的每一個(gè)像素賦予一個(gè)可學(xué)習(xí)的位置編碼,并在后續(xù)處理中將 T 0中的每一個(gè)像素值及其位置編碼記為一個(gè)token。接著, T 0經(jīng)線性展平操作(linear flatten)將所有像素展平,生成1×1×(1024×8×8)的序列張量 T1。然后,T1 被輸入N1個(gè)連續(xù)的Transformer編碼器(Transformer encoder)中進(jìn)行特征表示的學(xué)習(xí),輸出張量 T2。之后,T2 輸入Transformer循環(huán)體(TransCir),循環(huán)M次后,生成張量 T3,大小為1×1×(1024/16×32×32)。最后,T3由線性展平恢復(fù)操作(linear unflatten)重新拼接成尺寸為1024/16× 32×32的張量 T 4,并通過1×1的卷積操作(1×1 conv),獲得尺寸為3×32×32的輸出圖像 T5 。
Transformer生成器通過多頭自注意力機(jī)制提取信息,這一機(jī)制使其可以對(duì)輸入的上下文信息進(jìn)行編碼,并且可以讓模型學(xué)習(xí)到不同tokens之間的依賴關(guān)系[24,25]。
2.1.1 Transformer編碼器Transformer編碼器的組成如圖1 所示,包含歸一化 (layer norm,LN) 層、多頭自注意力機(jī)制 (multi-h(huán)ead self-attention,MHSA)層、多層感知機(jī)(multilayer perceptron,MLP)層和dropout層。
LN層旨在對(duì)張量 T5 進(jìn)行歸一化操作,獲得歸一化輸出Y。MHSA層的多頭自注意力機(jī)制運(yùn)算則可以提取特征信息。MLP層由兩個(gè)FC層構(gòu)成,其后的dropout層則可避免過擬合現(xiàn)象的產(chǎn)生。
需要說(shuō)明的是,Transformer編碼器的個(gè)數(shù)與數(shù)據(jù)集的分辨率有關(guān),圖像的分辨率越高,生成高質(zhì)量圖像所需的Transformer編碼器的個(gè)數(shù)越多。以數(shù)據(jù)集CIFAR-10為例,取N1=5,N2=4。
2.1.2 TransCir循環(huán)體TransCir循環(huán)體由線性展平恢復(fù)、像素重組(pixel shuffle)、位置編碼、線性展平、Transformer編碼器組成。其中,像素重組操作是對(duì)輸入張量進(jìn)行上采樣,以提高圖像的分辨率。TransCir循環(huán)體的個(gè)數(shù)也與數(shù)據(jù)集的分辨率有關(guān),選取循環(huán)體個(gè)數(shù)M=2。
2.2 CSN判別器
2.2.1 CSN判別器的結(jié)構(gòu)
CSN判別器的結(jié)構(gòu)如圖2所示,它由殘差型網(wǎng)絡(luò)Res_CSN、多個(gè)CSN模塊、平均池化層(AvgPooling)和全連接層(FC)構(gòu)成。需要說(shuō)明的是,CSN模塊由3×3卷積層和譜歸一化層級(jí)聯(lián)而成,深紅色CSN模塊含ReLU激活函數(shù)層,其余不含。殘差型網(wǎng)絡(luò)Res_CSN的結(jié)構(gòu)如圖2虛線框所示。其中的直聯(lián)結(jié)構(gòu)由均值池化層(2×2 AvgPooling)、卷積層(1×1 conv)和譜歸一化層(spectral norm)構(gòu)成。直聯(lián)結(jié)構(gòu)使網(wǎng)絡(luò)能獲得更穩(wěn)定且保留更多原圖信息的特征圖。譜歸一化層的數(shù)學(xué)表達(dá)為
2.2.2 CSN判別器是Lipschitz判別器
假設(shè)判別器網(wǎng)絡(luò)D(x)由n個(gè)多層神經(jīng)網(wǎng)絡(luò)Di(x)(i = 2, …, n)構(gòu)成,則第n個(gè)神經(jīng)網(wǎng)絡(luò)的輸出Dn(x)可以表示為
如果要使n層判別器網(wǎng)絡(luò)滿足1-Lipschitz條件,就需要保證每一層權(quán)重矩陣的最大奇異值恒等于1。由式(3)可知,CSN判別器通過譜歸一化層將每一層的權(quán)重矩陣的最大奇異值恒等于 從而使CSN判別器的梯度大小被限制在1以內(nèi),但并沒有破壞權(quán)重矩陣的結(jié)構(gòu)。因此,CSN判別器滿足Lipschitz 條件,能幫助提高TCSNGAN訓(xùn)練的穩(wěn)定性。進(jìn)一步地,以STL-10[26]數(shù)據(jù)集為例,圖3對(duì)比了TransGAN的Transformer判別器、TCSNGAN的CSN判別器的訓(xùn)練損失函數(shù)曲線。其中,縱坐標(biāo)為損失函數(shù)值,橫坐標(biāo)為模型訓(xùn)練的次數(shù)。由圖3可知,CSN判別器的損失函數(shù)值在迭代1 500次之后基本穩(wěn)定,而Transformer判別器的損失函數(shù)值在迭代2 500次之后仍然不穩(wěn)定。
3 DiffAugment數(shù)據(jù)增強(qiáng)
深度神經(jīng)網(wǎng)絡(luò)因參數(shù)眾多,對(duì)訓(xùn)練數(shù)據(jù)的數(shù)量要求較高。在訓(xùn)練數(shù)據(jù)不足的情況下,常采用數(shù)據(jù)增強(qiáng)操作擴(kuò)充訓(xùn)練集。常用的數(shù)據(jù)增強(qiáng)方法包括調(diào)整亮度、對(duì)比度、飽和度、色調(diào)等以降低模型對(duì)色彩的敏感性,或者隨機(jī)剪裁、隨機(jī)反轉(zhuǎn)等以降低模型對(duì)目標(biāo)位置的敏感性[27]。但是,在訓(xùn)練GAN模型時(shí)使用上述數(shù)據(jù)增強(qiáng)方法,反而會(huì)破壞生成器和判別器之間的微妙平衡,使得生成器性能大幅下降。Zhao等人[28]提出了一種可微分的數(shù)據(jù)增強(qiáng)方法DiffAugment,它可在數(shù)據(jù)增強(qiáng)的同時(shí)動(dòng)態(tài)調(diào)節(jié)生成器和判別器之間的平衡,從而提升網(wǎng)絡(luò)性能。因此,本文在CSN判別器的訓(xùn)練過程中也引入了DiffAugment數(shù)據(jù)增強(qiáng)。
DiffAugment的具體方法為:對(duì)真實(shí)圖像x和生成圖像G(z)進(jìn)行可微分的數(shù)據(jù)增強(qiáng)T,T的可學(xué)習(xí)參數(shù)會(huì)隨著每一次梯度更新而動(dòng)態(tài)變化。因此,基于DiffAugment數(shù)據(jù)增強(qiáng),判別器的學(xué)習(xí)對(duì)象從真實(shí)圖像x和生成圖像G(z)的分布轉(zhuǎn)變?yōu)榭晌⒃鰪?qiáng)函數(shù)T(x)和T(G(z))的分布,并在每一次模型學(xué)習(xí)后動(dòng)態(tài)變化。T隨著模型的變化而變化,可避免訓(xùn)練過程中生成器性能的下降。
4 實(shí)驗(yàn)結(jié)果與分析
4.1 數(shù)據(jù)集和實(shí)驗(yàn)設(shè)置
實(shí)驗(yàn)在CIFAR-10和STL-10數(shù)據(jù)集上進(jìn)行。CIFAR-10數(shù)據(jù)集包含60 000張尺寸為32×32的圖像,50 000張訓(xùn)練圖像和10 000張測(cè)試圖像。STL-10數(shù)據(jù)集則包含5 000張訓(xùn)練圖像和100 000張無(wú)標(biāo)簽圖像,尺寸48×48。
TCSNGAN模型的訓(xùn)練過程如圖4所示。噪聲矩陣 Z 輸入Transformer生成器,得到生成圖像;繼而生成圖像和數(shù)據(jù)集中的真實(shí)圖像均輸入CSN判別器,判別器判斷兩幅圖像的相似度,并輸出f(x)。Transformer生成器的損失函數(shù)[14]為
其中:y為真實(shí)圖像或生成圖像的標(biāo)簽,取值為1或-1;f(x)為判別器網(wǎng)絡(luò)的輸出值。需要說(shuō)明的是,TCSNGAN的Transformer生成器模型借鑒了TransGAN的生成器模型(https://github.com/VITA-Group/TransGAN)。
實(shí)驗(yàn)環(huán)境為Ubuntu 20.04,PyTorch深度學(xué)習(xí)框架,硬件為4塊RTX 3090顯卡。生成器和判別器的訓(xùn)練均使用了Adam優(yōu)化器,beta1為0,beta2為0.99。生成器的學(xué)習(xí)率為0.000 batchsize為128;判別器的學(xué)習(xí)率為0.000 15,batchsize為64。
實(shí)驗(yàn)使用Inception Score(IS)[30]和Frechet Inception Distance(FID)[31]作為生成圖像質(zhì)量的評(píng)價(jià)指標(biāo)。IS通過圖像類別分類器來(lái)評(píng)價(jià)圖像的質(zhì)量,圖像越清晰,類別越豐富,則IS的值越高。FID則通過計(jì)算生成圖像和真實(shí)圖像的分布距離來(lái)評(píng)價(jià)圖像的質(zhì)量,F(xiàn)ID值越小,則代表生成圖像越接近真實(shí)圖像。
4.2 實(shí)驗(yàn)結(jié)果與分析
本文從以下幾方面開展實(shí)驗(yàn):確定CIFAR-10和STL-10數(shù)據(jù)集下的最佳CSN模塊數(shù)K;通過消融實(shí)驗(yàn)驗(yàn)證提出技術(shù)的有效性;從模型復(fù)雜度、生成圖像的質(zhì)量角度,比較了TCSNGAN與已有模型。1)確定最佳CSN模塊數(shù)K
首先在CIFAR-10數(shù)據(jù)集(圖像尺寸為32×32)上進(jìn)行枚舉實(shí)驗(yàn),確定最優(yōu)的K值,如表1所示。當(dāng)K=4時(shí),IS值最大,F(xiàn)ID值最小。此外,當(dāng)Klt;2時(shí),訓(xùn)練判別器的損失函數(shù)很快收斂,說(shuō)明在生成器與判別器的對(duì)抗過程中,生成器生成的圖像很容易就騙過了判別器,使得判別器過早收斂。當(dāng)Kgt;5時(shí),在判別器的訓(xùn)練前期,損失函數(shù)基本不更新,且最后只能生成一些沒有意義的噪聲圖像。
進(jìn)一步地,圖5給出了K=2, 3, 4時(shí)的生成圖像。由圖可知,當(dāng)K=2時(shí),大部分圖像無(wú)法肉眼觀察出它具體的類別,并且背景會(huì)和圖像混在一起。當(dāng)K=3時(shí),可以分辨出大部分圖像的類別,但生成的圖像有些還不完整并存在扭曲的現(xiàn)象。當(dāng)K=4時(shí),達(dá)到最佳生成圖像的質(zhì)量,大部分圖像為完整清晰的類別。
然后在STL-10數(shù)據(jù)集(圖像尺寸為48×48)上進(jìn)行枚舉實(shí)驗(yàn),確定最優(yōu)的K值,如表2所示。當(dāng)K=10時(shí),IS值最大,F(xiàn)ID值最小。此外,當(dāng)Klt;6時(shí),訓(xùn)練判別器的損失函數(shù)過早收斂。當(dāng)Kgt;11時(shí),在判別器的訓(xùn)練前期,損失函數(shù)基本不更新,且最后只能生成一些沒有意義的噪聲圖像。圖6給出了K=6, 8, 10時(shí)的生成圖像,由圖可知,當(dāng)K=6時(shí),生成圖像大部分沒有確切的形狀,只是顏色的堆積,并且較為模糊。當(dāng)K=8時(shí),圖像逐漸形成確定的形狀,少許圖像能夠分辨其類別。當(dāng)K=10時(shí),達(dá)到最佳生成圖像的質(zhì)量,大部分圖像能夠分辨其類別并且更加清晰。綜上可知,當(dāng)數(shù)據(jù)集中圖像的分辨率越高,判別器中可調(diào)節(jié)CSN模塊數(shù)K的最優(yōu)值更大,也即分辨率更高圖像生成所需的判別器的結(jié)構(gòu)更復(fù)雜。具體地,對(duì)于CIFAR-10和STL-10數(shù)據(jù)集來(lái)說(shuō),圖像的分辨率從32×32提高到了48×48后,CSN塊增加了7個(gè)。
2)消融實(shí)驗(yàn)消融實(shí)驗(yàn)驗(yàn)證了在卷積神經(jīng)網(wǎng)絡(luò)上增加譜歸一化層構(gòu)成CSN判別器,及引入DiffAugment數(shù)據(jù)增強(qiáng)方法的效果。實(shí)驗(yàn)結(jié)果如表3所示。其中,TCSNGAN(only CNN)表示采用Transformer的生成器網(wǎng)絡(luò)結(jié)構(gòu),但將判別器中CSN模塊后的譜歸一化層舍去。DiffAug(r)、DiffAug(f)、DiffAug(r+f)分別表示僅在真實(shí)圖像上、僅在生成圖像上和在生成圖像與真實(shí)圖像上均使用DiffAugment數(shù)據(jù)增強(qiáng)。
由表3可知,與不采用譜歸一化層的TCSNGAN(only CNN)相比,TCSNGAN在CIFAR-10數(shù)據(jù)集上的IS提高了16.9%,F(xiàn)ID降低了48.6%;在STL-10數(shù)據(jù)集上IS提高了26%,F(xiàn)ID降低了4.9%。進(jìn)一步地,與TCSNGAN相比,在生成圖像和真實(shí)圖像集上均使用DiffAugment數(shù)據(jù)增強(qiáng)方法,在CIFAR-10數(shù)據(jù)集上IS提高了0.2%,F(xiàn)ID降低了8.3%;在STL-10數(shù)據(jù)集上,IS提高了19.6%,F(xiàn)ID降低了50.4%。因此,在CIFAR-10和STL-10數(shù)據(jù)集上采用譜歸一化和數(shù)據(jù)增強(qiáng)均對(duì)TCSNGAN生成的圖像質(zhì)量有提高的作用,其中譜歸一化對(duì)IS的提高作用更為明顯,而數(shù)據(jù)增強(qiáng)則對(duì)降低FID值更有幫助。因?yàn)閿?shù)據(jù)增強(qiáng)通過改變數(shù)據(jù)集的形狀顏色等特征,本質(zhì)上擴(kuò)充了原本數(shù)據(jù)集的大小,增強(qiáng)了判別器的泛化能力,所以生成圖像的多樣性得到了進(jìn)一步的提高。此外,僅在真實(shí)圖像上使用DiffAugment數(shù)據(jù)增強(qiáng),在CIFAR-10和STL-10數(shù)據(jù)集上均生成了具有一定質(zhì)量的圖像,但效果不佳;僅在生成圖像上使用DiffAugment數(shù)據(jù)增強(qiáng),在CIFAR-10和STL-10數(shù)據(jù)集上均生成了失敗的圖像,原因是使用數(shù)據(jù)增強(qiáng)后完全破壞了生成器生成圖像的規(guī)律,判別器無(wú)法從生成的圖像中學(xué)習(xí)到規(guī)律,導(dǎo)致生成器生成的圖像完全迷惑了判別器。兩組實(shí)驗(yàn)證明了在進(jìn)行數(shù)據(jù)增強(qiáng)時(shí),需要同時(shí)維持生成器和判別器的微妙平衡,只維持其中一方會(huì)導(dǎo)致生成圖像的質(zhì)量下降,而僅在生成圖像上使用數(shù)據(jù)增強(qiáng),則會(huì)導(dǎo)致生成失敗的圖像。
3)模型復(fù)雜度對(duì)比TCSNGAN模型采用了Transformer生成器、CSN判別器,而TransGAN模型和ViTGAN則采用了Transformer生成器和Transformer判別器。通過比較這三種模型的復(fù)雜度,分別對(duì)模型的計(jì)算量FLOPs(floating point operations)和參數(shù)量Params進(jìn)行統(tǒng)計(jì),如表4所示。其中,G表示每秒進(jìn)行1 000個(gè)浮點(diǎn)運(yùn)算,M表示兆字節(jié)。
由表4可知,判別器的計(jì)算量占生成網(wǎng)絡(luò)模型總計(jì)算量的大部分,其中TransGAN和ViTGAN中Transformer判別器占總計(jì)算量的80%左右,TCSNGAN中CSN判別器占總計(jì)算量的46.76%,因此優(yōu)化判別器的計(jì)算量是優(yōu)化GAN計(jì)算量的關(guān)鍵。Transformer模型運(yùn)用了多頭自注意力機(jī)制來(lái)提取圖像特征,需要對(duì)長(zhǎng)序列的張量進(jìn)行乘法運(yùn)算,當(dāng)序列張量的長(zhǎng)度越長(zhǎng),其計(jì)算量會(huì)呈指數(shù)級(jí)擴(kuò)大。例如,TCSNGAN與TransGAN具有相同的生成器,而前者的CSN判別器的計(jì)算量遠(yuǎn)小于后者Transformer判別器的計(jì)算量,因此與TransGAN相比,TCSNGAN在CIFAR-10數(shù)據(jù)集上,F(xiàn)LOPs減少了62.05%,Params減少了9.71%;在STL-10數(shù)據(jù)集上,F(xiàn)LOPs減少了54.16%,Params減少了35.89%。綜上,與TransGAN和ViTGAN相比,復(fù)雜度更低的Trans-CSN模型可應(yīng)用于更低性能的硬件設(shè)備,提高模型在實(shí)際場(chǎng)景中的普適性。
4)TCSNGAN與其他GAN模型的比較最后,比較了TCSNGAN與其他GAN模型的性能,如表5所示。其中,LN表示層規(guī)范化(layer norm),即GAN-LN為采用層規(guī)范化的GAN;WN表示權(quán)重規(guī)范化(weight norm),即GAN-WN為采用權(quán)重規(guī)范化的GAN;DiffAug表示采用DiffAugment數(shù)據(jù)增強(qiáng)。需要說(shuō)明的是,除了TransGAN、ViTGAN和TCSNGAN外,其余GAN模型均使用基于CNN的生成器,帶*數(shù)據(jù)為本文復(fù)現(xiàn)時(shí)獲得的最佳結(jié)果。表5結(jié)果表明,從IS和FID指標(biāo)看,TCSNGAN模型生成的圖像質(zhì)量?jī)?yōu)于多數(shù)已有GAN模型,包括TransGAN。DiffAugment數(shù)據(jù)增強(qiáng)能進(jìn)一步提升模型的性能,特別是FID值。由于數(shù)據(jù)增強(qiáng)對(duì)復(fù)雜模型Transformer的性能提升更有效,DiffAugment數(shù)據(jù)增強(qiáng)用于TransGAN模型后,其指標(biāo)提升效果優(yōu)于TCSNGAN。但是,TransGAN存在訓(xùn)練不穩(wěn)定的問題,即在訓(xùn)練過程中常常會(huì)出現(xiàn)生成器或判別器訓(xùn)練不充分的情況,此時(shí)生成的圖像質(zhì)量下降嚴(yán)重,IS和FID指標(biāo)也就不能總是達(dá)到最佳值。這種情況在圖像分辨率稍大的數(shù)據(jù)集(如STL-10)上更容易發(fā)生。而TCSNGAN模型能穩(wěn)定訓(xùn)練,模型性能保持在最佳值。
進(jìn)一步地,圖7給出了三種模型的生成圖像,可知,SN-GAN模型生成的圖像中部分目標(biāo)的輪廓與真實(shí)物體存在差距;相較而言,TransGAN和TCSNGAN模型生成的圖像中的目標(biāo)更接近真實(shí)物體。
綜上,TCSNGAN模型與TranGAN相比,評(píng)價(jià)指標(biāo)IS和FID與TranGAN接近,但計(jì)算量減少了60%左右,因此本文模型具有較強(qiáng)的競(jìng)爭(zhēng)力。此外,雖然ViTGAN在CIFAR-10數(shù)據(jù)集上的性能優(yōu)于TCSNGAN模型,但由表4可知,TCSNGAN模型的計(jì)算量同樣具有較大的優(yōu)勢(shì)。
4.3 模型在火焰圖像生成中的應(yīng)用在基于深度神經(jīng)網(wǎng)絡(luò)的火災(zāi)檢測(cè)研究中,獲取多樣化、數(shù)量大的火災(zāi)圖像樣本數(shù)據(jù),是提高火災(zāi)檢測(cè)模型準(zhǔn)確率的方法之一。但是,火災(zāi)圖像的獲取存在困難,通過圖像生成模型可以擴(kuò)充火災(zāi)圖像數(shù)據(jù)集。本文采用TCSNGAN生成火焰圖像,并將這些火焰圖像與不同的背景圖像融合,從而生成多樣化、大數(shù)量的火災(zāi)圖像樣本,擴(kuò)充火災(zāi)檢測(cè)模型的訓(xùn)練樣本集。具體地,從Bilkent大學(xué)火災(zāi)數(shù)據(jù)集[33]中裁剪出火焰區(qū)域(利用數(shù)據(jù)集的標(biāo)注數(shù)據(jù)),經(jīng)圖像縮放建立尺寸為32×32的火焰數(shù)據(jù)集。圖8(a)給出了火焰數(shù)據(jù)集中的部分樣本圖像。將火焰數(shù)據(jù)集中的圖像輸入訓(xùn)練好的TCSNGAN模型中,生成圖8(b)所示的火焰圖像,可知,生成的火焰圖像與真實(shí)的火焰圖像非常接近。由于生成火焰圖像的尺寸僅為32×32,生成的火焰圖像還需要通過像素重組(pixel shuffle)進(jìn)行上采樣,這樣才能較好地與不同的背景圖像融合。圖8(c)給出了生成火焰圖像經(jīng)上采樣6倍后的圖像,尺寸為192×192。
進(jìn)一步,生成的火焰圖像與不同背景圖像進(jìn)行融合,得到如圖9所示的火災(zāi)圖像。這些融合后的火災(zāi)圖像可用于擴(kuò)充火災(zāi)檢測(cè)模型的訓(xùn)練數(shù)據(jù)集,提高火災(zāi)檢測(cè)模型的準(zhǔn)確率。
5 結(jié)束語(yǔ)
基于深度神經(jīng)網(wǎng)絡(luò)的生成式對(duì)抗網(wǎng)絡(luò)是圖像生成領(lǐng)域的研究熱點(diǎn)。鑒于現(xiàn)有GAN的訓(xùn)練穩(wěn)定性差以及模型復(fù)雜度高的問題,本文提出基于Transformer和譜歸一化的卷積神經(jīng)網(wǎng)絡(luò)的圖像生成模型TCSNGAN,其生成器為Transformer,判別器為CSN。該CSN判別器滿足Lipschitz條件,使得TCSNGAN可以穩(wěn)定訓(xùn)練。在公共數(shù)據(jù)集CIFAR-10和STL-10上的實(shí)驗(yàn)表明,生成圖像的質(zhì)量評(píng)價(jià)指標(biāo)IS和FID優(yōu)于多數(shù)現(xiàn)有模型,并且,該網(wǎng)絡(luò)的判別器結(jié)構(gòu)更簡(jiǎn)單,使其更易于應(yīng)用在實(shí)際場(chǎng)景中。進(jìn)一步,本文將該模型應(yīng)用于火災(zāi)圖像生成中,以擴(kuò)充火災(zāi)檢測(cè)模型的訓(xùn)練樣本集,提高火災(zāi)檢測(cè)模型的準(zhǔn)確率。但模型直接生成的圖像尺寸較小,在實(shí)際應(yīng)用中仍需要經(jīng)圖像縮放處理,下一步將研究如何生成尺寸較大的圖像。
參考文獻(xiàn):
[1]Gui Jie,Sun Zhenan,Wen Yonggang,et al. A review on generative adversarial networks: algorithms,theory,and applications [J].IEEE Trans on Knowledge and Data Engineering ,2023, 35 (4): 3313-3332.
[2]林懿倫,戴星原,李力,等. 人工智能研究的新前線: 生成式對(duì)抗網(wǎng)絡(luò) [J]. 自動(dòng)化學(xué)報(bào),2018, 44 (5): 775-792. (Lin Yilun,Dai Xingyuan,Li Li,et al. The new frontier of AI research: generative adversarial networks [J].Acta Automatica Sinica ,2018, 44 (5): 775-792.)
[3]陳佛計(jì),朱楓,吳清瀟,等. 生成對(duì)抗網(wǎng)絡(luò)及其在圖像生成中的應(yīng)用研究綜述 [J]. 計(jì)算機(jī)學(xué)報(bào),202 23 (2): 347-369. (Chen Foji,Zhu Feng,Wu Qingxiao,et al. A survey about image generative with generative adversarial nets [J].Chinese Journal of Compu-ter ,202 23 (2): 347-369.)
[4]胡銘菲,左信,劉建偉. 深度生成模型綜述 [J]. 自動(dòng)化學(xué)報(bào),2022, 48 (1): 40-74. (Hu Mingfei,Zuo Xin,Liu Jianwei. Survey on deep generative model [J].Acta Automatica Sinica ,2022, 48 (1): 40-74.)
[5]Wang Zhengwei,She Qi,Ward T. Generative adversarial networks in computer vision: a survey and taxonomy [J].ACM Computing Surveys ,202 54 (2): 1-38.
[6]Goodfellow I,Pouget J,Mirza M,et al. Generative adversarial networks [J].Communications of the ACM ,2020, 63 (11): 139-144.
[7]Radford A,Metz L,Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks [C]// Proc of International Conference on Learning Representations. 2016.
[8]Brock A,Donahue J,Simonyan K. Large scale GAN training for high fidelity natural image synthesis [C]//Proc of International Conference on Learning Representations. 2018.
[9]Karras T,Laine S,Aila T. A style-based generator architecture for generative adversarial networks [C]// Proc of IEEE/CVF International Conference on Computer Vision. Piscataway,NJ: IEEE Press,2019: 4401-4410.
[10]Vaswani A,Shazeer N,Parmar N,et al. Attention is all you need [C]// Proc of the 31st International Conference on Neural Information Processing Systems. Red Hook,NY: Curran Associates Inc.,2017: 6000-6010.
[11]Hong Yongjun,Hwang U,Yoo J,et al. How generative adversarial networks and their variants work: an overview [J].Communications of the ACM ,2019, 52 (1): 1-43.
[12]Dosovitskiy A,Beyer L,Kolesnikov A,et al. An image is worth 16×16 words: Transformers for image recognition at scale [C]// Proc of International Conference on Learning Representations. 2021.
[13]裴炤,邱文濤,王淼,等. 基于Transformer動(dòng)態(tài)場(chǎng)景信息生成對(duì)抗網(wǎng)絡(luò)的行人軌跡預(yù)測(cè)方法 [J]. 電子學(xué)報(bào),2022, 50 (7): 1537-1547. (Pei Zhao,Qiu Wentao,Wang Miao,et al. Pedestrian trajectory prediction method using dynamic scene information based Transformer generative adversarial network [J].Acta Electronica Sinica ,2022, 50 (7): 1537-1547.
[14]Jiang Yifan,Chang Shiyu,Wang Zhangyang. TransGAN: two pure transformers can make one strong GAN,and that can scale up [C]// Advances in Neural Information Processing Systems. Cambridge,MA: MIT Press,2021: 14745-14758.
[15]Arjovsky M,Bottou L. Towards principled methods for training gene-rative adversarial networks [C]// Proc of International Conference on Learning Representations. 2017.
[16]Arjovsky M,Chintala S,Bottou L. Wasserstein GAN [C]//Proc of International Conference on Learning Representations. 2017.
[17]Gulrajani I,Ahmed F,Arjovsky M,et al. Improved training of Wasserstein GANs [C]// Advances in Neural Information Processing Systems. Cambridge,MA: MIT Press,2017: 5767-5777.
[18]Miyato T,Kataoka T,Koyama M,et al. Spectral normalization for ge-nerative adversarial networks [C]// Proc of International Conference on Learning Representations. 2018.
[19]Zhou Zhiming,Liang Jiadong,Song Yuxuan,et al. Lipschitz generative adversarial nets [C]// Proc of International Conference on Machine Learning. New York: ACM Press,2019: 7584-7593.
[20]Saxena D,Cao J. Generative adversarial networks (GANs) challenges,solutions,and future directions [J].ACM Computing Surveys ,202 54 (3): 1-42.
[21]Karnewar A,Wang O. MSG-GAN: multi-scale gradients for generative adversarial networks [C]// Proc of IEEE/CVF International Conference on Computer Vision. Piscataway,NJ: IEEE Press,2020: 7799-7808.
[22]Bau D,Zhu Junyan,Wulff J,et al. Seeing what a GAN cannot gene-rate [C]// Proc of IEEE/CVF International Conference on Computer Vision. Piscataway,NJ: IEEE Press,2019: 4502-4511.
[23]Krizhevsky A,Hinton G. Learning multiple layers of features from tiny images [J].Handbook of Systemic Autoimmune Diseases ,2009, 1 (4): 1-60.
[24]Hao Yaru,Dong Li,Wei Furu,et al. Self-attention attribution: interpreting information interactions inside Transformer [C]// Proc of AAAI Conference on Artificial Intelligence. Palo Alto,CA: AAAI Press,2021: 12963-12971.
[25]Wang Junpu,Xu Guili,Li Chunlei,et al. Defect Transformer: an ef-ficient hybrid Transformer architecture for surface defect detection [C]// Proc of International Conference on Learning Representations. 2022.
[26]Coates A,Ng A,Lee H. An analysis of single-layer networks in unsupervised feature learning [C]// Proc of International Conference on Learning Representations. 2011: 215-223.
[27]Falcon W,Cho K. A framework for contrastive self-supervised lear-ning and designing a new approach [C]//Proc of International Con-ference on Learning Representations. 2009.
[28]Zhao Shengyu,Liu Zhijian,Lin Ji,et al. Differentiable augmentation for data-efficient GAN training [C]// Proc of Advances in Neural Information Processing Systems. Cambridge,MA: MIT Press,2020: 7559-7570.
[29]Bartlett P,Wegkamp M. Classification with a reject option using a hinge loss [J].Journal of Machine Learning Research ,2008, 9 (8): 1823-1840.
[30]Salimans T,Goodfellow I,Zaremba W,et al. Improved techniques for training GANs [C]// Advances in Neural Information Processing Systems. Cambridge,MA: MIT Press,2016: 2234-2242.
[31]Heusel M,Ramsauer H,Unterthiner T,et al. GANs trained by a two time-scale update rule converge to a local Nash equilibrium [C]// Advances in Neural Information Processing Systems. Cambridge,MA: MIT Press,2017: 6626-6637.
[32]Gong Xinyu,Chang Shiyu,Jiang Yifan,et al. AutoGAN: neural architecture search for generative adversarial networks [C]// Proc of IEEE/CVF International Conference on Computer Vision. Pisca-taway,NJ: IEEE Press,2019: 3224-3234.
[33]Ali R,Reza T,Reza D. Fire and smoke detection using wavelet analysis and disorder characteristics [C]// Proc of International Confe-rence on Computer Research and Development. Piscataway,NJ: IEEE Press,2011: 262-265.
收稿日期:2023-07-25;修回日期:2023-09-15
作者簡(jiǎn)介:錢惠敏(1980—),女,江蘇宜興人,副教授,碩導(dǎo),博士,CCF會(huì)員,主要研究方向?yàn)橛?jì)算機(jī)視覺、機(jī)器學(xué)習(xí);毛邱凌(1998—),男(通信作者),江蘇南通人,碩士研究生,主要研究方向?yàn)榛谏蓪?duì)抗網(wǎng)絡(luò)的圖像生成(am_hohai@163.com);陳實(shí)(1999—),男,江蘇鹽城人,碩士研究生,主要研究方向?yàn)榛谏疃壬窠?jīng)網(wǎng)絡(luò)的視頻分析與理解;韓怡星(1998—),男,江蘇蘇州人,碩士研究生,主要研究方向?yàn)榛谏疃葘W(xué)習(xí)的圖像理解;呂本杰(1998—),男,江蘇南通人,碩士研究生,主要研究方向?yàn)樾颖緢D像的生成.