跳转到内容

MAML — 学一个"好起点",几步就能学会新任务

是什么

MAML(Model-Agnostic Meta-Learning)是一套让模型学会快速学习新任务的方法。日常类比:你不是让学生背一本特定教材,而是教他一种”考前 30 分钟翻什么都能上手”的复习套路。

形式上:MAML 不直接训练模型解决某个任务,而是训练模型的初始参数 θ——这套 θ 的特点是,把它放到任意一个新任务上,做几步梯度下降就能适应得很好

传统训练:θ → 直接拿来用
MAML: θ(好起点) → K 步 SGD → θ′ → 在新任务上表现好

“好起点”四个字就是 MAML 的全部。

为什么重要

不理解 MAML,下面几件事都讲不通:

  • 为什么 GPT-3 论文里反复出现 “in-context learning”——meta-learning 是它的精神前身
  • 为什么 few-shot 分类任务上,“先 pretrain 再 fine-tune” 在 2017 年被超越,但又在 2020 年回归
  • 为什么 “学会学习”(learning to learn,Schmidhuber 1987 提出)这个 30 年前的口号到 2017 年才真正落地成可训练的算法
  • 为什么后续的 Reptile、ANIL、in-context learning 都在反复”瘦身”二阶梯度——MAML 太重了

核心要点

MAML 的训练循环可以拆成 内外两层

  1. 内循环(inner loop):从任务分布里抽一个 task T_i,在它的 support set 上做 K 步 SGD,得到适应后的参数 θ_i′。

    θ_i′ = θ - α · ∇_θ L_T_i(θ) (K = 1 时)
  2. 外循环(outer loop):把多个 task 的 query set 损失加起来,对原始 θ 求梯度——注意不是对 θ_i′,而是穿过 θ_i′ 的更新过程一直回到 θ。

    θ ← θ - β · ∑_i ∇_θ L_T_i(θ_i′)
  3. 二阶梯度:因为 θ_i′ 本身是 θ 的函数(θ_i′ = θ - α∇L(θ)),对 θ_i′ 求梯度再对 θ 求梯度,会冒出 Hessian-vector product——这就是 MAML 计算贵的根源。

三步加起来,模型学到的不是”答案”,而是”对答案最敏感的初始位置”。

实践案例

案例 1:正弦曲线 few-shot 回归

任务族:每个任务是一条正弦曲线 y = A sin(x + φ),A 和 φ 随机。

  • 给 MAML 看 5 个 (x, y) 点 → 它预测整条曲线
  • 普通预训练模型看 5 个点 → 预测一条平均曲线,对不上

为什么:MAML 训练时见过几千条不同的正弦曲线,它的 θ 已经”知道任务一定是某条正弦”,新任务只需要从 5 个点猜出 A 和 φ——5 步梯度足够。

案例 2:miniImageNet 5-way 1-shot 分类

任务:每次给 5 个类别、每类 1 张样本,问第 6 张属于哪一类。

  • MAML:48.7% 准确率(2017 年 SOTA)
  • 普通 fine-tune:跑不动,1 张样本不够 SGD 稳定下降

为什么:MAML 的 θ 已经是”识别新类别只需要看 1 张”的初始权重,内循环 1 步就能挪到合适位置。

案例 3:FOMAML 一阶近似

完整 MAML 要存 K 步中间状态、要算 Hessian——内存翻倍、速度慢。

FOMAML(first-order MAML):把二阶项当成 0,外循环直接用 ∇_θ L(θ_i′)。

# 完整 MAML(伪代码)
loss.backward(create_graph=True) # 保留计算图给二阶
# FOMAML
loss.backward() # 普通一阶

效果掉 1-2 个百分点,但训练速度快 2 倍、内存省一半。绝大多数生产实现用 FOMAML

踩过的坑

  1. 二阶梯度内存爆炸:K=5 步内循环就要存 5 份中间激活和 5 份梯度。batch size 必须开很小,否则 OOM。

  2. inner step K 是超参:K 太少欠适应(θ_i′ 还没到位),K 太多过拟合 support set(query set 反而变差)。常见取 1-5。

  3. task 分布要有结构:MAML 假设训练 task 和测试 task 来自同一族(都是正弦、都是图像分类)。跨族(正弦学完去做分类)不会泛化。

  4. first-order 近似在 RL 上明显劣化:监督任务 FOMAML 掉 1-2%,但在 MuJoCo 半猎豹换方向这种强化任务上,二阶项贡献 5%+,不能省。

  5. ANIL 的反直觉发现(Raghu 2019):MAML 学到的不是”快速适应”,而是”特征复用”。冻住前面所有层、内循环只更新最后一层(ANIL),效果几乎等同完整 MAML。这暗示 MAML 其实在做特征学习,“meta” 部分被高估。

