孔銳,莊峻賢,梁冠燁
(1.暨南大學(xué) 智能科學(xué)與工程學(xué)院,廣東 珠海 519070;2.暨南大學(xué) 信息科學(xué)技術(shù)學(xué)院,廣東 廣州 510632)
2019年底開始新型冠狀病毒肺炎(COVID-19)疫情在全球大規(guī)模暴發(fā)并危害著人類的健康,病毒檢測必不可少[1],而CT檢查為其主要輔助診斷手段。CT檢查由于其便捷性、可重復(fù)性、陽性率高等優(yōu)點(diǎn),在COVID-19患者的影像學(xué)輔助檢測中起到重要作用[2]。近年來,人工智能(AI)在醫(yī)學(xué)各個(gè)領(lǐng)域發(fā)展迅猛,基于AI的CT影像輔助診斷系統(tǒng)具有重要臨床意義,例如甲狀腺結(jié)節(jié)CT診斷AI系統(tǒng)、肺結(jié)節(jié)良惡性疾病篩查等[3],而對COVID-19的CT輔助檢測的AI方向探索仍然是當(dāng)下研究的熱點(diǎn)。
深度學(xué)習(xí)技術(shù)尤其是卷積神經(jīng)網(wǎng)絡(luò)在醫(yī)學(xué)影像上的應(yīng)用越來越廣泛,基于AI的COVID-19的CT輔助檢測,按照是否借助分割技術(shù)可以分為直接進(jìn)行分類任務(wù)訓(xùn)練以及借助圖像分割技術(shù)進(jìn)行分類任務(wù)訓(xùn)練[4]。在借助圖像分割技術(shù)的分類任務(wù)訓(xùn)練中,通常將圖像分割作為預(yù)處理部分,Wang等[5]使用肺部分割模型3D UNet++對CT圖像進(jìn)行預(yù)處理,并對預(yù)處理圖像進(jìn)行隨機(jī)翻轉(zhuǎn)、平移、放大等數(shù)據(jù)增強(qiáng),最后用ResNet-50模型對處理后的圖像進(jìn)行分類。Wang等[6]提出了一種專門為COVID-19設(shè)計(jì)的深度卷積神經(jīng)網(wǎng)絡(luò)COVID-Net,對包含13 870個(gè)病人的13 975張胸部X射線檢查的COVIDx數(shù)據(jù)集進(jìn)行三分類任務(wù)訓(xùn)練,對比VGG-19和Resnet-50的效果,發(fā)現(xiàn)在較少的模型參數(shù)和乘加累積操作數(shù)(multiplyaccumulate operations,MACs)的情況下COVIDNet具有更高的準(zhǔn)確率。
有監(jiān)督學(xué)習(xí)依賴高質(zhì)量人工標(biāo)注的標(biāo)簽信息,一方面,在醫(yī)學(xué)領(lǐng)域中需要專業(yè)的人員去對訓(xùn)練樣本進(jìn)行標(biāo)注,這需要耗費(fèi)大量的時(shí)間和精力;另一方面,在訓(xùn)練集中樣本數(shù)量較少的情況下,通過有監(jiān)督學(xué)習(xí)訓(xùn)練的模型缺少泛化性。然而,自監(jiān)督學(xué)習(xí)依賴的是樣本數(shù)據(jù)來提供監(jiān)督信息,并不依賴人工標(biāo)注的標(biāo)簽信息,能夠有效提高模型的泛化性[7]。自監(jiān)督學(xué)習(xí)的其中一種方法是令網(wǎng)絡(luò)學(xué)習(xí)到關(guān)于目標(biāo)的一個(gè)良好的特征表示,最大化正負(fù)樣本的距離,而且負(fù)樣本的數(shù)量越多自監(jiān)督學(xué)習(xí)的效果越好,譬如He等[8]提出動量對比(momentum contrast,MoCo)用于無監(jiān)督表征學(xué)習(xí),能很好地用于下游任務(wù);Chen等[9]提出表征對比學(xué)習(xí)框架SimCLR,在特征表示和對比損失之間加入了非線性變換等使模型性能極大提升。
多任務(wù)學(xué)習(xí)(multi-task learning,MTL)中假設(shè)不同任務(wù)數(shù)據(jù)分布之間存在一定的相似性[10],在分類任務(wù)中引入相關(guān)任務(wù)(related task)訓(xùn)練模型,可以有效利用相關(guān)任務(wù)訓(xùn)練信號包含的領(lǐng)域特定信息來提高模型的泛化能力[11]。MTL有很多形式,如聯(lián)合學(xué)習(xí)(joint learning),自主學(xué)習(xí)(learning to learn),借助輔助任務(wù)學(xué)習(xí)(learning with auxiliary tasks)等,這些只是其中一些別名[12]。本研究利用模型融合技術(shù)將MoCo[8,13]引入到借助輔助任務(wù)學(xué)習(xí)中并在此之上提出交替訓(xùn)練模式(alternate training mode,ALTM)算法。
MoCo[8,13]是一種無監(jiān)督方法,該方法編碼器由兩部分組成,分別是特征表示編碼器fq和動量編碼器fk。對一個(gè)訓(xùn)練樣本分別進(jìn)行兩種不同的數(shù)據(jù)增強(qiáng)生成一對查詢(query)樣本xq和鍵(key)值樣本xk。查詢樣本和鍵值樣本分別通過特征表示編碼器和動量編碼器生成一對編碼查詢q=fq(xq)和編碼鍵k=fk(xk)。特征表示編碼器fq和動量編碼器fk網(wǎng)絡(luò)參數(shù)的更新方式有所不同(圖1),θq和θk分別為fq和fk的參數(shù),通過對比損失分別對比編碼查詢與作為負(fù)樣本的字典的相似性和編碼查詢與作為正樣本的編碼鍵的相似性,并利用反向傳播進(jìn)行更新θq,θk在θq更新后通過動量跟隨進(jìn)行更新(其中m為動量系數(shù),通常取0.999):
圖1 MoCo更新過程Figure 1 The update process of MoCo
為了增加對比學(xué)習(xí)中負(fù)樣本的個(gè)數(shù),MoCo創(chuàng)建一個(gè)樣本容量大并且編碼鍵相差不大的字典用于訓(xùn)練特征表示編碼器,字典是數(shù)據(jù)樣本的動態(tài)序列。由于θk依靠動量跟隨進(jìn)行更新,且動量系數(shù)m接近于1,這種更新方式令字典在增加負(fù)樣本個(gè)數(shù)的同時(shí),字典中由同一樣本產(chǎn)生的鍵值的變化相差不大。圖2為字典Di:{k0,k1,……,kK}更新示意圖,字典長度為K+1,簡單起見一般設(shè)置為Batch大小的倍數(shù)。當(dāng)一個(gè)Batch的鍵值樣本通過動量編碼器生成鍵值序列}后,字典Di淘汰最早一批長度為大bs的鍵值序列并添加新生成的鍵值序列組成字典Di+1:,其中bs為Batch大小。除此之外,MoCo還提出了shuffling BN等技巧[8],使得該方法能夠很好地解決負(fù)樣本特征由于模型更新導(dǎo)致的前后不一致問題,并有效增加負(fù)樣本的個(gè)數(shù)。經(jīng)過MoCo訓(xùn)練的模型能夠很好地用于下游任務(wù)。
圖2 MoCo動態(tài)序列Figure 2 The dynamic queue of MoCo
本文基于模型融合方法將MoCo[8,13]引入到借助輔助任務(wù)學(xué)習(xí)中,在有監(jiān)督學(xué)習(xí)的過程中使分類任務(wù)受益于輔助任務(wù)。在此基礎(chǔ)上,本文提出一種有利于該融合模型的訓(xùn)練模式:交替訓(xùn)練模式(alternate training mode,ALTM),并通過實(shí)驗(yàn)驗(yàn)證該算法在分類任務(wù)中的可行性。同時(shí)探索了ALTM中輔助任務(wù)在損失函數(shù)中的權(quán)重和字典的長度之間的關(guān)系。
本研究提出的融合模型由兩個(gè)DenseNet[14](以下簡稱Model0和Model1)組成,融合模型(圖4)在兩個(gè)DenseNet的原本結(jié)構(gòu)(圖3)DT Block i之間加入了注意力(attention)機(jī)制,超參數(shù)為λ。Model0和Model1在Combined DT Block 1前的相關(guān)層共享參數(shù)。
圖3 Ensemble DenseNet結(jié)構(gòu)Figure 3 The struction of Ensemble DenseNet
圖4 Ensemble DenseNet結(jié)構(gòu)Figure 4 The struction of Ensemble DenseNet
基于融合模型的ALTM分為兩個(gè)階段。
(1)借助輔助任務(wù)學(xué)習(xí)階段:圖5為借助輔助任務(wù)學(xué)習(xí)階段中,輸入一批訓(xùn)練樣本{I0,I1,…,Ibs-1}的示意圖。與[8]的處理類似,本階段Model1中參數(shù)θ1不具有梯度,參數(shù)的更新依靠動量跟隨。一批訓(xùn)練樣本分別進(jìn)行兩次不同的數(shù)據(jù)增廣后得到查詢樣本和鍵值樣本對并輸入到融合模型中。
圖5 借助輔助任務(wù)學(xué)習(xí)Figure 5 Learning with Auxiliary Tasks
該階段的前向傳播的步驟主要如下:查詢樣本通過 Model0的卷積池化層和融合模型Combined DT Block1~4得到查詢樣本的高層特征圖,高層特征圖通過Model0的分類層和特征表示線性層分別輸出分類概率預(yù)測和查詢樣本的特征表示;而鍵值樣本通過Model1的卷積池化層和DT Block1~4得到鍵值樣本的高層特征圖,最后通過Model1的特征表示線性層輸出鍵值樣本的特征表示。
該步驟的反向傳播的步驟主要如下:對于查詢樣本的特征表示來說,鍵值樣本的特征表示作為正樣本而動態(tài)字典中儲存的特征表示作為負(fù)樣本,利用InfoNCE計(jì)算輔助任務(wù)損失。同時(shí),根據(jù)Model1分類層輸出的分類概率預(yù)測和查詢樣本的標(biāo)簽計(jì)算分類損失。最終計(jì)算多任務(wù)學(xué)習(xí)損失函數(shù)計(jì)算梯度對Model0的參數(shù)θ0進(jìn)行更新,而Model1參數(shù),依靠動量系數(shù)進(jìn)行更新。最后,利用本次計(jì)算的鍵值樣本特征表示更新動態(tài)字典。
本階段為借助輔助任務(wù)學(xué)習(xí),在進(jìn)行學(xué)習(xí)分類任務(wù)的同時(shí),希望融合模型中能夠?qū)W習(xí)良好的特征表示。仿照MoCo v2[13]和SimCLR中的設(shè)計(jì),在特征表示層前增加一層線性層優(yōu)化網(wǎng)絡(luò)的性能。本階段的使用的多任務(wù)損失函數(shù)為(γ為輔助任務(wù)損失函數(shù)的權(quán)重):
反向傳播時(shí),Model0的參數(shù)θ0依靠多任務(wù)損失函數(shù)(式1,γ>0)進(jìn)行更新,而Model1的參數(shù)θ1依靠θ0和動量系數(shù)m進(jìn)行更新:
由于Model0和Model1在Combined DT Block 1前的相關(guān)層共享參數(shù),在Model1參數(shù)進(jìn)行更新時(shí),這部分層參數(shù)不參與動量跟隨更新,而是與Model1相關(guān)層參數(shù)保持參數(shù)共享。
(2)分類任務(wù)訓(xùn)練階段:本階段為分類任務(wù)學(xué)習(xí),輸入一批訓(xùn)練樣本{I0,I1,…,Ibs-1},分別進(jìn)行兩次不同的數(shù)據(jù)增廣后得到查詢樣本和鍵值樣本,將查詢樣本輸入到融合模型中,融合模型利用多任務(wù)損失函數(shù)(式1,γ=0)進(jìn)行訓(xùn)練,而鍵值樣本通過Model1的前向傳播,在Model1的特征表示層輸出鍵值樣本的特征表示,最后利用鍵值樣本的特征表示對動態(tài)字典進(jìn)行更新。
為了避免Model1更新過快導(dǎo)致鍵值序列中的特征表示前后變化較大,一般在ALTM中,將序列長度定為upe倍的訓(xùn)練集大小,進(jìn)行upe個(gè)epoch的多任務(wù)學(xué)習(xí)階段后進(jìn)行1個(gè)epoch的分類任務(wù)訓(xùn)練階段,然后以這個(gè)順序不斷進(jìn)行循環(huán)訓(xùn)練。顯然,upe和γ的取值在ALTM中起著至關(guān)重要的作用。
本實(shí)驗(yàn)基于pytorch1.7cuda10.2進(jìn)行,使用NVIDIA Corporation GP104GL[Tesla P4]內(nèi)存為8G的GPU。評價(jià)指標(biāo)選用機(jī)器學(xué)習(xí)常用評價(jià)指標(biāo):精確率(precision)、召回率(recall)、F1-score、準(zhǔn)確率(accuracy)和曲線下面積(area under curve,AUC)。評價(jià)指標(biāo)的計(jì)算如圖6所示。交替訓(xùn)練中,Model0和Model1使用獨(dú)立的Adam優(yōu)化器,學(xué)習(xí)率為1e-4,學(xué)習(xí)率策略采用CosineAnnealingLR,T_max=10。訓(xùn)練參數(shù)設(shè)置:epoch=100,batch_size=16,分類損失函數(shù)采用交叉熵?fù)p失函數(shù),動態(tài)字典長度為upe倍的訓(xùn)練集大小,動態(tài)字典初始化于MoCo相同,使用標(biāo)準(zhǔn)化后的標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)序列。
圖6 評價(jià)指標(biāo)Figure 6 Evaluation metrics
本實(shí)驗(yàn)數(shù)據(jù)集采用COVID-CT-Dataset[15],訓(xùn)練集、驗(yàn)證集、測試集中樣本個(gè)數(shù)分別為425、118、203(表1),包含兩類樣本(COVID 和NonCOVID)。訓(xùn)練樣本來自從2020年1月19日到2020年3月25日在medRxic和bioRxic中關(guān)于COVID-19的預(yù)印本中的CT圖像。COVID-CT數(shù)據(jù)集部分樣本如圖7所示,a~d為COVID樣本,e~h為NonCOVID樣本。受實(shí)驗(yàn)結(jié)果[16]的啟發(fā),實(shí)驗(yàn)選用兩個(gè) DenseNet169 構(gòu)成融合DenseNet169(Ensemble DenseNet169,ED169)模型,由于ALTM算法需要融合模型中Model0和Model1參數(shù)初始化相同,實(shí)驗(yàn)中無論單一模型DenseNet169還是融合模型ED169中的Model0和Model1,均引入ImageNet預(yù)訓(xùn)練模型。
圖7 COVID-CT數(shù)據(jù)集示例Figure 7 Example CT images from the COVID-CT-Dataset
表1 COVID-CT數(shù)據(jù)集Table 1 COVID-CT-Dataset n
表2結(jié)果顯示,無論λ取值為多少,從5個(gè)評價(jià)指標(biāo)的結(jié)果來看,融合模型ED169整體表現(xiàn)比DenseNet169更優(yōu)。
表2 融合模型訓(xùn)練結(jié)果Table 2 The result of Ensemble DenseNet169
融合模型雖然利用了參數(shù)共享方法,但是總體來看參數(shù)量大致為單模型的兩倍,并且融合模型加入了注意力機(jī)制,組成更有效的結(jié)構(gòu),所以即使不使用ALTM時(shí)模型表現(xiàn)也比DenseNet169更優(yōu)。
使用ALTM后的融合模型的性能和未使用ALTM的融合模型性能對比實(shí)驗(yàn)結(jié)果如表3所示,可見無論λ取值為多少,5個(gè)評價(jià)指標(biāo)的結(jié)果大多數(shù)都取得一定程度上的提升,證明ALTM對融合模型各方面都有實(shí)質(zhì)性的提升。
表3 ALTM 實(shí)驗(yàn)結(jié)果Table 3 The result of ensemble DenseNet169 trained by ALTM
融合模型由Model0和Model1構(gòu)成,當(dāng)λ越大時(shí),Model0的對融合模型的影響越大。ALTM在ED169中有所提升,但是在λ=0.95時(shí)提升較小,這是由于Model0對融合模型的影響過大,從而削弱了ALTM的作用。此外,研究發(fā)現(xiàn)對于每個(gè)γ值來說,upe值較大時(shí)模型泛化能力有時(shí)候雖然不是最優(yōu),但是也有一定的提升,研究認(rèn)為這得益于在表示學(xué)習(xí)時(shí)負(fù)樣本個(gè)數(shù)的增加。同時(shí),對于驗(yàn)證集來說,融合模型與單一模型相比有時(shí)候出現(xiàn)相差無幾甚至有所下降的情況,這可能是由于數(shù)據(jù)集樣本較少導(dǎo)致融合模型過擬合,采用ALTM算法后極大改善這種情況。
針對每一個(gè)不同取值λ的融合模型ED169,在upe固定的情況下選取性能表現(xiàn)最佳的模型所對應(yīng)的γ值。
為了進(jìn)一步探討超參數(shù)upe和γ對ALTM的影響,本階段實(shí)驗(yàn)采取以下設(shè)置:
1)upe=1,2,3,4,5,6;
2)γ=0.01,0.1,0.2,0.3,0.4,0.5,0.6;
3)λ=0.5,0.6,0.7,0.8,0.9,0.95。
如圖4所示,圖中縱坐標(biāo)為在λ為固定值,而upe取不同值時(shí),訓(xùn)練性能最好的模型所對應(yīng)的γ值??梢钥闯?,無論λ取值為多少,結(jié)果呈中間高兩邊低的趨勢。
在使用ALTM后,雖然融合模型中超參數(shù)λ不同導(dǎo)致融合模型最終的結(jié)果有所差異,但是可以看出λ影響的是模型本身的表現(xiàn)能力,對ALTM性能影響更重要的是upe和γ的取值關(guān)系。研究認(rèn)為,當(dāng)upe取值較小的時(shí)候,鍵值序列長度較小,負(fù)樣本數(shù)目較少,特征學(xué)習(xí)任務(wù)的難度較小,所以γ值較小模型的性能表現(xiàn)會更佳。在鍵值序列中負(fù)樣本的數(shù)目越多,負(fù)樣本提供輔助分類任務(wù)的相關(guān)信息量就越大,越有利于模型進(jìn)行分類任務(wù)的訓(xùn)練,所以當(dāng)upe值變大時(shí),適當(dāng)提高γ值能提升融合模型ED169的泛化能力。當(dāng)upe值較大的時(shí)候,鍵值序列長度較大,負(fù)樣本數(shù)目較多,但是由于upe值和Model1更新次數(shù)息息相關(guān),所以此時(shí)鍵值序列中負(fù)樣本的特征表示與upe值較小時(shí)相比變化較大,特征學(xué)習(xí)任務(wù)的難度較大。如果此時(shí)γ值較大的話,會導(dǎo)致模型偏向于學(xué)習(xí)更復(fù)雜的特征表示,并不利于分類任務(wù),所以理應(yīng)選取一個(gè)較小的γ值。
圖8 upe和γ對ALTM的影響Figure 8 Effects of upe andγon ALTM
本文在單GPU上利用并探究AI的方法對COVID-19的CT圖像進(jìn)行圖像識別,并基于深度學(xué)習(xí)提出一種交替訓(xùn)練的融合模型。He等[16]運(yùn)用MoCo對深度卷積神經(jīng)網(wǎng)絡(luò)DenseNet169在COVID-CT和LUNA 16的1000個(gè)樣本組成的數(shù)據(jù)集上進(jìn)行自監(jiān)督學(xué)習(xí),在COVID-CT數(shù)據(jù)集的測試集上準(zhǔn)確率(Accuracy):0.86,F(xiàn)1:0.85,AUC:0.94。由于在小數(shù)據(jù)集上對深度學(xué)習(xí)模型進(jìn)行訓(xùn)練容易導(dǎo)致模型過擬合,Loey M[17]等使用深度遷移學(xué)習(xí)模型的經(jīng)典數(shù)據(jù)增強(qiáng)技術(shù)以及條件生成對抗網(wǎng)絡(luò)(CGAN)生成更多的胸部CT圖像,用ResNet50對COVID進(jìn)行分類,模型在COVIDCT測試集上最好的分類表現(xiàn):精確率(Precision):0.876,召回率(Precision):0.814,F(xiàn)1-score:0.844,準(zhǔn)確率:0.829(指標(biāo)由測試集TP,TN,F(xiàn)P,F(xiàn)N計(jì)算得)。本研究提出ALTM算法,融合模型ED169在COVID-CT數(shù)據(jù)集的測試集上的精確率:0.86,召回率:0.91,F(xiàn)1-score:0.89,準(zhǔn)確率:0.88,AUC:0.92,與之前的相關(guān)研究相比,本研究提出的算法整體結(jié)果較優(yōu),能夠良好地解決模型在小數(shù)據(jù)集上的過擬合問題,提升模型的泛化性能。
在多任務(wù)學(xué)習(xí)中,引入特征表示來增加分類任務(wù)中模型的泛化能力是分類任務(wù)中常用的方法,有多種損失函數(shù)來最大化正負(fù)樣本之間的距離,但是如何選取負(fù)樣本和負(fù)樣本的個(gè)數(shù)一直是研究的重點(diǎn)。本研究結(jié)果顯示,利用模型融合技術(shù)將MoCo引入到多任務(wù)學(xué)習(xí)中也能有效提高模型性能。本研究特別提出了一種適用于融合模型的訓(xùn)練方法ALTM,通過對比DenseNet169,ED169和ED169+ALTM 3種情況在COVID19-CT數(shù)據(jù)集上的實(shí)驗(yàn),結(jié)果證明了ALTM的有效性。同時(shí),對影響ALTM超參數(shù)之間的內(nèi)在關(guān)系進(jìn)行實(shí)驗(yàn)和分析,深度分析了ALTM對融合模型訓(xùn)練的影響。
本實(shí)驗(yàn)仍有不足之處,在融合模型上ED169僅為兩個(gè)DenseNet169,各個(gè)部分進(jìn)行加權(quán)平均,而在今后的研究中,可基于經(jīng)典模型例如ResNet,DenseNet等構(gòu)造出一個(gè)高效,分類能力強(qiáng)的融合模型結(jié)合ALTM。除此之外,COVID-CT數(shù)據(jù)集的樣本個(gè)數(shù)較少,今后在擴(kuò)大COVID-CT的訓(xùn)練樣本個(gè)數(shù)的情況下來繼續(xù)提高新型冠狀病毒輔助檢測能力。
作者貢獻(xiàn)聲明
孔銳:項(xiàng)目指導(dǎo)及負(fù)責(zé)人,指導(dǎo)論文寫作與修改;莊峻賢:提出研究思路和框架,設(shè)計(jì)實(shí)驗(yàn),撰寫和修改論文;梁冠燁:統(tǒng)計(jì)數(shù)據(jù),修改論文。
利益沖突聲明
本研究未受到企業(yè)、公司等第三方資助,不存在潛在利益沖突。