機器學習2021 — 元學習 (Meta Learning):學習如何學習

機器學習2021 — 元學習 (Meta Learning):學習如何學習

機器學習2021 — 元學習 (Meta Learning):學習如何學習

什麼是元學習?

元學習(Meta Learning)的核心概念是「學習的學習」(Learn to Learn)。這個詞彙中的「元」(Meta)意味著「X 的 X」。它處理的是比傳統機器學習(Machine Learning, ML)更高一個層次的議題。

在傳統機器學習中,我們的目標是找到一個特定的函式(或模型),例如一個影像辨識分類器 fθ​。而在元學習中,我們的目標是找到一個學習演算法 Fϕ​,這個演算法本身也是一個函式。

這個學習演算法 F 的輸入是一組訓練資料集(Dataset),而它的輸出則是訓練完成的模型(例如一個分類器)。因此,元學習旨在透過資料,讓機器自動學習如何找到一個好的學習演算法。

導入元學習的動機

傳統的深度學習(Deep Learning)往往需要人工調整超參數(hyper parameter),例如網路架構、學習率(learning rate)等。調整這些參數是一件非常繁瑣的事情。

元學習的目的之一,就是讓機器能夠根據資料自動學習並決定這些原本由人工設定的超參數。當我們找到了這個好的學習演算法 Fϕ∗​ 後,它就可以應用於任何新的任務上,未來在新任務上就不必再手動調整參數了。

元學習的運作框架(三個步驟)

元學習與傳統機器學習一樣,可以分解為三個核心步驟:

步驟一:定義學習演算法 Fϕ​

我們首先要定義一個學習演算法 F。這個 F 包含了需要被學習出來的未知參數 ϕ。

在元學習的不同方法中,ϕ 可以代表不同的組件:

  1. 初始化參數:例如 MAML(Model-Agnostic Meta-Learning)就是學習一組最佳的初始化參數。

  2. 網路架構:例如神經網路架構搜索(NAS)就是將網路架構當作 ϕ 來學習。

  3. 優化器(Optimizer):學習如何更新參數的機制,例如學習率。

步驟二:定義元損失函式 L(ϕ)

我們需要定義一個損失函式 L(ϕ) 來判斷一個學習演算法 Fϕ​ 的好壞。如果 L(ϕ) 的值越小,代表這個學習演算法越好。

• 訓練資料的單位是「任務」(Task):在元學習中,我們收集的訓練資料是一大堆訓練任務。

• 任務結構:每一個訓練任務 N 都有其各自的訓練資料(在文獻中常稱為 Support Set)和測試資料(常稱為 Query Set)。

• 損失計算:對於每一個任務 N,我們將其訓練資料(Support Set)丟給學習演算法 Fϕ​ 進行任務內訓練(within-task training),得到一個模型 fθN​。接著,我們將這個訓練好的模型 fθN​ 用該任務的測試資料(Query Set)上,計算出任務損失 Ln​。

• 總損失:總損失 L(ϕ) 則是將所有 N 個訓練任務的損失 Ln​ 加總或平均起來。

整個「任務內訓練」加上「任務內測試」的過程,合起來稱為一個 Episode。

步驟三:優化與求解

最後一步是找到能讓總損失 L(ϕ) 最小化的 ϕ∗。

• 如果 L(ϕ) 對 ϕ 可以計算梯度(微分),我們就可以使用梯度下降法(Gradient Descent)進行優化。

• 如果 ϕ 是離散的參數(例如網路架構的層數),無法進行微分,則可能需要使用強化學習(Reinforcement Learning, RL)或演化演算法(Evolutionary Algorithm)來求解。

MAML:學習初始化參數的代表性方法

在眾多元學習方法中,MAML(Model-Agnostic Meta-Learning)是最廣為人知的方法之一。MAML 是 Model-Agnostic Meta-Learning 的縮寫,它在 2017 年 ICML 會議上提出。

• Model-Agnostic 意指「與模型無關」,即它可以跨模型使用,不受特定模型限制。但值得注意的是,MAML 仍有一個限制:所有任務的模型結構必須是相同的。

• 目標:MAML 旨在學習一個最好的初始化參數 ϕ。

• 核心思路:MAML 關注的是參數 ϕ 經過訓練以後的潛力,而不是它目前的表現。一個好的 ϕ 即使本身表現不佳,但只要將其作為初始參數,在不同的任務上經過梯度更新後,都能迅速找到各任務的最佳參數 θhead,那麼它就是一個好的 ϕ。

MAML 與模型預訓練(Model Pre-training)的區別

MAML 的訓練假設

在 MAML 訓練的元訓練(across-task training)過程中,通常假設任務內訓練(within-task training)只進行一次參數更新。

• 目的:這主要是為了減少計算量,因為元學習的計算需求通常非常大。此外,這也鼓勵 MAML 找到一個強大的初始化參數 ϕ,使其只需一次更新就能達到很好的結果。

• 測試階段:雖然訓練時假設只更新一次,但在實際測試新的任務時,你可以進行多次參數更新,以獲得更好的結果。

元學習的應用與相關概念

Few-shot Learning (小樣本學習)

元學習常被應用於解決 Few-shot Learning(小樣本學習)問題。

• 區別:Few-shot Learning 是期望達成的「目標」,即機器只看少量範例(如每個類別只有幾張圖片)就能學會分類。元學習是達成該目標的「手段」。

• N-way K-shot:在 Few-shot 影像分類任務中,常見的任務定義是 N-way K-shot classification。這意味著在一個任務內,有 N 個類別(class),而每個類別只有 K 個範例。

• 基準數據集:Omniglot 是常被用作 Few-shot Learning 基準測試(Benchmark)的語料庫,它包含 1623 個不同字符,每個字符有 20 個範例。

其他可學習的參數

除了初始化參數,元學習框架還可用於學習其他組件,證明了「萬物皆可 Meta」的潛力:

• 優化器(Optimizer):機器可以學會如何決定學習率或優化策略。

• 網路架構(NAS):Network Architecture Search (NAS) 本質上就是一種元學習,其學習的 ϕ 是網路架構。

• 資料增強(Data Augmentation):元學習可以自動學習出最佳的資料增強策略。

開發任務 (Development Task) 的重要性

如同傳統機器學習需要劃分訓練集、驗證集(Development Set)和測試集,元學習在訓練過程中也應當包含訓練任務(Training Task)、開發任務(Development Task)測試任務(Testing Task)。開發任務應用於選擇元學習演算法本身的超參數,避免在測試任務上過度擬合。

其他應用:

結語

元學習是機器學習發展的重要方向,它將學習的目標從單一的模型參數提升到學習演算法本身。儘管元學習在實作上仍然需要調校參數,但其核心目標是找到一個高效能的學習演算法 Fϕ∗​,使其在面對新的任務時,能像人類一樣具備快速學習與遷移的能力。目前元學習的應用已從簡單的圖像任務擴展到語音和自然語言處理等更複雜的任務中,前景可期。

參考資訊

李弘毅機器學習

生成式影片https://youtu.be/_I3GpMJhpSI

Comments

Loading comments…

Leave a Comment