适用 vs 不适用场景

适用

  • few-shot 学习——每类只有 1-5 个样本
  • 任务族结构清晰且训练时能采样到(正弦回归、Omniglot、miniImageNet)
  • 强化学习中的快速策略适应(机器人换重力、半猎豹换跑步方向)

不适用

  • 大数据 + 单一任务 → 直接训就行,不需要 meta
  • 跨族泛化 → MAML 不假设这种泛化能力
  • 部署时算力极紧 → 即使 FOMAML,inner loop 仍要在线 SGD;in-context learning(GPT-3)连 SGD 都不要,更便宜

历史小故事(可跳过)

  • 1987 年:Schmidhuber 博士论文提出”学会学习”(learning to learn),但当时神经网络还没法支持二阶反向传播。
  • 1990s-2000s:Bengio 等人陆续提出 meta-learning 想法,多停留在概念。
  • 2016 年:Matching Networks(Vinyals)用 attention 做 few-shot;Memory-Augmented NN(Santoro)用外部记忆——都绕开了”训练初始化”这条路。
  • 2017 年 3 月:Chelsea Finn(Berkeley 博士生,Pieter Abbeel + Sergey Levine 学生)发表 MAML。5 页正文,思路简单得让人怀疑没人想到过——“为什么不直接对初始化求梯度”。
  • 2018 年:OpenAI 的 Reptile 把外循环简化到 θ ← θ + ε(θ_i′ - θ),连 query set 都不用。
  • 2019 年:ANIL 论文揭示 MAML 主要在做特征学习,不是真正的”快速适应”。
  • 2020 年:GPT-3 出现,in-context learning 不更新参数也能 few-shot,meta-learning 被重新解读为 transformer 的隐式行为。

学到什么

  1. “学会学习”可以编码进参数本身——这是 2017 年最重要的概念突破之一
  2. 二阶梯度是双刃剑:表达力强但内存贵;FOMAML 用 5% 精度换 2x 速度,工程上几乎总值
  3. 简单算法 + 正确视角:MAML 数学上就是嵌套 SGD,但”对初始化求梯度”的视角让它和所有前作区分开
  4. 后继者总在简化:MAML → FOMAML → Reptile → ANIL → in-context learning,每一步都在去掉显式梯度更新

一句话区分容易混淆的概念

  • MAML vs pretrain + fine-tune:pretrain 在大数据上学到一组参数后,fine-tune 不知道这组参数好不好”适应”;MAML 显式优化”被适应后效果好”。
  • MAML vs multi-task learning:multi-task 是把多个任务损失加起来一次训练完,输出一组通用参数;MAML 输出的是一组好起点,每个任务还要再走 K 步 SGD。
  • MAML vs in-context learning:MAML 在 inner loop 里真的更新参数;in-context learning 把 support set 写进 prompt,参数完全不动,靠 transformer 的注意力”模拟”出适应行为。
  • MAML vs Reptile:MAML 外循环对 θ_i′ 求梯度(要二阶),Reptile 直接拉 θ 朝 θ_i′ 方向走(一阶差分)。Reptile 像”看到学生用 5 步走到终点,那 θ 就往终点方向移一点”。

延伸阅读

关联

  • adam-2014 —— MAML 内外两层都用 SGD/Adam,外循环对 Adam 更敏感
  • transformer-2017 —— 同年的论文,后来 in-context learning 的载体
  • gpt-3 —— in-context learning 是 MAML 思想在 transformer 上的隐式实现
  • bert —— “pretrain + fine-tune” 范式与 MAML 是两条不同路线
  • lora-2021 —— 都在解决”少量参数适应新任务”,但 LoRA 不需要 meta 训练

反向链接

  • adam-2014 —— Adam — 让深度学习自己挑步长的优化器
  • bert —— BERT — 双向 Transformer 预训练
  • gpt-3 —— GPT-3 — Language Models are Few-Shot Learners
  • prototypical-networks-2017 —— Prototypical Networks — 每类算个均值,比距离就够了