張輝宜, 張 進, 黃 俊
(安徽工業(yè)大學(xué) 計算機科學(xué)與技術(shù)學(xué)院,安徽 馬鞍山 243000)
傳統(tǒng)監(jiān)督學(xué)習(xí)中每個樣本只含有一個語義信息,但是現(xiàn)實世界的數(shù)據(jù)往往含有多個類別的語義信息,即單個樣本關(guān)聯(lián)著多個語義標簽。例如,一幅天空的圖像可以同時標注“藍天”、“白云”等語義標簽;一段新聞文檔可以同時屬于“時事”、“政策”等多個類別。針對這些含有多個語義標簽的多標簽數(shù)據(jù),如果只考慮單一語義標簽對其進行學(xué)習(xí),就很難獲得很好的分類效果。多標簽學(xué)習(xí)的應(yīng)用領(lǐng)域十分廣泛,包含了圖像分類[1]、文本分類[2]、音樂分類[3]以及生物學(xué)分類[4]等多個領(lǐng)域。隨著現(xiàn)實生活中多標簽圖像數(shù)量越來越多、種類越來越復(fù)雜,多標簽學(xué)習(xí)在圖像分類上的應(yīng)用也顯得更加重要。利用標簽之間的相關(guān)性可以提升多標簽?zāi)P偷姆诸愋阅躘5]。根據(jù)標簽相關(guān)性挖掘的程度,可以將多標簽分類模型分為3類:沒有利用到標簽相關(guān)性的一階方法[6-7];挖掘標簽對之間關(guān)系的二階方法[8-9];和利用所有標簽或類別標簽子集中標簽關(guān)系的高階方法[10-11]。初期采用淺層分類模型[12-13]對多標簽圖像進行分類,在人工干預(yù)提取數(shù)據(jù)特征的情況下,淺層模型一般都能取得較好的分類結(jié)果。近十年來應(yīng)用深度學(xué)習(xí)理論[14]尤其是卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network, CNN)構(gòu)建了一批經(jīng)典深度卷積神經(jīng)網(wǎng)絡(luò)模型,例如AlexNet[15]、VGG[16]、ResNet[17],這些模型可以對大量多標簽圖像樣本進行有效的深層特征學(xué)習(xí),但單獨利用卷積神經(jīng)網(wǎng)絡(luò)對多標簽圖像進行分類缺乏對標簽相關(guān)性的利用,這會影響模型的分類性能,因此多標簽分類的現(xiàn)有工作往往會利用標簽相關(guān)性以提高性能。在標簽相關(guān)性中標簽的共現(xiàn)關(guān)系可以通過概率圖模型很好地表述,在以往的研究工作中,有很多基于這種數(shù)學(xué)理論的方法可以對標簽關(guān)系進行建模[18,19],但是概率圖模型的計算成本過高,為了解決這個問題,使用遞歸網(wǎng)絡(luò)將標簽編碼為嵌入向量,以實現(xiàn)標簽間相關(guān)性建模的方法被提出[20],該方法也存在著遞歸神經(jīng)網(wǎng)絡(luò)模型依賴于預(yù)定義或?qū)W習(xí)的標簽順序的不足,且無法很好地獲得標簽全局依賴性。2019年Chen等提出了ML-GCN模型[21],ML-GCN利用訓(xùn)練數(shù)據(jù)集中所有標簽類別的標簽共現(xiàn)關(guān)系建立了整體的標簽相關(guān)性,在最終分類階段使用圖卷積網(wǎng)絡(luò)(Graph Convolutional Network,GCN)[22]傳播標簽共現(xiàn)嵌入并將標簽共現(xiàn)嵌入與CNN特征合并,但是ML-GCN學(xué)習(xí)到的標簽共現(xiàn)嵌入維度遠遠高于需要分類的標簽類別數(shù),這會影響模型的分類性能。提出基于圖注意力網(wǎng)絡(luò)(Graph Attention Network,GAT)[23]的多標簽圖像分類模型ML-GAT,ML-GAT采用降維[24]的方法對ML-GCN中標簽共現(xiàn)嵌入維度過高的問題進行改進,同時采用圖注意力網(wǎng)絡(luò)對標簽之間的關(guān)系進行更加精確的建模。
針對通過圖卷積神經(jīng)網(wǎng)絡(luò)得到標簽共現(xiàn)嵌入維度過高的問題,ML-GAT采用詞嵌入降維模塊對高維雙向Transformer 的表征編碼器(Bidirectional Encoder Representation from Transformers,BERT)[25]標簽語義嵌入表示矩陣進行降維,得到低維標簽語義嵌入表示矩陣。為了學(xué)習(xí)標簽之間非對稱的關(guān)系特征,將低維標簽語義嵌入表示矩陣和標簽類別共現(xiàn)圖輸入GAT,獲取標簽共現(xiàn)嵌入模塊,得到維度合適的低維標簽共現(xiàn)嵌入。同時ML-GAT采用圖像特征提取模塊提取圖像特征。為了匹配低維標簽共現(xiàn)嵌入維度,圖像特征需要經(jīng)過圖像特征降維模塊進行降維,在降維的同時也減少了圖像特征中的冗余部分。最后,將標簽共現(xiàn)嵌入與降維后的圖像特征通過圖像特征與標簽共現(xiàn)嵌入融合模塊進行融合,得到多標簽預(yù)測評分。多標簽圖注意力網(wǎng)絡(luò)模型結(jié)構(gòu)如圖1所示。
圖1 多標簽圖注意力網(wǎng)絡(luò)模型結(jié)構(gòu)Fig. 1 Model structure of multi-label graph attention network
ML-GAT中圖像通用特征提取模塊使用101層ResNet,即ResNet-101模型。ResNet-101是目前主流的CNN之一,其優(yōu)點是易于調(diào)整,可以比較方便地利用在多標簽圖像分類任務(wù)上,并且有較強的特征提取能力。因為ML-GAT采用的是在ImageNet上預(yù)訓(xùn)練的ResNet-101,所以需要去除用來對ImageNet進行分類的全連接層,為了控制圖像維度,需要同時去除ResNet-101的自適應(yīng)池化層,這樣可以得到多標簽圖像特征提取器。將解析度為448×448的多標簽圖像樣本I輸入多標簽圖像特征提取器,可提取多標簽圖像的特征圖F:
F=fResNet(I;θResNet)∈RW×H×D
其中,特征圖F的長寬為W、H,通道數(shù)為D,fResNet表示圖像通用特征提取模塊,θResNet是ResNet-101模型參數(shù)。
因為在圖像通用特征與標簽共現(xiàn)嵌入融合模塊中需要將圖像特征與標簽共現(xiàn)嵌入維度進行匹配,同時對圖像特征進行降維,可以一定程度上提高圖像特征的判別力,故在ML-GAT中采取以下步驟對對特征圖F的長寬W,H以及通道數(shù)D進行降維,F(xiàn)首先通過卷積層conv1下采樣,得到F′∈RW′×H′×D,W′和H′代表降維后特征圖F′的長與寬,再通過一層卷積層conv2對特征圖F′的通道數(shù)D進行降維,得到F″∈RW′×H′×d″,d″為降維后F″的通道數(shù),最后經(jīng)過全局最大值池化層GMP,提取多標簽圖像的特征紋理,去除無用特征。這樣可以為每一張圖像提取一個維度為Rd″的圖像特征向量x:
x=fGMP(fconv2(fconv1(F);θconv1);θconv2)∈Rd″
其中,fGMP為全局最大值池化運算,fconv1和fconv2分別為卷積層conv1與conv2進行的卷積運算,θconv1,θconv2分別為卷積層conv1與卷積層conv2的模型參數(shù)。
GAT首先針對每一個標簽節(jié)點i,計算標簽節(jié)點i與包括自身節(jié)點自身在內(nèi)的所有鄰居節(jié)點j之間的相關(guān)系數(shù)eij:
其中,W∈RM×d′是共享參數(shù),對標簽節(jié)點i的特征zi和標簽節(jié)點j的特征zj進行增維,增維后的維度為M,在ML-GAT中最后一層M=d″,[·‖·]表示對標簽節(jié)點i,j的特征進行拼接可以將兩個維度為RM的向量拼接為R2M的向量,j∈Ni是與標簽節(jié)點i存在共現(xiàn)關(guān)系的一跳鄰居節(jié)點。a∈R2M運算將拼接特征映射到一個實數(shù)上。
對相關(guān)系數(shù)采用歸一化運算得到注意力系數(shù):
其中,LeakyReLU是激活函數(shù)。用注意力系數(shù)aij用來計算每個節(jié)點的最終輸出特征:
其中,σ為非線性激活函數(shù)。因為通過計算得到的標簽i對標簽j的注意力系數(shù)與標簽j對標簽i的注意力系數(shù)不同,所以GAT得到的標簽節(jié)點特征可以一定程度上表示多標簽學(xué)習(xí)中標簽與標簽之間的非對稱關(guān)系,例如"飛機”和“天空”這一對標簽,有“天空”這一標簽時“飛機”有著很小的概率同時出現(xiàn),而“飛機”標簽出現(xiàn)時則會大概率伴隨著“天空”標簽的出現(xiàn)。ML-GAT采用GAT可以單獨為每一對標簽計算注意力系數(shù),得到能更加準確表達標簽之間關(guān)系的標簽共現(xiàn)嵌入。
在GAT獲取標簽共現(xiàn)嵌入模塊,經(jīng)過兩層GAT的計算,可以得到一個帶有類別標簽間非對稱關(guān)系,維度為RC×d″的標簽共現(xiàn)嵌入Zl+2。每一層的GAT計算為
Zl+1=fGAT(Zl,U)+Zl
其中,fGAT表示一層GAT計算,U∈RC×C表示標簽節(jié)點從標簽類別共現(xiàn)圖中獲得的相關(guān)矩陣建立方式與ML-GCN中相同,U中元素uij取值取決于類別標簽i與類別標簽j之間的共現(xiàn)次數(shù),為了能更好地將上一層的信息傳遞到下一層,因此在計算時將加上之前一層GAT的計算結(jié)果Zl。
對于一張多標簽圖像樣本,本模型使用的多標簽分類損失函數(shù)(Multi-label Soft Margin Loss):
實驗所采用的軟硬件環(huán)境為Intel Pentium G4560 @ 3.50 GHz,NVIDIA GeForece GTX 1080Ti 11 GB顯卡,12 GB內(nèi)存,操作系統(tǒng)為Ubuntu 16.04,編程語言為Python,深度學(xué)習(xí)框架為Pytorch 1.5。
ML-GAT在兩種常用多標簽圖像數(shù)據(jù)集上進行對比實驗,分別是:Microsoft COCO 2014(MS-COCO 2014)[26]和PASCAL Visual Object Classes Challenge(VOC 2007)[27]。MS-COCO 2014擁有80個類別的多標簽圖像,包含82 081張圖像組成的訓(xùn)練集和 40 504張圖像組成的驗證集,平均每張圖像都擁有2.9個類別標簽。VOC 2007數(shù)據(jù)集包含9 963張圖像組成的訓(xùn)練集、驗證集和測試集,包含20個常見物體類別標簽。
在ML-GAT中,將維度為RC×L的高維BERT標簽語義嵌入矩陣Z0輸入到詞嵌入降維模塊,預(yù)訓(xùn)練BERT標簽次嵌入矩陣維度L取值為1 024,經(jīng)過一層卷積核長度4寬度為1的卷積層進行下采樣,水平步長為4垂直步長為1,得到低維標簽語義嵌入表示矩陣Zl∈RC×d′,d′此時為256,將Zl輸入GAT獲取標簽共現(xiàn)嵌入模塊,經(jīng)過兩層GAT計算得到標簽共現(xiàn)嵌入Zl+2∈RC×d″。為了將標簽共現(xiàn)嵌入應(yīng)用在圖像特征上,將多標簽圖像解析度設(shè)置為448×448,將其輸入圖像通用特征提取模塊,得到多標簽圖像特征圖F∈RW×H×D,D為2 048,W、H均為14,針對VOC 2007數(shù)據(jù)集模型,采用的卷積層conv1不改變其W、H,使W′、H′與W、H相等。MS-COCO 2014數(shù)據(jù)集中W,H通過長寬為5卷積核的conv1計算得到值均為10的W′、H′,在兩種數(shù)據(jù)集上均經(jīng)過長寬為1的卷積核的卷積層conv2,對特征圖通道數(shù)D進行降維,得到F″∈RW′×H′×d″,最后采用池化核大小為W′×H′的全局最大值池化層GMP得到維度為Rd″的圖像特征向量x,d″是圖像特征向量x的維度,同時也是標簽共現(xiàn)嵌入的列維度,在VOC 2007數(shù)據(jù)集上的取值分別為{300,512,768},而在MS-COCO 2014數(shù)據(jù)集上設(shè)置d″為{1 024,1 280,1 536},d″參數(shù)設(shè)置由參數(shù)搜索和數(shù)據(jù)集中的標簽類別標簽個數(shù)共同決定,參數(shù)搜索策略為試錯法,由于MS-COCO 2014所含有的類別標簽數(shù)是VOC 2007中所含類別標簽數(shù)的4倍,因此d″也同步增加。設(shè)置初始學(xué)習(xí)率為0.005,采用隨機梯度下降作為優(yōu)化器,權(quán)重衰減設(shè)置為10-4,動量設(shè)置為0.9,總共訓(xùn)練100輪。
測試采用的評價指標有:平均每類精度(CP)、平均每類召回率(CR)和平均每類F1(CF1)值。另外針對整體分類結(jié)果使用平均整體精度(OP),平均整體召回率(OR),平均整體F1(OF1)進行評價。針對每個類別的分類準確度,取平均值得到平均精度均值(mAP)[28],評價指標定義如下:
在MS-COCO 2014數(shù)據(jù)集的實驗中,因為實驗設(shè)備條件有限,且數(shù)據(jù)集中樣本相對較多,故進行實驗時,采用隨機抽取部分訓(xùn)練樣本用作訓(xùn)練模型,再將訓(xùn)練出的模型在全部測試樣本上進行測試的方法。對于ML-GCN和ResNet-101進行同樣的采樣、訓(xùn)練、測試方法進行實驗,在MS-COCO 2014訓(xùn)練樣本列表中采用Python的Random模塊,從82 081張訓(xùn)練樣本中隨機抽取4 000個樣本,采樣3次,訓(xùn)練出3個模型分別測試,對所有測試產(chǎn)生的評價指標,取3次測試的均值作為實驗結(jié)果。VOC 2007數(shù)據(jù)集采用全部訓(xùn)練樣本和測試樣本進行實驗。實驗結(jié)果中各評價指標中最佳值均已加粗。
ML-GAT在VOC2007上的測試結(jié)果如表1所示,經(jīng)過與近幾年來的主流深度多標簽圖像分類模型進行對比(實驗數(shù)據(jù)來源中除ResNet-101、ML-GAT,其他方法數(shù)據(jù)均來自各論文中給出的測試結(jié)果),在d″設(shè)置為512的情況下,ML-GAT在mAP這一指標上達到了94.3,在14個類別的分類上為最佳值。在MS-COCO 2014數(shù)據(jù)集上ML-GAT的測試結(jié)果如表2所示,此時d″設(shè)置為1 280,在所有標簽上的預(yù)測與前3個標簽上的預(yù)測結(jié)果中,有7個主要分類指標超過或持平ML-GCN,說明ML-GAT模型可以在多個常用數(shù)據(jù)集上取得較好的分類結(jié)果。
表1 在VOC 2007上的實驗結(jié)果Table 1 Experimental results on VOC 2007
表2 在MS-COCO 2014上的實驗結(jié)果Table 2 Experimental results on MS-COCO 2014
為了比較不同數(shù)據(jù)集上標簽共現(xiàn)嵌入列維度d″對分類性能的影響,分別對兩個數(shù)據(jù)集設(shè)置不同的d″進行對比實驗,如圖2所示,在VOC 2007數(shù)據(jù)集中,d″取值為512時,ML-GAT在mAP評價指標上達到最佳,在MS-COCO 2014數(shù)據(jù)集上進行一次采樣測試,d″大小為1 280時得到最佳mAP,這說明MS-COCO 2014數(shù)據(jù)集中的標簽類別更多,標簽共現(xiàn)嵌入中冗余部分較少。而VOC 2007因為標簽類別較少,因此標簽共現(xiàn)嵌入冗余部分較多。通過在這兩種數(shù)據(jù)集上進行對比實驗,驗證了ML-GAT在標簽共現(xiàn)嵌入降維,與對標簽之間非對稱關(guān)系的提取上采取的策略是有效的。
(a) VOC 2007
(b) MS-COCO 2014
圖卷積神經(jīng)網(wǎng)絡(luò)與CNN結(jié)合的深度多標簽圖像分類模型ML-GCN在多標簽圖像的分類上取得了很好的效果,但是ML-GCN中通過GCN獲取到的標簽共現(xiàn)嵌入維度過高,標簽共現(xiàn)嵌入沒有很好的反應(yīng)標簽之間非對稱關(guān)系,針對ML-GCN存在的這兩點不足,提出一種基于圖注意力網(wǎng)絡(luò)的多標簽圖像分類模型ML-GAT。ML-GAT通過對輸入GAT的高維標簽語義嵌入表示矩陣進行降維,解決了ML-GCN利用GCN獲取標簽共現(xiàn)嵌入時,冗余部分降低模型分類準確度問題,同時GAT可以對標簽鄰居之間計算不同注意力系數(shù),學(xué)習(xí)標簽之間非對稱關(guān)系特征,促進模型分類。通過在主流數(shù)據(jù)集上與多標簽深度學(xué)習(xí)經(jīng)典模型進行對比實驗,ML-GAT模型在多標簽圖像分類主要評價指標上,相較經(jīng)典深度多標簽圖像分類模型有一定的改進,實驗證明了ML-GAT模型的有效性。