蔡德潤 李紅燕
(北京大學(xué)信息科學(xué)技術(shù)學(xué)院 北京 100871)
(機器感知與智能教育部重點實驗室(北京大學(xué)) 北京 100871)
突發(fā)公共衛(wèi)生安全事件往往會對社會醫(yī)療資源造成巨大的壓力.例如2020年初新冠肺炎疫情的暴發(fā)所帶來的醫(yī)護人員人手短缺、醫(yī)療資源擠兌等問題.其原因之一是新型冠狀病毒感染者容易出現(xiàn)“炎癥風(fēng)暴”[1],導(dǎo)致病情迅速惡化,死亡風(fēng)險上升.醫(yī)護人員需要投入大量的精力去觀察和跟蹤患者生理狀況的變化,并需要根據(jù)患者死亡風(fēng)險程度調(diào)配不同的醫(yī)療設(shè)備.例如體外膜肺氧合設(shè)備能夠為搶救贏得寶貴的時間,但是數(shù)量比較少,適用于重癥心肺功能衰竭患者.如果能夠利用患者的生命體征數(shù)據(jù)構(gòu)建深度學(xué)習(xí)模型,對死亡風(fēng)險上升的患者發(fā)出預(yù)警,則可以節(jié)省醫(yī)護人員的精力,及時對醫(yī)療設(shè)備進行合理的配置,增加醫(yī)療資源的利用率[2-7].
深度學(xué)習(xí)模型的成功應(yīng)用是建立在大量帶標(biāo)簽訓(xùn)練數(shù)據(jù)上的,并往往要求測試數(shù)據(jù)和訓(xùn)練數(shù)據(jù)服從同一分布,這在實際應(yīng)用中常常不能得到滿足.由于各種現(xiàn)實條件的限制,收集到的訓(xùn)練數(shù)據(jù)具有一定的局限性,例如可能某年齡段[8]、某科室或者某種并發(fā)癥占據(jù)了大多數(shù).這種局限性導(dǎo)致深度學(xué)習(xí)模型不能夠?qū)ζ渌闆r下的數(shù)據(jù)進行普適地預(yù)測.域適應(yīng)(domain adaptation)方法能夠利用源域和目的域的相似性,將源域上學(xué)習(xí)到的知識遷移到目的域上,從而解決該問題.
但是,將域適應(yīng)方法應(yīng)用在重癥監(jiān)護病人死亡風(fēng)險預(yù)測任務(wù)上時還遇到主要來自3個方面的困難:整體數(shù)據(jù)分布偏移、類別之間的數(shù)據(jù)分布偏移以及時序數(shù)據(jù)的多樣性和復(fù)雜性.其中整體數(shù)據(jù)分布偏移與類別之間的數(shù)據(jù)分布偏移如圖1所示:
Fig. 1 Data distribution shift圖1 數(shù)據(jù)分布偏移
整體數(shù)據(jù)分布偏移指的是源域和目的域整體的數(shù)據(jù)分布往往不相同.例如,在重癥監(jiān)護室內(nèi)收集到的數(shù)據(jù)中可能老年人占據(jù)大多數(shù).圖1中老年患者A與青年患者B的生命體征不相類似,表示以老年患者為主體的源域和以青年患者為主體的目的域的數(shù)據(jù)分布是有差異的.以醫(yī)療領(lǐng)域的MIMIC-Ⅲ(medical information mart for intensive care Ⅲ)數(shù)據(jù)集為例,血壓作為反映患者生理狀況的重要指標(biāo)之一,在不同年齡段的患者之間的分布是不同的.如圖2所示,患者的平均血壓隨著年齡增加而逐漸變低.這些生理指標(biāo)分布的差異導(dǎo)致在老年患者數(shù)據(jù)上訓(xùn)練的模型不能夠很好地泛化到青年患者的數(shù)據(jù)上.域適應(yīng)方法能夠適當(dāng)減小患者A與患者B的高級特征之間的距離,消除整體數(shù)據(jù)分布偏移所帶來的影響.
Fig. 2 Distribution of mean blood pressure with age圖2 平均血壓隨年齡的分布情況
類別之間的數(shù)據(jù)分布偏移指的是不同域之間同一類別的數(shù)據(jù)分布往往不同.無論特征來自哪個域,相同類別的特征之間應(yīng)該相隔較近,不同類別的特征之間應(yīng)該相隔較遠(yuǎn).因此需要在域適應(yīng)的基礎(chǔ)上進行類別適配.如圖1所示,患者A與患者C的年齡相近,數(shù)據(jù)分布也類似.但二者的存活結(jié)果不相同,屬于不同的類,因此需要進行類別適配,增加二者高級特征之間的距離.
時序數(shù)據(jù)的多樣性和復(fù)雜性所帶來的困難指的是患者各項生理指標(biāo)構(gòu)成了數(shù)據(jù)分布互不相同的不同通道,不同通道的時序變化趨勢共同描繪了病人的生理狀況.深度學(xué)習(xí)模型只有在理解不同時間步之間復(fù)雜的時序依賴關(guān)系并且有效地提取高級特征之后,才能進行域適應(yīng).
本文提出了一種基于域?qū)购图有杂嘞议g隔損失的無監(jiān)督域適應(yīng)方法(additive margin softmax based adversarial domain adaptation, AMS -ADA).其中域?qū)故且环N類似生成對抗網(wǎng)絡(luò)的方法,能夠解決整體數(shù)據(jù)分布偏移的困難.加性余弦間隔損失引入了度量學(xué)習(xí)的思想,能夠解決類別之間數(shù)據(jù)分布偏移帶來的困難.此外,本文使用帶有注意力機制的雙向長短程網(wǎng)絡(luò)作為特征提取器來應(yīng)對時序數(shù)據(jù)的多樣性和復(fù)雜性.
無監(jiān)督域適應(yīng)問題指的是利用源域的數(shù)據(jù)和標(biāo)簽以及目的域的數(shù)據(jù)訓(xùn)練深度學(xué)習(xí)模型,希望模型能夠在目的域上取得盡可能高的準(zhǔn)確度.與許多其他的遷移學(xué)習(xí)方法[9-13]相比,域適應(yīng)對目的域上的標(biāo)簽不做要求,進一步降低了獲取標(biāo)注數(shù)據(jù)的壓力.深度學(xué)習(xí)模型可以簡單地視作特征提取器和分類器2個部分.如果深度學(xué)習(xí)模型的特征提取器能夠從不同域之間的數(shù)據(jù)提取出域不變(domain invariant)的特征,那么在源域上訓(xùn)練的分類器就可以很好地應(yīng)用在目的域上.域不變的特征是指在源域和目的域都具有表現(xiàn)力和判別力的特征,蘊涵了源域和目的域之間可以共享的知識.為了實現(xiàn)提取出域不變特征的這一目標(biāo),減少整體數(shù)據(jù)分布的偏移,通常的做法有2種:
1) 基于特征映射的方法.對深度學(xué)習(xí)模型從源域和目的域提取出的高級特征之間施加距離約束,使得神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)出的高級特征的分布相似.如DDC(deep domain confusion)[14],DAN(deep adaptation network)[15]等方法使用了最大均值差異來衡量高級特征之間的分布差異,Deep CORAL[16]方法采用CORAL距離來衡量高級特征之間的分布差異.
2) 基于域?qū)沟姆椒?引入生成對抗網(wǎng)絡(luò)的思想,用域判別器判斷深度學(xué)習(xí)模型學(xué)習(xí)出的高級特征屬于源域還是目的域.以對抗訓(xùn)練的方式使特征提取器和域判別器達到平衡.當(dāng)域判別器無法辨別特征來自哪一個域的時候,說明特征提取器提取了具有域不變性的特征.如Adversarial Discriminative Domain Adaptation[17],Domain Adversarial Neural Networks[18]等.
近年來,域?qū)狗椒ㄒ云鋬?yōu)異的表現(xiàn)而備受關(guān)注.為了減少類別之間的數(shù)據(jù)分布偏移,進一步提升無監(jiān)督域適應(yīng)的效果,一些工作在域?qū)狗椒ǖ幕A(chǔ)上引入了度量學(xué)習(xí)的思想.例如Wang等人[19]和Yin等人[20]在域適應(yīng)任務(wù)中引入了三元組損失(triplet loss),在一定程度上最小化類內(nèi)距離和最大化類間距離.但是三元組損失的計算需要遍歷大量樣本對,增加了額外的計算量,并且需要選取合適大小的隱層特征作為三元組損失的優(yōu)化對象,增加了調(diào)整超參數(shù)的負(fù)擔(dān).
為了解決將域適應(yīng)方法應(yīng)用在死亡風(fēng)險預(yù)測任務(wù)上時遇到的困難以及相關(guān)工作的不足,本文提出了一種基于域?qū)购图有杂嘞议g隔損失的無監(jiān)督域適應(yīng)方法AMS -ADA.該方法在沒有目的域樣本標(biāo)簽的情況下,利用源域帶標(biāo)簽的數(shù)據(jù)和目的域不帶標(biāo)簽的數(shù)據(jù)進行訓(xùn)練,提升模型在目的域的準(zhǔn)確度.該方法主要由特征提取器G、域判別器D和加性余弦間隔損失分類器C組成,其架構(gòu)圖如圖3所示,源域和目的域數(shù)據(jù)流向分別用實線和虛線的箭頭表示.
Fig. 3 Overall architecture圖3 整體架構(gòu)
本文的研究目的是將無監(jiān)督域適應(yīng)方法應(yīng)用在重癥監(jiān)護病人死亡風(fēng)險預(yù)測任務(wù)上.在重癥監(jiān)護室內(nèi)各種醫(yī)療設(shè)備每隔一段時間記錄下病人的各項生命體征,這些記錄可以自然地視為時序數(shù)據(jù).
特征提取器負(fù)責(zé)從輸入數(shù)據(jù)提取有效的高級特征.為了應(yīng)對時序數(shù)據(jù)的復(fù)雜性和多樣性所帶來的困難,本文選取了帶有注意力機制的雙向長短程記憶網(wǎng)絡(luò)作為特征提取器.其中雙向長短時記憶網(wǎng)絡(luò)(bidirectional long short term memory, BiLSTM)作為嵌入層,對輸入的特征進行初步的提取,捕捉基本的時序信息.嵌入層將輸入x∈m×d變成輸出H∈u×d,即每個時間步的特征維度從m變?yōu)閡,并且包含了上下文的信息.
為了更好地提取時序信息,本文使用了自注意力機制[21],對嵌入層輸出的每一個時間步計算注意力值ai,i=1,2,…,d,再根據(jù)注意力值對所有時間步進行加權(quán)求和.注意力機制能夠使得深度學(xué)習(xí)模型更關(guān)注重要的時間步,從而能夠提取出表現(xiàn)力更強的特征.
記W1∈na×u,W2∈r×na為參數(shù)矩陣,na為計算注意力的隱層向量維度,r為注意力頭的個數(shù).Softmax操作對每個行向量進行,目的是使得每個時間步的注意力值的和為1.注意力矩陣A∈r×d的計算方式表示為
A=Softmax(W2tanh(W1H)).
(1)
最后,注意力層的輸出即為整個特征提取器的輸出G(x)∈r×u,表示為
M=AHT.
(2)
域判別器的作用是以域?qū)沟男问竭M行域適應(yīng),學(xué)習(xí)到域不變的特征,試圖解決整體數(shù)據(jù)分布偏移的問題.域?qū)菇梃b了生成對抗網(wǎng)絡(luò)的思想,使特征提取器和域判別器之間相互競爭,當(dāng)域判別器無法辨別特征來自源于還是目的域時,特征提取器學(xué)會了如何提取域不變的特征.
記源域和目的域的概率分布為p(Xs)和p(Xt),域判別器D的優(yōu)化目標(biāo)可以表達為
(3)
特征提取器G的優(yōu)化目標(biāo)可表示為
(4)
特征提取器和域判別器的優(yōu)化目標(biāo)可以結(jié)合在一起,寫成極小極大的優(yōu)化形式:
(5)
特征提取器和域判別器都是深度學(xué)習(xí)模型,在實踐中通常以梯度下降最小化損失函數(shù)的形式進行優(yōu)化.記特征提取器和域判別器的模型參數(shù)為θG和θD,域判別器的損失函數(shù)Ldisc(θG,θD)可以寫為
(6)
在對抗訓(xùn)練的過程中,特征提取器和域判別器的優(yōu)化是交替進行的,形式化地表達為
(7)
以域?qū)沟姆绞竭M行域適應(yīng),能夠利用生成對抗網(wǎng)絡(luò)強大的擬合數(shù)據(jù)分布的能力,更好地提取出域不變的特征.
加性余弦間隔(additive margin softmax, AM-Softmax)損失引入了度量學(xué)習(xí)的思想,能夠增強不同類別的樣本之間的可區(qū)分性.它作為最終的分類損失函數(shù),能夠同時端到端地最小化類內(nèi)距離和最大化類間距離,不需要再耗費時間去選取深度學(xué)習(xí)模型中哪一層的特征作為優(yōu)化目標(biāo).相比于三元組損失函數(shù),它不需要額外計算樣本對之間的距離,節(jié)省了訓(xùn)練所需時間.此外,在角度空間端到端地對類內(nèi)距離和類間距離進行優(yōu)化相比于三元組損失對隱層向量進行優(yōu)化能取得更好的效果.接下來以對Softmax損失進行改進的形式介紹加性余弦間隔損失的動機和原理.
記n為當(dāng)前批次訓(xùn)練樣本的數(shù)量,yi為樣本xi的類別標(biāo)簽,共有c類.Ⅱ(·)為示性函數(shù),當(dāng)括號內(nèi)表達式為真時其值為1,當(dāng)表達式為假時其值為0.p(j|xi)為模型給出的樣本xi屬于第j類的概率.Softmax損失函數(shù)LS可以寫為
(8)
記樣本xi對在深度學(xué)習(xí)模型中最后一層的輸入為fi,W為最后一層的權(quán)重矩陣,Wj為權(quán)重矩陣中對應(yīng)輸出類別j的行向量.省略偏置項,Softmax損失函數(shù)進一步寫為
(9)
記fi與Wj的夾角為cosθi,j,對權(quán)重矩陣和輸入進行歸一化,即令‖fi‖=1,‖Wj‖=1.記縮放值為η,Softmax函數(shù)可以用余弦值來表示:
(10)
將向量內(nèi)積寫成夾角的形式,使得對決策邊界的分析從歐氏空間轉(zhuǎn)變?yōu)榻嵌瓤臻g.現(xiàn)在以二分類的場景對決策邊界進行分析,如圖4所示.此時類別數(shù)c=2.當(dāng)cosθi,0>cosθi,1時,樣本xi被判定為c0類.同理,當(dāng)cosθi,1>cosθi,0時,樣本xi被判定為c1類.當(dāng)前情況下,Softmax損失能夠為不同類別劃分清晰的界限,但是沒有顯式地優(yōu)化類間的離散度度以及類內(nèi)的聚合度.為了增加決策邊界的寬度,引入邊界閾值m.現(xiàn)在對決策邊界施加更加嚴(yán)格的要求,當(dāng)cosθi,0-m>cosθi,1時,樣本xi被判定為c0類,當(dāng)cosθi,1-m>cosθi,0時,樣本xi被判定為c1類.將二分類的情況推廣為多分類便可得到加性余弦間隔損失.記特征提取器和分類器的參數(shù)分別為θG和θC,加性余弦間隔損失LAMS(θG,θC)形式化地表達為
(11)
Fig. 4 Comparison between AM-Softmax Loss and Softmax Loss圖4 加性余弦間隔損失和Softmax損失的對比
對決策邊界施加的限制能夠在角度空間最大化分類器的決策邊界,從而達到最小化類內(nèi)距離和最大化類間距離的目的.
本文提出的方法含有可訓(xùn)練參數(shù)的部分為特征提取器、域判別器和分類器,其參數(shù)分別記為θG,θD,θC.由式(6)和式(11)可得最終的損失函數(shù)L(θG,θC,θD):
L(θG,θC,θD)=LAMS(θG,θC)-λLdisc(θG,θD),
(12)
其中λ為平衡因子,調(diào)節(jié)LAMS和Ldisc的比例.
本文提出方法的詳細(xì)訓(xùn)練流程如AMS-ADA算法所示.首先對特征提取器、域判別器和分類器的參數(shù)進行隨機的初始化.訓(xùn)練過程中對這些參數(shù)以梯度下降的形式進行交替優(yōu)化.本文選用深度學(xué)習(xí)領(lǐng)域中常用的Adam優(yōu)化器完成梯度下降的任務(wù).在對抗訓(xùn)練的每次迭代的過程中,為了使得域判別器能夠更好地指導(dǎo)特征提取器生成域不變的特征,需要增加域判別器的更新次數(shù),即域判別器更新Ndisc次之后,特征提取器和分類器才更新一次.域判別器的更新是指計算Ldisc(θG,θD)后通過反向傳播更新域判別器的參數(shù).特征提取器和分類器的更新也是類似地計算各自的損失函數(shù)后通過反向傳播對參數(shù)進行更新.當(dāng)損失函數(shù)收斂之后,得到訓(xùn)練好的模型.此時將目的域的數(shù)據(jù)輸入模型,得到最終的預(yù)測值.
算法1.基于域?qū)购图有杂嘞议g隔損失的無監(jiān)督域適應(yīng)方法.
① 隨機初始化θG,θD,θC;
② repeat
③ fori=1,2,…,Ndiscdo
④ 根據(jù)式(6)計算Ldisc(θG,θD);
⑥ end for
⑦ 根據(jù)式(12)計算L(θG,θC,θD);
⑩ until模型參數(shù)收斂
本文選用MIMIC-Ⅲ數(shù)據(jù)集[22]進行實驗.MIMIC-Ⅲ數(shù)據(jù)集是麻省理工大學(xué)維護的公共臨床數(shù)據(jù)庫,包含2001—2016年之間約6萬例的住院記錄,每條記錄包括人口統(tǒng)計特征、醫(yī)療干預(yù)記錄、成像報告、生命體征記錄、護理記錄等信息.
Harutyunyan等人[2]在MIMIC-Ⅲ數(shù)據(jù)集的基礎(chǔ)上定義了死亡風(fēng)險預(yù)測任務(wù).一般來說,患者進入重癥監(jiān)護室后的48 h以內(nèi)的情況較為危急,因此本文選取患者進入重癥監(jiān)護室之后的48 h以內(nèi)的數(shù)據(jù)對患者的存活結(jié)果進行預(yù)測.
根據(jù)Harutyunyan等人[2]的工作,本文在MIMIC-Ⅲ數(shù)據(jù)集中提取了76維的特征,包括心率(heart rate)、舒張壓(systolic blood pressure)、收縮壓(diastolic blood pressure)、血氧飽和度(SpO2)、毛細(xì)血管填充率(capillary refill rate)等60維的連續(xù)特征和格拉斯哥昏迷指數(shù)(Glasgow coma scale)等12維的離散特征以及4維的關(guān)于患者信息的常量.經(jīng)過數(shù)據(jù)清洗和預(yù)處理后,最終得到的輸入數(shù)據(jù)共有48個時間步,每個時間步有76維的特征.
Purushotham等人[23]嘗試在不同年齡段的急性低氧性呼吸衰竭患者之間進行了遷移學(xué)習(xí).本文沿用了該文的實驗設(shè)置,將MIMIC-Ⅲ數(shù)據(jù)集的ICU數(shù)據(jù)庫中所有患者按照年齡分為4組,如表1所示:
Table 1 Different Domains of MIMIC-Ⅲ Dataset表1 MIMIC-Ⅲ數(shù)據(jù)集不同域的劃分
由于數(shù)據(jù)集中的正負(fù)樣本比例相差較大,且屬于二分類問題,為了避免正負(fù)樣本不均衡對評價指標(biāo)帶來的影響,本次實驗采用ROC曲線(receiver operating characteristic curve)下的面積值(area under curve, AUC)作為評價標(biāo)準(zhǔn).本文采用了5種方法與本文提出的AMS-ADA方法進行對比:
1) BiLSTM.使用結(jié)合自注意力機制的BiLSTM網(wǎng)絡(luò)作為基線,在源域上訓(xùn)練,在目的域上測試,沒有使用任何無監(jiān)督域適應(yīng)學(xué)習(xí)方法.
2) CORAL.使用結(jié)合自注意力機制的BiLSTM網(wǎng)絡(luò)提取特征,采用基于特征映射的遷移學(xué)習(xí)方法Deep CORAL[16].該方法以CORAL距離衡量源域特征分布和目的域特征分布的差異.
3) DAN.使用結(jié)合自注意力機制的BiLSTM網(wǎng)絡(luò)提取特征,采用基于特征映射的遷移學(xué)習(xí)方法DAN[15].該方法以最大均值差異衡量源域特征分布和目的域特征分布的差異.
4) ADA(adversarial domain adaptation).與1)~3)所述方法采用相同的特征提取器,并且使用了域?qū)狗椒?,使用Softmax損失函數(shù).
5) Tri-ADA(triplet loss guided adversarial domain adaptation)[19].與1)~4)所述方法使用相同的特征提取器.使用了域?qū)狗椒?并且在此基礎(chǔ)上加上三元組損失函數(shù),以解決類別之間的數(shù)據(jù)分布偏移的問題.
本文在4個不同年齡段,即4個域之間兩兩進行無監(jiān)督域適應(yīng)任務(wù),實驗結(jié)果如表2所示.例如將青年患者的數(shù)據(jù)作為源域,將中年患者的數(shù)據(jù)作為目的域,無監(jiān)督域適應(yīng)任務(wù)記為青年→中年.
Table 2 Experimental Results of Mortality Prediction Task Based on Unsupervised Domain Adaptation表2 基于無監(jiān)督域適應(yīng)的死亡風(fēng)險預(yù)測實驗結(jié)果
本文提出的AMS-ADA方法在12個無監(jiān)督域適應(yīng)任務(wù)中的10個取得了最高的AUC值,說明了該方法的有效性.BiLSTM方法沒有使用任何無監(jiān)督域適應(yīng)方法,因此表現(xiàn)較差.對于相隔較遠(yuǎn)的域,BiLSTM方法的表現(xiàn)下降較為明顯.比如對于任務(wù)中年→青年,BiLSTM方法的AUC值為0.867,而對于任務(wù)中年→高齡老年,BiLSTM方法的AUC值下降為0.754.相隔較遠(yuǎn)的域意味著年齡相隔較大,數(shù)據(jù)的分布差異更為顯著,因此對模型的準(zhǔn)確度影響較大.CORAL方法和DAN方法使用了基于特征映射的遷移方法嘗試解決全局的數(shù)據(jù)分布差異的問題,從結(jié)果上可以看出這2種方法相比BiLSTM方法有一定的提升.ADA方法引入了域?qū)?,相比于基于特征映射的方法能夠更好地減少全局的數(shù)據(jù)分布差異,因此效果更好.Tri-ADA以域?qū)沟男问竭M行域適應(yīng),并且加入了三元組損失以減少類別之間的數(shù)據(jù)分布差異.實驗結(jié)果較之CORAL和DAN方法有了一定的提升.為了更精細(xì)地對齊類別之間的數(shù)據(jù)分布,本文提出的AMS-ADA方法引入了加性余弦間隔損失,相比ADA方法和Tri-ADA方法的準(zhǔn)確度有了進一步的提升,說明了本文提出方法的有效性.
為了直觀體現(xiàn)本文提出方法的優(yōu)越性,分別訓(xùn)練BiLSTM,Tri-ADA,AMS-ADA這3種方法,取各個方法的分類器的最后一層輸出特征投影到角度空間進行可視化.選取青年患者的數(shù)據(jù)作為源域,高齡老年患者的數(shù)據(jù)作為目的域.BiLSTM方法的源域和目的域特征可視化結(jié)果分別如圖5和圖6所示.BiLSTM方法在目的域的準(zhǔn)確度下降,其原因之一是類別之間的分布偏移.在源域訓(xùn)練時,不同類別的特征之間具有明顯的界限.但是決策邊界不夠?qū)挘谀康挠驕y試時由于分布偏移導(dǎo)致分類錯誤.
Fig. 5 Source domain feature visualization of BiLSTM method圖5 BiLSTM方法的源域特征可視化
Fig. 6 Target domain feature visualization of BiLSTM method圖6 BiLSTM方法的目的域特征可視化
因此,模型應(yīng)該顯式地增大決策邊界,保持類內(nèi)緊湊性和類間可分離性.Tri-ADA方法的源域和目的域特征可視化結(jié)果分別如圖7和圖8所示.Tri-ADA方法在源域訓(xùn)練時以三元組損失的形式增大了類間距離,因此在目的域測試時不同類別的特征之間可分離性加強,從而降低了錯誤率.
Fig. 7 Source domain feature visualization of Tri-ADA method圖7 Tri-ADA方法的源域特征可視化
Fig. 8 Target domain feature visualization of Tri-ADA method圖8 Tri-ADA方法的目的域特征可視化
AMS-ADA方法引入了AM-Softmax損失函數(shù),能夠進一步在角度空間增加決策邊界的寬度,其源域和目的域特征可視化結(jié)果分別如圖9和圖10所示.存活患者的特征與死亡患者的特征的重疊部分進一步縮小,取得了很好的類間可分離性和類內(nèi)緊湊性.得益于更寬的決策邊界,在源域上訓(xùn)練的分類器對類別偏移的敏感程度下降,因此在目的域上測試時能夠取得更好的準(zhǔn)確度.
Fig. 9 Source domain feature visualization of AMS-ADA method圖9 AMS-ADA方法的源域特征可視化
Fig. 10 Target domain feature visualization of AMS-ADA method圖10 AMS-ADA方法的目的域特征可視化
深度學(xué)習(xí)模型的實際應(yīng)用中容易遇到訓(xùn)練數(shù)據(jù)不足、整體數(shù)據(jù)分布偏移和類別之間數(shù)據(jù)分布偏移的問題.本文提出了一種基于域?qū)购图有杂嘞议g隔損失的無監(jiān)督域適應(yīng)方法應(yīng)對這些問題.本文以域?qū)沟男问綔p少了整體數(shù)據(jù)之間數(shù)據(jù)分布偏移.為了進一步改善無監(jiān)督域適應(yīng)的效果,引入度量學(xué)習(xí)的思想,以最小化加性余弦間隔損失的形式減少了類別之間的數(shù)據(jù)分布偏移.所提出的方法在重癥監(jiān)護病人死亡風(fēng)險預(yù)測任務(wù)上進行了驗證,在MIMIC-Ⅲ數(shù)據(jù)集上的實驗結(jié)果和可視化分析結(jié)果證明了該方法的有效性.未來的工作會嘗試將所提出方法擴展到醫(yī)療領(lǐng)域的其他任務(wù)中,例如疾病預(yù)測和住院時長預(yù)測等任務(wù).
作者貢獻聲明:蔡德潤提出算法思路、完成實驗并撰寫論文;李紅燕提出了指導(dǎo)意見并修改論文.