許仁杰,劉寶弟,張凱,劉偉鋒
(中國石油大學(xué)(華東)海洋與空間信息學(xué)院,青島 266580)
元學(xué)習(xí)(meta learning)是一種“學(xué)習(xí)如何學(xué)習(xí)”的機器學(xué)習(xí)算法。模型無關(guān)的元學(xué)習(xí)(Model Agnostic Meta Learning,MAML)算法[1]通過元任務(wù)集合中的數(shù)據(jù)學(xué)習(xí)可以快速適應(yīng)某些目標數(shù)據(jù)任務(wù)的初始模型[2-5],可以使用數(shù)量有限、有標記的任務(wù)樣本進行訓(xùn)練,且在訓(xùn)練時可使用不同的模型,因而被廣泛應(yīng)用于解決各個領(lǐng)域的問題[6-14]。
雖然MAML 在解決回歸、分類和強化學(xué)習(xí)等問題中都有較好的表現(xiàn),但是其計算復(fù)雜度高、過擬合、梯度下降速度慢等問題還有待解決,因此研究者們從多個角度對MAML 進行了改進,包括簡化運算、改進損失函數(shù)和運算流程等。針對MAML 的計算復(fù)雜度太高與梯度更新方法計算過于復(fù)雜的問題,有學(xué)者提出了只用一階導(dǎo)數(shù)對二階導(dǎo)數(shù)進行逼近的元參數(shù)優(yōu)化的一階MAML、Reptile[10]方法;這一類方法雖然可以簡化MAML 的計算復(fù)雜度,但通常是以準確率的降低作為代價。針對MAML 在某些情況下會產(chǎn)生過擬合或者無法有效訓(xùn)練等問題,有學(xué)者通過信息論、高斯過程等方法提出了更緊的損失函數(shù)[6-7,12],或者通過設(shè)置參數(shù)使MAML 對過去學(xué)到的元知識進行遺忘[9];但這種方式通常是針對某一類具體問題而進行的,往往有它自身的局限性。針對MAML 下降速度過慢的問題,有學(xué)者從一個函數(shù)空間中構(gòu)造更好的權(quán)函數(shù)[8,15]或者提出了更有效的梯度更新方式[16-18];但這一類改進方法會在提升準確率的同時帶來更大的計算量。
盡管上述方法都對MAML 進行了有效的改進,但是由于通過前兩種方法改進MAML 在訓(xùn)練過程中認為每個樣本對于元知識的影響都是一樣的,無法很好地根據(jù)不同的任務(wù)對損失函數(shù)進行調(diào)整,也不能根據(jù)抽取樣本是否能很好地體現(xiàn)該任務(wù)的性質(zhì)而改變樣本對整體的影響,所以在學(xué)習(xí)過程中依然可能會產(chǎn)生訓(xùn)練速度低、過擬合或準確率較低等問題;而后一種方法的計算量較大。為了解決這些問題,本文通過概率方法構(gòu)造出了一個更好的權(quán)函數(shù)來提高MAML 的訓(xùn)練速度以及準確率。與文獻[8]中從再生核希爾伯特空間搜索損失函數(shù)不同,本文提出了一種更輕量、更便于計算的權(quán)函數(shù),對每個任務(wù)損失函數(shù)進行加權(quán),用來表示不同的任務(wù)在訓(xùn)練過程中的重要程度。具體地,本文認為隨機抽取的任務(wù)近似符合一個高斯分布,越靠近這個高斯分布的期望的任務(wù)在元參數(shù)更新過程中占據(jù)更加重要的地位;相反地,越遠離高斯分布期望的任務(wù)所占的權(quán)重應(yīng)該越小。添加這個權(quán)函數(shù)的MAML 可以在更快逼近任務(wù)分布的期望的同時避免一些小概率出現(xiàn)的任務(wù)對網(wǎng)絡(luò)訓(xùn)練造成更大的影響,從而在提升訓(xùn)練速度的同時增加模型的準確度,訓(xùn)練好的元參數(shù)也能更適用于高概率出現(xiàn)的任務(wù)。
將本文方法與基礎(chǔ)的MAML 方法在Omniglot 與Mini-ImageNet 數(shù)據(jù)集上進行小樣本圖像分類實驗,結(jié)果表明在大多數(shù)情況下,本文方法的準確率都高于傳統(tǒng)的MAML。
本文主要工作包括:從高斯隨機過程的角度提出了一種與迭代相關(guān)的MAML 解釋方法,并根據(jù)這種解釋方法通過貝葉斯分析提出了一種加權(quán)的MAML——BW-MAML,最后通過實驗驗證了BW-MAML 的有效性。
本文工作的基礎(chǔ)是HB-MAML(Model-Agnostic Meta-Learning as Hierarchical Bayesian)[3]以及加權(quán)元學(xué)習(xí)[8]。文獻[7]從貝葉斯分析的角度出發(fā),將元學(xué)習(xí)的過程描述為一個高斯隨機過程,并以此提出了一個正則化項;而文獻[8]從泛函分析的角度,認為常用的平方誤差與Hinge 誤差在原空間的核函數(shù)都能構(gòu)成再生核希爾伯特空間,并在這個空間中選取最優(yōu)的損失函數(shù)。通過文獻[8]中方法可以找到下降速度更快的損失函數(shù),但該損失函數(shù)是通過抽取的樣本獲得的,所以根據(jù)抽取任務(wù)的不同會使損失函數(shù)產(chǎn)生較大的波動,進而使優(yōu)化難度偏高。基于上述文獻的成果,本文從貝葉斯分析[19]與高斯隨機過程[18,20]的角度在線性函數(shù)空間中找到一個更便于計算的最優(yōu)損失函數(shù),使更重要任務(wù)的損失在損失函數(shù)中占更大的權(quán)重,通過優(yōu)化這個損失函數(shù)可以使元參數(shù)更容易向最優(yōu)解進行梯度下降。
最終在外循環(huán)中使用隨機梯度下降方法通過迭代求得其最小值。在這個過程中,將每個任務(wù)的損失相加作為整個模型的損失,旨在求得在迭代一次后對每個任務(wù)的損失都最小的θ。所以將
作為外循環(huán)的迭代方法。
MAML 在訓(xùn)練時,首先從任務(wù)分布中抽取一些任務(wù),使用一個內(nèi)循環(huán)針對每個任務(wù)的參數(shù)根據(jù)損失函數(shù)進行梯度下降;然后根據(jù)更新過參數(shù)的任務(wù)損失,使用一個外循環(huán)對元參數(shù)進行梯度下降,以獲得一個最適合全部任務(wù)的元參數(shù)。在這個過程中,MAML 將所有任務(wù)視為是同等重要的。
基于貝葉斯權(quán)函數(shù)的模型無關(guān)元學(xué)習(xí)就是在MAML 的元梯度更新方法上進行改進,在本文中,根據(jù)不同任務(wù)在訓(xùn)練中重要性不同,在外循環(huán)的元梯度下降時求MAML 中每個任務(wù)損失的加權(quán)和,從而能使元參數(shù)更快地進行訓(xùn)練。本文采用由貝葉斯分析推導(dǎo)而來的損失函數(shù),因此本文將這種改進算法稱為基于貝葉斯權(quán)函數(shù)的模型無關(guān)元學(xué)習(xí)(Bayes-Weighted Model-Agnostic Meta-Learning,BW-MAML)算法。
接下來介紹貝葉斯分析角度設(shè)置的損失函數(shù)及權(quán)函數(shù)的推導(dǎo)過程。
高斯隨機過程[18,20]是機器學(xué)習(xí)中常用的方法之一,在實踐中可以對機器學(xué)習(xí)的梯度下降過程視為一串隨機的概率事件進行分析。而對于其中的一個隨機事件,與文獻[7]中的推導(dǎo)類似,根據(jù)貝葉斯分析將上文中的損失函數(shù)(3)重寫成一個概率形式:
元學(xué)習(xí)的損失函數(shù)最小的問題就轉(zhuǎn)化為一個令負log 概率最小的問題,也就是找到一個元參數(shù),使在各個任務(wù)中經(jīng)過一次或幾次梯度下降后的任務(wù)參數(shù)屬于該任務(wù)的概率最高。
基于損失函數(shù)(4),可以得到如下推斷:如果使用抽取的訓(xùn)練任務(wù)以元參數(shù)為基礎(chǔ)進行訓(xùn)練,在理想情況下,第n個元參數(shù)θ()n會在數(shù)次迭代后達到一個對該任務(wù)最優(yōu)的點,記為,本文認為所有的都是對的逼近,而且由于噪聲的存在,一般認為:
由于抽取的任務(wù)隨機,并且都屬于同一個任務(wù)分布P(T),所以這些任務(wù)都獨立同分布,即它們都擁有同樣的統(tǒng)計學(xué)規(guī)律。根據(jù)一般性假設(shè),在本文中認為這些任務(wù)在任務(wù)空間中都符合高斯分布[18,20],使用一個邊界似然函數(shù)來表示一步元參數(shù)更新的條件概率:
為元參數(shù)的更新方式,而不是簡單地把各個樣本看作是均勻分布。
又由于每個θi符合一個高斯分布,所以任意幾個的值的分布也應(yīng)該符合一個同期望的高斯分布,所以把這個公式的右側(cè)進行歸一化作為本文算法的權(quán)函數(shù)就可以得到最終的元迭代格式:
通過將添加這個權(quán)函數(shù)的元參數(shù)更新方式替代原本的元參數(shù)更新方式,可以對優(yōu)化元參數(shù)貢獻更大的損失進行強調(diào),對出現(xiàn)概率較小的損失則通過較小的權(quán)函數(shù)降低其對整個迭代過程的影響。因此BW-MAML 可以降低整個梯度下降過程的隨機性,并且使終點更加趨近于所有分布的平均值,以獲得一個更重視高概率出現(xiàn)的任務(wù),一定程度上忽略小概率出現(xiàn)任務(wù)對元參數(shù)產(chǎn)生的影響。訓(xùn)練時算法的偽代碼如算法1 所示。
算法1 BW-MAML 的訓(xùn)練過程。
輸入 任務(wù)分布p(T)步長α,β;
輸出 優(yōu)化后的參數(shù)θ。
如圖1 所示,BW-MAML 等價于將MAML 通過幾個任務(wù)的參數(shù)求得下一步的元參數(shù)的過程改為通過估計元參數(shù)的期望,并將得到的期望作為下一步的元參數(shù)開始下一次迭代。
圖1 一階BW-MAML原理Fig.1 Principle of first-order BW-MAML
BW-MAML 與基礎(chǔ)MAML 算法的不同點體現(xiàn)在算法1 中的第8)行,簡單來講,傳統(tǒng)的MAML 算法直接將幾個任務(wù)的損失相加,而本文算法在計算任務(wù)損失函數(shù)的加權(quán)和的同時使用高斯分布的權(quán)函數(shù)而不是均勻分布,使元參數(shù)能更快、更準確地逼近最優(yōu)解。
本文在Mini-ImageNet 數(shù)據(jù)集[21]與Omniglot 數(shù)據(jù)集[22]上進行了小樣本圖像分類實驗,對BW-MAML 的有效性和實用性進行驗證。
Omniglot 是一個手寫字母數(shù)據(jù)集,包含50 個不同字母的1 623 個不同手寫字符,在處理數(shù)據(jù)集時將其分成了包含30個字母的“背景”集和包含20 個字母的“評估”集;Mini-ImageNet 數(shù)據(jù)集是元學(xué)習(xí)和小樣本學(xué)習(xí)中常用的數(shù)據(jù)集之一,它包含100 類共60 000 幅彩色圖片,每類中含有600 個樣本,每幅圖片的規(guī)格為84×84。
為了驗證BW-MAML 在較小數(shù)據(jù)集上的性能,在Omniglot 數(shù)據(jù)集上測試了一階MAML(First-Order MAML,F(xiàn)OMAML)與BW-MAML 的5-way 1-shot、5-way 5-shot、20-way 1-shot 以及20-way 5-shot 的小樣本分類對比實驗,其中,NwayK-shot 意味著在任務(wù)中包含N個類,而每個類中包含K個樣本。在網(wǎng)絡(luò)選擇上,本文采用了一個使用3×3 卷積核的四層卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network,CNN)作為其內(nèi)容網(wǎng)絡(luò)。在訓(xùn)練過程中,每次從訓(xùn)練集中隨機抽取6 個訓(xùn)練任務(wù),然后對內(nèi)容網(wǎng)絡(luò)按照一階BW-MAML、一階MAML等不同算法針對每個任務(wù)每次進行5 次梯度下降,總共進行60 000 次迭代。對于超參數(shù),與MAML 相同,本文選擇任務(wù)參數(shù)學(xué)習(xí)率α=0.1,元參數(shù)學(xué)習(xí)率β=0.001,元參數(shù)的訓(xùn)練使用Adam[23]作為優(yōu)化器。本文將準確率定義為測試集中預(yù)測正確的數(shù)量與總量的比值,表1 中的準確率是10 組準確率的平均值。從表1 可以看出,在Omniglot 數(shù)據(jù)集上,1-way 5-shot 與5-way 5-shot 時BW-MAML 和MAML 的準確率接近,20-way 1-shot 與20-way 5-shot 時,BW-MAML 相對MAML 的準確率平均提升了0.199 個百分點。
表1 兩種算法在Omniglot上的準確率對比 單位:%Tab.1 Accuracy comparison of two algorithms on Omniglot unit:%
在較大的數(shù)據(jù)集Mini-ImageNet 上進行實驗時,本文將Mini-ImageNet 隨機分為不相交的訓(xùn)練集與測試集,并將訓(xùn)練集依次傳入對網(wǎng)絡(luò)進行訓(xùn)練。與上一組實驗類似,本文在Mini-ImageNet 上進行了一階、二階MAML 與一階、二階BWMAML 與其他元學(xué)習(xí)算法的5-way 1-shot、5-way 5-shot 的小樣本分類對比實驗,除每次訓(xùn)練迭代100 000 次以外,其他超參數(shù)與在Omniglot 上的實驗一致。實驗結(jié)果如表2 所示,可以看出,在Mini-ImageNet 上BW-MAML 的各項準確率都比MAML 更高。通過使用權(quán)函數(shù)對損失的重要性進行區(qū)分,BW-MAML 比MAML 的平均準確率提高了0.907 個百分點,可見本文的方法無論是在Omniglot 還是在Mini-ImageNet 這樣略大的數(shù)據(jù)集上都表現(xiàn)得更好。
表2 Mini-ImageNet上的準確率對比 單位:%Tab.2 Accuracy comparison on Mini-ImageNet unit:%
為了驗證每次抽取的不同任務(wù)數(shù)對模型的影響,在Mini-ImageNet 中使用5-way 1-shot 的一階BW-MAML 并進行60 000 次迭代,每隔500 步使用100 個測試任務(wù)對模型效果進行評估,然后選取了準確率變化較明顯的訓(xùn)練時期(前段)以使結(jié)果更為明顯,其他參數(shù)設(shè)置與之前的實驗相同。從第n=500,1 000,1 500,2 000,2 500 步與訓(xùn)練完成后最終的準確率探究了每次抽取4 個、6 個與8 個任務(wù)對BW-MAML 訓(xùn)練速度的影響,結(jié)果如表3 所示。從表3 可以看出,BW-MAML 在收斂速度方面的效果也優(yōu)于MAML,在訓(xùn)練進行2 500 步后,6 任務(wù)時BW-MAML 的準確率是最高的,且比同樣6 任務(wù)的MAML 準確率提高了1.9 個百分點。但在訓(xùn)練完成后,6 任務(wù)的最終的準確率介于8 任務(wù)和4 任務(wù)的準確率之間??梢婋m然最終的準確率和每次訓(xùn)練所用的任務(wù)數(shù)存在正比例關(guān)系,但在2 500 步內(nèi),BW-MAML 在6 任務(wù)情況下的訓(xùn)練速度最快。
表3 針對不同任務(wù)數(shù)在Mini-ImageNet上的準確率對比Tab.3 Contrast experiment for different task numbers on Mini-ImageNet
由于MAML 在選擇任務(wù)上具有隨機性,而在實際使用這些任務(wù)進行訓(xùn)練時并沒有考慮每個任務(wù)對元參數(shù)的影響。在本文中通過理論推導(dǎo)并論證了一種新的貝葉斯加權(quán)的MAML,然后通過實驗驗證了這個方法在兩個數(shù)據(jù)集上的實用性,并通過一個對比實驗檢驗了超參數(shù)(任務(wù)數(shù))的選擇,這證明本文提出的方案確實提升了實驗的準確率,本文的方法可以提升在較為符合高斯分布的數(shù)據(jù)集上的準確率。在常用的數(shù)據(jù)集中BW-MAML 比MAML 的準確率更高。但還有很多新的思路亟待嘗試,比如先選擇一種更好的損失基函數(shù),然后再對這組基函數(shù)求出最優(yōu)的權(quán)系數(shù);或者先通過一些方法求出樣本大概的分布情況,然后在這個基礎(chǔ)上再進行加權(quán);再或者直接通過高斯過程設(shè)計出新的結(jié)構(gòu)以取代梯度下降等。