張?zhí)? 郭輝 郭靜純
摘要:針對增量學(xué)習(xí)存在的災(zāi)難性遺忘和新任務(wù)數(shù)據(jù)逐步積累問題,提出了基于新舊任務(wù)之間相似度的樣本重放優(yōu)化學(xué)習(xí)方法,相似度越高,重放樣本越少。并選擇MINIST數(shù)據(jù)集在卷積神經(jīng)網(wǎng)絡(luò)上進行了實驗研究,驗證了該方法的可行性和有效性。
關(guān)鍵詞: 增量學(xué)習(xí);災(zāi)難性遺忘;樣本重放;任務(wù)相似度
中圖分類號: TP183? ? ? ? 文獻標(biāo)識碼:A
文章編號:1009-3044(2021)08-0013-03
Abstract: To solve the problem of catastrophic forgetting and gradual accumulation of new task data in incremental learning, an optimal learning approach based on the similarity difference between old and new tasks is proposed. The more similar the tasks are, the less the old samples will be replayed. Moreover, MINIST data set is selected to conduct experimental research on the convolutional neural network, which verifies the feasibility and effectiveness of the method.
Key words:incremental learning; catastrophic forgetting; sample replay; task similarity
隨著深度學(xué)習(xí)的快速發(fā)展和在圖像、語音等領(lǐng)域的應(yīng)用,其在單個任務(wù)處理方面取得了優(yōu)異的性能。但當(dāng)它面對多任務(wù)增量學(xué)習(xí)時,常常產(chǎn)生“災(zāi)難性遺忘”現(xiàn)象[1],即學(xué)習(xí)新任務(wù)時會改變原有的網(wǎng)絡(luò)參數(shù),相應(yīng)的舊任務(wù)記憶就會急劇下降甚至完全消失。
樣本重放是緩解災(zāi)難性遺忘的主要方法之一,包括兩種典型方式:一種通過舊任務(wù)的偽樣本生成器保留其信息,如深層生成重放[2]和記憶重放GANs[3],不使用舊任務(wù)原始數(shù)據(jù),但GAN模型訓(xùn)練較復(fù)雜;另一種直接選用舊任務(wù)的原始數(shù)據(jù)子集,如內(nèi)存固定的iCaRl[4]及其改進訓(xùn)練樣本不均衡的增量學(xué)習(xí)文獻[5],文獻[6]提出一種自動記憶框架,基于樣本參數(shù)化選取具有代表性的舊樣本子集,采用雙層優(yōu)化訓(xùn)練框架。這些方法均未考慮新舊任務(wù)之間的相似度差異:相似度越高,網(wǎng)絡(luò)提取的共有信息越多,則對舊任務(wù)的回顧應(yīng)越少。此外,真實環(huán)境下新任務(wù)數(shù)據(jù)通常按照時間順序流式到達,新數(shù)據(jù)較少,無法滿足上述方法的需要。針對這些問題,本文提出了一種基于任務(wù)相似度的增量學(xué)習(xí)優(yōu)化方法,根據(jù)兩者之間相似度差異設(shè)置不同比例的訓(xùn)練數(shù)據(jù),避免重復(fù)訓(xùn)練,減少資源占用,加快訓(xùn)練速度。
1 樣本重放增量學(xué)習(xí)優(yōu)化方法
增量學(xué)習(xí)優(yōu)化方法的實現(xiàn)過程主要分為以下三個階段:首先,當(dāng)新任務(wù)到達時,用特征提取器提取新舊類特征,進行相似度差異分析;其次,根據(jù)相似度差異結(jié)果,計算新舊任務(wù)不同比例的訓(xùn)練數(shù)據(jù)增量,構(gòu)建每批次增量訓(xùn)練數(shù)據(jù)集;最后,進行增量優(yōu)化訓(xùn)練,實現(xiàn)符合真實場景下的新任務(wù)數(shù)據(jù)增量訓(xùn)練和任務(wù)增量學(xué)習(xí)。
1.1 符號表示
假設(shè)增量學(xué)習(xí)分為1個初始階段和N個新任務(wù)的增量階段。在初始階段使用數(shù)據(jù)[D0]進行訓(xùn)練得到網(wǎng)絡(luò)模型[Θ0];在第[i]個增量階段,若有[s]個舊類[X1,X2,...,Xs],新類[Xi,i∈N],模型狀態(tài)為[Θi-1],令[Di?j]、[Dij]、[Dj]分別表示第[i]類第[j]個批次的新增樣本數(shù)據(jù)、前[j]個批次新數(shù)據(jù)和第[j]個批次的新舊訓(xùn)練數(shù)據(jù)。
1.2 任務(wù)相似度分析
根據(jù)假設(shè),新任務(wù)數(shù)據(jù)流式到達。當(dāng)新任務(wù)到達時,首先,選取同等數(shù)量的舊任務(wù)樣本和首次到達的新任務(wù)樣本作為代表性樣本一起訓(xùn)練特征提取網(wǎng)絡(luò)作為特征提取器[φ],通過使用新舊任務(wù)的平衡數(shù)據(jù)集,特征提取器可以更均衡地提取新舊任務(wù)的樣本特征,使網(wǎng)絡(luò)能充分學(xué)習(xí)新舊任務(wù)樣本之間的差異,得到更具有代表性的樣本特征。
對新任務(wù)樣本數(shù)據(jù)提取特征后,采用余弦相似度衡量新舊任務(wù)之間的相似程度,其值越大,特征越相似,計算公式如下:
1.3 構(gòu)建增量訓(xùn)練數(shù)據(jù)集
由于相似度較高的兩個任務(wù),在進行網(wǎng)絡(luò)訓(xùn)練時,相同部分特征已經(jīng)被提取到了,所以對于相似度較高的任務(wù),新舊任務(wù)越相似,則越應(yīng)減少舊任務(wù)的重放訓(xùn)練樣本數(shù)量,減少重復(fù)訓(xùn)練造成的資源浪費;反之,則應(yīng)增加舊類的數(shù)量,強化舊類知識,減少網(wǎng)絡(luò)對于新類的偏向。根據(jù)新舊任務(wù)之間的相似度,令每批次重放舊任務(wù)的樣本增量為[Doldk?j],其計算公式如下:
1.4 蒸餾損失和分類損失計算
蒸餾損失最早在文獻[7]中提出,在增量學(xué)習(xí)中適用于文獻[4,6,8],主要用來促使新的模型和舊的模型在舊類上保持相同的預(yù)測能力。增量學(xué)習(xí)損失包括蒸餾損失[LdΘi;Θi-1;x]和衡量分類準(zhǔn)確度的交叉熵損失[LcΘi;x]之和,兩者的計算公式分別如下:
1.5 增量優(yōu)化訓(xùn)練
通過分析不同任務(wù)之間的相似性差異,在新任務(wù)數(shù)據(jù)流式到達時設(shè)置不同比例的新舊數(shù)據(jù)進行增量優(yōu)化訓(xùn)練,整個的訓(xùn)練流程總結(jié)如下:
算法1 增量優(yōu)化訓(xùn)練
輸入 1個初始任務(wù)(2個類別的分類任務(wù))的數(shù)據(jù)集[D0],N個新增任務(wù)(一個類別表示一個任務(wù))的流式數(shù)據(jù)集[Di,i∈N]
輸出 N+1個任務(wù)(N+2個類別)的分類性能
(1) 用數(shù)據(jù)[D0]訓(xùn)練得到網(wǎng)絡(luò)模型[Θ0]
(2) 新任務(wù)到達,[Di1=500],[Di?j=500],有s個舊類(s的初始值為2)
(3) 新舊類之間進行相似度差異分析,用公式(1)計算新類與每個舊類的余弦相似度[sφXold,φXnew]
(4) 根據(jù)相似度差異結(jié)果,用公式(2)計算舊類每批次投放的樣本增量[Doldk?j,k∈s]
(5) 用公式(3)構(gòu)建第j個批次的訓(xùn)練數(shù)據(jù)
(6) 進行增量訓(xùn)練
(7) if各個類別的分類性能達到預(yù)期 //測試網(wǎng)絡(luò)分類性能
(8) then if 還有未完成的任務(wù) then 返回步驟(2) //繼續(xù)訓(xùn)練下一個增量任務(wù)
(9) else 輸出N+1個任務(wù)(N+2個類別)的分類性能 //已經(jīng)完成N+1個任務(wù)的增量學(xué)習(xí)
(10) end if
(11) else then 返回步驟(5) //任務(wù)分類準(zhǔn)確率沒有達到要求,繼續(xù)訓(xùn)練
(12) end if
2 實驗研究
選取MNIST數(shù)據(jù)集中的數(shù)字0、1、2在三層卷積神經(jīng)網(wǎng)絡(luò)上進行增量學(xué)習(xí),以數(shù)字0和1作為初始階段,數(shù)字2為新增類別階段。實驗結(jié)果如表1所示。
由表1可知,采用本文方法進行增量學(xué)習(xí),在第6批次時的平均準(zhǔn)確率為0.9818,比重放全部舊數(shù)據(jù)的準(zhǔn)確率0.99稍小,但訓(xùn)練數(shù)據(jù)量急劇下降,由5923+6741個舊樣本變?yōu)?0+66,顯著提升了訓(xùn)練效率。以此類推依次完成數(shù)字3-9的增量學(xué)習(xí),對比結(jié)果如圖1所示。
圖1中折線圖的橫坐標(biāo)為增量學(xué)習(xí)的各個階段,縱坐標(biāo)為平均分類精度,圖中結(jié)果表明相較于使用全部的新舊類訓(xùn)練數(shù)據(jù),使用新的基于任務(wù)相似度的增量學(xué)習(xí)優(yōu)化方法雖然在分類精度上有所下降,但是結(jié)果相差不大,能有效緩解災(zāi)難性遺忘的影響,且所使用的訓(xùn)練數(shù)據(jù)集要遠小于使用全部的訓(xùn)練集,減少了訓(xùn)練量,加快了訓(xùn)練速度。
3 結(jié)論
針對增量學(xué)習(xí)中的災(zāi)難性遺忘問題,提出了一種基于新舊任務(wù)相似度的樣本重放學(xué)習(xí)方法,在盡量保持對舊任務(wù)記憶的同時著力提升學(xué)習(xí)效率,據(jù)此選用MINIST數(shù)據(jù)集進行實驗研究,驗證了該方法的可行性與有效性,為緩解災(zāi)難性遺忘提供了新的解決思路。
參考文獻:
[1] McCloskey M,Cohen N J.Catastrophic interference in connectionist networks:the sequential learning problem[J].Psychology of Learning and Motivation,1989,24:109-165.
[2] Shin H, Lee J K, Kim J, et al. Continual learning with deep generative replay[C]. Advances in Neural Information Processing Systems. Curran Associates: New York, 2017:2991-3000.
[3] Wu C S, Herranz L, Liu X L, et al. Memory Replay GANs: learning to generate images from new categories without forgetting[C].Advances in Neural Information Processing Systems. Curran Associates: New York, 2018: 5962-5972.
[4] Rebuffi S A, Kolesnikov A, Sperl G, et al. iCaRL: Incremental Classifier and Representation Learning[C]. Proc of the IEEE Conf on Computer Vision and Pattern Recognition. Piscataway: IEEE Computer Society, 2017: 5533-5542.
[5] Castro F M, Marin-Jimenez M J, Guil N, et al. End-to-End Incremental Learning[C]. European Conference on Computer Vision. Berlin: Springer, 2018:233-248.
[6] Liu Y Y, Su Y , Liu A A , et al. Mnemonics Training: Multi-Class Incremental Learning Without Forgetting[C]. CVPR, 2020:12242-12251.
[7] Hinton G, Vinyals O, Dean J. Distilling the Knowledge in a Neural Network[J]. Computer Science, 2015, 14(7)38-39.
[8] Zenke F, Poole B, Ganguli S. Continual Learning Through Synaptic Intelligence[C].International Conference on Machine Lea rning. Lille: International Machine Learning Society, 2017:3987-3995.
【通聯(lián)編輯:唐一東】