田豆,李鳳蓮,張雪英,張晉義
(太原理工大學信息與計算機學院,山西榆次 030600)
腦卒中[1]是導致死亡的最常見疾病之一。近幾年,基于機器學習方法開展腦卒中發(fā)病風險預測成為研究的熱點[2-3]。區(qū)別于一般機器學習分類模型的構(gòu)建思路,文獻[4]中提出一種采用線性近似Q-learning算法訓練分類模型。文獻[5-6]通過用神經(jīng)網(wǎng)絡代替線性近似Q-learning,提出一種基于深度強化學習的分類算法。目前,國內(nèi)外學者將強化學習與環(huán)境交互直接用于數(shù)據(jù)分類的研究[7]逐漸增多。該文將特征選擇和深度強化學習融合,并采用線性衰減ε-貪婪策略優(yōu)化探索與利用的過程,提出一種融合特征選擇的新型深度強化學習(Feature Selection and new Deep Reinforcement Learning,F(xiàn)S-nDRL)分類模型并應用于腦卒中經(jīng)顱多普勒(Transcranial Doppler,TCD)數(shù)據(jù)集的輔助分類診斷。
卡方檢驗的基本公式如式(1)所示:
其中,χ2為卡方值,W為實際頻數(shù),T為理論頻數(shù)。當用于特征選擇時,特征f與類別c之間的相關(guān)性如式(2)[8]所示:
卡方值越大,表明特征與類別之間的相關(guān)性越大,即特征對類別的區(qū)分度較高,因此在選擇時,會優(yōu)先挑選卡方值更高的特征。
卡方檢驗法屬于一種過濾式特征選擇算法,比較適用于分類問題[9-10]特征選擇的場景。故該文采用卡方檢驗法進行特征選擇,以去除冗余特征。
ε-貪婪策略,即學習器目前選擇最佳動作的概率是1-ε,而隨機選擇其他動作的概率是ε。如式(3)所示:
線性衰減ε-貪婪策略,則是在ε-貪婪策略的基礎(chǔ)上,使探索率ε從定義的初始值開始,隨時間線性下降到最小值。
強化學習主要通過不斷地嘗試摸索來尋找最優(yōu)策略,即智能體通過對環(huán)境的自我探索逐漸學習控制行為。強化學習問題常采用馬爾科夫決策過程(MDP)進行建模,并使用五元組描述模型。
深度強化學習(Deep Reinforcement Learning,DRL)是谷歌的DeepMind 人工智能研究團隊將具有智能感知能力的深度網(wǎng)絡和具有智能決策能力的強化學習相結(jié)合研發(fā)的新算法,廣泛應用于智能 通 信[11]、網(wǎng)絡安全[12]、無人駕駛[13]、3D 對象分類[14]等領(lǐng)域。
1)狀態(tài)空間S={s1,s2,…,snm},n表示輸入訓練樣本個數(shù),m表示輸入訓練樣本包含的屬性數(shù);狀態(tài)st=(xt,ft)指智能體在當前時刻t對樣本xt進行分類時選擇了特征ft。
2)動作空間A=Ay∪Af={a1,a2,…,am},表示智能體在狀態(tài)st做出的動作at。
Af是屬性特征動作的集合,指智能體在狀態(tài)st時對樣本xt選擇其中一個特征ft作為對應的動作。
Ay是決策特征動作的集合,指當前樣本xt可能的預測類別。如果A=Ay,則停止決策過程。
3)獎賞函數(shù)R={r1,r2,…,rnm}指智能體針對當前預測問題在狀態(tài)st時采取動作at獲得的獎懲。
在DQN 中,定義一個狀態(tài)動作值函數(shù)Q(st,at),式(5)表示智能體在狀態(tài)st利用最優(yōu)策略π選擇動作at時獲得的Q值。
定義最優(yōu)函數(shù)Q*(st,at)表示智能體在狀態(tài)st采取動作at獲得的最高累計獎賞值,公式服從貝爾曼(Bellman)方程:
其中,γ是折扣因子,st+1是下一狀態(tài),at+1是根據(jù)最優(yōu)策略π選擇的動作,Q*(st,at)是獲得的最優(yōu)Q值。經(jīng)過反復的值迭代,Q(st,at)收斂到Q*(st,at)。
在使用深度神經(jīng)網(wǎng)絡擬合Q函數(shù)時,Q函數(shù)由權(quán)重為θ的深度神經(jīng)網(wǎng)絡近似為Q(st,at;θ),其損失函數(shù)定義如下:
其中,y是目標值,θ′是目標網(wǎng)絡的參數(shù),目標網(wǎng)絡的結(jié)構(gòu)與當前網(wǎng)絡Q(st,at;θ)相同。
該文利用經(jīng)驗回放來穩(wěn)定收斂過程,構(gòu)造經(jīng)驗池P以存儲智能體狀態(tài)轉(zhuǎn)移記錄et,其中智能體狀態(tài)轉(zhuǎn)移記錄et采用ε-貪婪策略得到,然后抽取存儲在P中的經(jīng)驗樣本,通過最小化損失函數(shù)(式(7))來訓練Q 網(wǎng)絡。根據(jù)式(9)使用梯度下降法迭代更新Q 網(wǎng)絡參數(shù)θ,繼續(xù)模擬環(huán)境中的交互過程,存儲新的狀態(tài)轉(zhuǎn)移記錄以替換舊的轉(zhuǎn)移記錄。直到目標網(wǎng)絡與當前網(wǎng)絡融合,訓練過程反復進行。
其中,a為學習率。
在每次優(yōu)化后,使用式(10)更新目標網(wǎng)絡的參數(shù)θ′:
其中,ρ為目標網(wǎng)絡的更新因子。
該文提出的分類預測模型首先使用卡方檢驗特征選擇法對數(shù)據(jù)進行冗余特征處理,然后建立輸入樣本對應的MDP,在MDP 中使用線性衰減ε-貪婪策略選擇動作at對環(huán)境進行探索和利用,然后使用梯度下降法對Q 網(wǎng)絡進行訓練,得到訓練完成的深度Q 網(wǎng)絡;進一步使用驗證集對模型進行優(yōu)化,得到優(yōu)化后的分類預測模型;最后采用測試集進行分類模型性能驗證。具體FS-nDRL 分類模型流程圖如圖1 所示。
圖1 FS-nDRL分類模型流程圖
實驗數(shù)據(jù)集采用UCI 中的4 個數(shù)據(jù)集:Ecoli(E)、Wine(W)、Yeast(Y)、Lymphography(L)以及從山西省人民醫(yī)院隨機選取的腦卒中TCD 數(shù)據(jù),各數(shù)據(jù)集詳情如表1、2 所示。
表1 UCI數(shù)據(jù)集
該文模型的神經(jīng)網(wǎng)絡有3 個完全連接的隱含層,其中每層網(wǎng)絡的激活函數(shù)使用修正線性單元(Rectified Linear Unit,ReLU)激活函數(shù),各層中的神經(jīng)元數(shù)量為64。最后一層是網(wǎng)絡的輸出層,主要用于分類預測,使用softmax 作為該層的分類器。學習率α設置為0.000 5,折扣因子γ設置為0.95,智能體采用動作Af選擇屬性特征ft時,獲得的獎賞值l 設置為0.001。采用線性衰減ε-貪婪策略選擇動作a,探索率ε從設置的初始值1 開始,并隨著時間的推移每100 步線性降低至最小值0.1。該文使用交叉驗證方法選擇數(shù)據(jù)的60%作為訓練集,用來構(gòu)建分類模型,再將其中的20%作為驗證集,優(yōu)化模型,最后將剩余的20%作為測試集,來評估模型的性能。
表2 腦卒中TCD數(shù)據(jù)集
為說明FS-DRL 方法的有效性,該文首先構(gòu)建了新型深度強化學習分類預測模型(newDeep Reinforcement Learning,nDRL),并采用支持向量機(Support Vector Machines,SVM)、極限學習機(Extreme Learning Machine,ELM)和代價調(diào)整極限學習機[15](Classspecific Cost Regulation Extreme Learning Machine,CCR-ELM)分類預測模型作為對比方法,進行模型性能驗證。
該文采用準確率評估預測模型的分類性能,準確率如式(11)所示:
其中,TP代表被分類模型預測為正類的正類樣本數(shù),F(xiàn)N代表被分類模型預測為負類的正類樣本數(shù),TN代表被分類模型預測為負類的負類樣本數(shù),F(xiàn)P代表被分類模型預測為正類的負類樣本數(shù)。準確率A越大,說明模型的分類效果越好。
該文提出的方法nDRL、FS-nDRL 與已有方法ELM、CCR-ELM、SVM 在UCI 數(shù)據(jù)集上分類準確率性能對比如圖2 所示。由圖2 可知,該文提出的FSnDRL 算法相比其他算法,在E 數(shù)據(jù)集的準確率分別提高了8.92%、5.39%、7.58%、2.47%;在W 數(shù)據(jù)集的準確率分別提高了19.38%、13.3%、6.76%、1.09%;在Y 數(shù)據(jù)集的準確率分別提高了8.38%、5.26%、8%、2.41%;在L 數(shù)據(jù)集的準確率分別提高了5.74%、4.98%、5.7%、3.02%。
圖2 UCI數(shù)據(jù)集準確率結(jié)果對比
進一步將該文提出的方法用于腦卒中TCD 數(shù)據(jù)集,與ELM、CCR-ELM、SVM 和nDRL 的性能對比如圖3 所示。由圖3 可知,該文提出的FS-nDRL 算法相比其他算法,在data1 數(shù)據(jù)集的準確率分別提高了18.39%、15.86%、11.55%、5.05%;在data2 數(shù)據(jù)集的準確率分別提高了19.15%、16.41%、17.48%、5.33%;在data3數(shù)據(jù)集的準確率分別提高了20.64%、14.44%、23.24%、5.42%;在data4 數(shù)據(jù)集的準確率分別提高了19.65%、13.07%、10.66%、5.33%。
圖3 腦卒中數(shù)據(jù)集準確率結(jié)果對比
綜上可知,該文提出的融合特征選擇深度強化學習的新型分類預測模型FS-nDRL,對公共數(shù)據(jù)集UCI 和腦卒中TCD 數(shù)據(jù)集整體預測性能比已有算法的分類性能更優(yōu)[16]。
針對數(shù)據(jù)集采用傳統(tǒng)機器學習算法分類預測性能偏低的問題,該文提出一種融合特征選擇和深度Q 網(wǎng)絡的分類算法。實驗結(jié)果表明,與已有方法SVM、ELM、CCR-ELM 相比,該文提出的方法對UCI和腦卒中TCD 數(shù)據(jù)集在準確率上有所提升。但基于深度強化學習的分類預測模型,總體計算復雜度依然遠高于其他已有分類預測模型,因此仍有必要進一步研究降低其計算復雜度的有效方法,并且實際的腦卒中TCD 數(shù)據(jù)經(jīng)常是極不平衡的。下一步主要從強化學習模型的構(gòu)建機制來研究有效方法,改進算法適用于分類實際中的非平衡腦卒中TCD 數(shù)據(jù)集。