活到老,學到老,人類可以在不斷變化的環(huán)境中連續(xù)自適應地學習——在新的環(huán)境中不斷吸收新知識,并根據(jù)不同的環(huán)境靈活調(diào)整自己的行為。模仿碳基生命的這一特性,針對連續(xù)學習(continual learning,CL)的機器學習算法的研究應運而生,并成為大家日益關注的焦點。
那么,什么是連續(xù)學習?相較于傳統(tǒng)單任務的機器學習方法,連續(xù)學習旨在學習一系列任務,即在連續(xù)的信息流中,從不斷改變的概率分布中學習和記住多個任務,并隨著時間的推移,不斷學習新知識,同時保留之前學到的知識。
(相關資料圖)
然而,這個領域的技術發(fā)展并非一帆風順,面臨著許多難題?!肚f子·秋水》中曾描述過一個這樣的故事:戰(zhàn)國時期,燕國有一少年聽聞趙國都城邯鄲人走路姿勢異常優(yōu)美,心向往之。遺憾的是,他在跟隨邯鄲人學步數(shù)月后,卻把之前走路姿勢忘記了,最后甚至都不會走路了,無奈只好爬回了燕國。有趣的是,這則寓言故事深蘊著當前連續(xù)學習模型的困境之一——災難性遺忘(catastrophic forgetting),模型在學習新任務之后,由于參數(shù)更新對模型的干擾,會忘記如何解決舊任務。而對于機器學習技術而言,另一普遍關注的概念便是泛化誤差(generalization error),這是衡量機器學習模型泛化能力的標準,用以評估訓練好的模型對未知數(shù)據(jù)預測的準確性。泛化誤差越小,說明模型的泛化能力越好。
盡管目前很多實驗研究致力于解決連續(xù)學習中的災難性遺忘問題,但是對連續(xù)學習的理論研究還十分有限。哪些因素與災難性遺忘和泛化誤差相關?它們?nèi)绾蚊鞔_地影響模型的連續(xù)學習能力?對此我們所知甚少。
近期,來自美國俄亥俄州立大學Ness Shroff教授團隊的研究工作“Theory on Forgetting and Generalization of Continual Learning”或有望為這一問題提供詳細的解答。他們從理論上解釋了過度參數(shù)化(over parameterization)、任務相似性(task similarity)和任務排序(task ordering)對遺忘和泛化誤差的影響,發(fā)現(xiàn)更多的模型參數(shù)、更低的噪聲水平、更大的相鄰任務間差異,有助于降低遺忘。同時,通過深度神經(jīng)網(wǎng)絡(DNN),他們在真實數(shù)據(jù)集上驗證了該理論的可行性。
圖注:論文封面,該論文于2023年2月刊登在ArXiv上
連續(xù)學習線性模型的構建
在經(jīng)典的機器學習理論中,參數(shù)越多,模型越復雜,往往會帶來不期望見到的過擬合。但以DNN為代表的深度學習模型則不然,其參數(shù)越多,模型訓練效果越好。為了理解這一現(xiàn)象,作者更加關注在過參數(shù)化的情況下(p>n),連續(xù)學習模型的表現(xiàn)。文章首次定義了基于過參數(shù)化線性模型的連續(xù)學習模型,考量其在災難性遺忘和泛化誤差問題上的閉合解(定理1.1)。
定理1.1當p≥n+2時,則:
T={1,…,T}代表任務序列;||wi? - wj?||2表征任務i和j之間的相似性;p為模型實際參數(shù)的數(shù)量;n為模型需要的參數(shù)數(shù)量;r為過參數(shù)化的比例,r=1-n/p;σ為噪聲水平;ci,j =(1-r)(rT-i-rj-i+rT-j),其中1≤i≤j≤T;更多參數(shù)介紹詳看原始文獻和附錄部分。
(9)式和(10)式分別為災難性遺忘FT和泛化誤差GT的數(shù)學表示。它們不僅描述了連續(xù)學習在線性模型中是如何工作的,還為其在一些真實的數(shù)據(jù)集和DNN中的應用提供指導。
連續(xù)學習中的鼎足三分
在上述數(shù)學模型的基礎上,作者還研究了在連續(xù)學習過程中,過參數(shù)化、任務之間的相似程度和任務的訓練順序三個因素對災難性遺忘和泛化誤差的影響。
1)過參數(shù)化
·更多的模型訓練參數(shù)將有助于降低遺忘
如定理1.1所示,當表示參數(shù)數(shù)量的p趨近于0時,E[FT]也將趨近于零。
·噪聲水平和(或)任務間相似度低的情況下,過參數(shù)化更好
為了比較過參數(shù)化和欠參數(shù)化時模型的性能,作者構建了與定理1.1類似的,在欠參數(shù)情況下的理論模型定理1.2。
定理1.2當n≥p+2時,則:
如定理1.2所示,欠參數(shù)化的情況下,當噪聲水平σ較大時,以及當訓練的任務間區(qū)分度較大時,E[FT]和E[GT]都變大。相反,過參數(shù)化的情況下,當噪聲水平σ較大時,以及當訓練的任務間不太相似時,E[FT]和E[GT]都變小。這表明當噪聲水平高和(或)訓練任務相似性較低時,過參數(shù)化的情況可能比欠參數(shù)化的情況訓練效果更好,即存在良性過擬合。
2)連續(xù)訓練任務的相似性
· 泛化誤差隨著任務相似性的增加而降低,而遺忘則可能不會隨之降低
如定理1.1所示,由于公式(10)中G2項的系數(shù)始終為正,所以當任務之間越相似,區(qū)分度越少時,泛化誤差會相應降低。但是由于公式(9)中,F(xiàn)2項的系數(shù)并不總是為正,所以可能出現(xiàn)任務之間的相似性增加模型的遺忘性能也增加的情況。
3)任務訓練順序
· 在早期階段將差異大的任務相鄰訓練,將有助于降低遺忘
為了找到連續(xù)學習中,任務的最優(yōu)訓練順序。作者考慮了兩種特殊情況。情況一,任務集由一個特殊的任務,和剩余其它完全一模一樣的任務組成。情況二,任務集由數(shù)目相同的不同任務組成。通過對兩種情況的比較分析得出:
首先,特殊的任務在訓練時,應優(yōu)先在前半段執(zhí)行;
其次,相鄰任務之間應差異較大;這些措施都將有助于降低連續(xù)學習模型的遺忘。但是,最小化的遺忘和最小化的泛化誤差的最佳任務訓練排序有時并不相同。
DNN對連續(xù)學習模型的驗證
最后,為了驗證上述推論的可靠性,作者使用DNN在真實數(shù)據(jù)集上進行實驗。后續(xù)的實驗結果明確地證實了,任務相似性對連續(xù)學習模型災難性遺忘的非單調(diào)性影響。而關于任務排序影響的實驗結果也與前面線性模型中的發(fā)現(xiàn)一致,即應在模型訓練早期設置區(qū)分度較大的任務學習,并安排區(qū)分度較大任務相鄰訓練。
表1:使用TRGP和TRGP+兩種任務策略在不同數(shù)據(jù)集中訓練得到的準確性和反向遷移(用負值表示遺忘;值越大/正,表示知識反向遷移效果越好)結果
正向遷移:在學習新任務的過程中,利用以前的任務中學習到的經(jīng)驗來幫助新任務的知識學習。
反向遷移:在學習新任務的過程中,學習到的新知識,鞏固了以前任務的知識學習。
PMNIST數(shù)據(jù)集:MNIST數(shù)據(jù)集是機器學習模型訓練所使用的經(jīng)典數(shù)據(jù)集,包含0-9這10個數(shù)字的手寫樣本,其中每個樣本的輸入是一個圖像,標簽是圖像所代表的數(shù)字。PMNIST是基于MNIST數(shù)據(jù)集的變種,由10種不同的MNIST樣本置換順序的連續(xù)學習任務組成,可進行連續(xù)學習問題的評估。Split CIFAR-100數(shù)據(jù)集:CIFAR-100數(shù)據(jù)集也是機器學習模型訓練所使用的經(jīng)典數(shù)據(jù)集,包含100種分類任務,如蜜蜂、蝴蝶等。每類有600張彩色圖像,其中500張作為訓練集,100張作為測試集。同樣,為了在該數(shù)據(jù)集上進行連續(xù)學習問題的評估,作者將CIFAR-100數(shù)據(jù)集等分為10組,每一組由10個完全不同的分類任務組成,重構了Split CIFAR-100連續(xù)學習數(shù)據(jù)集。
更有趣的是,作者發(fā)現(xiàn),相較于賦以不同時間點學習的舊任務相同的權重(TRGP)的策略,賦以最近學習的舊任務更多的權重(TRGP+),可以更好地促進連續(xù)學習模型的知識正向遷移和反向遷移(表 1)。這些發(fā)現(xiàn)有望為后續(xù)連續(xù)學習策略的設計提供理論參考。
Lin, S., Ju, P., Liang, Y., & Shroff, N. (2023). Theory on Forgetting and Generalization of Continual Learning. ArXiv. /abs/2302.05836韓亞楠, & Liu, Jianwei & Luo, Xiong-Lin. (2021). 連續(xù)學習研究進展. Journal of Computer Research and Development. 10.7544/issn1000-1239.2022.20201058.
關鍵詞: