王 聰,王 杰,劉全明,梁吉業(yè)
1.山西大學(xué) 計(jì)算機(jī)與信息技術(shù)學(xué)院,太原030006
2.山西大學(xué) 計(jì)算智能與中文信息處理教育部重點(diǎn)實(shí)驗(yàn)室,太原030006
傳統(tǒng)的監(jiān)督學(xué)習(xí),如支持向量機(jī)(support vector machine,SVM)、神經(jīng)網(wǎng)絡(luò)(neural networks,NN)等,通常需要大量良好的標(biāo)記樣本對(duì)模型進(jìn)行訓(xùn)練,以便獲得較好的模型泛化能力。同時(shí),在處理高維數(shù)據(jù)(如視頻、語(yǔ)音、圖像、文檔)時(shí),訓(xùn)練一個(gè)好的監(jiān)督模型所需要的標(biāo)記樣本數(shù)量會(huì)進(jìn)一步增長(zhǎng)。這使得傳統(tǒng)監(jiān)督學(xué)習(xí)很難應(yīng)用于一些缺乏標(biāo)記訓(xùn)練樣本的任務(wù)中。
半監(jiān)督學(xué)習(xí)(semi-supervised learning,SSL)[1]是近十多年發(fā)展起來(lái)的一種新型機(jī)器學(xué)習(xí)方法,其思想是在標(biāo)記樣本數(shù)量很少的情況下,通過(guò)在模型訓(xùn)練中引入無(wú)標(biāo)記樣本來(lái)避免傳統(tǒng)監(jiān)督學(xué)習(xí)在訓(xùn)練樣本不足(學(xué)習(xí)不充分)時(shí)出現(xiàn)性能(或模型)退化的問(wèn)題。半監(jiān)督學(xué)習(xí)的研究具有重要的實(shí)用價(jià)值,因?yàn)樵谠S多實(shí)際應(yīng)用中,無(wú)標(biāo)記樣本的獲取相對(duì)容易,而標(biāo)記樣本的獲取成本往往較高。因此,減少標(biāo)記樣本的使用能夠大幅縮減人力、時(shí)間和資源的開(kāi)銷,從而降低生產(chǎn)成本。同時(shí)在標(biāo)記樣本數(shù)量減少數(shù)十或數(shù)百倍(甚至更多)的情況下,半監(jiān)督算法能夠取得與傳統(tǒng)監(jiān)督學(xué)習(xí)算法相近甚至更好的效果,提升生產(chǎn)效率。半監(jiān)督學(xué)習(xí)的研究具有重要的理論價(jià)值,它是介于傳統(tǒng)監(jiān)督學(xué)習(xí)和無(wú)監(jiān)督學(xué)習(xí)之間的一種新型機(jī)器學(xué)習(xí)方法,是對(duì)傳統(tǒng)機(jī)器學(xué)習(xí)理論的拓展和補(bǔ)充。
圖半監(jiān)督學(xué)習(xí)(semi-supervised learning on graphs)作為半監(jiān)督學(xué)習(xí)的一個(gè)重要分支,在理論和實(shí)踐上引起了極大的關(guān)注。給定一個(gè)由少量標(biāo)記節(jié)點(diǎn)和大量未標(biāo)記節(jié)點(diǎn)組成的圖,它的目標(biāo)是為圖中的未標(biāo)記節(jié)點(diǎn)分配標(biāo)簽。生成對(duì)抗網(wǎng)絡(luò)(generative adversarial networks,GAN)[2]由于其強(qiáng)大的表征能力已經(jīng)被廣泛應(yīng)用于半監(jiān)督學(xué)習(xí),但它在圖半監(jiān)督學(xué)習(xí)任務(wù)上的工作較少?,F(xiàn)有的工作主要關(guān)注在低密度區(qū)域生成未標(biāo)記樣本來(lái)削弱子圖之間的信息傳播,從而使決策邊界更清晰,如GraphSGAN[3]通過(guò)GAN 在子圖之間的低密度區(qū)域生成未標(biāo)記樣本,減少子圖邊緣節(jié)點(diǎn)的影響,從而提高圖半監(jiān)督分類效果。但受限于標(biāo)記樣本過(guò)少,監(jiān)督信息的不足仍在一定程度上限制了其性能。針對(duì)這個(gè)問(wèn)題,本文提出了一種新的圖半監(jiān)督學(xué)習(xí)框架(semi-supervised learning on graphs using adversarial training with generated sample,SemiGATDS),它由圖嵌入模塊、兩個(gè)生成器、一個(gè)分類器和一個(gè)判別器五部分組成。其中,圖嵌入模塊將圖映射到特征空間,在特征空間中,一個(gè)生成器生成服從真實(shí)樣本分布的標(biāo)記樣本,另一個(gè)生成器生成與真實(shí)樣本分布不同的未標(biāo)記樣本。分類器負(fù)責(zé)為給定的樣本分配標(biāo)簽,判別器用來(lái)區(qū)分樣本標(biāo)簽對(duì)是否來(lái)自真實(shí)分布。通過(guò)生成器、判別器和分類器的對(duì)抗訓(xùn)練,當(dāng)模型達(dá)到穩(wěn)態(tài)時(shí),生成的標(biāo)記樣本擴(kuò)充了標(biāo)記樣本訓(xùn)練集,生成的未標(biāo)記樣本削弱了子圖邊緣節(jié)點(diǎn)的影響,迫使分類界限更加清晰,從而提高了分類效果。本文在Cora、Citeseer、Pubmed[4]三個(gè)數(shù)據(jù)集上評(píng)估了SemiGATDS 的分類性能,并討論了不同數(shù)量的標(biāo)記樣本和不同生成樣本比例對(duì)算法的影響,實(shí)驗(yàn)結(jié)果驗(yàn)證了本文方法的有效性。
半監(jiān)督學(xué)習(xí)旨在利用大量未標(biāo)記樣本來(lái)提高模型性能。半監(jiān)督學(xué)習(xí)有以下幾種范式:生成式方法[5]、基于支持向量機(jī)的半監(jiān)督學(xué)習(xí)算法[6]、基于分歧的方法[7]和圖半監(jiān)督學(xué)習(xí)[8-9]。其中,由于圖半監(jiān)督學(xué)習(xí)解釋性強(qiáng)、性能優(yōu)越,受到很多的關(guān)注,它的核心思想是數(shù)據(jù)集中每個(gè)樣本對(duì)應(yīng)于圖中一個(gè)節(jié)點(diǎn),若兩個(gè)樣本之間的相似度很高(或相關(guān)性很強(qiáng)),則對(duì)應(yīng)的節(jié)點(diǎn)之間存在一條邊,邊的“強(qiáng)度”(strength)正比于樣本之間的相似度(或相關(guān)性)。利用圖上的鄰接關(guān)系將標(biāo)簽從標(biāo)記樣本向無(wú)標(biāo)記樣本傳播。
關(guān)于圖半監(jiān)督學(xué)習(xí)的研究大致分為兩類,基于圖的拉普拉斯正則化框架[10]是其中一個(gè)重要的研究方向。Zhou 等人[11]通過(guò)在損失函數(shù)中使用基于圖的拉普拉斯正則化項(xiàng),在圖上平滑標(biāo)簽信息。文獻(xiàn)[12]提出了一種基于高斯隨機(jī)場(chǎng)和形式化圖拉普拉斯正則化框架的算法。Belkin 等人[13]提出了一種利用幾何的邊緣分布理論進(jìn)行半監(jiān)督學(xué)習(xí)的正則化方法ManiReg。另一個(gè)研究方向是將半監(jiān)督學(xué)習(xí)與圖嵌入[14]相結(jié)合。文獻(xiàn)[15]首次將深度神經(jīng)網(wǎng)絡(luò)引入圖的拉普拉斯正則化框架中進(jìn)行半監(jiān)督學(xué)習(xí)和圖嵌入。Yang 等人[16]提出了聯(lián)合圖嵌入學(xué)習(xí)和節(jié)點(diǎn)標(biāo)簽預(yù)測(cè)模型Planetoid。DeepWalk[17]是第一個(gè)關(guān)于圖嵌入的工作,作為一種無(wú)監(jiān)督圖嵌入學(xué)習(xí)方法,如果與分類器相結(jié)合,很容易轉(zhuǎn)化為半監(jiān)督學(xué)習(xí)基線模型。圖卷積神經(jīng)網(wǎng)絡(luò)(graph convolutional network,GCN)[18]是第一個(gè)用于圖半監(jiān)督學(xué)習(xí)的圖卷積模型,它在這個(gè)問(wèn)題上表現(xiàn)出了強(qiáng)大的能力。
GAN 作為一種功能強(qiáng)大的深度生成模型,最早用來(lái)表示自然圖像上的數(shù)據(jù)分布,通過(guò)生成器和判別器的互相博弈學(xué)習(xí)產(chǎn)生更好的輸出。最近在半監(jiān)督學(xué)習(xí)框架中展示了它們的能力[19]。半監(jiān)督生成對(duì)抗網(wǎng)絡(luò)(semi-supervised generative adversarial networks,SGAN)[20]最早是在計(jì)算機(jī)視覺(jué)領(lǐng)域提出的。SGAN用分類器取代了GAN 中的判別器。為了防止生成器過(guò)度訓(xùn)練,Salimans 等人[21]首次提出特征匹配損失,將GAN 應(yīng)用于關(guān)于“K+1”類的半監(jiān)督學(xué)習(xí)。Li 等人[22]認(rèn)識(shí)到生成器和判別器可能無(wú)法同時(shí)達(dá)到最優(yōu),并且無(wú)法控制生成樣本的語(yǔ)義信息,提出了Triple-GAN。隨著標(biāo)記樣本數(shù)量的減少,Triple-GAN 的性能改善更加顯著,這表明生成的樣本標(biāo)簽對(duì)可以有效地用于訓(xùn)練分類器。文獻(xiàn)[23]意識(shí)到生成器也存在同樣的問(wèn)題,從理論上解釋了為什么生成與真實(shí)樣本分布不同的樣本可以提高SSL 性能。通過(guò)精心設(shè)計(jì)生成器的損失,生成器可以生成與真實(shí)樣本分布不同的樣本,迫使分類器的決策邊界位于不同類的數(shù)據(jù)流形之間,這反過(guò)來(lái)又增強(qiáng)了分類器的泛化能力。
基于GAN 的圖半監(jiān)督學(xué)習(xí)的研究工作較少,如GraphSGAN。這項(xiàng)工作的主要思想是在子圖之間的密度間隙生成未標(biāo)記樣本,削弱不同類之間的信息傳播,但是用于訓(xùn)練的標(biāo)記樣本過(guò)少仍然是制約其性能的關(guān)鍵。針對(duì)這個(gè)問(wèn)題,本文提出了Semi-GATDS,該算法同時(shí)生成服從真實(shí)樣本分布的標(biāo)記樣本和與真實(shí)樣本分布不同的未標(biāo)記樣本,以提高圖半監(jiān)督學(xué)習(xí)性能。
設(shè)G=(V,E)表示一個(gè)圖,其中V代表節(jié)點(diǎn)集,E?V×V代表邊集。假設(shè)每個(gè)節(jié)點(diǎn)vi與d維實(shí)值特征向量wi∈Rd和標(biāo)簽yi∈{1,2,…,K}相關(guān)聯(lián)。如果節(jié)點(diǎn)vi的標(biāo)簽yi未知,則節(jié)點(diǎn)vi是一個(gè)未標(biāo)記節(jié)點(diǎn)。設(shè)標(biāo)記節(jié)點(diǎn)集合為VL,未標(biāo)記節(jié)點(diǎn)集合為VU=V?VL。通常,有|VL|?|VU|。由此,本文形式化地定義圖上的半監(jiān)督學(xué)習(xí)問(wèn)題,給定部分標(biāo)記圖G=(VL?VU,E),使用與每個(gè)節(jié)點(diǎn)和圖相關(guān)聯(lián)的特征w來(lái)學(xué)習(xí)函數(shù)f,預(yù)測(cè)圖中未標(biāo)記節(jié)點(diǎn)的標(biāo)簽。
本文模型框架如圖1 所示,SemiGATDS 由五部分組成,分別是圖嵌入模塊、兩個(gè)生成器、一個(gè)分類器和一個(gè)判別器?;贕AN 的模型不能直接應(yīng)用于圖數(shù)據(jù),因此,遵循文獻(xiàn)[3]的設(shè)置,首先使用網(wǎng)絡(luò)表示學(xué)習(xí)算法(本文使用TADW(text-associated deepwalk)[24]對(duì)節(jié)點(diǎn)原始特征進(jìn)行預(yù)處理)學(xué)習(xí)每個(gè)節(jié)點(diǎn)的潛在分布表示qi,然后將潛在分布表示qi與原始特征向量wi拼接,即xi=(wi,qi)。在模型中,將生成標(biāo)記樣本的生成器稱為gG,它接受真實(shí)標(biāo)簽y和隨機(jī)噪聲z作為輸入,并生成以y為標(biāo)簽的服從真實(shí)樣本分布的標(biāo)記樣本;生成未標(biāo)記樣本的生成器稱為bG,它接受隨機(jī)噪聲z為輸入,生成與真實(shí)樣本分布不同的未標(biāo)記樣本;分類器C,為給定的樣本分配標(biāo)簽;判別器D,判斷樣本標(biāo)簽對(duì)是否來(lái)自真實(shí)樣本分布。
圖1 SemiGATDS 模型示意圖Fig.1 Illustration of SemiGATDS
在模型中,考慮“K+1”類分類問(wèn)題。gG首先通過(guò)真實(shí)標(biāo)簽y和200 維隨機(jī)噪聲z,采樣于先驗(yàn)分布Pz(z)(實(shí)驗(yàn)中使用均勻分布噪聲z)生成樣本xgG~PgG(x|y,z),與條件標(biāo)簽y組成標(biāo)記樣本。接著bG通過(guò)隨機(jī)噪聲z生成未標(biāo)記樣本xbG~PbG(x|z) 。C接受四種不同類型的樣本:標(biāo)記樣本xL、未標(biāo)記樣本xU、來(lái)自gG的生成樣本xgG和來(lái)自bG的生成樣本xbG,并依據(jù)條件分布PC(y|x)為它們產(chǎn)生偽標(biāo)簽。對(duì)于帶標(biāo)簽的數(shù)據(jù)xL和gG生成的樣本xgG,期望C為它們分配正確的標(biāo)簽(為xL分配標(biāo)簽yL,為xgG分配它的條件標(biāo)簽y)。對(duì)于bG生成的樣本xbG和未標(biāo)記樣本xU,期望C將它們分別識(shí)別為第“K+1”類(即“假”類)和前K類其中之一。D接受C和gG生成的樣本標(biāo)簽對(duì)(xC,yC) 和(xgG,ygG),以及標(biāo)記樣本(xL,yL)作為輸入,并將標(biāo)記樣本標(biāo)簽對(duì)視為真樣本,而來(lái)自gG和C的樣本標(biāo)簽對(duì)均為假樣本。定義各個(gè)部分的損失如下:
將gG的損失函數(shù)定義為:
其中,PD(x,y)表示樣本標(biāo)簽對(duì)(x,y),來(lái)自真實(shí)樣本分布的概率,最小化損失函數(shù),使得gG生成更接近真實(shí)樣本分布的標(biāo)記樣本。
bG的損失函數(shù)定義為:
為特征匹配損失,它最小化了bG生成樣本與真實(shí)樣本中心點(diǎn)之間的距離,以確保生成器在類和類之間的密度間隙中生成樣本。
為pull-away term[11],它具有增加生成特征的多樣性從而增加生成熵的效果,這里可以鼓勵(lì)bG生成更多不同類別的樣本。其中N是批次大小,xi、xj是同一批次的樣本。λ0是用來(lái)平衡兩個(gè)損失的超參數(shù),實(shí)驗(yàn)中將其設(shè)置為1。
C的損失函數(shù)由四部分組成:
C的總損失是:
其中,損失和損失分別表示標(biāo)記樣本和gG生成樣本的交叉熵?fù)p失,損失迫使C將未標(biāo)記樣本識(shí)別為前“K”類,而損失迫使C將bG生成的樣本識(shí)別為“K+1”類。λ1、λ2、λ3是用于平衡每個(gè)損失的超參數(shù),實(shí)驗(yàn)中將這三個(gè)超參數(shù)均設(shè)置為0.5。
最后,判別器D的損失由三部分組成,分別為:
D的總損失是:
其中,損失迫使判別器D增大真實(shí)標(biāo)記樣本對(duì)被視為真類的概率,損失和迫使判別器D減小生成樣本標(biāo)簽對(duì)被視為真類的概率,β1、β2是用于平衡每個(gè)損失的超參數(shù),實(shí)驗(yàn)中將這兩個(gè)超參數(shù)均設(shè)置為1。
在訓(xùn)練過(guò)程中,SemiGATDS 由三組對(duì)抗訓(xùn)練組成:(1)gG通過(guò)生成以標(biāo)簽y為條件的標(biāo)記樣本來(lái)與D進(jìn)行對(duì)抗訓(xùn)練;(2)C通過(guò)為未標(biāo)記樣本生成置信度高的標(biāo)簽與D進(jìn)行對(duì)抗訓(xùn)練;(3)bG通過(guò)生成未標(biāo)記樣本與C進(jìn)行對(duì)抗訓(xùn)練。生成的未標(biāo)記樣本迫使分類界限更清晰,生成的標(biāo)記樣本對(duì)擴(kuò)充了監(jiān)督信息,模型從這兩種生成樣本中學(xué)習(xí)。詳細(xì)的訓(xùn)練過(guò)程如算法1 所示。
算法1SemiGATDS 訓(xùn)練算法
假設(shè)給定圖G=(VL?VU,E),其中節(jié)點(diǎn)總數(shù)為s(包含標(biāo)記節(jié)點(diǎn)和未標(biāo)記節(jié)點(diǎn)),節(jié)點(diǎn)特征維度d,圖嵌入表示維度e,節(jié)點(diǎn)類別數(shù)為k。本文算法的時(shí)間復(fù)雜度主要由計(jì)算節(jié)點(diǎn)的潛在分布表示和訓(xùn)練生成器、分類器、判別器四個(gè)神經(jīng)網(wǎng)絡(luò)產(chǎn)生。其中圖嵌入算法TADW 的時(shí)間復(fù)雜度為O(s2)。
本文使用的生成器、分類器、判別器均采用全連接神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)。神經(jīng)網(wǎng)絡(luò)時(shí)間復(fù)雜度依據(jù)浮點(diǎn)運(yùn)算次數(shù)計(jì)算,一次浮點(diǎn)運(yùn)算可以定義為一次乘法和一次加法。生成器和判別器均是擁有兩個(gè)隱藏層的神經(jīng)網(wǎng)絡(luò),分別具有(c1,c1)個(gè)神經(jīng)元,bG生成器輸入為隨機(jī)噪聲z,維度為t1,輸出為節(jié)點(diǎn)特征和節(jié)點(diǎn)圖嵌入表示拼接后的維度d+e,第一層執(zhí)行t1×c1次乘加操作,第二層執(zhí)行c1×c1次乘加操作,最后一層執(zhí)行c1×(d+e) 次操作,總共執(zhí)行t1×c1+c1×c1+c1×(d+e)次操作,假設(shè)每批次訓(xùn)練m個(gè)樣本,bG生成器的總操作次數(shù)為m(t1×c1+c1×c1+c1×(d+e)),時(shí)間復(fù)雜度為O(m×c1×(d+e))。gG生成器輸入為隨機(jī)噪聲z與標(biāo)簽y的拼接,標(biāo)簽y經(jīng)過(guò)編碼后其維度為t2,因此gG生成器的輸入維度為t1+t2,其每批次訓(xùn)練m個(gè)樣本,gG生成器的總操作次數(shù)為m((t1+t2)×c1+c1×c1+c1×(d+e)),時(shí)間復(fù)雜度為O(m×c1×(d+e))。判別器D的輸入維度即節(jié)點(diǎn)特征維度為d+e,輸出為真假即維度為1,其每批次訓(xùn)練總操作次數(shù)為m(c1×(d+e)+c1×c1+c1),時(shí)間復(fù)雜度為O(m×c1×(d+e))。分類器C輸入維度即節(jié)點(diǎn)特征維度為d+e,擁有5個(gè)隱藏層的神經(jīng)網(wǎng)絡(luò),分別具有(c1,c1,c2,c2,c2)個(gè)神經(jīng)元輸出為類別個(gè)數(shù),其維度為k。以此類推,每批次訓(xùn)練總操作數(shù)為m((d+e)×c1+c1×c1+c1×c2+2c2×c2+c2×k),時(shí)間復(fù)雜度為O(m×c1×(d+e))。
綜上,SemiGATDS 算法總的時(shí)間復(fù)雜度為O(s2)+O(m×c1×(d+e))。
數(shù)據(jù)集統(tǒng)計(jì)匯總?cè)绫? 所示。在引文網(wǎng)絡(luò)數(shù)據(jù)集Citeseer、Cora 和Pubmed 中,節(jié)點(diǎn)是文檔,邊是引文鏈接。標(biāo)記節(jié)點(diǎn)數(shù)表示用于訓(xùn)練的標(biāo)記節(jié)點(diǎn)的個(gè)數(shù)。每個(gè)文檔都有以詞袋模型(bag-of-words model)表示的特征,并根據(jù)主題賦予特定的標(biāo)簽。
表1 數(shù)據(jù)集統(tǒng)計(jì)Table 1 Dataset statistics
為了避免過(guò)度調(diào)整網(wǎng)絡(luò)體系結(jié)構(gòu)和超參數(shù),所有實(shí)驗(yàn)均使用默認(rèn)設(shè)置進(jìn)行訓(xùn)練與測(cè)試。具體地說(shuō),分類器C有5 個(gè)隱藏層,分別具有(500,500,250,250,250)個(gè)神經(jīng)元。隨機(jī)層采用零均值高斯噪聲,隱藏層輸入標(biāo)準(zhǔn)差為0.05,輸出標(biāo)準(zhǔn)差為0.5。生成器bG具有兩個(gè)500 個(gè)神經(jīng)元的隱藏層,每個(gè)隱藏層后面都有一個(gè)批歸一化層,輸出層使用Tanh 激活函數(shù)。生成器gG和bG具有相同結(jié)構(gòu),不同的是前者以噪聲z和真實(shí)標(biāo)簽y的拼接作為輸入。判別器也采用和生成器相同的隱藏層結(jié)構(gòu),只是對(duì)輸入層和輸出層作了相應(yīng)的調(diào)整。模型由ADAM 進(jìn)行優(yōu)化,所有參數(shù)均使用Xavier初始化方法。
為了公平比較,實(shí)驗(yàn)遵循文獻(xiàn)[16]中的設(shè)置,對(duì)于每個(gè)類,選擇20 個(gè)樣本(文檔)作為標(biāo)記樣本用于訓(xùn)練,同時(shí)選擇1 000 個(gè)樣本作為測(cè)試樣本。所有實(shí)驗(yàn)結(jié)果取10 次隨機(jī)拆分的平均值。在這3 個(gè)數(shù)據(jù)集中,將提出的方法SemiGATDS 與4 類方法進(jìn)行了比較:
(1)基于正則化的方法LP(label propagation)[11]、ICA(iterative classification algorithm)[25]和ManiReg[13];
(2)基于圖嵌入的方法DeepWalk[17]、SemiEmb[15]和Planetoid[16];
(3)基于圖卷積的方法Chebyshev[26]、GCN[18];
(4)基于GAN 的方法Triple-GAN[22]、GraphSGAN[3]。
由于原始Triple-GAN 并未用于圖,本文在圖上重新實(shí)現(xiàn)了Triple-GAN,并復(fù)現(xiàn)了GraphSGAN,在3個(gè)數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn)。其中Triple-GAN 的生成器生成服從真實(shí)樣本分布的標(biāo)記樣本,而GraphSGAN的生成器生成與真實(shí)樣本分布不同的未標(biāo)記樣本。
本文在3 個(gè)數(shù)據(jù)集上均訓(xùn)練了200 個(gè)epoch。表2 顯示了SemiGATDS 與上述方法對(duì)比的實(shí)驗(yàn)結(jié)果。
表2 分類準(zhǔn)確率匯總Table 2 Summary of results of classification accuracy 單位:%
實(shí)驗(yàn)結(jié)果表明,本文方法優(yōu)于所有基于正則化、圖嵌入以及圖卷積的方法,且比Cora、Citeseer 和Pubmed 數(shù)據(jù)集上的最佳結(jié)果分別提升了2.4 個(gè)百分點(diǎn)、0.2 個(gè)百分點(diǎn)和0.4 個(gè)百分點(diǎn)。同時(shí)由表可知,在Cora 和Citeseer 數(shù)據(jù)集上,基于GAN 的方法均優(yōu)于其他方法,也驗(yàn)證了將生成對(duì)抗網(wǎng)絡(luò)用于圖半監(jiān)督學(xué)習(xí)任務(wù)的有效性。而GraphSGAN 的效果優(yōu)于Triple-GAN,說(shuō)明產(chǎn)生的與真實(shí)樣本分布不同的未標(biāo)記樣本對(duì)分類效果影響更大。SemiGATDS 結(jié)合兩者的優(yōu)點(diǎn),同時(shí)生成的服從真實(shí)樣本分布的標(biāo)記樣本和與真實(shí)樣本分布不同的未標(biāo)記樣本,共同對(duì)模型產(chǎn)生了影響,獲得了比Triple-GAN 和GraphSGAN 更好的結(jié)果,從而驗(yàn)證了SemiGATDS 的有效性。
為了進(jìn)一步了解SemiGATDS 使用不同數(shù)量的標(biāo)記樣本訓(xùn)練時(shí)的表現(xiàn),本文通過(guò)改變每類選擇的標(biāo)記樣本的數(shù)量n獲得不同的訓(xùn)練集。表3~表5 顯示了3 個(gè)數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果。由表可知,隨著有標(biāo)記樣本比例的增加,用于訓(xùn)練模型的數(shù)據(jù)增加,模型能夠?qū)W到的信息越多,從訓(xùn)練數(shù)據(jù)中得到的模型的分類性能越好。以Cora 數(shù)據(jù)集為例,當(dāng)n為10 時(shí),Triple-GAN、GraphSGAN 和SemiGATDS 分類準(zhǔn)確率分別為76.4%、82.9%和83.5%;當(dāng)n為20 時(shí),它們的分類準(zhǔn)確率上漲到81.3%、84.0%和85.4%。并且當(dāng)n值相同時(shí),SemiGATDS 所獲得的結(jié)果仍然好于GraphSGAN 和Triple-GAN。同 樣的,在Citeseer 和Pubmed 數(shù)據(jù)集上也可以觀察到相同的結(jié)果,說(shuō)明生成的標(biāo)記樣本可以擴(kuò)充圖半監(jiān)督學(xué)習(xí)中的標(biāo)記樣本訓(xùn)練集,生成的未標(biāo)記樣本可以強(qiáng)制決策邊界位于正確的位置。這兩種生成樣本同時(shí)起作用,使Semi-GATDS 獲得了更好的效果。
表3 Cora 數(shù)據(jù)集上不同數(shù)量標(biāo)記樣本下的分類準(zhǔn)確率Table 3 Classification accuracy under different number of labeled samples on Cora dataset
表4 Citeseer數(shù)據(jù)集上不同數(shù)量標(biāo)記樣本下的分類準(zhǔn)確率Table 4 Classification accuracy under different number of labeled samples on Citeseer dataset
表5 Pubmed 數(shù)據(jù)集上不同數(shù)量標(biāo)記樣本下的分類準(zhǔn)確率Table 5 Classification accuracy under different number of labeled samples on Pubmed dataset
在Cora 數(shù)據(jù)集上,本文對(duì)比了Triple-GAN、GraphSGAN 和SemiGATDS 的分類準(zhǔn)確率與epoch的關(guān)系,實(shí)驗(yàn)取了前20 個(gè)epoch 的結(jié)果,如圖2 所示。
圖2 算法在Cora 數(shù)據(jù)集上分類準(zhǔn)確率與訓(xùn)練周期的關(guān)系Fig.2 Relationship between classification accuracy and training period of algorithms on Cora dataset
通過(guò)觀察發(fā)現(xiàn)了兩個(gè)不同的訓(xùn)練階段:
一階段:三個(gè)模型訓(xùn)練波動(dòng)比較大。推測(cè)是因?yàn)樵诔跏茧A段生成的樣本質(zhì)量不高,對(duì)模型造成了干擾。
二階段:模型趨于穩(wěn)定,SemiGATDS 明顯超過(guò)了Triple-GAN 和GraphSGAN。從分類器的角度看,gG生成的標(biāo)記樣本用于擴(kuò)充圖半監(jiān)督學(xué)習(xí)中標(biāo)記樣本訓(xùn)練集,bG生成的未標(biāo)記樣本減少了密度間隙中鄰近節(jié)點(diǎn)的影響。兩種生成樣本的共同作用,使得分類器得到了更好的分類效果。
為了探究生成的未標(biāo)記樣本和標(biāo)記樣本的比例對(duì)實(shí)驗(yàn)結(jié)果的影響,本文對(duì)比了模型在3 種數(shù)據(jù)集Cora、Citeseer、Pubmed 上,不同生成比例下的性能,如表6~表8 所示。
表6 SemiGATDS 在Cora 數(shù)據(jù)集上不同生成比例(未標(biāo)記樣本∶標(biāo)記樣本)的分類準(zhǔn)確率Table 6 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Cora dataset
表7 SemiGATDS 在Citeseer數(shù)據(jù)集上不同生成比例(未標(biāo)記樣本∶標(biāo)記樣本)的分類準(zhǔn)確率Table 7 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Citeseer dataset
表8 SemiGATDS 在Pubmed 數(shù)據(jù)集上不同生成比例(未標(biāo)記樣本∶標(biāo)記樣本)的分類準(zhǔn)確率Table 8 Classification accuracy of SemiGATDS under different generation ratios(unlabeled samples∶labeled samples)on Pubmed dataset
從表中結(jié)果可以得出如下結(jié)論:在Citeseer、Pubmed 兩個(gè)數(shù)據(jù)集上,當(dāng)生成的未標(biāo)記樣本和標(biāo)記樣本比例為1∶1 時(shí),模型的效果更好。在Cora 數(shù)據(jù)集上,當(dāng)生成的未標(biāo)記樣本和標(biāo)記樣本比例為1∶2 時(shí)模型的效果更好,但生成的未標(biāo)記樣本和標(biāo)記樣本比例為1∶1 和比例為1∶2 的效果相差不大,因此最終選取1∶1 的比例作為所有實(shí)驗(yàn)的基準(zhǔn)。
現(xiàn)有基于GAN 的圖半監(jiān)督學(xué)習(xí)算法能有效提升半監(jiān)督學(xué)習(xí)的分類性能,但標(biāo)記樣本過(guò)少仍是其面臨的主要困難。針對(duì)這個(gè)問(wèn)題,本文提出了一種基于GAN 的圖半監(jiān)督學(xué)習(xí)框架SemiGATDS,它通過(guò)生成器、分類器以及判別器之間的對(duì)抗訓(xùn)練,同時(shí)生成服從真實(shí)樣本分布的標(biāo)記樣本和與真實(shí)樣本分布不同的未標(biāo)記樣本,當(dāng)模型達(dá)到穩(wěn)態(tài)時(shí),生成的標(biāo)記樣本可以擴(kuò)充標(biāo)記樣本訓(xùn)練集,生成的未標(biāo)記樣本可以減少密度間隙中鄰近節(jié)點(diǎn)的影響,使決策邊界更清晰,從而提高圖半監(jiān)督分類的效果。在多個(gè)數(shù)據(jù)集上本文提出的SemiGATDS 均優(yōu)于現(xiàn)有的方法,進(jìn)一步討論了不同數(shù)量的標(biāo)記樣本和不同生成樣本比例對(duì)SemiGATDS 性能的影響,實(shí)驗(yàn)結(jié)果驗(yàn)證了該方法的有效性。