機器學習2021 -GAN 生成對抗網路
機器學習2021 -GAN 生成對抗網路¶
生成對抗網路(GAN)介紹
生成對抗網路(Generative Adversarial Network, GAN)是一種廣為人知的生成式模型(Generative Model)。GAN的核心目標是訓練一個網路,使其輸出不再是單一固定的結果,而是一個機率分佈,這對於需要創造力或擁有多種可能輸出的任務(如影像生成、對話系統)尤其重要。
GAN的架構與運作原理
GAN主要由兩個相互競爭的網路組成:生成器(Generator, G)與判別器(Discriminator, D)。
- 生成器 (Generator, G):
◦ G是一個類神經網路,負責產生假的數據或圖片。
◦ G的輸入通常是一個隨機變數 Z,Z是從一個簡單的、可取樣的分佈(如 Normal Distribution)中採樣出來的向量。
◦ G會將這個簡單的輸入分佈轉換成一個複雜的分佈(稱為 PG)。
◦ G的訓練目標是騙過判別器,使其輸出的圖片看起來足夠真實。
- 判別器 (Discriminator, D):
◦ D也是一個類神經網路,其輸入是一張圖片,輸出是一個數值。
◦ 這個數值代表輸入圖片與真實資料(Pdata)的相似程度,數值越大,代表圖片越像真實的二次元人物圖像或其他真實資料。
◦ D的訓練目標是分辨真正的資料(如圖庫中的圖片)和 G 產生出來的假圖片。
GAN的運作概念常被比喻為演化過程中的對抗關係,例如枯葉蝶與天敵的互相進化。G與D之間是互相砥礪、亦敵亦友的合作關係。
GAN的訓練目標與挑戰
GAN訓練的根本目標是希望 G 產生的分佈(PG)與真實資料的分佈(Pdata)越接近越好。這種接近程度是通過衡量兩個分佈之間的差異度(Divergence)來決定的,目標是找到 G 的參數,使 Divergence 最小。
然而,計算這種 Divergence 存在困難。GAN透過訓練 D 來克服這一限制:
-
Min-Max 問題: GAN的訓練是一個有最小化(Minimize)又有最大化(Maximize)的 Min-Max 問題。
-
訓練 D (最大化 Objective Function): D 的訓練目標是最大化一個 Objective Function (V)。D將真實圖片(從 Pdata 取樣)視為類別 1,將生成圖片(從 PG 取樣)視為類別 2,進行二元分類器的訓練。這個 Objective Function 的最大值,與兩個分佈之間的 JS Divergence(JS 散度)有關。
-
訓練 G (最小化 Divergence): G 的訓練目標是最小化 D 的 Objective Function 的最大值。G透過調整其參數,目標是讓 D 給予生成圖片更高的分數。訓練過程是反覆進行的,先訓練 D,再訓練 G。
GAN訓練的難點與WGAN的發展
GAN 以難以訓練而聞名,因為 G 和 D 必須棋逢敵手,任何一方停止進步,訓練過程都可能崩潰。
早期的 GAN 使用 JS Divergence 作為隱含的衡量標準,導致了嚴重的問題:
• JS Divergence 限制: 在高維空間中,真實資料分佈和生成資料分佈的重疊部分極少。
• Log2 困境: 如果兩個分佈沒有重疊,JS Divergence 算出來永遠是 Log2,無法提供 G 有效的梯度來調整參數。這使得 D 的準確率在實際操作中經常達到 100%。
為了解決這個問題,學界提出了使用 Wasserstein Distance(又稱 Earth Mover Distance)來衡量分佈相似度,並由此發展出 WGAN(Wasserstein GAN)。
• Wasserstein Distance (WD): WD 概念上是將一個分佈的「土」挪到另一個分佈所需的最小平均移動距離。
• WD 優勢: 即使兩個分佈沒有重疊,WD 也能反映它們之間的距離差異。這使得 Generator 即使只有微小的進步(稍微靠近 Pdata),WD 的值也會有變化,從而能持續提供梯度來訓練 G。
• WGAN 實作限制: 為了確保計算結果是 WD,WGAN 要求 Discriminator(此時常被稱為 Critic)必須是一個足夠平滑的 1-Lipschitz Function。實現此限制的方法包括梯度懲罰(Gradient Penalty, GP) 或譜歸一化(Spectral Normalization, SN)。
GAN的應用類型
GAN的應用廣泛,主要分為兩大類:
- 非條件式生成(Unconditional Generation):
◦ G只接收隨機向量 Z 作為輸入,產生圖片。
◦ 例子包括生成逼真的二次元人臉 或高清的真實人臉。
◦ 透過對輸入 Z 向量進行內插(Interpolation),可以觀察到兩張生成圖片之間連續且平滑的變化。
- 條件式生成(Conditional Generation, cGAN):
◦ G 接收條件 X(Condition)和隨機向量 Z 作為輸入,產生輸出 Y。
◦ D 不僅會看輸出的 Y 圖片,還會看條件 X,判斷 Y 是否清晰且與 X 匹配。
◦ 文字轉圖片 (Text-to-Image): 輸入一段文字描述(如「紅眼睛」),輸出符合該描述的圖片。
這時候需要把圖片跟文字敘述成對輸入
◦ 圖像轉換 (Image-to-Image Translation, Pix2pix): 輸入一張圖片,輸出另一張轉換風格或內容的圖片,例如黑白圖著色、素描轉實景、或白天轉夜晚。
無成對資料學習:CycleGAN
當訓練資料中,輸入(X domain)和輸出(Y domain)的資料是不成對的時候(Unpaired Data),例如將真人照片風格轉換為二次元人物頭像,GAN 需要額外的機制來確保輸入與輸出的關聯性。
CycleGAN 透過引入循環一致性(Cycle Consistency)來解決此問題:
• 訓練兩個生成器 G (X → Y) 和 F (Y → X)。
• 要求圖片從 X 領域轉換到 Y 領域後,再透過 F 還原回 X 領域時,最終輸出的圖片應與原始輸入越接近越好。
• 這種循環結構(Cycle)強迫 G 輸出的圖片必須與輸入圖片保有某些關係,避免 G 完全無視輸入。
• CycleGAN是雙向的,同時也要求 Y → X → Y 的循環一致性。
• CycleGAN的類似方法還有 DiscoGAN 和 Dual GAN。
• CycleGAN可用於非督導式的風格轉換、文字風格轉換及非督導式翻譯等任務。
GAN的評估方法
評估 G 生成結果的好壞並不簡單。除了主觀的人眼觀察外,常見的客觀自動評估方法包括:
• 品質 (Quality): 將生成圖片丟入影像分類器,若輸出機率分佈越集中,代表圖片品質越高。
• 多樣性 (Diversity): 將大量圖片丟入分類器,若所有輸出分佈的平均結果越平坦,代表 G 產生的多樣性(Variety)越高。
• Inception Score (IS): 一種結合 Quality 和 Diversity 的分數。
• Fréchet Inception Distance (FID): 目前較常用的指標。將真實圖片和生成圖片丟入 Inception Network 的隱藏層,取出高維向量,假設它們的分佈為高斯分佈,並計算兩分佈之間的 Fréchet Distance。FID 數值越小,代表生成品質越好。
在評估時需要注意 GAN 訓練中常見的問題:
• 模式崩塌(Mode Collapse): G 僅能生成真實資料分佈中的少數幾張圖片,重複性高。
• 模式丟失(Mode Dropping): G 只能生成真實資料分佈的一部分,多樣性不足。
參考資訊
Comments
Loading comments…
Leave a Comment