蘇曉明,斯 琴,王海波,董建敏
(1. 內(nèi)蒙古工業(yè)大學(xué)數(shù)據(jù)科學(xué)與應(yīng)用學(xué)院,內(nèi)蒙古 呼和浩特 010000;2. 內(nèi)蒙古自治區(qū)紀(jì)檢監(jiān)察大數(shù)據(jù)實(shí)驗室,內(nèi)蒙古 呼和浩特 010000)
聚類是通過學(xué)習(xí)分析數(shù)據(jù)的潛在分布,進(jìn)而將相似的樣本劃入相同的簇。它是數(shù)據(jù)科學(xué)以及機(jī)器學(xué)習(xí)領(lǐng)域中的一個重要研究方向。傳統(tǒng)的聚類算法包括基于劃分的K-均值算法[6]、基于層次的 CURE算法[7]、基于概率分布的 EM 算法[8]以及基于密度的DBSCAN算法[9]。但是對于高維數(shù)據(jù)集,傳統(tǒng)的聚類算法很難達(dá)到理想的聚類效果,會需要很高的內(nèi)存消耗。雖然目前也有相關(guān)方法對原始高維數(shù)據(jù)進(jìn)行降維,但是只能進(jìn)行線性降維,而不能進(jìn)行非線性關(guān)系映射處理。針對該問題,基于深度學(xué)習(xí)的聚類算法引發(fā)了眾多學(xué)者的研究興趣,它使用不同的非線性函數(shù)降維,減少噪聲數(shù)據(jù),提高數(shù)據(jù)有效性。
深度聚類算法的主要思想是將聚類目標(biāo)融合進(jìn)數(shù)據(jù)表達(dá)里面。例如Bo等[3]首次提出了深度聚類網(wǎng)絡(luò)(Deep Clustering Network,DCN),它采用棧式自動編碼器對數(shù)據(jù)進(jìn)行非線性降維,使用包含K-均值原理[6]的損失函數(shù)去訓(xùn)練模型,使得模型更容易對數(shù)據(jù)進(jìn)行K-均值劃分。深度嵌入聚類網(wǎng)絡(luò)(Deep Embedding Clustering,DEC)[4]通過 KL散度損失函數(shù)使得自動編碼器生成的數(shù)據(jù)表達(dá)更接近聚類的中心,提高了內(nèi)聚性。改進(jìn)的深度嵌入聚類網(wǎng)絡(luò)(Improved Deep Embedding Clustering,IDEC)[5]在 DEC模型的基礎(chǔ)上添加了重建損失函數(shù)來幫助模型學(xué)習(xí)到更好的表達(dá)。但是這些方法只提取了有效的數(shù)據(jù)表達(dá),沒有考慮到數(shù)據(jù)結(jié)構(gòu)信息對聚類目標(biāo)準(zhǔn)確性的影響。
針對以上問題,一些基于圖結(jié)構(gòu)的聚類算法被廣泛使用,Kip等[2]提出了圖自動編碼器和圖變分自編碼器,它們使用圖卷積網(wǎng)絡(luò)作為編碼器將結(jié)構(gòu)信息和結(jié)點(diǎn)特征融合在一起,去學(xué)習(xí)結(jié)點(diǎn)的表征。Yang等[11]設(shè)計了一種基于圖嵌入的高斯混合變分自動編碼器,將局部數(shù)據(jù)結(jié)構(gòu)添加到深度高斯混合模型中進(jìn)行聚類。Bo等[1]提出結(jié)構(gòu)化深度聚類網(wǎng)絡(luò)(Structural Deep Clustering Network,SDCN),將圖卷積網(wǎng)絡(luò)融入聚類網(wǎng)絡(luò)中,旨在捕捉數(shù)據(jù)的結(jié)構(gòu)化信息,提升模型聚類性能。但是上述網(wǎng)絡(luò)缺少針對聚類目標(biāo)更具判別性的數(shù)據(jù)表征。
為了解決上述問題,本文提出深度圖注意力聚類網(wǎng)絡(luò),旨在融合數(shù)據(jù)結(jié)構(gòu)信息的基礎(chǔ)上利用注意力機(jī)制的特點(diǎn),根據(jù)預(yù)定聚類目標(biāo)提取數(shù)據(jù)重要部分,從而提升整個模型的分類準(zhǔn)確度。
本文提出了一種基于圖注意力機(jī)制的深度聚類網(wǎng)絡(luò)。首先利用自動編碼器模塊分層提取不同表征的全局信息。根據(jù)輸入數(shù)據(jù)特點(diǎn)建立圖結(jié)構(gòu)數(shù)據(jù),計算圖結(jié)構(gòu)數(shù)據(jù)中鄰居結(jié)點(diǎn)重要度,采用多頭圖注意力模塊提取包含圖結(jié)構(gòu)的特征信息,將不同層全局表達(dá)與對應(yīng)圖結(jié)構(gòu)信息相連,通過隨機(jī)梯度下降與反向傳播來優(yōu)化重建損失與基于KL散度聚類損失的加權(quán)和,學(xué)習(xí)網(wǎng)絡(luò)表征及其簇分配,如下圖1所示。
圖1 深度圖注意力聚類網(wǎng)絡(luò)示意圖Fig.1 diagram of deep attention clustering network
對于深度聚類來說,有效的數(shù)據(jù)表達(dá)是非常重要的。在這里我們?yōu)榱吮3忠话阈圆⑹鼓P湍苓m用于不同的數(shù)據(jù)格式而選擇自動編碼器(Autoencoder,AE)[9]提取全局信息。假設(shè)有L層編碼器,i代表層數(shù)。第i層編碼器或者解碼器學(xué)習(xí)到的特征表達(dá)公式為:
這里我們采用的是多頭注意力機(jī)制思想融入圖結(jié)構(gòu)中形成多頭圖注意力機(jī)制[11],選定某一結(jié)點(diǎn),它的每個鄰居結(jié)點(diǎn)選取K個獨(dú)立的注意力機(jī)制進(jìn)行計算,K個不同結(jié)果采用平均或者拼接方式作用在該結(jié)點(diǎn)上。如圖2所示,本圖演示了三頭注意力機(jī)制,不同顏色箭頭表示不同注意力機(jī)制的計算。
圖2 多頭注意力機(jī)制示意圖Fig.2 diagram of multi-head attention mechanism
其中Ni表示與結(jié)點(diǎn)i相鄰的結(jié)點(diǎn),K表示注意力機(jī)制的個數(shù),δ表示非線性函數(shù),αi,j表示結(jié)點(diǎn) j與結(jié)點(diǎn)i之間的注意力權(quán)重系數(shù),其計算公式如(5)所示:
聚合鄰居信息時,需要對每個節(jié)點(diǎn)的所有鄰居的注意力進(jìn)行歸一化。歸一化之后的注意力權(quán)重才是真正的注意力聚合系數(shù)。通過學(xué)生 t分布計算圖注意力模塊輸出zi與聚類中心ui的相似度通過高自信度計算聚類頻率目標(biāo)分布pij,運(yùn)用KL散度優(yōu)化qij和pij得到聚類的損失函數(shù),如公式(6)所示,通過隨機(jī)梯度下降與反向傳播優(yōu)化重建損失與基于 KL散度聚類損失的加權(quán)和,學(xué)習(xí)網(wǎng)絡(luò)表征及其簇分配。
本文實(shí)驗環(huán)境如下:
操作系統(tǒng)為Ubuntu1804。
硬件平臺為 Intel(R) Xeon(R) E5-2640 v4@2.40 GHz CPU,120 GB。
內(nèi)存,11 GB。
高速緩存,GPU是Nvidia Geforce GTX 1080.編程環(huán)境Python3.8。
深度學(xué)習(xí)框架為Pytorch1.7.1。
我們的模型在六種數(shù)據(jù)集上進(jìn)行了評估,其中HHAR包含來自智能手機(jī)和手表的10 299個傳感器記錄。所有樣本分成6類:騎自行車、坐、站立、行走、上樓梯和樓梯;Reuters是一個文本數(shù)據(jù)集,包含大約 81萬標(biāo)記好的英文新聞故事的分類樹,分為企業(yè)/工業(yè)、政府/社會、市場和經(jīng)濟(jì)四類;ACM2是一個描述論文之間關(guān)系的數(shù)據(jù)集,包含數(shù)據(jù)庫、無線通信和數(shù)據(jù)挖掘選三類論文;DBLP是作者為結(jié)點(diǎn)的數(shù)據(jù)集。如果兩位作者是合作關(guān)系,那么他們之間有一個邊。作者寫作的類別為數(shù)據(jù)庫、數(shù)據(jù)挖掘、機(jī)器學(xué)習(xí)和信息檢索;Citeseer是一個描述文檔之間關(guān)系的數(shù)據(jù)庫,文檔之間有引文鏈接,那么設(shè)定一條邊。文檔的類型包括代理、商業(yè)智能、數(shù)據(jù)庫、信息檢索、機(jī)器語言和HCI。
為了衡量算法的聚類性能,采用如下聚類評價指標(biāo)進(jìn)行評價。(1)聚類準(zhǔn)確度(Clustering Accuracy,ACC)表示的是預(yù)測為正的樣本中有多少是對的;(2)歸一化互信息(Normalized Mutual Information,NMI)是衡量2個聚類之間共享信息量的信息論度量,較可靠地評價不平衡數(shù)據(jù)集聚類效果;(3)調(diào)整蘭德指數(shù)(Adjusted Rand Index,ARI)是衡量兩個數(shù)據(jù)分布的吻合程度的,值越大意味著聚類結(jié)果與真實(shí)情況越吻合;④F1值(F1-score)是精確率和召回率的調(diào)和平均數(shù),綜合考慮準(zhǔn)確率和召回率的影響。從表 1的對比實(shí)驗的結(jié)果可以看出,我們的算法在其中四種數(shù)據(jù)集上的結(jié)果準(zhǔn)確度更優(yōu)。加粗的黑色數(shù)值表示該算法在當(dāng)前數(shù)據(jù)集的結(jié)果最優(yōu)。
表1 不同算法在數(shù)據(jù)集上的指標(biāo)值對比Tab.1 model result comparison on different datasets
續(xù)表
本文提出深度圖注意力聚類網(wǎng)絡(luò),該模型全方位地抽取數(shù)據(jù)的圖結(jié)構(gòu)信息、全局表達(dá)以及對聚類更有效的局部特征,具體流程如下,首先利用自動編碼器學(xué)習(xí)較好的全局特征,在原數(shù)據(jù)的基礎(chǔ)上構(gòu)建圖結(jié)構(gòu)。根據(jù)鄰居結(jié)點(diǎn)重要度,采用圖多頭注意力模塊提取包含圖結(jié)構(gòu)的特征信息,將不同層全局表達(dá)與對應(yīng)圖結(jié)構(gòu)信息相連,通過隨機(jī)梯度下降與反向傳播優(yōu)化重建損失與面向聚類目標(biāo)的KL散度聚類損失。在5個公開的圖像數(shù)據(jù)集上的實(shí)驗表明,我們的網(wǎng)絡(luò)具有較優(yōu)的聚類性能與良好的泛化性能。