黃友文,魏國慶,胡燕芳
(江西理工大學(xué) 信息工程學(xué)院,江西 贛州 341000)
文本分類是文本管理的基本工作。文本分類是自然語言處理(NLP)領(lǐng)域的一個重要分支。目前,隨著互聯(lián)網(wǎng)和媒體的快速發(fā)展,新聞網(wǎng)站已成為人們獲取新聞信息的主要平臺,每天向數(shù)百萬用戶提供信息。鑒于網(wǎng)絡(luò)平臺強(qiáng)大的實(shí)時性能,大量新聞文本呈現(xiàn)出快速增長的趨勢,這些文本往往沒有經(jīng)過處理,人工整理耗時耗力。因此,如何高效準(zhǔn)確地對這些海量文本進(jìn)行分類,以及如何快速獲取有效信息一直是學(xué)術(shù)界和產(chǎn)業(yè)界關(guān)注的焦點(diǎn)。
通過在大量文本上進(jìn)行預(yù)訓(xùn)練的大型語言模型在各種自然語言處理任務(wù)中的效果好于小型語言模型,然而這些大型語言模型變得越來越復(fù)雜,計算成本也越來越高,嚴(yán)重阻礙了這些模型的廣泛應(yīng)用。知識蒸餾作為一種有效的模型壓縮方法,可以緩解這個問題,該方法不但可以將大型模型壓縮成較小的模型,而且不會顯著降低模型的性能。知識蒸餾的過程如下: 先訓(xùn)練一個結(jié)構(gòu)較為復(fù)雜、分類性能較好的模型作為教師模型,然后構(gòu)建一個小型模型作為學(xué)生模型,最后對教師模型進(jìn)行蒸餾,把教師模型學(xué)到的知識轉(zhuǎn)移到小模型中。Hintion等[1]首次提出采用知識蒸餾的方式對模型進(jìn)行壓縮。Sanh等[2]使用BERT[3]模型通過知識蒸餾的方式得到一個具有6層Transformer的小模型DistilBERT。Jiao等[4]采用Transformer蒸餾的方法把BERT模型學(xué)到的知識遷移到小型模型TinyBERT中。Liu等[5]提出一種采用獨(dú)特自蒸餾機(jī)制進(jìn)行微調(diào)的語言模型FastBERT。
基于以上問題和解決方法,本文提出了基于知識蒸餾的文本分類模型DistillBIGRU。本文主要工作如下:
(1) 本文在MPNet模型的基礎(chǔ)上,融合圖卷積網(wǎng)絡(luò)GCN提出了一種結(jié)合大規(guī)模預(yù)訓(xùn)練和直推式學(xué)習(xí)的文本分類模型MPNetGCN作為教師模型;
(2) 選擇并構(gòu)建了一個復(fù)雜度和參數(shù)量較低的BiGRU網(wǎng)絡(luò)作為學(xué)生模型;
(3) 在目標(biāo)數(shù)據(jù)集上微調(diào)(Fine-tune)教師模型,把教師模型輸出的文本所屬類別概率預(yù)測值作為學(xué)生模型輸入文本的標(biāo)簽,對學(xué)生模型進(jìn)行訓(xùn)練,最終把教師模型學(xué)到的知識遷移到學(xué)生模型中,得到一個分類效果好的小型DistillBIGRU模型。
目前,文本分類的方法主要有兩種,一種是基于傳統(tǒng)機(jī)器學(xué)習(xí)的方法,另一種是基于深度學(xué)習(xí)的方法。在傳統(tǒng)機(jī)器學(xué)習(xí)領(lǐng)域,史瑞芳[6]提出了一種改進(jìn)的貝葉斯文本分類器,Miao等[7]提出了一種基于PCA和KNN的混合文本分類算法,周慶平等[8]提出了一種改進(jìn)的支持向量機(jī)文本分類方法。
傳統(tǒng)機(jī)器學(xué)習(xí)方法常采用詞頻-逆文本頻率(TF-IDF)或詞袋模型來表示文本,因此往往會導(dǎo)致語義表示不充分和“維度災(zāi)難”。近年來,隨著深度學(xué)習(xí)方法的興起,基于該方法的模型在計算機(jī)視覺和NLP等領(lǐng)域都取得了許多重要成果。在NLP領(lǐng)域,卷積神經(jīng)網(wǎng)絡(luò)(CNN)和循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)是深度學(xué)習(xí)中常用的神經(jīng)網(wǎng)絡(luò)。CNN能夠利用多層感知結(jié)構(gòu)來捕捉文本的顯著特征。然而,CNN的特征捕獲能力往往取決于卷積內(nèi)核的大小,同時CNN忽略了本地信息之間的依賴性特征,這些缺點(diǎn)對文本分類的準(zhǔn)確性有一定的影響。與CNN相比,RNN的輸出值取決于當(dāng)前時間單位的輸入和以前時間單位的輸出,該結(jié)構(gòu)可以捕獲文本上下文信息。Guo等[9]提出了基于CNN-RNN的混合文本分類模型。Wu等[10]提出了一種基于字符級CNN和支持向量機(jī)的文本分類模型,但RNN在訓(xùn)練過程中存在梯度爆炸和破壞的問題。在此基礎(chǔ)上,研究人員提出了改進(jìn)的RNN模型,包括長短時記憶網(wǎng)絡(luò)(LSTM)和門控循環(huán)單元(GRU),通過選擇保存信息來克服梯度爆炸和破壞的問題。雙向長短時記憶網(wǎng)絡(luò)(BiLSTM)和雙向門控循環(huán)單元(BiGRU)分別是LSTM和GRU的進(jìn)一步發(fā)展,它們可以從文本序列的前向和后向獲得上下文特征,并更好地解決文本分類的任務(wù)。王海濤等[11]提出了基于LSTM和CNN的文本分類方法。
基于深度學(xué)習(xí)的文本分類模型分為兩種: 預(yù)訓(xùn)練語言模型和非預(yù)訓(xùn)練語言模型。在非預(yù)訓(xùn)練語言模型方面,Liu等[12]提出了一種具有注意力機(jī)制和卷積層的雙向LSTM文本。孔韋韋等[13]提出了一種融合注意力機(jī)制、CNN和LSTM的混合分類模型。張昱等[14]提出一種基于組合卷積神經(jīng)網(wǎng)絡(luò)的文本分類方法。硬件設(shè)備的快速迭代推動了預(yù)訓(xùn)練語言模型的快速發(fā)展。2018年,Google AI推出了自編碼(AutoEncoder)語言模型BERT,該模型在當(dāng)時刷新了多項(xiàng)NLP任務(wù)榜單。BERT模型采用多層雙向Transformer結(jié)構(gòu),在預(yù)訓(xùn)練階段創(chuàng)新性地提出了Mask機(jī)制和Next Sentence Predict任務(wù)。與此同時,Mask也帶來了一些缺點(diǎn),由于該模型只在預(yù)訓(xùn)練階段進(jìn)行Mask,導(dǎo)致了預(yù)訓(xùn)練和后續(xù)Fine-tune的差異。另外,Mask機(jī)制假設(shè)在一句話中被掩蓋的詞之間相互獨(dú)立,而實(shí)際上被掩蓋的詞之間可能是有聯(lián)系的。為了解決這些問題,2019年谷歌推出了自回歸(AutoRegressive)語言模型XLNet[15],該模型引入了排列語言建模PLM(Permutation Language Modeling)用于預(yù)訓(xùn)練來解決Mask機(jī)制帶來的負(fù)面影響,通過使用Transformer-XL解決了超長序列的依賴問題。但是XLNet沒有充分利用句子的位置信息,因此在訓(xùn)練前和調(diào)優(yōu)之間存在位置差異。2020年,Song等[16]提出了一種新的預(yù)訓(xùn)練語言模型MPNet,它繼承了BERT和XLNet的優(yōu)點(diǎn),避免了它們的局限性。與BERT中的掩碼語言模型MLM(Masked Language Model)相比,MPNet通過排列語言建模利用了預(yù)測標(biāo)記之間的依賴關(guān)系,并將輔助位置信息作為輸入,使模型看到完整的句子,從而減少了位置差異。以上模型大都存在模型復(fù)雜度高、參數(shù)量大、訓(xùn)練數(shù)據(jù)龐大的特點(diǎn),為了使這些模型能夠被推廣使用,需要對這些模型進(jìn)行壓縮。Ma等[17]和Li等[18]相繼提出了基于知識蒸餾的模型壓縮方法,通過該方法能夠使得分類效果不好的、復(fù)雜度較低的模型經(jīng)過知識蒸餾后能夠?qū)崿F(xiàn)較好的分類效果。
與傳統(tǒng)的語言模型訓(xùn)練方式不同,圖卷積網(wǎng)絡(luò)GCN的概念被Thomas[19]首次提出并在文本分類任務(wù)中取得了很好的效果。該網(wǎng)絡(luò)通過構(gòu)建圖的方式來對詞與詞之間以及詞與文檔之間的關(guān)系進(jìn)行建模,然后對圖中的節(jié)點(diǎn)進(jìn)行分類。圖中的節(jié)點(diǎn)代表文本單詞和文檔,邊的數(shù)值為語義相似度權(quán)重。由于圖中節(jié)點(diǎn)的決策不僅受節(jié)點(diǎn)自身的影響,還受與其相鄰的節(jié)點(diǎn)影響,所以該網(wǎng)絡(luò)的抗干擾能力更強(qiáng)。另外無標(biāo)簽的數(shù)據(jù)也有助于圖的構(gòu)建進(jìn)而提高模型的分類效果?;贕CN,Yao等[20]提出一種用于文本分類的圖形卷積網(wǎng)絡(luò)Text GCN。Tang等[21]提出了基于圖卷積網(wǎng)絡(luò)的混合文本分類模型IGCN。Lin等[22]結(jié)合BERT模型和GCN提出一種新的文本分類模型BERTGCN。
考慮到預(yù)訓(xùn)練語言模型以及GCN在文本分類領(lǐng)域的優(yōu)勢,本文在MPNet模型的基礎(chǔ)上融合圖卷積網(wǎng)絡(luò)GCN,提出了MPNetGCN文本分類模型。同時,為了加快推理速度,推廣其在下游任務(wù)中的應(yīng)用,本文采用知識蒸餾的方法,把MPNetGCN模型學(xué)到的知識遷移到一個較小的模型BiGRU中,得到適用于下游任務(wù)的文本分類模型DistillBIGRU。
一個輸入長度為4的MPNet[16]模型結(jié)構(gòu)如圖1所示。文本在輸入模型前會被打亂,重新排列組合成一個新的文本作為模型的輸入。輸入文本由4個詞組成,即可產(chǎn)生4!種組合方式,可以隨機(jī)排列組合成24種句子,輸入的句子為其中的一種。假設(shè)輸入的文本序列表示為x=(x1,x4,x3,x2),文本在輸入模型后被分為預(yù)測部分和非預(yù)測部分。假設(shè)非預(yù)測部分的長度c=2,表示為(xz<=c)=(x1,x4)(圖1中左邊虛線的左側(cè)部分),z<=c表示x中的前c個單詞。預(yù)測部分表示為(xz>c)=(x3,x2)(圖1中右邊虛線的右側(cè)部分),z>c表示x中的后c個單詞。為了使預(yù)訓(xùn)練中的輸入信息與下游任務(wù)中的輸入信息保持一致,在非預(yù)測部分添加了預(yù)測單詞(token)的掩碼符號M和位置信息P,此時序列可表示為(xz<=c,Mz>c,xz>c)=(x1,x4,M,M,x3,x2)。添加了M后的非預(yù)測部分可表示為(xz<=c,Mz>c)=(x1,x4,M,M),預(yù)測部分保持不變。圖1中兩條虛線中間的部分Mz>c=(M,M)表示預(yù)測單詞部分的掩碼符號,相應(yīng)的位置序列同時也更新為(p1,p4,p3,p2,p3,p2)。通過位置補(bǔ)償,模型在預(yù)測每個單詞時都能看到全句的位置信息,避免了PLM中位置信息的缺失。模型對于非預(yù)測部分采用雙流注意力機(jī)制(內(nèi)容流、查詢流)來提取特征,使得輸出和輸入的依賴性一致。
圖1 MPNet結(jié)構(gòu)圖
圖2為一個單層的GCN。GCN[19]是一個由邊和節(jié)點(diǎn)構(gòu)成的異構(gòu)圖。圖中的節(jié)點(diǎn)分為詞節(jié)點(diǎn)和文檔節(jié)點(diǎn),節(jié)點(diǎn)與節(jié)點(diǎn)之間用邊連接,邊的數(shù)值代表節(jié)點(diǎn)與節(jié)點(diǎn)之間的關(guān)系權(quán)重。
圖2 單層圖卷積網(wǎng)絡(luò)結(jié)構(gòu)
節(jié)點(diǎn)與節(jié)點(diǎn)之間的關(guān)系權(quán)重的計算方式如式(1)所示。其中PMI(i,j)代表節(jié)點(diǎn)i和節(jié)點(diǎn)j之間的互信息,其計算如式(2)所示,式中p(i,j)代表節(jié)點(diǎn)i和節(jié)點(diǎn)j同時出現(xiàn)的概率,p(i),p(j)分別代表節(jié)點(diǎn)i和節(jié)點(diǎn)j單獨(dú)出現(xiàn)的概率。TF-IDF代表詞頻-逆文本頻率。
(1)
(2)
本文提出的MPNetGCN模型的結(jié)構(gòu)如圖3所示,該模型由MPNet模型、GCN、全連接層、softmax分類層以及詞向量存儲庫組成。MPNetGCN模型是一個GCN和MPNet模型相結(jié)合的模型,該模型利用單詞或文檔的語料庫構(gòu)建一個GCN異構(gòu)圖,然后通過預(yù)先訓(xùn)練的MPNet模型來初始化圖中節(jié)點(diǎn)的向量表示,最后用GCN對這些節(jié)點(diǎn)進(jìn)行分類。通過聯(lián)合訓(xùn)練BERT和GCN模塊,使得模型能夠結(jié)合大規(guī)模預(yù)訓(xùn)練和圖卷積網(wǎng)絡(luò)的優(yōu)勢,提高模型自身的分類效果。該模型的訓(xùn)練方式如下: ①利用MPNet模型生成詞節(jié)點(diǎn)對應(yīng)的特征向量; ②利用這些特征向量去初始化GCN; ③迭代訓(xùn)練數(shù)據(jù),更新GCN中的權(quán)重。
圖3 MPNetGCN模型結(jié)構(gòu)圖
MPNetGCN在構(gòu)建和后續(xù)的訓(xùn)練中考慮到了以下兩個方面的問題: ①GCN和MPNet對訓(xùn)練數(shù)據(jù)的加載和迭代方式不同。GCN在訓(xùn)練時需要加載整個圖中的所有節(jié)點(diǎn),而MPNet受到模型大小和內(nèi)存容量的限制,每次只能加載一個批(batch)的數(shù)據(jù)。即MPNet每次只能更新一個批的特征向量,而GCN需要加載所有批的特征向量。②MPNet的權(quán)重可通過模型的損失函數(shù)進(jìn)行反向傳播來更新,但GCN存在梯度消失和過度平滑的缺點(diǎn),因此如果采用GCN的輸出結(jié)果計算損失進(jìn)行反向傳播更新權(quán)重,可能導(dǎo)致GCN無法更新權(quán)重。
對于問題一,本文通過構(gòu)建詞向量存儲庫的方法分批次地存放詞向量(詞向量存儲庫的大小根據(jù)詞節(jié)點(diǎn)的數(shù)量設(shè)定)。每個批的詞向量放在同一個存儲單元內(nèi),每迭代一個批的數(shù)據(jù)更新對應(yīng)的存儲單元。由于詞節(jié)點(diǎn)對應(yīng)的存儲單元內(nèi)的特征向量是不斷變化的,所以會對GCN的訓(xùn)練產(chǎn)生干擾,通過采用小學(xué)習(xí)率和增加迭代數(shù)據(jù)的次數(shù)的方法來訓(xùn)練模型可以減少這種干擾。對于問題二,本文聯(lián)合MPNet模型的輸出以及GCN的輸出與標(biāo)簽做交叉損失進(jìn)行反向傳播來更新GCN的權(quán)重。即通過把MPNet模型對輸入文本的概率預(yù)測值和GCN模型對輸入文本的概率預(yù)測值相加進(jìn)行歸一化后作為GCN模型對輸入文本的最終預(yù)測概率值,然后利用該預(yù)測值計算GCN的損失并更新權(quán)重。計算過程如式(5)~式(7)所示,式(5)中,X代表輸入文本經(jīng)MPNet模型后得到的特征向量,WM代表與MPNet模型相連的全連接層的內(nèi)部權(quán)重。式(6)中,ZMPNetGCN、ZMPNet分別代表GCN和MPNet模型對輸入文本所屬類別的預(yù)測概率。式(7)中,losscem代表GCN的損失函數(shù),T代表批數(shù)量(batchsize),C代表數(shù)據(jù)集中包含C類文本,第t個樣本屬于i的概率yti的值為0或1,如果第t個樣本的真實(shí)類別等于i則取1,否則取0。
ZMPNet=softmax(WM·X)
(5)
Z=ZGCN+ZMPNet
(6)
(7)
本文采用第2節(jié)提出的MPNetGCN模型作為教師模型。在文本分類任務(wù)中使用知識蒸餾的目的是為了在一定的分類準(zhǔn)確率范圍內(nèi)盡可能地使蒸餾過后得到的模型更小。目前,常用的基于深度學(xué)習(xí)且結(jié)構(gòu)較為簡單的文本分類模型大都基于LSTM。GRU作為LSTM的一種變體,將忘記門和輸入門合成為一個單一的更新門,同樣還混合了細(xì)胞狀態(tài)和隱藏狀態(tài)。GRU網(wǎng)絡(luò)的結(jié)構(gòu)只包含了兩個門,其結(jié)構(gòu)比標(biāo)準(zhǔn)的LSTM模型要簡單,模型的參數(shù)更少,但性能基本相同。從計算成本和時間成本來看,GRU網(wǎng)絡(luò)更有效率。由于在處理文本序列時需要考慮上下文的語義關(guān)系,因此本文選擇BiGRU作為學(xué)生模型的主體部分,其結(jié)構(gòu)如圖4所示。
圖4 BiGRU網(wǎng)絡(luò)結(jié)構(gòu)圖
(8)
(9)
(10)
z(r)(x)=[max(H1),max(H2),…,max(Ht)]
(11)
對于輸入序列x,利用學(xué)生模型預(yù)測該序列屬于每個類別的概率的計算如式(12)所示,Wr代表全連接層中的權(quán)重信息。
p(r)(x)=softmax(W(r)·z(r)(x))
(12)
(13)
在知識蒸餾前,首先根據(jù)Dl訓(xùn)練數(shù)據(jù)對MPNetGCN 教師模型進(jìn)行微調(diào),調(diào)優(yōu)的目標(biāo)是最小化損失函數(shù)的值。
本文通過“標(biāo)簽”把教師模型學(xué)到的知識遷移到學(xué)生模型中。知識蒸餾的目的是為了讓學(xué)生模型學(xué)到更多的外部知識,提高模型的泛化能力。數(shù)據(jù)集中的每個樣本的真實(shí)標(biāo)簽(hard target)為獨(dú)熱的編碼形式,根據(jù)式(13)計算交叉熵?fù)p失時,學(xué)生模型的輸出結(jié)果中只有一維參與了損失loss的計算,忽略了標(biāo)簽與標(biāo)簽之間的關(guān)系。比如類別貓和類別狗的相似性較高,類別貓和類別車的相似性較低,因此在計算損失時前者應(yīng)該給予更小的loss。采用教師模型對輸入文本所屬類別的預(yù)測概率(soft label)作為學(xué)生模型的label,在計算loss時可以使學(xué)生模型輸出結(jié)果中不為0的每一維都參與運(yùn)算,可以使學(xué)生模型學(xué)到更多的信息,在一定程度上提高模型的泛化能力。
為了給學(xué)生模型提供一個更好的“視野”,本文采用教師模型對輸入文本所屬類別的預(yù)測概率(soft label)經(jīng)過式(14)變換后的結(jié)果作為學(xué)生模型輸入文本的“標(biāo)簽”,p(x)代表教師模型對輸入文本x所屬類別的預(yù)測概率。Logit強(qiáng)調(diào)模型在不同情況下應(yīng)該學(xué)習(xí)到不同關(guān)系。例如,評論“我喜歡這部電影”的負(fù)面可能性很小,而評論“這部電影本可以更好”可以是正的或負(fù)的,具體取決于上下文。Logit把這種正和負(fù)的關(guān)系反映到了標(biāo)簽中,通過這種方式可以使模型學(xué)到更多的信息。
(14)
學(xué)生模型的損失函數(shù)采用均方誤差損失函數(shù),計算過程如式(15)、式(16)所示。式(15)中zs(x)表示對于輸入文本x,學(xué)生模型輸出的特征向量。rs(xu)表示學(xué)生網(wǎng)絡(luò)對輸入文本xu的標(biāo)簽預(yù)測值,WT、bT為可訓(xùn)練參數(shù)。式(16)中,N代表批數(shù)量(batchsize),pt(xu)表示教師模型對輸入文本xu預(yù)測得到的屬于每個類別的概率,Du表示一個批(batch)的數(shù)據(jù)(該部分?jǐn)?shù)據(jù)只需要文本部分,不需要文本對應(yīng)的標(biāo)簽)。
rs(x)=WT·zs(x)+bT
(15)
(16)
迭代訓(xùn)練數(shù)據(jù)進(jìn)行知識蒸餾,當(dāng)損失函數(shù)的值最小的時候,學(xué)生模型的學(xué)習(xí)能力達(dá)到飽和,蒸餾完成。
為了檢驗(yàn)本文提出的文本分類模型的分類效果,選取了4個廣泛使用的文本分類數(shù)據(jù)集進(jìn)行實(shí)驗(yàn),數(shù)據(jù)集的相關(guān)介紹如下:
20NG(20newsgroups): 該數(shù)據(jù)集包含20個不同種類的文檔,新聞種類按照新聞主題劃分。
R8: 該數(shù)據(jù)集為Reuters-21578數(shù)據(jù)集的子集,Reuters-21578數(shù)據(jù)集由8個不同種類的路透社財經(jīng)新聞文檔組成。
R52: 該數(shù)據(jù)集與R8數(shù)據(jù)集類似,同為Reuters-21578數(shù)據(jù)集的子集。
MR: 該數(shù)據(jù)集為電影評論數(shù)據(jù)集,評論分為“積極”和“消極”兩個種類。
對數(shù)據(jù)集進(jìn)行數(shù)據(jù)清洗(去掉文本中的停用詞、文本分割等),然后對數(shù)據(jù)集的相關(guān)信息進(jìn)行統(tǒng)計,包括: 樣本數(shù)量(text_num)、訓(xùn)練集樣本數(shù)量(train_num)、測試集樣本數(shù)量(test_num)、數(shù)據(jù)集包含單詞的數(shù)量(word_num)、節(jié)點(diǎn)數(shù)量(node_num)、數(shù)據(jù)集中樣本種類的數(shù)量(class)以及數(shù)據(jù)集中樣本的平均長度(AL),統(tǒng)計結(jié)果如表1所示。
表1 數(shù)據(jù)集各項(xiàng)指標(biāo)統(tǒng)計結(jié)果表
實(shí)驗(yàn)平臺的配置為Intel Xeon3104處理器、16 GB內(nèi)存、GTX2080Ti顯卡,并使用64位操作系統(tǒng)Ubuntu 18.04。
本文實(shí)驗(yàn)分為兩部分進(jìn)行,第一部分為MPNetGCN,第二部分為知識蒸餾部分。
該部分實(shí)驗(yàn)所選取的對比基線模型如下(其中對預(yù)訓(xùn)練模型加載預(yù)訓(xùn)練權(quán)重,然后在數(shù)據(jù)集上進(jìn)行Fine-tune處理)。
BiLSTM: 由前向LSTM和后向LSTM構(gòu)成的模型。
TextGCN[20]: 由雙層GCN構(gòu)成的文本分類模型。
BERT[3]: 實(shí)驗(yàn)使用谷歌開源的BERT-Base模型,使用官方提供的預(yù)訓(xùn)練權(quán)重加載模型。
RoBERTa[23]: 強(qiáng)化BERT模型,實(shí)驗(yàn)使用facebook官方提供的預(yù)訓(xùn)練權(quán)重加載模型。
BERTGCN[22]: 實(shí)驗(yàn)中使用作者提供的預(yù)訓(xùn)練權(quán)重初始化該模型。
RoBERTaGCN[23]: 一種基于GCN和RoBERTa的混合文本分類模型。
MPNet[16]: 實(shí)驗(yàn)使用作者提供的預(yù)訓(xùn)練權(quán)重加載模型。
該部分實(shí)驗(yàn)中的各模型初始超參數(shù)設(shè)置如表2所示。
表2 各模型的初始超參數(shù)表
實(shí)驗(yàn)中的各模型分類準(zhǔn)確率統(tǒng)計結(jié)果如表3所示(1M=1百萬)。
表3 各模型分類準(zhǔn)確率統(tǒng)計表
續(xù)表
根據(jù)表3可以看出本文提出的語言模型MPNetGCN在除MR外的實(shí)驗(yàn)數(shù)據(jù)集上均取得了最高的分類準(zhǔn)確率。與RoBERTaGCN相比,MPNetGCN模型平均準(zhǔn)確率提高了0.4%。從表2可以看出,MR數(shù)據(jù)集中樣本的平均長度最短,GCN在該數(shù)據(jù)集上的分類效果最差。另外,BERT模型、RoBERTa模型、MPNet模型在融合GCN后模型的分類效果的提升相比其他數(shù)據(jù)集最低。這是因?yàn)樵谖谋鹃L度較短的情況下GCN能利用的節(jié)點(diǎn)信息較少,導(dǎo)致GCN捕捉文本特征的能力有所下降,說明GCN不善于處理短文本。RoBERTa和BERT結(jié)構(gòu)基本相同[23],由于RoBERTa在預(yù)訓(xùn)練階段與其他預(yù)訓(xùn)練語言模型相比使用了更大的數(shù)據(jù)集,事先學(xué)習(xí)到了更多的“知識”,在集成GCN后,RoBERTaGCN在MR數(shù)據(jù)集上取得了最好的分類效果,分類準(zhǔn)確率達(dá)到了89.7%。另外,通過表3還可以發(fā)現(xiàn),預(yù)訓(xùn)練語言模型的平均分類準(zhǔn)確率明顯高于非預(yù)訓(xùn)練語言模型BiLSTM和TextGCN。以上結(jié)果都體現(xiàn)出了大規(guī)模預(yù)訓(xùn)練對于提升模型的性能有很大作用,同時也說明可以通過使用更大的數(shù)據(jù)集對模型進(jìn)行預(yù)訓(xùn)練來進(jìn)一步提高模型的性能。但從表3也可以看出預(yù)訓(xùn)練語言模型的參數(shù)量也遠(yuǎn)遠(yuǎn)高于非預(yù)訓(xùn)練語言模型。
在面對長度適中且分類邊界較為清晰(70詞左右,8個類別)的數(shù)據(jù)集R8、R52時,MPNetGCN表現(xiàn)最好,分類準(zhǔn)確率最高,分別達(dá)到了98.3%和97.4%。通過表3還可以看出,各模型的分類準(zhǔn)確率都明顯高于在其他數(shù)據(jù)集上的分類準(zhǔn)確率,說明了現(xiàn)有的文本分類模型更善于處理這種類型的數(shù)據(jù)。
從表1中可以看出,20NG數(shù)據(jù)集中的文本平均長度達(dá)到221詞,遠(yuǎn)遠(yuǎn)高于其他數(shù)據(jù)集。在該數(shù)據(jù)集上MPNetGCN的平均分類準(zhǔn)確率達(dá)到了91.0%,高于實(shí)驗(yàn)中的其他模型。此外,BERTGCN模型和MPNetGCN模型的分類準(zhǔn)確率相較于BERT和MPNet提升最為明顯;且GCN的分類準(zhǔn)確率比BERT模型提高了1%,體現(xiàn)出GCN在處理長文本時更有優(yōu)勢。
結(jié)合表1以及表3中各模型在多個數(shù)據(jù)集上的分類準(zhǔn)確率,可以看出本文提出的MPNetGCN模型在面對不同長度的文本時文本的分類效果綜合表現(xiàn)最好,平均分類準(zhǔn)確率達(dá)到了93.8%,高于其他模型。通過各模型之間的對比還可以發(fā)現(xiàn),在融合了GCN后各模型的平均分類準(zhǔn)確率都得到了一定的提升,說明通過多個模型來提取文本特征進(jìn)行文本分類的方式能夠避免單一模型自身的一些局限性,從而提高模型的整體分類性能。
該部分實(shí)驗(yàn)選擇的對比基線模型如下(其中對預(yù)訓(xùn)練模型進(jìn)行加載預(yù)訓(xùn)練權(quán)重,然后在數(shù)據(jù)集上進(jìn)行Fine-tune的處理):
BiGRU: 由前向GRU網(wǎng)絡(luò)和后向GRU網(wǎng)絡(luò)構(gòu)成的模型。
DistillLSTM[24]: 以BERT-Base作為教師模型,BiLSTM作為學(xué)生模型,通過知識蒸餾得到的模型。
DistillBERT6[2]: 通過BERT-Base模型蒸餾得到的含有6層transformer的BERT6模型。
實(shí)驗(yàn)超參數(shù)設(shè)置如表2所示。
在該部分實(shí)驗(yàn)中,各模型的分類準(zhǔn)確率結(jié)果如表4所示。其中,T表示教師模型。
表4 各模型分類準(zhǔn)確率統(tǒng)計表
從表4可以看出,以MPNetGCN作為教師模型,BIGRU、BILSTM作為學(xué)生模型,在進(jìn)行知識蒸餾后,DistillBIGRU和DistillLSTM對文本的分類效果二者相當(dāng),平均分類準(zhǔn)確率均達(dá)到了91%,高于實(shí)驗(yàn)中的其他模型。但BiGRU結(jié)構(gòu)上更加簡單,因此選擇BiGRU作為學(xué)生模型更有效率。結(jié)合表3看出,無論是以MPNetGCN作為教師模型,還是以BERT作為教師模型,在進(jìn)行知識蒸餾后學(xué)生模型的性能相比蒸餾前都得到了提升。DistillBIGRU(T=MPNetGCN)模型和DistillLSTM(T=BERT)模型相比于BIGRU模型和BILSTM模型分類準(zhǔn)確率分別提升了5.4%,4.3%;DistillLSTM(T=MPNetGCN) 模型相比于BILSTM模型平均分類準(zhǔn)確率提升了5.2%,說明了采用“教師-學(xué)生”知識蒸餾的方法對學(xué)生模型的性能提升是有幫助的。
DistillLSTM(T=MPNetGCN)模型與DistillLSTM(T=BERT)模型相比平均分類準(zhǔn)確率提高了0.9%,說明了教師模型性能的提升可以提高學(xué)生模型的學(xué)習(xí)能力。
DistillBERT6模型和DistillLSTM模型在教師模型均為BERT模型的情況下,由于BERT6模型參數(shù)更多,分類性能更高,導(dǎo)致在進(jìn)行知識蒸餾后DistillBERT6模型的平均分類準(zhǔn)確率相比DistillLSTM模型提升了0.6%,說明了學(xué)生模型自身分類性能的提高對于提升最終模型的分類效果是有利的。
BERT-Base模型的參數(shù)量約為108M,DistillBiGRU模型的參數(shù)量約為13M。根據(jù)表3和表4,BERT模型在數(shù)據(jù)集上的平均分類準(zhǔn)確率達(dá)到了91.2%,DistillBiGRU模型的平均分類準(zhǔn)確率達(dá)到了91.0%,在文本分類方面DistillBIGRU模型與BERT模型的分類效果相當(dāng),驗(yàn)證了本文提出方法的合理性及有效性。
針對現(xiàn)有預(yù)訓(xùn)練語言模型參數(shù)量龐大、算法復(fù)雜以及訓(xùn)練成本高等問題,本文提出了基于知識蒸餾的文本分類模型DistillBiGRU。首先結(jié)合MPNet模型和GCN提出了MPNetGCN語言模型作為教師模型,該模型在實(shí)驗(yàn)中的多個數(shù)據(jù)集上取得了最好的分類效果,與BERTGCN模型相比平均分類準(zhǔn)確率提高了1.3%。在知識蒸餾階段,本文選擇BiGRU作為學(xué)生模型,在實(shí)驗(yàn)中,通過蒸餾得到的模型DistillBiGRU在多個數(shù)據(jù)集上的平均分類準(zhǔn)確率達(dá)到了91.0%,在參數(shù)量遠(yuǎn)小于BERT模型(約為BERT模型的1/9)的前提下,平均分類準(zhǔn)確率與BERT模型相當(dāng)。但是,在利用MPNetGCN進(jìn)行文本分類時,存在GCN和MPNet模型加載數(shù)據(jù)不同步的問題,雖然該問題對最終的分類結(jié)果影響不大,但不可忽略。另外,不同的蒸餾方法對蒸餾后的學(xué)生模型的性能也有影響。這些問題在后續(xù)的研究中值得關(guān)注。