曹爽
摘? 要: 為了提升合成表格數(shù)據(jù)的質(zhì)量,提出一種簡(jiǎn)單的方法生成每個(gè)類的數(shù)據(jù),使用度量損失控制每一類結(jié)構(gòu)化數(shù)據(jù)的生成,將此方法命名為SCGAN。文章用此方法在二分類問題上進(jìn)行了嘗試。使用三種不同的度量損失在三個(gè)真實(shí)的數(shù)據(jù)集上訓(xùn)練生成對(duì)抗網(wǎng)絡(luò):逐次對(duì)每一類數(shù)據(jù)進(jìn)行合成,利用合成數(shù)據(jù)訓(xùn)練分類器模型,使用gmean來評(píng)估模型的性能。結(jié)果表明,單獨(dú)生成每一類數(shù)據(jù)能夠提升模型的分類性能。
關(guān)鍵詞: 合成數(shù)據(jù); 度量損失; 生成對(duì)抗網(wǎng)絡(luò); 分類器
中圖分類號(hào):TP391? ? ? ? ? 文獻(xiàn)標(biāo)識(shí)碼:A? ? ?文章編號(hào):1006-8228(2021)04-25-03
Abstract: In order to improve the quality of tabular data synthesis, a simple method to generate data of each category is proposed, and it is named SCGAN and uses metrics loss to control the generation of structured data of each category. In this paper, the binary classification problem is tried to be solved by this method. By using three different metrics losses, the generative adversarial network is trained on three real datasets that each category of data are synthesized one by one, the classifier model are trained with the synthesized data, and gmean is used to evaluate the performance of the model. The results show that generating each category of data separately can improve the classification performance of the model.
Key words: synthesized data; metrics loss; generative adversarial networks; classifier
0 引言
近年來,生成對(duì)抗網(wǎng)絡(luò)在生成高質(zhì)量合成圖像方面取得了很大的成功[1]。多種數(shù)據(jù)類型,數(shù)據(jù)分布不確定,多模態(tài)分布,數(shù)據(jù)不均衡等特點(diǎn)對(duì)生成表格型數(shù)據(jù)帶來了挑戰(zhàn)[2]。MedGAN提出醫(yī)學(xué)生成對(duì)抗網(wǎng)絡(luò),來生成逼真的合成病歷[3]。TableGAN使用生成對(duì)抗網(wǎng)絡(luò)來合成假表,這些假表在統(tǒng)計(jì)上類似于原始表[4]。CTGAN對(duì)連續(xù)數(shù)據(jù)進(jìn)行建模,對(duì)離散數(shù)據(jù)增加條件損失來合成高質(zhì)量數(shù)據(jù)[2]。
本文在CTGAN的基礎(chǔ)上提出一種無(wú)監(jiān)督的生成對(duì)抗網(wǎng)絡(luò)方法,將衡量指標(biāo)FID[5],MMD[6],最小二乘作為度量模塊應(yīng)用到生成對(duì)抗網(wǎng)絡(luò)模型中,利用單個(gè)類別的數(shù)據(jù)訓(xùn)練模型生成大量的合成數(shù)據(jù),利用梯度懲罰[7]和譜歸一化方法[8]來增強(qiáng)模型訓(xùn)練的穩(wěn)定性。在三個(gè)真實(shí)的數(shù)據(jù)集上選取相同數(shù)量的生成數(shù)據(jù)對(duì)三種度量方法做了比較,實(shí)驗(yàn)結(jié)果顯示,本文提出的方法能夠提升生成數(shù)據(jù)的質(zhì)量,提升模型分類的性能。
1 SCGAN
1.1 生成對(duì)抗網(wǎng)絡(luò)
生成對(duì)抗網(wǎng)絡(luò)是一種生成模型[1],包含生成器(G)和判別器(D)兩部分。生成器目的是生成逼真的合成數(shù)據(jù)以最大程度的騙過判別器來達(dá)到損失的最小化,判別器爭(zhēng)取將真實(shí)數(shù)據(jù)和合成數(shù)據(jù)分別開來[9]。以下為生成對(duì)抗網(wǎng)絡(luò)的一般形式:
其中[z]是隨機(jī)輸入的噪聲,一般為高斯分布中的隨機(jī)采樣點(diǎn),[pz]是潛在向量[z]的先驗(yàn)分布,[G?]是生成器函數(shù),[D?]是判別器函數(shù)。
1.2 度量損失
為了保證生成數(shù)據(jù)的質(zhì)量,將三種度量損失:FID,MMD,最小二乘等加入到生成對(duì)抗網(wǎng)絡(luò)模型中,由于最小二乘比較簡(jiǎn)單,在此我們著重介紹前兩種方法。
⑴ Frechet Inception Distance (FID)
FID[5]常用于評(píng)估生成器最終生成的圖像質(zhì)量,計(jì)算真實(shí)數(shù)據(jù)和合成數(shù)據(jù)在特征層面的距離,距離越小,說明合成數(shù)據(jù)與真實(shí)數(shù)據(jù)越相似,以下是FID的計(jì)算公式:
其中[Pr],[Pg]分別表示真實(shí)數(shù)據(jù)和生成數(shù)據(jù),[C]表示數(shù)據(jù)的協(xié)方差矩陣,[u]表示數(shù)據(jù)的均值,我們將這種評(píng)估方式應(yīng)用到生成表格數(shù)據(jù)的生成對(duì)抗模型中,參與生成器模型的訓(xùn)練,鼓勵(lì)生成器學(xué)習(xí)真實(shí)數(shù)據(jù)的分布。
⑵ Maximum Mean Discrepancy (MMD)
MMD[6]是一種基于最大均方差的統(tǒng)計(jì)檢驗(yàn)來優(yōu)化兩類樣本的分布,常用于評(píng)估生成圖像的質(zhì)量。此處,我們使用MMD衡量生成的結(jié)構(gòu)化數(shù)據(jù),定義如下:
給定兩類結(jié)構(gòu)化數(shù)據(jù)集,[V=v1,v2,…vm]和[W=w1,w2,…wm],以下為MMD計(jì)算公式:
其中[k?]是高斯核函數(shù)。
1.3 SCGAN整體流程
整體流程如圖1所示。我們使用生成對(duì)抗網(wǎng)絡(luò)對(duì)劃分好的訓(xùn)練集進(jìn)行訓(xùn)練,生成指定類別的合成數(shù)據(jù),TrainData0表示第一類數(shù)據(jù)對(duì)應(yīng)生成數(shù)據(jù)Fake0,TrainData1表示第二類數(shù)據(jù)對(duì)應(yīng)生成數(shù)據(jù)Fake1,在G,D網(wǎng)絡(luò)中我們遵循了CTGAN的網(wǎng)絡(luò)結(jié)構(gòu),但是由于我們是生成指定類別的數(shù)據(jù),所以在生成器和判別器中去除了條件輸入,在G中加入了3種度量損失函數(shù)。當(dāng)生成指定類別的數(shù)據(jù)后,對(duì)生成的數(shù)據(jù)每個(gè)類分別選取500個(gè)和1000個(gè)樣本,最終組成1000和2000大小的訓(xùn)練集,訓(xùn)練分類器(SVM,RF,DT)模型,使用gmean[10]評(píng)估分類器的性能。
2 實(shí)驗(yàn)
2.1 數(shù)據(jù)集介紹
本文研究的數(shù)據(jù)集來自于①Covtype,用來預(yù)測(cè)森林覆蓋類型的多分類數(shù)據(jù)集,我們選擇了Ponderosa Pine,Krummholz這兩類數(shù)據(jù)來測(cè)試我們的模型。②Adult是一個(gè)從人口普查數(shù)據(jù)庫(kù)中提取的個(gè)人信息記錄的數(shù)據(jù)集,我們將收入是否超過50k,作為分類的二進(jìn)制標(biāo)簽。③BitcoinHeist是一個(gè)有關(guān)比特幣交易圖的數(shù)據(jù)集,簡(jiǎn)記為Bit,從中選取了princetonCerber和montrealCryptoLocker類別的數(shù)據(jù),對(duì)數(shù)據(jù)進(jìn)行二分類。
2.2 方法比較
在我們的SCGAN中,我們對(duì)比了使用不同度量下生成樣本的質(zhì)量,而且也與不加度量損失的生成對(duì)抗網(wǎng)絡(luò)和原始的CTGAN進(jìn)行了對(duì)比。SCGAN-FID表示在生成器上使用FID作為度量損失,SCGAN-MMD表示在生成器上使用MMD作為度量損失,SCGAN-LS表示在生成器上使用最小二乘作為度量損失,GAN表示沒有加度量損失。值得注意的一點(diǎn),在三種度量方法和沒有使用度量方法的GAN中,除了損失函數(shù)的差異,其他迭代次數(shù)和網(wǎng)絡(luò)都是一致的。
2.3 實(shí)驗(yàn)結(jié)果
在實(shí)驗(yàn)中,我們記錄了每一種方法以及每一種數(shù)據(jù)集在每一種基分類器實(shí)驗(yàn)結(jié)果,為了顯現(xiàn)整體的有效性,表1至表3是每一種方法在三個(gè)基分類器上的平均結(jié)果。從表1和表2中可以看到,在三個(gè)真實(shí)的數(shù)據(jù)集上,本文提出的SCGAN整體優(yōu)于CTGAN,另外,在表3中,我們記錄了不使用度量損失下的GAN模型的性能,根據(jù)在gmean指標(biāo)上的評(píng)估可以看到,進(jìn)一步說明了度量損失的有效性。
3 總結(jié)
本文提出的SCGAN,分別進(jìn)行每一類別的數(shù)據(jù)合成,通過實(shí)驗(yàn)表明能夠提升模型的分類性能。我們只在二分類問題上進(jìn)行了嘗試,將此方法應(yīng)用到多類不均衡數(shù)據(jù)集中是我們接下來的研究重點(diǎn)。
參考文獻(xiàn)(References):
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems,2014:2672-2680
[2] Xu L, Skoularidou M, Cuesta-Infante A, et al. Modeling tabular data using conditional gan[C]//Advances in Neural Information Processing Systems,2019:7335-7345
[3] Choi E, Biswal S, Malin B, et al. Generating multi-label discrete patient records using generative adversarial networks[J]. arXiv preprint arXiv:1703.06490,2017.
[4] Park N, Mohammadi M, Gorde K, et al. Data synthesis based on generative adversarial networks[J].arXiv preprint arXiv:1806.03384,2018.
[5] Heusel M, Ramsauer H, Unterthiner T, et al. Gans trained by a two time-scale update rule converge to a local nash equilibrium[J]. Advances in neural information processing systems,2017.30: 6626-6637
[6] Sutherland D J, Tung H Y, Strathmann H, et al.Generative models and model criticism via optimized maximum mean discrepancy[J]. arXiv preprint arXiv:1611.04488,2016.
[7] Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems,2017.30: 5767-5777
[8] Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks[J].arXiv preprint arXiv:1802.05957,2018
[9] 張重生著.人工智能 人臉識(shí)別與搜索[M].電子工業(yè)出版社,2020.
[10] Leevy J L, Khoshgoftaar T M, Bauder R A, et al. A survey on addressing high-class imbalance in big data[J]. Journal of Big Data,2018.5(1):42