曹潤芝,韓 斌,劉嘎瓊
(江蘇科技大學 計算機學院,鎮(zhèn)江 212114)
世界衛(wèi)生組織公布肺炎是導致兒童死亡的主要原因之一[1],據(jù)估計每年160 萬5 歲以下兒童因患肺炎而死亡.肺炎病患的醫(yī)學影像需要具備專業(yè)知識的醫(yī)生進行評估,但在實際診斷中醫(yī)生由于缺少經(jīng)驗和視覺疲憊等主觀因素會造成誤診,因此準確高效地通過肺部影像進行輔助診斷對于肺炎的治療至關重要[2,3].面對醫(yī)療診斷過程決策難且不確定性高的問題,Chouhan等人[4]采用遷移學習建立肺炎檢測的深度學習框架,利用AlexNet、DenseNet121、Inception V3、GoogLeNet和ResNet18 網(wǎng)絡模型提取圖像特征,并使用集成學習方法將各網(wǎng)絡的輸出合并到一個預測向量中,使用多數(shù)表決權進行最終預測.潘麗艷等人[5]使用深度學習技術對肺區(qū)域病毒或細菌的病原學類型進行判斷,從而在臨床上規(guī)范兒童肺炎治療.Al-Antari 等人[6]提出一種基于YOLO 預測器的同步深度學習計算機輔助診斷系統(tǒng),能可靠地將新冠肺炎與其它呼吸系統(tǒng)疾病進行區(qū)分.何新宇等人[7]提出基于深度卷積神經(jīng)網(wǎng)絡的肺炎圖像識別模型,采用GoogLeNet Incepion V3 網(wǎng)絡模型進行特征提取,以及隨機森林分類器進行分類預測,最終在識別準確率和敏感度上實現(xiàn)明顯提高.
深度學習技術在圖像的分類[8]、檢測[9]等方面具有高精度的優(yōu)勢,能準確地提取醫(yī)學圖像的特征從而輔助疾病診斷.然而在現(xiàn)實中仍然存在以下問題:(1)訓練模型需要大量的優(yōu)質數(shù)據(jù),而醫(yī)療數(shù)據(jù)涉及病人隱私,容易造成數(shù)據(jù)孤島現(xiàn)象.(2)傳統(tǒng)的深度學習過程將各方數(shù)據(jù)集匯總到一起進行模型訓練,對于單機的算力需求較高且效率較低.(3)經(jīng)典殘差網(wǎng)絡忽略了圖像的通道特征且對批量大小具有高依賴性.
Mao 等人[10]論述聯(lián)邦學習具有不共享本地數(shù)據(jù)實現(xiàn)聯(lián)合訓練的優(yōu)勢,能夠解決深度學習過程中的醫(yī)療數(shù)據(jù)隱私泄漏問題.利用聯(lián)邦學習對醫(yī)學圖像進行分析,能為臨床醫(yī)生提供更好的輔助診斷工具[11].
針對以上問題,本文提出了一種由聯(lián)邦學習框架(federated learning,FL)、改進殘差網(wǎng)絡(residual networkgroup normalization,ResNet-GN)和壓縮激勵網(wǎng)絡(squeezeand-excitation network,SE)融合形成的FL-SE-ResNet-GN方法,將殘差網(wǎng)絡的批量歸一化方式(batch normalization,BN)替換為組歸一化方式(group normalization,GN),嵌入壓縮激勵網(wǎng)絡,在提取圖像深層特征的同時關注通道特征,最終融合聯(lián)邦學習框架進一步提升肺炎輔助診斷的效果和保護患者數(shù)據(jù)隱私.
基于卷積神經(jīng)網(wǎng)絡的模型通過不斷地增加卷積層和池化層組合的數(shù)量從而學習圖像更深層的特征,例如LeNet[12]、AlexNet[13]、VGG[14]等模型.然而當模型的層數(shù)不斷地增加,淺層網(wǎng)絡參數(shù)會逐漸趨于零,導致其無法更新,造成梯度消失現(xiàn)象.為了解決上述問題,2016年He 等人[15]提出了深度殘差網(wǎng)絡(residual network,ResNet)模型.其結構如圖1所示,通過在兩個權重層外部使用跳躍連接實現(xiàn)恒等映射,從而構成一個殘差塊.定義該殘差塊的輸入為x,模塊最終的整體輸出為H(x),則殘差模塊兩層權重層部分的殘差函數(shù)為F(x)=H(x)?x.其中跳躍連接不會引入新的參數(shù),一定程度上保證了訓練的效率.相較于殘差模塊的整體輸出H(x),計算通過優(yōu)化轉換后的殘差函數(shù)F(x)更容易.
圖1 經(jīng)典殘差模塊
壓縮激勵網(wǎng)絡[16]結構簡單,便于耦合到已有的深度學習網(wǎng)絡模型框架中.SE 模塊主要關注了通道之間的相關性,通過訓練得到各個通道不同的權重信息,含有主要特征的通道分配更高的權重,從而提高模型的表達能力.SE 模塊由3 部分組成:壓縮(squeeze)、激勵(excitation)和重分配(reweight).定義SE 模塊的輸入為W×H×C的特征,其中,W表示寬,H表示高,C表示通道數(shù);使用縮放參數(shù)減少通道個數(shù)從而降低計算量.
首先SE 模塊利用全局池化對輸入特征進行壓縮,將特征轉化為1×1×C的向量,使其獲得全局的感受野;
然后利用兩次全連接操作、ReLU 函數(shù)和縮放參數(shù)實現(xiàn)降維和升維操作.最后實現(xiàn)重分配,使用Sigmoid函數(shù)對激勵操作的輸出實現(xiàn)歸一化,從而得到對應通道的歸一化權重,并將權重與輸入特征相乘,實現(xiàn)了輸入特征的重標定,從而增強了有用特征的提取,提高特征提取的準確度.
聯(lián)邦學習[10]是一種特殊的分布式機器學習框架,與傳統(tǒng)的機器學習的區(qū)別在于缺少了中心服務器匯總數(shù)據(jù)的過程.聯(lián)邦學習采用客戶端-服務器架構如圖2所示,聯(lián)邦學習的組成部分為中心服務器和若干個客戶端,其每輪的訓練過程是客戶端首先從中心服務器下載聯(lián)邦共享模型和參數(shù),利用本地的數(shù)據(jù)進行迭代訓練,最后將該輪訓練得到的參數(shù)或模型上傳到中心服務器進行聚合,實現(xiàn)聯(lián)邦共享模型和參數(shù)的更新.
圖2 基于聯(lián)邦學習的客戶端-服務器架構
本文為了準確地提取肺部圖像更深層次的特征同時保護醫(yī)療數(shù)據(jù)的隱私性.首先,在殘差網(wǎng)絡的基礎上引入壓縮激勵網(wǎng)絡.然后,將殘差網(wǎng)絡中的批量歸一化方式轉換為組歸一化方式形成改進后的深度神經(jīng)網(wǎng)絡模型(SE-ResNet-GN).最后,將改進后的深度學習模型融合聯(lián)邦學習框架,運用聯(lián)邦平均算法實現(xiàn)分布式模型訓練以獲得最終的聯(lián)邦共享模型.
本文以經(jīng)典的ResNet-18 作為基礎網(wǎng)絡,總共由17 個卷積層和1 個全連接層組成,第一個卷積層使用64個7×7的卷積核,剩下的卷積層使用 3×3卷積核.
首先在殘差塊內部的輸出后嵌入激勵壓縮網(wǎng)絡,本文中令縮放參數(shù)r為16,輸入特征依次經(jīng)過兩次全連接層進行16 倍降維和升維操作,利用Sigmoid 激活函數(shù)獲得通道的歸一化權重并與通道特征相乘從而重標定輸入的特征.
其次,殘差塊使用BN 方式實現(xiàn)歸一化,在學習過程中,BN 方式對一個批量的數(shù)據(jù)計算得到均值和方差并通過滑動平均的方式獲得訓練的全局均值和方差,因此BN 方式對批量大小(batch size)的依賴程度較高,當數(shù)據(jù)批量較小時難以保證訓練效果,而較大批量的數(shù)據(jù)對于計算機的算力要求較高.因此本文使用組歸一化[17]方式替換經(jīng)典殘差塊內部的BN.改進后的網(wǎng)絡如圖3所示.GN 通過對通道進行分組操作,將卷積層的輸出特征圖的通道分為G組,G為16,圖3中表示為64//G,并計算每組的均值和方差,因此使用GN方式性能穩(wěn)定,對批數(shù)量的魯棒性更強.
圖3 融合SE 塊的改進殘差網(wǎng)絡
定義歸一化運算的輸入的特征圖x為[N,C,H,W],其中N表征批量數(shù),C表征通道數(shù),H、W表征高度和寬度,y與 β為學習的映射參數(shù),BN 保留輸入通道C的維度,其歸一化公式如下所示:
GN 進行歸一化首先將輸入x的通道劃分為G組,則每組包含C/G個通道數(shù),然后計算組內元素的均值和方差,公式如下所示:
基于聯(lián)邦學習的肺炎輔助診斷中對于病人數(shù)據(jù)的數(shù)量和質量要求較高,對于單一的醫(yī)院或者醫(yī)療機構而言本地的數(shù)據(jù)集數(shù)量較少而且類型單一,難以達到深度學習的要求.與此同時,醫(yī)療數(shù)據(jù)因涉及大量病人的隱私,形成大量的數(shù)據(jù)孤島.本文聯(lián)邦學習采用客戶端-服務器架構,將單個醫(yī)院或者醫(yī)療機構作為一個客戶端以及擁有足夠算力的可信第三方作為中心服務器,提供FL-SE-ResNet-GN 框架足夠的算力支持并進行數(shù)據(jù)傳輸,能夠讓客戶端在不暴露本地的數(shù)據(jù)的情況下進行深度學習的模型訓練,各個客戶端將本地訓練的參數(shù)上傳到中心服務器應用聯(lián)邦平均算法進行聚合.聯(lián)邦平均算法步驟如算法1 所示.
算法1.聯(lián)邦平均算法輸入:參與運算的客戶端比例:;客戶端:;學習率:;本地訓練批數(shù)量B:;本地迭代次數(shù):E;服務端擬合輪數(shù):.輸出:全局權重.C(0~1) client j(1≤j≤J) η b∈Bt∈T Step 1.初始化t=1 w0 Step 2.從 到T m←max(C·J,1)1) S t← m 2) 數(shù)量為的客戶端子集client j∈S ti∈E 3) 并行進行本地訓練,同時本地訓練輪數(shù).client j ωi,j,t=ωi?1,j,t?η??j,t(ωi?1,j,t;b)i∈E 4) 利用本地數(shù)據(jù)對下載的模型參數(shù)進行本地E 輪迭代更新:,client j ω j,t ωt+1←J∑j=1 5)獲得各的本地參數(shù),利用本地客戶端的數(shù)據(jù)量占所有參與訓練的客戶端的總數(shù)據(jù)量的比值得到聚合模型:nj n ω j,t
在聯(lián)邦學習過程中,各個客戶端在本地進行獨立的模型訓練,其采用的批處理數(shù)量與各自計算機算力相關,同時使用小批量數(shù)據(jù)在BN 中誤差較大,從而影響聚合的聯(lián)邦共享模型的準確度.因此,為了減少對批處理數(shù)量的依賴.本文使用聯(lián)邦學習框架與改進的SEResNet-GN 網(wǎng)絡模型相結合實現(xiàn)分布式訓練,當聯(lián)邦聚合的次數(shù)達到限制或者訓練模型收斂后將會結束訓練.具體的實現(xiàn)過程如圖4所示.圖中使用clienti和clientj表示不同的客戶端.
圖4 FL-SE-ResNet-GN 模型的訓練流程
改進的深度神經(jīng)網(wǎng)絡模型融合聯(lián)邦學習框架后的訓練算法如算法2 所示.
算法2.FL-SE-ResNet-GN輸入:本地數(shù)據(jù)集,服務端擬合總輪數(shù)Smax.w0{Data1,Data2,···,Datan}初始化:模型網(wǎng)絡結構,權重.輸出:收斂的FL-SE-ResNet-GN 模型.Step 1.客戶端從中心服務器下載初始的FL-SE-ResNet-GN 模型和初始參數(shù).
?
本文實驗的環(huán)境為Windows 10 操作系統(tǒng),Intel(R)CPU E5-1620 v3@3.5 GHz 3.5 GHz 處理器,8 GB NVIDIA Quadro M4000 顯卡,32 GB 內存,500 GB 硬盤,計算機語言為Python,實驗框架基于PyTorch 實現(xiàn).
本實驗所使用的數(shù)據(jù)集是來自2018年美國加州大學圣迭戈分校公開的Chest X-Ray Images 圖像數(shù)據(jù)集[18].該數(shù)據(jù)集由兩類圖像組成,分別是正常(normal)和患肺炎(pneumonia).訓練集和測試集的數(shù)量如表1所示.
表1 數(shù)據(jù)集分布
本次實驗所使用的評估指標分別為準確率(Accuracy,ACC),召回率(Recall),精確率(Precision,P),F1 分數(shù)(F1).以上指標的計算公式如下.
其中,TP表示正常樣本被正確分類的數(shù)量,TN表示肺炎樣本被正確分類的數(shù)量,FP表示肺炎樣本被分類為正常樣本數(shù)量,FN表示正常樣本中被分類為肺炎樣本的數(shù)量.
本節(jié)使用Chest X-Ray Images 圖像[18]作為數(shù)據(jù)集進行實驗.首先,為了驗證批量歸一化和組歸一化方式受批處理數(shù)量參數(shù)的影響,在不同的批處理數(shù)量下將SE-ResNet-GN 模型與嵌入SE的原始殘差網(wǎng)絡(以下簡稱SE-ResNet-BN) 進行比較,并分析模型的準確率、召回率、精準率和F1 等指標.然后,運用聯(lián)邦學習框架與改進SE-ResNet-GN 模型相融合,并與經(jīng)典的神經(jīng)網(wǎng)絡模型進行對比.最后,將本文改進的方法與其它研究者已有的研究成果進行對比.
3.4.1 Batch size 對融合BN和GN 網(wǎng)絡的影響
本節(jié)將改進后的SE-ResNet-GN 與原始網(wǎng)絡對比,批數(shù)量分別設定為4、8、16和32,其訓練結果如表2.
表2 Batch size 大小對SE-ResNet-BN和SE-ResNet-GN的指標影響
從表2可知,batch size為4–32 之間,SE-ResNet-BN模型的準確率最大差值為11.5 個百分點,準確率最小差值為2 個百分點.在上述batch size 下的SE-ResNet-GN模型的準確率變化較為平穩(wěn),上述batch size 下準確率差值最大為3 個百分點,差值最小為0.2 個百分點.當batch size為32 時,SE-ResNet-GN的準確率、召回率、F1 分數(shù)分別比SE-ResNet-BN 高1.1 個百分點、5.8 個百分點、1.8 個百分點.以上結果表明,整體上,SE-ResNet-GN 模型不僅具有較好的評價指標效果而且受batch size的影響較少.
3.4.2 本文方法與經(jīng)典深度神經(jīng)網(wǎng)絡模型對比
本節(jié)實驗假定有3 個客戶端進行20 輪聯(lián)邦擬合過程,客戶端本地進行10 輪迭代訓練.實驗數(shù)據(jù)集將根據(jù)客戶端的數(shù)量等量劃分成為各自本地的數(shù)據(jù).本節(jié)實驗選用AlexNet[13]、VGG[14]、ResNet[15]三種模型以及相關改進SE-ResNet-BN 模型與本文方法在相同的實驗環(huán)境和數(shù)據(jù)集下進行對比實驗,其結果如表3.
表3 本文方法與經(jīng)典深度神經(jīng)網(wǎng)絡模型性能對比
由表3可知,SE-ResNet-BN 模型的指標優(yōu)于ResNet,表明引入SE 模塊實現(xiàn)注意力機制能夠提升圖像分類的性能.經(jīng)典深度神經(jīng)網(wǎng)絡模型中準確率最高的SEResNet-BN 模型的準確率比本文方法低1.2 個百分點.本文方法的各項評價指標與AlexNet、ResNet、SEResNet-BN 相比具有明顯提升.本文方法利用改進殘差網(wǎng)絡能提取更深層次的圖像特征,嵌入壓縮激勵模塊關注通道特征,并進行特征重標定增強有效特征的提取,能夠提高聯(lián)邦共享模型的準確性,同時本文方法具有本地數(shù)據(jù)不對外共享的優(yōu)勢,能夠有效保護醫(yī)療數(shù)據(jù)的隱私,從而打破醫(yī)療數(shù)據(jù)孤島現(xiàn)象.因此與上述經(jīng)典方法相比,在診斷準確率和數(shù)據(jù)隱私保護方面具有優(yōu)勢.本節(jié)實驗將聯(lián)邦學習訓練的迭代擬合過程與傳統(tǒng)方法迭代過程對比,實驗準確率和損失值變化如圖5和圖6所示,整體上訓練過程中的準確率和損失值變化結果均優(yōu)于其它網(wǎng)絡.
圖5 迭代次數(shù)和準確率關系
圖6 迭代次數(shù)和損失值關系
3.4.3 與其它方法的對比
通過對Chest X-Ray Images 圖像數(shù)據(jù)集進行訓練,文獻[18]使用遷移學習方法建立了一種基于深度學習框架的診斷工具.文獻[19]提出一種結合殘差思想和膨脹卷積[20]的肺炎圖像分類方法[19],通過殘差網(wǎng)絡結構克服模型深度增加引起訓練過程中的過擬合和退化問題,并利用膨脹卷積避免肺炎圖像分類過程中的損失.文獻[21]利用AlexNet和InceptionV3 網(wǎng)絡模型并結合知識蒸餾方法提高對肺炎CT 圖像的分類性能,提出了AlexNet_S 方法.本文方法與上述3 種方法的實驗結果對比如表4所示.實驗結果表明,本文方法的準確率和召回率最優(yōu).在精度指標上本文方法比文獻[19]方法提升4.2 個百分點,但低于遷移學習[18]和AlexNet_S[21]的方法的精度指標.整體上本文方法在準確率、精度、召回率方面性能較好,同時融合聯(lián)邦學習框架,在醫(yī)療數(shù)據(jù)隱私安全性方面具有較大優(yōu)勢,這在現(xiàn)實應用中具有重要的實際意義.
表4 本文方法與其它方法對比
為了提高醫(yī)生疾病診斷的效率和準確性,本文融合聯(lián)邦學習架構和改進后的SE-ResNet-GN 模型,利用聯(lián)邦學習過程中數(shù)據(jù)保存在本地且不對外共享的優(yōu)勢,保證了醫(yī)療數(shù)據(jù)的隱私性并具有打破數(shù)據(jù)孤島的優(yōu)勢.為了提取更深層次的特征同時避免梯度消失,本文以殘差網(wǎng)絡作為基礎模型,在聯(lián)邦學習過程中,各個客戶端在本地進行獨立訓練,為了避免批處理數(shù)量對于聯(lián)邦共享模型的影響,本文將傳統(tǒng)的批量歸一化方式轉換為組歸一化方式,對輸入特征的通道進行分組運算,提高了模型的穩(wěn)定性,同時引入激勵壓縮網(wǎng)絡關注通道間的相關性.經(jīng)過與其它的深度神經(jīng)網(wǎng)絡以及融合聯(lián)邦學習框架后的模型對比實驗后發(fā)現(xiàn)本文提出的模型具有更好的準確率與安全性.