宋 蒙,蔣生強(qiáng)
(1.中國電子科技集團(tuán)公司第五十四研究所,河北 石家莊 050081;2.華為技術(shù)有限公司,廣東 東莞 523000)
支持向量機(jī)(Support Vector Machines, SVM)[1-4]是基于概率計算理論的機(jī)器學(xué)習(xí)類算法[5-7]中比較常見的方法。C-支持向量[8]分類應(yīng)用多以經(jīng)典的高斯函數(shù)為核函數(shù),如何縮短搜尋SVM最優(yōu)訓(xùn)練參數(shù)組合(Optimal Training Parameters Combination, OTPC)的時間一直是近年來學(xué)術(shù)界研究的熱點(diǎn)。
在實(shí)際的軟件實(shí)現(xiàn)中,最經(jīng)典的得到OTPC的方法是交叉驗(yàn)證訓(xùn)練[9]和網(wǎng)格搜索[10]相融合進(jìn)行搜索。交叉驗(yàn)證訓(xùn)練需要對訓(xùn)練數(shù)據(jù)進(jìn)行折疊處理,處理后再對每組折疊后的數(shù)據(jù)單獨(dú)訓(xùn)練;而網(wǎng)格搜索法則需要對每組訓(xùn)練數(shù)據(jù)進(jìn)行遍歷驗(yàn)證訓(xùn)練。因此這兩種方法需要訓(xùn)練的次數(shù)較多,延長了得到OTPC的時間。
在傳統(tǒng)的硬件實(shí)現(xiàn)中,多數(shù)情況下都是基于序列最小優(yōu)化 (Sequential Minimal Optimization, SMO)[11]算法來完成單次交叉驗(yàn)證的訓(xùn)練過程。雖然LIBSVM使用了改進(jìn)的SMO算法實(shí)現(xiàn)交叉驗(yàn)證過程[12-14],然而改進(jìn)的SMO算法在搜索工作集索引時,需要計算二階導(dǎo)數(shù)信息,另外還需要計算運(yùn)算量較大的核函數(shù),這些都使得軟件執(zhí)行交叉驗(yàn)證無法滿足實(shí)時應(yīng)用的需求。
為提升SVM搜索OTPC的性能,本文分別從算法和實(shí)現(xiàn)兩方面進(jìn)行了探索:首先論證了共享點(diǎn)積矩陣(Share Dot Product Matrix, SDPM)算法[15],然后進(jìn)行了硬件實(shí)現(xiàn),并對性能進(jìn)行了分析和比較。
SVM是基于兩種數(shù)據(jù)集的模型。其學(xué)習(xí)目標(biāo)是尋找最優(yōu)的超平面來劃分訓(xùn)練數(shù)據(jù)集。超平面可以用數(shù)學(xué)方法表示如下:
wTx+b=0。
(1)
SVM基本的思想是找到wT使得分隔是最大的。通過選取合適的核函數(shù)K(xi,xj)[15]和懲罰系數(shù)C就可以處理非線性的數(shù)據(jù),這樣原問題就轉(zhuǎn)化為對偶問題:
(2)
典型的高斯核函數(shù)為:
(3)
在網(wǎng)格搜索方法的一次迭代中,需要對索引集的數(shù)據(jù)先進(jìn)行計算處理,再完成后面的訓(xùn)練,因此以后每次的訓(xùn)練均需要預(yù)先處理及后續(xù)訓(xùn)練。而SDPM算法先讀數(shù)據(jù),并對數(shù)據(jù)進(jìn)行初始化,將處理后的數(shù)據(jù)直接存儲為點(diǎn)積矩陣,因此后期的迭代過程中不再有中間的初始化及計算過程。故SDPM算法僅需要在搜索之前對點(diǎn)積進(jìn)行運(yùn)算再存儲到點(diǎn)積矩陣;在之后需要點(diǎn)積向量的搜索過程中,只需統(tǒng)一在點(diǎn)積矩陣中調(diào)取向量結(jié)果即可。
下面從存儲計算量和點(diǎn)積計算量這兩方面來分析SDPM算法的復(fù)雜度。
(1) 存儲計算量
訓(xùn)練數(shù)據(jù)集的長度為N,點(diǎn)積數(shù)據(jù)位長為L,存儲計算量則為T=W×N2。
(2) 點(diǎn)積計算量
傳統(tǒng)方法搜索OTPC過程中點(diǎn)積的計算次數(shù)Num1約為:
(4)
式中,m為懲罰系數(shù)可選數(shù)量,n為高斯核函數(shù)參數(shù)σ可選數(shù)量,s為交叉驗(yàn)證折疊次數(shù),r為迭代平均次數(shù)因子。
而使用SDPM算法搜索OTPC時,點(diǎn)積的計算次數(shù)Num2為:
(5)
在SVM構(gòu)建模型搜索OTPC時,SDPM算法需要將數(shù)據(jù)集中的參數(shù)進(jìn)行組合,對數(shù)據(jù)類別精準(zhǔn)分類,并對數(shù)據(jù)折疊處理和交叉驗(yàn)證訓(xùn)練。本文使用了層次設(shè)計法:首先,將問題進(jìn)行簡化,將一個大過程分解成幾個不同層次的小過程;然后,尋找不同層次之間的銜接和關(guān)聯(lián),按序進(jìn)行處理;最后,設(shè)置各個不同層次的輸入以及輸出,實(shí)現(xiàn)各個不同層次的功能。
軟件設(shè)計流程如圖1所示。
圖1 SVM模型的軟件實(shí)現(xiàn)框圖Fig.1 Software realization diagram of SVM model
基于SDPM算法的SVM模型的軟件實(shí)現(xiàn)分為以下幾個步驟:
① 讀取數(shù)據(jù)并對數(shù)據(jù)進(jìn)行初始化。
② 訓(xùn)練集中只有兩類數(shù)據(jù)可以直接讀取,如果是多類數(shù)據(jù),選擇兩類數(shù)據(jù)s折疊,折疊后計算并存儲點(diǎn)積矩陣。
③ 選取新訓(xùn)練參數(shù)并開始新的交叉驗(yàn)證。其中包括更新拉格朗日系數(shù)、計算閾值b等。
④ 對訓(xùn)練結(jié)果進(jìn)行匯總,統(tǒng)計數(shù)據(jù)的準(zhǔn)確率。
⑤ 完成②~④后,重新選定訓(xùn)練參數(shù)之后運(yùn)行③和④,將訓(xùn)練集中所有參數(shù)組合并全部按此完成訓(xùn)練,得到所有訓(xùn)練的結(jié)果。
⑥ 在兩類數(shù)據(jù)訓(xùn)練結(jié)束后,從訓(xùn)練數(shù)據(jù)集中選擇其他的兩類數(shù)據(jù)完成③~⑤。訓(xùn)練數(shù)據(jù)集中所有類別均需如此完成訓(xùn)練。完成訓(xùn)練后將結(jié)果保存并輸出。
雖然SDPM算法利用前期計算并存儲點(diǎn)積矩陣的方法減少了點(diǎn)積的計算量。然而軟件在建立模型過程中仍然需要涉及大量的計算,無法滿足SVM實(shí)時應(yīng)用的要求。本文基于SDPM算法的軟硬件協(xié)同架構(gòu),能在最短時間內(nèi)搜索出SVM的OTPC,進(jìn)而完成模型的構(gòu)建,滿足SVM在實(shí)時場合應(yīng)用的需要。本文設(shè)計的軟硬件協(xié)同架構(gòu)如圖2所示。
圖2 軟硬件協(xié)同架構(gòu)Fig.2 Software-hardware collaborative framework
系統(tǒng)整體框架設(shè)計如下:硬件系統(tǒng)的結(jié)構(gòu)由PC機(jī)、協(xié)處理器板卡(加速板卡)和DDR3內(nèi)存構(gòu)成。PC機(jī)采用X86處理器,加速板卡使用Xilinx的Virtex 7系列FPGA芯片。
系統(tǒng)軟件實(shí)現(xiàn)如下:硬件上執(zhí)行浮點(diǎn)運(yùn)算會耗費(fèi)大量的資源,且本文對于算法的執(zhí)行結(jié)果,允許一定的精度損失。因此,在使用硬件實(shí)現(xiàn)算法時,全部采用定點(diǎn)運(yùn)算。測試表明,定點(diǎn)運(yùn)算不會對OTPC造成太大的影響。而硬件的定點(diǎn)化則要求訓(xùn)練數(shù)據(jù)在輸入到FPGA時按定點(diǎn)格式進(jìn)行編碼。
PC機(jī)首先讀取所有的訓(xùn)練數(shù)據(jù)集,存儲后進(jìn)行初始化。然后對訓(xùn)練數(shù)據(jù)集分別進(jìn)行控制,目的是迅速找到此次需要搜索的OTPC訓(xùn)練數(shù)據(jù)集,定位訓(xùn)練數(shù)據(jù)集后進(jìn)行s折疊并將首次折疊的數(shù)據(jù)輸入FPGA。數(shù)據(jù)初始化是將所有的訓(xùn)練數(shù)據(jù)集統(tǒng)一打包成16位的數(shù)據(jù),其中1位為符號位,15位為小數(shù)位;當(dāng)訓(xùn)練數(shù)據(jù)集為多類時,需要先將數(shù)據(jù)集進(jìn)行分類,從中取出兩類;而訓(xùn)練數(shù)據(jù)集為兩類時,可直接進(jìn)行訓(xùn)練。
FPGA運(yùn)算后將一組拉格朗日系數(shù)向量回傳給PC機(jī),然后PC機(jī)求解閾值b,基于b值和系數(shù)向量可以構(gòu)建出模型,并對構(gòu)建的模型進(jìn)行評估和驗(yàn)證,在得到全部組合訓(xùn)練參數(shù)的準(zhǔn)確率后,找到準(zhǔn)確率最高的參數(shù)組合,以此類推,找到任意兩類數(shù)據(jù)組合的OTPC后輸出。
FPGA的內(nèi)部邏輯模塊包括主進(jìn)程調(diào)度、數(shù)據(jù)接收和點(diǎn)積計算模塊、訓(xùn)練模塊、指數(shù)運(yùn)算模塊和中間數(shù)據(jù)存儲模塊5個部分,其中主進(jìn)程調(diào)度模塊根據(jù)FPGA的時鐘信息調(diào)度其他功能模塊,數(shù)據(jù)接收和點(diǎn)積計算模塊首先是接收定點(diǎn)格式的初始化訓(xùn)練數(shù)據(jù),將數(shù)據(jù)進(jìn)行點(diǎn)積計算后存儲為點(diǎn)積矩陣。訓(xùn)練模塊完成搜索工作集索引和更新系數(shù)、梯度等交叉訓(xùn)練過程,并將運(yùn)算結(jié)果送達(dá)至PC機(jī)。指數(shù)運(yùn)算模塊從DDR3中讀取兩路點(diǎn)積向量,最終完成核函數(shù)運(yùn)算并存儲至FPGA內(nèi)部RAM。
主要對所設(shè)計的基于SDPM算法的SVM硬件結(jié)構(gòu)的參數(shù)誤差、運(yùn)行時間等方面進(jìn)行比較。
參數(shù)誤差是表征LIBSVM和SDPM模型差異的主要參數(shù),參數(shù)誤差主要包含(a*b)及迭代次數(shù)。其中,a*為拉格朗日系數(shù)向量,b為模型的閾值。表1~表3分別統(tǒng)計出在數(shù)據(jù)集Iris(Iris-setosa & Iris-versicolor),TestD1,TestD2下的LIBSVM和SDPM兩種模型的a*參數(shù)誤差。
表1 Iris下的a*參數(shù)誤差
Tab.1a*parameter error of cross-validation training in Iris
LIBSVM α 索引LIBSVM α 值SDPM α 索引SDPM α 值索引誤差α 值誤差15-4.000 0015-4.000 0000464.000 00046-4.000 0000482.598 229482.599 38004.4*10-4821.401 771821.400 62008.2*10-497-4.000 0097-4.000 0000
表2 TestD1下的a*參數(shù)誤差
Tab.2a*parameter error of cross-validation training in TestD1
LIBSVM α 索引LIBSVM α 值SDPM α 索引SDPM α 值索引誤差α 值誤差353-0.427 90353-0.429 7604.3*10-33563.116 9503563.118 53005.0*10-4451-2.689 05451-2.688 7601.0*10-4
表3 TestD2下的a*參數(shù)誤差
Tab.3a*parameter error of cross-validation training in TestD2
LIBSVM α 索引LIBSVM α 值SDPM α 索引SDPM α 值索引誤差α 值誤差57-0.376 6157-0.377 7603.1*10-393-0.271 4793-0.272 4503.6*10-31480.514 2881480.512 93002.6*10-3211-0.366 88211-0.364 6406.1*10-33840.500 6953840.501 93202.5*10-3
由表1~表3可知,對于3種不同數(shù)據(jù)集, SDPM方式訓(xùn)練萃取得到的a*參數(shù)索引值和LIBSVM得到的一樣, 且a*參數(shù)值的誤差很小可以忽略不計。SDPM中的中間數(shù)據(jù)均是歸一化后的定點(diǎn)格式,而LIBSVM為全浮點(diǎn)運(yùn)算,兩種方式得到的a*值存在一定的誤差,但是由表1~表3可知誤差微乎其微,不會造成最終分類結(jié)果的誤判。表4統(tǒng)計出不同數(shù)據(jù)集下LIBSVM和SDPM兩種模型的閾值和迭代次數(shù)。
表4 交叉驗(yàn)證訓(xùn)練的閾值和迭代次數(shù)統(tǒng)計
Tab.4 Statistics of threshold and iteration times of cross-validation training
數(shù)據(jù)集LIBSVM b值LIBSVM迭代次數(shù)SDPM b值SDPM迭代次數(shù)b值誤差迭代次數(shù)差異Pendigits0.075 815680.075 9063231.3*10-3-245Iris-0.132 392-0.132 2881.4*10-4-84SPECTF1.049 603241.044 2931915.1*10-3-133TestD10.003 35370.003 20364.5*10-2-30TestD20.002 29110.002 337141.7*10-2+3
由表4可知,數(shù)據(jù)集為Pendigits,Iris,SPECTF,TestD1時,SDPM的迭代次數(shù)遠(yuǎn)遠(yuǎn)小于LIBSVM。這大大縮減了模型搜索OTPC的時間,而且b值的誤差均小于0.01,并不會對分組造成誤判。因此SDPM相對于LIBSVM在不對分組進(jìn)行誤判的情況下,大大減少了訓(xùn)練時間。
表5列出了不同測試數(shù)據(jù)集下LIBSVM和SDPM的生成模型時長和SDPM的速度提升倍數(shù)。由于SVM訓(xùn)練的模型只用了兩類數(shù)據(jù),因此只選取Pendigits和Iris的測試數(shù)據(jù)集中第一類和第二類數(shù)據(jù)用于生成應(yīng)用模型的時間比較。
表5 生成模型時間對比
Tab.5 Comparison of running time ofcross-validationtraining
數(shù)據(jù)集LIBSVM時長/msSDPM時長/msSDPM速度提升倍數(shù)Pendigits(1&2)3047.441.1Iris(1&2)-0.1-SPECTF462.220.9TestD11243.238.8TestD2782.432.5
由表5可知,SDPM相比于LIBSVM在生成SVM的應(yīng)用模型時,Pendigits,SPECTF,TestD1,TestD2數(shù)據(jù)集的提升倍數(shù)分別為41.1,20.9,38.8,32.5,而且速度提升的倍數(shù)與數(shù)據(jù)集的長度成正比。在硬件的設(shè)計方面也充分考慮了速度提升方法:硬件采用并行架構(gòu),較于傳統(tǒng)的串行架構(gòu)可以并行進(jìn)行處理數(shù)據(jù)更新;SDPM利用專用的RAM存取中間數(shù)據(jù),從而有效減少了訓(xùn)練時間。
目前機(jī)器學(xué)習(xí)和人工智能[16-17]已經(jīng)深入了各行各業(yè)。未來,隨著技術(shù)的發(fā)展還將不斷突破目前的計算性能,但隨之帶來的問題是如何解決機(jī)器學(xué)習(xí)所伴隨的超高計算復(fù)雜度。本文將共享點(diǎn)積矩陣SDPM算法與機(jī)器學(xué)習(xí)算法SVM模型進(jìn)行深度融合,利用SDPM算法超低復(fù)雜度的實(shí)現(xiàn)優(yōu)勢,解決機(jī)器學(xué)習(xí)算法高復(fù)雜度的實(shí)現(xiàn)難題,并將所提算法進(jìn)行了硬件實(shí)現(xiàn)驗(yàn)證。結(jié)果顯示,在硬件設(shè)計中引入了SDPM算法,減少了訓(xùn)練時間,運(yùn)行速度提升近30倍,為下一步設(shè)計具有超高速、超低功耗、高精度和低復(fù)雜度的數(shù)字信號處理技術(shù)所用的專用芯片提供技術(shù)支撐。