陳慧玲, 張 曄, 田奧升, 趙晗馨
(國防科技大學(xué)電子科學(xué)學(xué)院, 長沙 410073)
隨著傳感器技術(shù)的飛速發(fā)展,序列數(shù)據(jù)分類在健康監(jiān)測、智能家居控制、設(shè)備監(jiān)控等領(lǐng)域得到了廣泛的應(yīng)用。 一些對時(shí)效性要求較高的現(xiàn)實(shí)應(yīng)用,例如災(zāi)難預(yù)測、氣體泄漏檢測、故障檢測[1-2]等,都需要提早地對序列數(shù)據(jù)進(jìn)行分類。 因此,序列數(shù)據(jù)的早期分類具有重要研究價(jià)值[3]。 然而,真實(shí)世界中的早期分類的流數(shù)據(jù)輸入形式為分類增加了難度,并且難以設(shè)置合適的停止條件退出分類。
近年來,傳統(tǒng)方法取得了較好的早期分類效果[4-6],然而手工設(shè)計(jì)的特征需要大量的專家經(jīng)驗(yàn)。此外,這些方法還需要為不同長度的數(shù)據(jù)訓(xùn)練多個(gè)不同的分類器。 深度方法由于其自動分類方案以及有效的特征提取能力在序列數(shù)據(jù)分類領(lǐng)域取得了卓越性能[7]。 部分研究人員逐漸利用深度學(xué)習(xí)的方法解決序列數(shù)據(jù)早期分類任務(wù)[8]。 因此,本文主要關(guān)注基于深度學(xué)習(xí)的方法。
在時(shí)間早期分類領(lǐng)域,基于深度的方法主要分為一階段的方法和二階段的方法。 其中,一階段的方法是指同時(shí)對分類過程及退出過程進(jìn)行優(yōu)化。 這類方法通常設(shè)置分類子網(wǎng)和提前退出子網(wǎng)并對其聯(lián)合優(yōu)化。 然而,分類子網(wǎng)和退出子網(wǎng)的優(yōu)化具有一定的沖突[8]。 這是由于在時(shí)間推移過程中,分類準(zhǔn)確率隨信息量的增長遞增的同時(shí)早期性不斷降低。因此,一階段的方法難以同時(shí)對2 個(gè)子網(wǎng)進(jìn)行優(yōu)化。二階段的方法通過將分類過程與退出過程分離來緩解這種沖突,首先單獨(dú)對分類器進(jìn)行訓(xùn)練,然后通過制定設(shè)置閾值或一定的退出規(guī)則來退出分類。 遞歸神經(jīng)網(wǎng)絡(luò)(recursive neural networks,RNN)由于其貫序的輸入方式被廣泛應(yīng)用于早期分類[9],然而其局部特征提取能力不強(qiáng)。 Huang 等學(xué)者[10-12]利用卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks,CNN)良好的局部特征提取能力,結(jié)合CNN 和RNN 構(gòu)建混合模型來同時(shí)提取局部特征及時(shí)序信息,然后利用分類概率計(jì)算置信度,并制定了一定的退出規(guī)則。Hsu 等學(xué)者[13]同樣使用了混合分類模型,并引入了注意力機(jī)制以增強(qiáng)模型的可解釋性。 此外,現(xiàn)有方法也未考慮到識別正確的概率隨時(shí)間變化的規(guī)律。
為了解決這些問題,本文提出了掩碼時(shí)間注意力機(jī)制以及置信度損失函數(shù)。 首先,本文利用基于掩碼時(shí)間注意力機(jī)制的時(shí)間卷積網(wǎng)絡(luò)對于不同長度的數(shù)據(jù)產(chǎn)生自適應(yīng)的注意力權(quán)重,從而動態(tài)地抑制無關(guān)信息,并更加關(guān)注關(guān)鍵區(qū)域的有效信息,增強(qiáng)特征圖的信息表達(dá)能力。 然后,本文受到正確類別的概率隨時(shí)間推移而遞增[14-15]的啟發(fā),設(shè)計(jì)了置信度損失函數(shù)。 通過對不滿足該條件的概率進(jìn)行懲罰,使得正確類別的概率隨數(shù)據(jù)長度增加而平滑地增加,利于退出閾值的設(shè)置。
深度模型的表達(dá)能力隨著參數(shù)的增加而不斷提升。 然而參數(shù)的增加帶來了更大的計(jì)算量,同時(shí)也增加了大量的冗余信息。 因此,注意力機(jī)制被引入深度模型對網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)制。 該機(jī)制的核心思想是抑制無關(guān)信息,使模型關(guān)注更有效的關(guān)鍵特征[16]。 注意力機(jī)制的核心思想是通過一定的變換來學(xué)習(xí)不同特征重要性的差異,顯著提高了信息處理與應(yīng)用的效率,此外,還具有通用性、直觀性和可解釋性等優(yōu)點(diǎn)。 因此被廣泛應(yīng)用于機(jī)器翻譯、文本分類、語音識別、圖像處理等多個(gè)領(lǐng)域。 根據(jù)注意力機(jī)制插入的位置,注意力機(jī)制可分為時(shí)間注意力、空間注意力、通道注意力等[17]
實(shí)現(xiàn)注意力機(jī)制,通常首先對特征圖進(jìn)行非線性變化得到注意力分?jǐn)?shù),然后對其進(jìn)行SoftMax歸一化當(dāng)成注意力權(quán)重,最后將相應(yīng)的權(quán)重作用于原特征圖進(jìn)行加權(quán)或者逐點(diǎn)相乘獲得新的特征表示。
由于RNN 不能大規(guī)模并行以及具有長時(shí)間遺忘的缺陷,Bai 等學(xué)者[18]提出了具有時(shí)序處理能力的時(shí)間卷積網(wǎng)絡(luò)(Temporal Convolutional Network,TCN)。 膨脹因果卷積為TCN 的主要組成部分,其結(jié)構(gòu)如圖1 所示。
圖1 膨脹因果卷積示意圖Fig. 1 Dilated causal convolution
由圖1 可知,膨脹因果卷積具有嚴(yán)格的時(shí)間約束,這是由于因果卷積的應(yīng)用使得某一時(shí)刻的特征只能觀察到該時(shí)刻之前的數(shù)據(jù)[19]。 相比于常規(guī)卷積,因果卷積防止了未來信息的泄漏。 單純的因果卷積受限于卷積核大小,難以有效提取全局特征。常規(guī)CNN 通過引入pooling 層來增加感受野,然而pooling 層會造成一定的信息損失。 因此Chen 等學(xué)者[20]提出膨脹卷積,通過對卷積時(shí)的輸入間隔采樣來增加感受野。 采樣率、即膨脹率,指的是kernel的間隔數(shù)量(標(biāo)準(zhǔn)的CNN 中膨脹率為1)。 通常,膨脹率隨著深度模型的層數(shù)加深而增大,因此膨脹卷積使得感受野大小隨著層數(shù)呈指數(shù)型增長。 除膨脹因果卷積外,TCN 使用WeightNorm和Dropout來正則化網(wǎng)絡(luò),并且對不同卷積層進(jìn)行殘差連接以更好地對網(wǎng)絡(luò)進(jìn)行訓(xùn)練。
為了解決模型難以動態(tài)關(guān)注流數(shù)據(jù)的關(guān)鍵識別區(qū)域的問題,本文提出了掩碼時(shí)間注意力機(jī)制。 此外,考慮到識別正確的概率隨時(shí)間推移而增加,本文提出了置信度損失函數(shù)。
為了使模型能夠自適應(yīng)地關(guān)注不同長度數(shù)據(jù)的關(guān)鍵特征區(qū)域,利用有限的參數(shù)提取更有效的特征,本文為序列數(shù)據(jù)的早期分類設(shè)計(jì)了基于掩碼的時(shí)間注意力機(jī)制。 常規(guī)的注意力機(jī)制對于所有時(shí)刻的特征計(jì)算其注意力分布,然而對于早期分類持續(xù)輸入的序列數(shù)據(jù),在某一時(shí)刻只能觀察到該時(shí)刻之前的特征。 因此,本文將掩碼引入常規(guī)的時(shí)間注意力機(jī)制以防止未來信息的泄露。 具體的掩碼時(shí)間注意力過程如圖2 所示。
圖2 掩碼時(shí)間注意力機(jī)制結(jié)構(gòu)圖Fig. 2 Masked time attention mechanism
首先,本文將輸入特征經(jīng)過線性層以及Tanh 激活函數(shù)變換得到不同時(shí)刻的注意力分布,其大小為1 ×T(T為完整序列數(shù)據(jù)的長度)。 然后,本文對注意力分?jǐn)?shù)進(jìn)行擴(kuò)充,對其重復(fù)T次、并拼接在一起,得到一個(gè)大小為T×T的注意力矩陣。 同時(shí),輸入特征也被采取同樣的擴(kuò)充,得到大小為C×T×T的特征矩陣(C為特征圖的通道數(shù))。 接下來,將注意力矩陣的上三角填充為負(fù)無窮(該步驟簡稱為掩碼,即圖2 中的Mask 操作),使得注意力矩陣經(jīng)過SoftMax操作后上三角的注意力值為0。 這表示在t時(shí)刻,模型只會關(guān)注t時(shí)刻之前的時(shí)刻特征。 此后,對特征矩陣和經(jīng)過掩碼的注意力矩陣進(jìn)行逐元素乘法,得到不同時(shí)刻的動態(tài)特征(大小為C×T×T)。最后,本文對該動態(tài)特征使用平均池化得到C×T的融合特征,每個(gè)時(shí)刻的特征都是由該時(shí)刻前的特征通過加權(quán)相應(yīng)的注意力權(quán)重得到。
將提出的掩碼時(shí)間注意力機(jī)制嵌入TCN 中,得到基于掩碼時(shí)間注意力的TCN 網(wǎng)絡(luò),該網(wǎng)絡(luò)結(jié)構(gòu)如圖3 所示。 首先,完整的序列數(shù)據(jù)被輸入到多個(gè)時(shí)域卷積塊提取出局部時(shí)序特征。 然后,網(wǎng)絡(luò)利用掩碼時(shí)間注意力機(jī)制對不同時(shí)刻的局部特征進(jìn)行動態(tài)加權(quán),輸出各時(shí)刻的動態(tài)融合特征。 最后,這些動態(tài)特征通過線性層以及SoftMax函數(shù)得到分類概率。
圖3 基于掩碼時(shí)間注意力機(jī)制的時(shí)間卷積網(wǎng)絡(luò)結(jié)構(gòu)圖Fig. 3 The architecture of temporal convolutional network based on masked time attention mechanism
考慮到當(dāng)分類器觀察到更多的信息時(shí)應(yīng)該對正確的活動類別有更大的影響,本文引入了對正確類別分類概率隨時(shí)間的約束,即隨著數(shù)據(jù)長度的增加,正確的類別輸出更高的概率分?jǐn)?shù)。
具體地,本文設(shè)計(jì)了一個(gè)置信度損失,該損失定義為:
其中,θ為模型的所有參數(shù),Lp(θ) 、Lc(θ) 分別表示常規(guī)的交叉熵?fù)p失函數(shù)和本文設(shè)計(jì)的違背時(shí)間約束的懲罰損失,對此求得的數(shù)學(xué)定義見如下公式:
其中,表示分類器輸入第i個(gè)樣本的前t個(gè)數(shù)據(jù)得到的輸出分類概率,N為訓(xùn)練集的樣本總數(shù)。為了便于理解,本文在圖4 中對該損失函數(shù)做進(jìn)一步說明。
圖4 正確類別分類概率隨時(shí)間變化曲線Fig. 4 The classification probability curve of correct class regard to time
圖4 繪制了一個(gè)樣本的正確類別分類概率隨時(shí)間變化的示意圖。 圖4 中,在ta時(shí)刻之前,概率P一直單調(diào)遞增,該現(xiàn)象符合正確類別概率隨時(shí)間遞增的約束。 因此,本文提出的置信度損失不對其進(jìn)行懲罰,即此時(shí)li為0。 在ta時(shí)刻之后,正確類別的分類概率開始下降。 例如在tb時(shí)刻,正確類別的概率低于其在tb時(shí)刻前的最大正確類別概率(),這不滿足本文提出的置信度約束。 因此,該樣本在tb時(shí)刻的損失通過tb時(shí)刻之前的最大正確類別概率減去tb時(shí)刻的正確類別概率計(jì)算得到,具體參見式(4)。Lc(θ) 的設(shè)計(jì)將正確類別的檢測分?jǐn)?shù)限制為隨著活動的進(jìn)展而單調(diào)地不減少。
本節(jié)對具體的訓(xùn)練及測試流程進(jìn)行介紹,設(shè)計(jì)研發(fā)過程如圖5 所示。
圖5 訓(xùn)練及測試過程Fig. 5 Training and testing process
在訓(xùn)練階段,利用訓(xùn)練集數(shù)據(jù)對提出的基于掩碼時(shí)間注意力機(jī)制的時(shí)間卷積網(wǎng)絡(luò)進(jìn)行訓(xùn)練。 隨后,將訓(xùn)練集的所有序列數(shù)據(jù)輸入到之前訓(xùn)練的模型中,得到所有樣本不同時(shí)刻的分類概率。 利用這些分類概率,采用Sharma 等學(xué)者[14]提出的退出規(guī)則計(jì)算出該數(shù)據(jù)集的退出閾值β。
在測試階段,將測試數(shù)據(jù)隨時(shí)間逐漸輸入到訓(xùn)練好的模型中。 在t時(shí)刻,將長度為t的數(shù)據(jù)輸入到模型得到該時(shí)刻的分類概率,當(dāng)該分類概率的最大值大于閾值β時(shí),則停止繼續(xù)輸入更多的數(shù)據(jù),將t時(shí)刻的分類結(jié)果作為該樣本的分類結(jié)果,并將t時(shí)刻作為提前退出的時(shí)刻。 通過該分類結(jié)果和該退出時(shí)刻來計(jì)算準(zhǔn)確率及早期性。 如果t時(shí)刻的分類概率的最大值小于閾值β,則繼續(xù)輸入數(shù)據(jù),重復(fù)測試過程,直至不能再觀測到任何數(shù)據(jù)。 保留最后時(shí)刻的分類結(jié)果作為該樣本的分類結(jié)果,且該樣本的早期性為1。
為了驗(yàn)證提出方法的有效性,本文采用了公開的UCR[21]存儲庫提供的單變量數(shù)據(jù)集,從其中選取了不重復(fù)的8 個(gè)數(shù)據(jù)集。 UCR 存儲庫中的序列數(shù)據(jù)從諸多現(xiàn)實(shí)應(yīng)用采集而來,包括電氣設(shè)備監(jiān)控?cái)?shù)據(jù)、心電圖數(shù)據(jù)、動作識別數(shù)據(jù)以及其他傳感器數(shù)據(jù)等。 UCR 存儲庫依據(jù)一定的規(guī)則將這些數(shù)據(jù)集劃分了訓(xùn)練集和測試集,并對數(shù)據(jù)進(jìn)行了歸一化。
實(shí)驗(yàn)采用的深度學(xué)習(xí)框架Pytotch1.9.0,所使用的硬件環(huán)境為NVIDIA RTX 3080 GPU。 實(shí)驗(yàn)中,使用Adam 優(yōu)化器對模型參數(shù)進(jìn)行訓(xùn)練,學(xué)習(xí)率衰減為原來的一半。 采用訓(xùn)練集損失最小的模型作為測試模型。 每個(gè)模型的訓(xùn)練迭代次數(shù)設(shè)置為200。 所有模型均使用了3 個(gè)時(shí)域卷積模塊,卷積核尺寸為3,隱藏層的通道數(shù)為64。 根據(jù)經(jīng)驗(yàn),本文將2.2 節(jié)中式(1)中的參數(shù)μ設(shè)置為6。 為了衡量早期分類性能,研究使用準(zhǔn)確率和早期性的調(diào)和平均值(harmonic mean,HM) 作為評價(jià)指標(biāo),HM[21]具體定義為:
據(jù)式(6)可知,HM的值隨早期分類性能的提升而增加。
為驗(yàn)證本文提出方法的有效性,將本文提出的模型與ECLN[22]、 ETMD[14]、EARLIE[23]進(jìn)行對比,4種方法在8 個(gè)數(shù)據(jù)集上的測試結(jié)果見表1。
表1 在8 個(gè)數(shù)據(jù)集上的對比實(shí)驗(yàn)結(jié)果Tab. 1 The comparative experimental results on 8 datasets
從表1 可以看出,對比其他3 種方法,本文提出的模型在8 個(gè)數(shù)據(jù)集上均取得了最優(yōu)的早期分類結(jié)果,證明了本文提出方法的先進(jìn)性。
為了分別驗(yàn)證本文提出的掩碼時(shí)間注意力機(jī)制以及置信度損失函數(shù)的有效性,本節(jié)分別對這2 個(gè)部分進(jìn)行消融。 本文設(shè)置的基線模型為去除了掩碼時(shí)間注意力機(jī)制的時(shí)間卷積網(wǎng)絡(luò),并使用經(jīng)典的交叉熵?fù)p失函數(shù)對模型進(jìn)行訓(xùn)練。 首先,為了證明提出的掩碼時(shí)間注意力機(jī)制的效果,本文將該模塊添加到基線模型進(jìn)行第一個(gè)消融實(shí)驗(yàn)。 該實(shí)驗(yàn)在8 個(gè)數(shù)據(jù)集上實(shí)驗(yàn)結(jié)果見表2 第2、3 列。 其次,本文將交叉熵?fù)p失函數(shù)替換為提出的置信度損失函數(shù),以證明該損失函數(shù)的有效性。 該實(shí)驗(yàn)結(jié)果見表2 第3、4 列。
表2 在8 個(gè)數(shù)據(jù)集上的消融實(shí)驗(yàn)結(jié)果Tab. 2 The ablation experimental results on 8 datasets
觀察表2 第2、3 列,相比于基線方法,添加了掩碼時(shí)間注意力的基線方法分別將8 個(gè)數(shù)據(jù)集上的HM分?jǐn)?shù)提高了1%,5.48%,10.56%,3.57%,0.51%,2.33%,1.81%,0.02%。 因此,本文提出的掩碼時(shí)間注意力機(jī)制顯著提升了模型的早期分類性能。
為了進(jìn)一步說明添加了掩碼時(shí)間注意力的模型能實(shí)現(xiàn)更有效的分類,本文在圖6 中繪制了Synthetic Control 數(shù)據(jù)集的分類準(zhǔn)確性隨數(shù)據(jù)長度變化的結(jié)果。 圖6 中,本文提出的掩碼時(shí)間注意力機(jī)制提高了幾乎所有長度的數(shù)據(jù)的分類性能。
圖6 注意力機(jī)制對Synthetic Control 數(shù)據(jù)集的不同長度數(shù)據(jù)準(zhǔn)確率的影響Fig. 6 The effect of the attention mechanism on the accuracy of varied-length data on the Synthetic Control dataset
觀察表2 的第3、4 列,用本文提出的置信度損失替換經(jīng)典的交叉損失函數(shù)后,在8 個(gè)數(shù)據(jù)集上的HM分?jǐn)?shù)分別提高了1.67%,1.76%,3.3%,3.1%,2.13%,4.72%,4.83%,0.39%。 這表明,使用了本文提出的置信度損失函數(shù)訓(xùn)練模型使得模型的早期分類性能得到了顯著的提升。
本文提出了基于掩碼時(shí)間注意力機(jī)制的時(shí)間卷積網(wǎng)絡(luò),提高了模型對不同長度數(shù)據(jù)的自適應(yīng)能力。此外,本文通過設(shè)計(jì)的置信度損失函數(shù)促使正確類別的概率隨信息量的增加遞增,有利于設(shè)置更合理的退出閾值。 在8 個(gè)公開數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果證明了本文提出的方法的有效性。 然而,固定閾值難以適應(yīng)難度程度不同的數(shù)據(jù),該問題將在未來進(jìn)行更深入的探討研究。