顧 昕,葉海良,楊 冰,曹飛龍
(中國計量大學(xué) 理學(xué)院,浙江 杭州 310018)
在社交媒體、生物工程和交通運輸?shù)阮I(lǐng)域存在許多具有圖結(jié)構(gòu)的數(shù)據(jù),它們往往是不規(guī)則且無序的,因此,傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò)[1-2]難以直接對其進(jìn)行處理。近年來,圖神經(jīng)網(wǎng)絡(luò)(Graph Neural Networks, GNNs)[3]成功地將卷積運算推廣到圖數(shù)據(jù)上,有效地解決了這一問題。GNNs主要是利用節(jié)點特征信息及邊信息來計算節(jié)點表示,應(yīng)用于節(jié)點分類[4]及鏈接預(yù)測[5]等節(jié)點級別的任務(wù)。對于圖分類、圖生成等圖級別的任務(wù)而言,沒有池化層的GNNs無法獲取圖數(shù)據(jù)的圖級表示,因此需要構(gòu)建針對圖數(shù)據(jù)的池化操作。
現(xiàn)有的圖池化方法主要有基于核的方法和基于神經(jīng)網(wǎng)絡(luò)的方法?;诤说姆椒ㄍǔR罁?jù)圖的結(jié)構(gòu)特性計算不同圖之間的相似性,進(jìn)而實現(xiàn)圖分類。如隨機(jī)游走核[6],它同時對兩個待分類的圖執(zhí)行隨機(jī)游走,然后計算兩次游走產(chǎn)生的路徑數(shù),從而得到兩個圖的相似性。然而大多數(shù)基于核的方法沒有特征提取,直接將圖轉(zhuǎn)換為固定長度的實值特征向量,造成圖級表示學(xué)習(xí)不充分,而帶有特征提取的核方法[7]計算復(fù)雜度高以及特征提取與分類器分離,不能以端到端的方式學(xué)習(xí)。
基于神經(jīng)網(wǎng)絡(luò)的方法主要有節(jié)點聚類和節(jié)點采樣兩種類型。對圖中節(jié)點聚類實現(xiàn)圖池化的代表工作有分層可微分圖池化(Differentiable Pooling, DIFFPool)[8],它構(gòu)造了一個可微池化層,根據(jù)圖中節(jié)點的特征將其軟分配給固定數(shù)量的類,由于要計算軟聚類分配,因此計算復(fù)雜度較高。節(jié)點采樣方法主要是依據(jù)節(jié)點的重要性保留一定數(shù)量的節(jié)點,構(gòu)造一個新的粗化圖。Zhang等[9]直接對每個節(jié)點的表示向量排序,將較大的節(jié)點表示向量保留生成池化圖;Cangea等[10]和Gao等[11]利用可訓(xùn)練的投影向量對圖中節(jié)點進(jìn)行采樣。首先將節(jié)點表示向量在投影向量上的投影得分作為節(jié)點的重要性得分,然后根據(jù)所設(shè)置的池化率k將得分較低的節(jié)點丟棄實現(xiàn)圖池化。Lee等[12]提出自注意力圖池化(Self-Attention Graph Pooling, SAGPool)方法,利用圖卷積神經(jīng)網(wǎng)絡(luò)(Graph Convolutional Neural Networks, GCNConv)[13]自適應(yīng)地計算注意力得分,并將其作為節(jié)點的重要性得分指導(dǎo)節(jié)點采樣。
本文采取基于節(jié)點采樣的圖池化方法,雖然這方面的工作已取得了一定的成果,但它們?nèi)源嬉恍﹩栴},有待進(jìn)一步解決?,F(xiàn)有的基于節(jié)點采樣進(jìn)行圖池化的方法,對于沒有被采樣的節(jié)點采取直接丟棄的原則,然而被丟棄的節(jié)點也帶有一定有效的信息,因此不可避免造成圖信息的丟失。另外,學(xué)習(xí)節(jié)點重要性得分時并沒有考慮每個節(jié)點與其鄰居節(jié)點間的相關(guān)度,從而導(dǎo)致節(jié)點重要性得分的學(xué)習(xí)不全面。在丟棄重要性較低的節(jié)點時節(jié)點的邊也被丟棄,圖中容易形成孤立點,影響整個圖結(jié)構(gòu)的連通性。
針對上述問題,本文做了以下工作:(1)提出信息保留模塊,在丟棄節(jié)點前,對節(jié)點中有利于圖分類的信息先進(jìn)行聚合保留,從而保留丟棄節(jié)點中的有效信息。(2)采用多頭注意力機(jī)制學(xué)習(xí)節(jié)點重要性得分,通過中心節(jié)點與其鄰居節(jié)點間的相關(guān)度聚合鄰域信息,從而更充分地學(xué)到節(jié)點的重要性得分。(3)在節(jié)點采樣之后應(yīng)用保持圖連通性模塊,將孤立點與其鄰居節(jié)點相連,保證圖結(jié)構(gòu)的連通性。
本文提出的結(jié)合信息保留的多頭注意力圖池化模型(Multi-head attention graph pooling model with information retention, MHAPool)完整結(jié)構(gòu)如圖1所示,對于輸入的圖,在提取初始特征階段利用經(jīng)典且被廣泛使用的GCNConv[13]來提取初始特征,公式如下:
(1)
通常,池化層只將被采樣節(jié)點的信息保留,用作新圖的特征。然而,被丟棄的節(jié)點也帶有一定的有效信息,若直接丟棄節(jié)點則會造成圖信息的損失。因此,本文設(shè)計節(jié)點信息保留模塊,將丟棄節(jié)點中的有效信息保留下來。
圖1 結(jié)合信息保留的多頭注意圖池化模型
在學(xué)到的節(jié)點表示向量中,最大的節(jié)點表示向量帶有的特征往往最具判別性與代表性,更有利于圖分類,因此我們將圖中每個節(jié)點i依次作為中心節(jié)點,取其一階鄰居作為它的鄰域N(i),在鄰域N(i)中找出初始特征提取階段得到的最大的節(jié)點表示向量mi,公式如下:
(2)
然后,計算鄰域N(i)中每個節(jié)點與mi的余弦相似度:
χi,j=cos(mi,hj)。
(3)
式(3)中χi,j代表N(i)中每個節(jié)點的表示向量與最大的節(jié)點表示向量之間的相似度,相似度越高代表這個節(jié)點包含更多有利于圖分類的信息,因此我們將得到的相似度作為權(quán)重與每一個節(jié)點表示向量相乘求和:
(4)
本文采用多頭注意力機(jī)制[14]學(xué)習(xí)節(jié)點重要性得分,如圖2所示,每一頭有三支輸入,首先將節(jié)點i與一階鄰域中的節(jié)點j進(jìn)行線性變換,公式如下:
qc,i=Wc,qhi+bc,q,kc,j=Wc,khj+bc,k。
(5)
式(5)中qc,i和kc,j分別為線性變換后中心節(jié)點i的表示向量與其鄰域中節(jié)點j的表示向量,Wc,q,Wc,k,bc,q,bc,k表示可學(xué)習(xí)的權(quán)重和偏置,c表示注意力機(jī)制的頭數(shù)。
圖2 以3頭為例的多頭注意力學(xué)習(xí)節(jié)點重要性得分示意圖
得到qc,i和kc,j之后,使用縮放點積計算中心節(jié)點i與其鄰域中節(jié)點j的多頭注意力系數(shù),公式如下:
(6)
得到每個節(jié)點與相鄰節(jié)點的多頭注意力系數(shù)后,我們對鄰域中的節(jié)點再進(jìn)行一次線性變換如下:
vc,j=Wc,vhj+bc,v。
(7)
式(7)中vc,j為線性變換后節(jié)點j的表示向量,Wc,v,bc,v表示可學(xué)習(xí)的權(quán)重和偏置。
我們將變換后的鄰居節(jié)點表示vc,j與多頭注意力系數(shù)相乘之后取平均得到每個節(jié)點的重要性得分,公式如下:
(8)
式(8)中Z={z1,z2,…,zn}∈Rn×1,其中zi為節(jié)點i的重要性得分,C表示實際所取的注意力機(jī)制的頭數(shù)。
在學(xué)到節(jié)點重要性得分之后,設(shè)定池化率k并結(jié)合節(jié)點的重要性得分來生成索引,
idx=top-k(Z,|kN|)。
(9)
式(9)中|·|表示向下取整,N表示輸入圖的節(jié)點數(shù),top-k表示生成向量Z中前|kN|個值的索引,idx表示生成的索引。
隨后,根據(jù)生成的索引對圖中節(jié)點進(jìn)行采樣,實現(xiàn)圖池化。如圖1所示,在節(jié)點采樣時直接丟棄節(jié)點,與節(jié)點相連的邊也隨之被丟棄,導(dǎo)致部分節(jié)點成為沒有邊連接的孤立點,影響整個圖的連通性。受Ying等[8]的啟發(fā),為保持圖的連通性,本文采用如下方式進(jìn)行采樣:
M=A(:,idx),A′=MTAM,H′=H(idx,:)。
(10)
式(10)中H(idx,:)表示對圖的特征矩陣H執(zhí)行行提取,形成池化圖的特征矩陣,對于圖的鄰接矩陣A,首先對其列提取得到M,然后通過MTAM得到池化圖的鄰接矩陣A′,此操作可以使得孤立點與鄰居節(jié)點相連,進(jìn)而保證整個圖的連通性。
本文構(gòu)建分層圖池化模型實現(xiàn)圖分類。圖3是一個以三層為例的圖池化模型架構(gòu)圖,每一層可分為3個部分:特征提取、圖池化和讀出操作。首先,通過GCNConv對輸入圖進(jìn)行特征提取;隨后,對節(jié)點作信息保留,同時利用多頭注意力機(jī)制學(xué)習(xí)節(jié)點重要性得分,并據(jù)此實現(xiàn)節(jié)點采樣;之后再將孤立點與鄰居節(jié)點相連保持圖的連通性;接著,在讀出操作中對采樣后的圖表示取平均和最大,并進(jìn)行拼接;最后,將各層讀出操作的輸出相加,傳至圖分類器,完成圖分類任務(wù)。
已有的圖池化方法中最具代表性的有DIFFPool[8]和SAGPool[12],其中DIFFPool[8]學(xué)習(xí)節(jié)點分配矩陣,將原圖中的每個節(jié)點以指定的概率分配給新圖中的不同類,重復(fù)多次將圖中節(jié)點聚合為一個超級節(jié)點,這是一種通過聚合節(jié)點實現(xiàn)圖池化的方法。本文提出的MHAPool學(xué)習(xí)節(jié)點重要性得分,是一種依據(jù)節(jié)點重要性得分丟棄節(jié)點實現(xiàn)圖池化的方法。SAGPool[12]同為學(xué)習(xí)節(jié)點重要性得分,進(jìn)而丟棄節(jié)點實現(xiàn)圖池化的方法,但其學(xué)習(xí)節(jié)點得分時只是通過一層GCNConv[13]。而本文提出的MHAPool采用多頭注意力機(jī)制學(xué)習(xí)節(jié)點重要性得分,通過中心節(jié)點與其鄰居節(jié)點間的相關(guān)度聚合鄰域信息,更充分地學(xué)到節(jié)點的重要性得分,并且設(shè)置信息保留模塊,從而保留被丟棄節(jié)點中的有效信息。
本章介紹了數(shù)據(jù)集、對比方法和實驗設(shè)置,展示了本文提出的MHAPool對比實驗和消融實驗的結(jié)果,并對模型中的關(guān)鍵參數(shù)進(jìn)行了討論。
本文在4個生物信息數(shù)據(jù)集(DD[15], PROTEINS[16],NCI1[17]和NCI109[17])和3個社交網(wǎng)絡(luò)數(shù)據(jù)集(IMDB-BINARY[18],IMDB-MULTI[18]和COLLAB[19])上評估所提出模型的性能,具體介紹如下:DD[15]和PROTEINS[16]中的每個圖皆表示某種蛋白質(zhì)結(jié)構(gòu),其標(biāo)簽是每個圖所表示的蛋白質(zhì)是否為酶。NCI1[17]是美國國家癌癥研究所(National Cancer Institute, NCI)發(fā)布的用于癌細(xì)胞活性分類的生物信息數(shù)據(jù)集,其標(biāo)簽為細(xì)胞是否可以抑制癌細(xì)胞的生長。NCI109[17]中每個圖表示卵巢癌細(xì)胞的化學(xué)結(jié)構(gòu),圖分類對應(yīng)于卵巢癌細(xì)胞的活性篩選。IMDB-BINARY[18]是電影演員合作數(shù)據(jù)集,由出演動作電影和浪漫電影的演員組成。IMDB-MULTI[18]和IMDB-BINARY類似,它由出演喜劇電影、浪漫電影和科幻電影3種類別電影演員組成,圖分類是對出演不同類別電影的演員的進(jìn)行分類。COLLAB[19]是來自高能物理、凝聚態(tài)物理和天體物理這3個領(lǐng)域的科學(xué)家社交網(wǎng)絡(luò)數(shù)據(jù)集,圖分類是將每位科學(xué)家所屬的領(lǐng)域分類。
圖3 圖池化模型架構(gòu)
本文所提出的方法與以下兩種類別的圖分類方法比較。
基于核的方法:威斯費勒-萊曼核(Weisfeiler-Lehman Kernels, WL)[7],最短路徑核(Shortest-path Kernels, SP)[20],圖核(Graphlet Kernels, GK)[21],深度圖核(Deep Graph Kernels, DGK)[22],和匿名游走嵌入(Anonymous Walk Embeddings, AWE)[23]。
基于圖神經(jīng)網(wǎng)絡(luò)的方法:DIFFPool[8],gPool[11],SAGPool[12],EigenPool[24],基于信息的圖池化(Information-Based Pooling, iPool)[25],結(jié)構(gòu)學(xué)習(xí)分層圖池化(Hierarchical graph pooling with structure learning, SLPool)[26],基于ARMA濾波的圖神經(jīng)網(wǎng)絡(luò)(ARMA)[27],用于圖學(xué)習(xí)的瓦瑟斯坦嵌入(Wasserstein Embedding for Graph Learning, WEGL)[28],圖多集池化(Graph Multiset Pooling, GMT)[29],和空間卷積神經(jīng)網(wǎng)絡(luò)(Spatial Convolutional Neural Networks, SCNN)[30]。
實驗中將每個數(shù)據(jù)集隨機(jī)分成3部分:80%作為訓(xùn)練集,10%作為驗證集,其余10%作為測試集,將數(shù)據(jù)集隨機(jī)拆分過程重復(fù)10次,取10次實驗精度的平均值和標(biāo)準(zhǔn)差作為結(jié)果。本文使用作者提供的源碼得到對比方法的結(jié)果,同時為了公平比較,對本文所提出的方法和已有的方法使用相同的模型架構(gòu),并將節(jié)點表示維度都設(shè)為128。本文在PyTorch框架下實現(xiàn)了MHAPool,并使用Adam優(yōu)化器[31]對模型進(jìn)行優(yōu)化。MLP由3個全連接層組成,每層的神經(jīng)元數(shù)量依次設(shè)為256、128、64,最后接上softmax分類器,完成圖分類。在訓(xùn)練過程中采用了提前停止準(zhǔn)則,即若驗證損失在連續(xù)50個時期內(nèi)沒有減少,將提前停止訓(xùn)練。
本文在圖分類精度方面將所提出的MHAPool與其他模型進(jìn)行比較,結(jié)果如表1和表2所示,最佳模型以粗體突出顯示,次優(yōu)模型以下劃線顯示。
表1 4個生物信息數(shù)據(jù)集的統(tǒng)計信息以及MHAPool與對比方法在圖分類實驗上的比較結(jié)果
表2 3個社交網(wǎng)絡(luò)數(shù)據(jù)集的統(tǒng)計信息以及MHAPool與對比方法在圖分類實驗上的比較結(jié)果
表1總結(jié)整理了4個生物信息學(xué)數(shù)據(jù)集PROTEINS、DD、NCI1和NCI109的統(tǒng)計信息以及與對比方法的比較結(jié)果,表2歸納整理了3個社交網(wǎng)絡(luò)數(shù)據(jù)集IMDB-MULTI、IMDB-BINARY和COLLAB的統(tǒng)計信息以及與其他圖分類模型的比較結(jié)果。從實驗結(jié)果可以看出,MHAPool在PROTEINS,DD和NCI109這3個生物信息學(xué)數(shù)據(jù)集取得了最好的結(jié)果。特別是單個圖上節(jié)點數(shù)較多的數(shù)據(jù)集DD,本模型的圖分類精度達(dá)到最高,這說明了MHAPool具有較好的處理復(fù)雜數(shù)據(jù)的能力。在3個社交網(wǎng)絡(luò)數(shù)據(jù)集中,本模型在IMDB-MULTI,IMDB-BINARY這兩個數(shù)據(jù)集上比現(xiàn)有的最優(yōu)結(jié)果分別高出3.0%和1.4%。
對節(jié)點重要性得分學(xué)習(xí)方式的分析。對已有的使用GCNConv學(xué)習(xí)節(jié)點重要性得分與本文提出的多頭注意力機(jī)制(Multi-head Attentation,MHA)學(xué)習(xí)節(jié)點重要性得分在7個數(shù)據(jù)集上進(jìn)行圖分類實驗,實驗結(jié)果見表3,從實驗結(jié)果可以看出使用MHA的模型圖分類實驗的精度更高,結(jié)果表明多頭注意機(jī)制相比GCNConv學(xué)習(xí)到的節(jié)點重要性得分更全面,準(zhǔn)確。
對信息保留模塊的分析。為了說明信息保留模塊(Information Retention, IR)的作用,在圖分類數(shù)據(jù)集上對是否設(shè)置信息保留模塊的模型進(jìn)行實驗。實驗結(jié)果見表3,結(jié)果表明信息保留模塊保留了圖中節(jié)點的有效信息,解決了直接丟棄節(jié)點造成有效信息的丟失。
對圖連通性保持的分析。為了說明圖連通性保持模塊(Maintain Graph Connectivity, MGC)的作用,分別訓(xùn)練了是否帶有圖連通性保持模塊的模型。實驗結(jié)果見表3,結(jié)果表明在節(jié)點采樣之后加上圖連通性保持模塊,可以將由于節(jié)點采樣形成的孤立點與其鄰居節(jié)點相連保證圖結(jié)構(gòu)的連通性,使得圖分類實驗效果更好。
表3 MHAPool在圖分類數(shù)據(jù)集上的消融實驗結(jié)果
本節(jié)進(jìn)一步探究了關(guān)鍵超參數(shù)取不同值時對實驗效果的影響,分別為:圖池化模型的網(wǎng)絡(luò)層數(shù)L和學(xué)習(xí)節(jié)點重要性得分時多頭注意力機(jī)制的頭數(shù)c。在社交網(wǎng)絡(luò)和生物信息兩個領(lǐng)域中取大小各異的3個數(shù)據(jù)集COLLAB,NCI109和IMDB-BINARY作為代表,在這3個數(shù)據(jù)集上研究兩個超參數(shù)L和c取不同值時對圖分類性能的影響,這3個數(shù)據(jù)集的實驗結(jié)果見圖4。
IMDB-BINARY,NCI109和COLLAB這3個數(shù)據(jù)集中平均每個圖的節(jié)點數(shù)分別為13,32.1和74.5。在實驗中設(shè)置網(wǎng)絡(luò)層數(shù)L=1,2,3,4,5,在IMDB-BINARY數(shù)據(jù)集上,當(dāng)L取1時達(dá)到最好的實驗精度,而對于NCI109和COLLAB,L分別取2和4時達(dá)到最好的圖分類實驗效果,這表明對于較小的數(shù)據(jù)集,淺層的網(wǎng)絡(luò)就可以學(xué)習(xí)到圖表示,并達(dá)到較好的效果。在較大的數(shù)據(jù)集上需要加深網(wǎng)絡(luò)層數(shù)才能充分學(xué)習(xí)圖表示。對于較小的數(shù)據(jù)集隨著網(wǎng)絡(luò)層數(shù)的加深,實驗效果變差,這主要是因為每一層中都有GCNConv模塊,疊加多層GCNConv造成過平滑問題。
對于多頭注意力機(jī)制的頭數(shù)c,從實驗結(jié)果上來看,節(jié)點較少的數(shù)據(jù)集IMDB-BINARY,c取2時實驗精度達(dá)到最大值,對于NCI109和COLLAB,c分別取3和4時達(dá)到最好的圖分類結(jié)果。多頭注意力機(jī)制用于計算節(jié)點的重要性得分,在節(jié)點數(shù)較少的數(shù)據(jù)集上,頭數(shù)c取較小值即可學(xué)得較好的節(jié)點重要性得分,而對于節(jié)點數(shù)較多的數(shù)據(jù)集COLLAB,需要更多的頭數(shù)才能全面地學(xué)習(xí)節(jié)點重要性得分達(dá)到最優(yōu)的結(jié)果。
本文提出一種結(jié)合信息保留的多頭注意力圖池化方法,具有較好的圖分類性能。首先,通過信息保留模塊,在丟棄節(jié)點時保留圖中有效信息,解決因直接丟棄未被采樣的節(jié)點,造成圖信息損失的問題。其次,本文采用多頭注意力機(jī)制,考慮每個節(jié)點與其鄰居節(jié)點間的相關(guān)度,有效聚合鄰域信息,從而更充分地學(xué)習(xí)各節(jié)點的重要性得分。在節(jié)點采樣之后,設(shè)置圖連通性保持模塊,將孤立點與其鄰居節(jié)點相連,保證整個圖結(jié)構(gòu)的連通性。最后,在多個數(shù)據(jù)集上的圖分類實驗結(jié)果驗證了所提出方法的先進(jìn)性。
在未來的工作中,我們將考慮開發(fā)圖特征提取的方法,改善多層GCNConv易造成的過平滑問題。
圖4 超參數(shù)分析