費(fèi)經(jīng)泰,郝慶一,程一元,孫 釗
(1.巢湖學(xué)院 數(shù)學(xué)與統(tǒng)計(jì)學(xué)院,安徽 巢湖 238024;2.安慶師范大學(xué) 數(shù)理學(xué)院,安徽 安慶 246133;3.安徽建筑大學(xué) 數(shù)理學(xué)院,合肥 230601)
在大規(guī)模機(jī)器學(xué)習(xí)問題中,常??紤]如下的優(yōu)化問題[1]:
(1)
其中,n表示樣本的數(shù)目,d表示樣本的維度,每個(gè)分量函數(shù)fi是凸的,并且具有連續(xù)的Lipschitz梯度。這里假定ω*是問題(1)的最優(yōu)解。
對(duì)于問題(1),傳統(tǒng)的方法采用的是確定性梯度下降算法(GD)[2],GD算法的迭代公式如下:
其中,ηt表示學(xué)習(xí)率??梢钥吹矫康淮我?jì)算全部樣本的梯度,故計(jì)算的復(fù)雜度較高,所以后續(xù)學(xué)者提出了隨機(jī)梯度下降算法(SGD)[3],迭代公式如下:
ωt+1=ωt-ηt?fit(ωt)
SGD算法每次只選取一個(gè)樣本進(jìn)行迭代,因此計(jì)算量大大減少,并且隨機(jī)梯度是真實(shí)梯度的無偏估計(jì),即E[?fit(ωt)]=?F(ωt)。但隨著迭代的進(jìn)行,SGD算法所產(chǎn)生的方差也在不斷加大。因此即使在強(qiáng)凸條件下也僅具有次線性的收斂速度。[4-6]
為了克服上述問題,2013年,Johnson[7]提出了方差縮減梯度下降算法(SVRG),該算法采用正則化的隨機(jī)梯度,迭代公式如下:
它仍是真實(shí)梯度的無偏估計(jì),即E[vt]=?F(ωt)。并且文獻(xiàn)[7]證明了隨著迭代進(jìn)行,方差在不斷地減小,從而在強(qiáng)凸條件下能夠達(dá)到線性收斂速率,因此得到廣泛應(yīng)用。此后有學(xué)者相繼提出一些改進(jìn)或者相似的方差縮減類算法,具有代表性的有隨機(jī)平均梯度算法(SAG)[8],該算法每輪隨機(jī)選擇一個(gè)樣本計(jì)算梯度,其他樣本的梯度保持不變,最后再將整個(gè)樣本的梯度進(jìn)行平均來更新參數(shù)值。隨機(jī)加速平均梯度下降算法(SAGA)[9]將SAG算法中的梯度替換思想和SVRG的方差縮減思想相結(jié)合,從而在強(qiáng)凸條件下收斂速度更快些。但SAG算法和SAGA算法需要用一張?zhí)荻缺韥韮?chǔ)存每個(gè)樣本的梯度,故占用內(nèi)存比較大。2017年,Nguyen[10]提出了隨機(jī)遞歸梯度下降算法(SARAH),該算法在更新梯度的過程中使用了遞歸方法,所以不需要對(duì)梯度進(jìn)行保存,并且在迭代的過程當(dāng)中方差也在減小,因此也是有效的方差縮減算法。
本文在傳統(tǒng)的SARAH算法基礎(chǔ)上,提出了一種基于加權(quán)平均思想的方差縮減算法—WA-SARAH算法,證明了該算法在強(qiáng)凸條件下具有線性收斂速率,并且得到了更好的收斂階,最后在經(jīng)典機(jī)器學(xué)習(xí)數(shù)據(jù)集上,算法的實(shí)驗(yàn)結(jié)果表現(xiàn)良好。
本節(jié)先介紹傳統(tǒng)的SARAH算法,算法的流程如下:
算法1 SARAH算法輸入:初始向量ω0,內(nèi)循環(huán)數(shù)m,學(xué)習(xí)率η1:for s=1,2,…,do2: ω0=ω~s-13: v0=1n∑ni=1儊fi(ω0)4: ω1=ω0-ηv05: for t=1,…m-1,do6: 隨機(jī)選擇一個(gè)樣本it7: vt=儊fit(ωt)-儊fit(ωt-1)+vt-18: ωt+1=ωt-ηvt9: end for10: ω~s=ωt; t從0,1,2,…,m 隨機(jī)選取11:end for
SARAH算法采取的是兩層循環(huán)迭代,外層循環(huán)計(jì)算樣本的全梯度,內(nèi)層循環(huán)計(jì)算正則化的隨機(jī)梯度。不同于經(jīng)典的方差縮減梯度下降算法(SVRG),SARAH算法的內(nèi)層循環(huán)中隨機(jī)梯度采取的是遞歸地更新方式,它是最優(yōu)下降方向的有偏估計(jì),即
E[vt]=?F(ωt)-?F(ωt-1)+vt-1≠?F(ωt)
現(xiàn)對(duì)上述算法進(jìn)行推廣,提出一種基于加權(quán)平均隨機(jī)遞歸梯度下降算法(WA-SARAH),具體見算法2。從算法2中可以看到SARAH算法是WA-SARAH算法的一種特殊情形。
算法2 WA-SARAH算法輸入:初始向量ω0,內(nèi)循環(huán)數(shù)m,學(xué)習(xí)率η1:for s=1,2,…,do2: ω0=ω~s-13: v0=1n∑ni=1儊fi(ω0)4: ω1=ω0-ηv05: for t=1,…m-1,do6: 隨機(jī)選擇一個(gè)樣本it7: vt=p儊fit(ωt)-儊fit(ωt-1) +vt-18: ωt+1=ωt-ηvt9: end for10: ω~s=ωt; t從0,1,2,…,m 隨機(jī)選取11:end for
對(duì)WA-SARAH算法中第7行中的vt的更新方式
vt=p[?fit(ωt)-?fit(ωt-1)]+vt-1
(2)
重新進(jìn)行改寫,得到如下:
本節(jié)將給出WA-SARAH算法的收斂性分析,為此先給出相關(guān)的假設(shè)和引理。
假設(shè)1:(L-光滑).每個(gè)fi:Rd→R,i=1,2,…n是L-光滑的,即對(duì)任意的ω,ω′∈Rd,存在常數(shù)L>0,使得
‖?fi(ω)-?fi(ω′)‖≤L‖ω-ω′‖
假設(shè)2:(μ-強(qiáng)凸).函數(shù)F:Rd→R是μ-強(qiáng)凸的,即對(duì)任意的ω,ω′∈Rd,存在常數(shù)μ>0,使得
假設(shè)3:每個(gè)fi:Rd→R,i=1,2,…n是凸的,即對(duì)任意的ω,ω′∈Rd,有
fi(ω)≥fi(ω′)+?fi(ω′)T(ω-ω′)
引理1[2]:如果F是凸函數(shù)并且L-光滑,則對(duì)任意的ω,ω′∈Rd,有
(3)
2L[F(ω)-F(ω*)]≥‖?F(ω)‖2
(4)
(5)
引理2[2]如果F是強(qiáng)凸函數(shù),則對(duì)任意的ω∈Rd,有
2μ[F(ω)-F(ω*)]≤‖?F(ω)‖2
(6)
證明由ωt+1=ωt-ηvt,有
引理4在假設(shè)1的條件下,根據(jù)WA-SARAH算法,對(duì)?t≥1,有
證明E‖ρ?F(ωj)-vj‖2=
E‖[ρ?F(ωj-1)-vj-1]+[ρ?F(ωj)-ρ?F(ωj-1)]-[vj-vj-1]‖2=
‖ρ?F(ωj-1)-vj-1‖2+‖ρ?F(ωj)-ρ?F(ωj-1)‖2+E‖vj-vj-1‖2+
2ρ(ρ?F(ωj-1)-vj-1)T(?F(ωj)-?F(ωj-1))-
2(ρ?F(ωj-1)-vj-1)TE[vj-vj-1]-
2ρ(?F(ωj)-?F(ωj-1))TE[vj-vj-1]=
‖ρ?F(ωj-1)-vj-1‖2-ρ2‖?F(ωj)-?F(ωj-1)‖2+E‖vj-vj-1‖2
(7)
其中
注意?F(ω0)=v0,則‖ρ?F(ω0)-v0‖2=(ρ-1)2‖v0‖2。(7)式兩邊對(duì)j=1,2,…,t相加并取期望得
證明對(duì)于?j≥1,有
E‖vj‖2=E‖vj-1-ρ(?fij(ωj-1)-?fij(ωj))‖2=
上式兩邊對(duì)j=1,2,…,t相加得
(8)
由引理4可得
其中合理地選擇ρ,η和m,使得
證明由引理5可得
根據(jù)F的強(qiáng)凸性,結(jié)合?F(ω0)=v0可得
迭代可得
下面通過數(shù)值實(shí)驗(yàn)來驗(yàn)證算法的效率,這里采用標(biāo)準(zhǔn)的機(jī)器學(xué)習(xí)數(shù)據(jù)集:Mnist數(shù)據(jù)集和合成logistic回歸數(shù)據(jù)集(Synthetic logistic data)。其中Mnist數(shù)據(jù)集中訓(xùn)練集樣本數(shù)n=6 000,維度d=784,測(cè)試集樣本數(shù)n=1 000。logistic回歸數(shù)據(jù)集中訓(xùn)練集樣本數(shù)n=10 000,維度d=100,測(cè)試集樣本數(shù)n=3 000。采用的機(jī)器學(xué)習(xí)回歸任務(wù)是帶L2正則項(xiàng)的logistic回歸:
其中,(xi,yi)為給定樣本數(shù)據(jù),λ是正則化參數(shù),F(xiàn)(ω)是一個(gè)強(qiáng)凸光滑的損失函數(shù)。
(a) synthetic data訓(xùn)練集下收斂速率 (b) minist data訓(xùn)練集下收斂速率圖1 不同加權(quán)系數(shù)ρ下收斂速率對(duì)比圖
圖2表示不同隨機(jī)梯度下降算法(SGD,SVRG,SAGA,SARAH,WA-SARAH)在訓(xùn)練集下收斂速率對(duì)比圖。從圖中可以看到SAGA算法在訓(xùn)練集下收斂速度稍快一些,但該算法所占內(nèi)存較大。SGD算法和SVRG算法在迭代前期效果比SARAH算法和WA-SARAH算法好一些,但到了迭代后期,SARAH算法和WA-SARAH算法下降地更快些,即更加逼近最優(yōu)解。
圖3為在不同測(cè)試集下,不同算法在當(dāng)前迭代下錯(cuò)誤率對(duì)比圖。隨著迭代地進(jìn)行,最終可以看到SAGA算法的錯(cuò)誤率比其他算法要低些,其次是WA-SARAH算法和SARAH算法,SGD和SVRG的錯(cuò)誤率要相對(duì)高一點(diǎn)。從而說明本文所提算法的有效性。
(a) synthetic data測(cè)試集下錯(cuò)誤率 (b) minist data測(cè)試集下錯(cuò)誤率圖3 不同算法在不同測(cè)試集下錯(cuò)誤率對(duì)比圖