跳转到内容

ZeRO 2020 — 把训练状态切成 N 份让万亿参数成为可能

是什么

ZeRO(Zero Redundancy Optimizer)是微软 2020 年在超算顶会 SC 提出的一组显存优化方法:训练一个大模型时,把每张 GPU 都重复存的那份”训练状态”切成 N 份,每张只留 1/N。日常类比:原本 N 个学生每人都抄一整本笔记,现在改成每人只保管 1/N 页,要看时再相互借。

论文里”万亿参数”不是噱头——它给出的内存账本说明:1024 张 V100(每张 32GB)配合 ZeRO-3,理论上能装下 1 万亿参数的模型训练。这是首次有人把”训万亿”这件事算到能落地的程度。

为什么重要

现代大模型训练框架几乎都是 ZeRO 的徒孙。理解这篇论文能帮你看懂:

  • 为什么 PyTorch 后来推出的 fsdp-2023 本质就是 ZeRO Stage 3 的官方重写
  • 为什么 megatron-lm 和 DeepSpeed 经常组合出现:一个切计算、一个切状态,正交互补
  • 为什么 gpipe-2019 / pipedream-2019 这些流水线并行方案常被叠加在 ZeRO 之上,凑成 “3D Parallelism”
  • 为什么大模型微调脚本里 zero_optimization.stage 这个旋钮值得花时间调

核心要点

ZeRO 的洞察是:训练状态有三块,逐块切分,每切一块多省一截内存。

论文给出的内存账本(以 7.5B 参数模型 + 64 张 GPU 为例,单位 GB/卡):

配置优化器状态梯度参数总计
标准 DDP1203015165
ZeRO Stage 11.9301547
ZeRO Stage 21.90.51517
ZeRO Stage 31.90.50.22.6

三个 Stage 逐级切:

  1. Stage 1(Pos):切优化器状态

    • Adam 训练时,每个参数额外存 momentum + variance(fp32,共 8 字节/参数)
    • 加上 fp32 master copy(4 字节)总共 12 字节/参数——是模型参数本身的 6 倍
    • 切完每张卡只保留 1/N 份,省 4 倍以上
  2. Stage 2(Pos+g):再切梯度

    • 关键技巧:把传统 all-reduce 替换为 reduce-scatter
    • 每张卡只拿到自己负责那段的梯度,不收集完整梯度
    • 类比:4 个学生改卷,原本每人都拿到 4 份完整成绩,现在每人只拿自己那一份
  3. Stage 3(Pos+g+p):连参数本身也切

    • forward / backward 时按层动态 all-gather:算这层就把参数借齐,算完立刻丢
    • 通信量比 Stage 2 多 50%,但内存压到 1/N
    • 类比:每人只背 1/4 课本,上课讲到哪一页谁背就传给谁

关键设计:Stage 2 用 reduce-scatter 替代 all-reduce 的等式 all-reduce = reduce-scatter + all-gather 是 ZeRO 的灵魂——把一次 collective 拆两步,不增加通信量却能”白拿”一份内存优化。

为什么是 12 字节/参数:fp16 训练时,一份参数本身是 fp16(2 字节),但 Adam 优化器要保留 fp32 master copy(4 字节)+ fp32 momentum(4 字节)+ fp32 variance(4 字节)共 12 字节。这就是 OS 体积是参数本身 6 倍的来源——也是 Stage 1 收益最大的原因。

实践案例

案例 1:论文里的万亿参数账本

1024 张 V100(32GB)+ ZeRO Stage 3:

  • 每张卡分到 1/1024 的优化器状态、梯度、参数
  • 一个万亿参数模型,每张只需 ≈ 16GB 存训练状态,余下 16GB 装 activation 和缓冲
  • 这是首次”万亿训练”在工程上有了可行配置
  • 论文同时给出 100B 参数模型在 400 张 V100 上跑出 38 TFlops/卡 的实测——是当时所有公开方案里效率最高的

案例 2:消费级硬件微调 7B 模型

ds_config.json
{
"fp16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"offload_optimizer": { "device": "cpu" }
}
}

不开 ZeRO 训 7B 模型每卡至少要 80GB(A100 80GB × 4);开 Stage 3 后 4 张 24GB 的 RTX 4090 也能跑通。这是 HuggingFace accelerate 默认推荐 DeepSpeed 的根本原因。

案例 3:DeepSpeed 在 HuggingFace accelerate 里的一行启用

Terminal window
accelerate config # 交互式选 DeepSpeed → ZeRO Stage 3
accelerate launch train.py

accelerate 把底层 DeepSpeed / FSDP 都封装起来,用户只改一行配置就能切换并行策略——这是社区接受 ZeRO 的最大原因。

案例 4:Megatron-DeepSpeed 训 530B

微软 + NVIDIA 联合训 Megatron-Turing NLG 530B:

  • Megatron 切计算(tensor parallel:一层 attention 拆给多张卡)
  • DeepSpeed 切状态(ZeRO 把跨副本状态分摊)
  • 两者正交叠加

这是后来 “3D Parallelism” 的起点:Data × Tensor × Pipeline × ZeRO 四维组合。

踩过的坑

  1. 小模型上 Stage 3 反而更慢:模型 < 1B 参数时,多一次 all-gather 的通信开销 > 内存收益。这种情况用 Stage 1 或纯 DDP 才对。

  2. ZeRO 不解决 activation 内存:训练时 activation 常占总内存 40%+,ZeRO 完全切不动这块——必须配合 gradient checkpointing 一起用。

  3. 跨 node 比单 node 慢得多:ZeRO 频繁的 all-gather 在单 node NVLink(900GB/s)上很快,跨 node InfiniBand(约 25GB/s)就成瓶颈。1024 GPU 训练时 Stage 3 通信占比能到 30%+。

  4. 配置参数多,新人容易调爆:stage / offload / bucket_size / overlap_comm 一堆旋钮,调不好可能比 DDP 还慢。这也是 PyTorch 推出 fsdp-2023(更易用版 Stage 3)来收编社区的原因之一。

适用 vs 不适用

适用

  • 单卡装不下完整模型,但有多张 GPU 可用(4-8 张消费级卡微调 7B-70B)
  • 用 Adam / AdamW 优化器(OS 是大头,切了收益最高)
  • 训练框架是 PyTorch / HuggingFace Transformers 生态

不适用

  • 单卡能装下整个模型(< 1B 参数)→ DDP 更快
  • 用 SGD(无 momentum)→ Stage 1 收益接近零
  • 推理场景 → ZeRO 是训练时优化,推理用 quantization / vllm 等其他思路
  • 极端跨 node(千卡以上)→ 通信占比过高,需混合 TP / PP 才能稳住吞吐

历史小故事(可跳过)

  • 2019 年 10 月:微软 DeepSpeed 团队挂出 ZeRO 论文 arXiv 预印本
  • 2020 年 11 月:在 SC 2020 正式发表,DeepSpeed 开源到 GitHub 立刻成为大模型训练事实标配
  • 2021 年:同作者发表 ZeRO-Infinity,加 NVMe SSD offload,单张 V100 能训 1T 参数(速度极慢但能跑)
  • 2022 年:发表 ZeRO++,针对量化通信和 hierarchical partitioning 进一步降通信开销
  • 2023 年:PyTorch 推出 fsdp-2023,本质是 Stage 3 的官方原生实现
  • 2024 年:FSDP 进入 PyTorch 主线,ZeRO 思想成为事实标准

学到什么

  1. “复制是必须的”是个可挑战的假设——DDP 让每张卡都存完整状态是工程惯例,不是物理定律。规模一大,复制本身就是瓶颈。
  2. all-reduce = reduce-scatter + all-gather 这个等式总成立——把一个 collective 拆成两步,可以”白拿”一份内存优化。
  3. 工程化的胜利:ZeRO 没发明新算法,只是把”切分”贯彻到优化器状态、梯度、参数三层——这种”已有技术的极致组合”反而比新算法更有影响力。
  4. 通信换内存是分布式训练的核心 trade-off——没有银弹,只有不同硬件代际下不同的最优配置。
  5. 算账比写代码更重要:论文最大的贡献是那张内存账本表,让”万亿训练”从口号变成可工程化的目标。

延伸阅读

关联

  • megatron-lm —— Megatron 切计算 / ZeRO 切状态,正交互补
  • fsdp-2023 —— PyTorch 把 Stage 3 搬进主线
  • gpipe-2019 —— 流水线并行常和 ZeRO 叠加
  • pipedream-2019 —— 1F1B 流水线,3D Parallelism 的另一维
  • alpa-2022 —— 把数据/张量/流水/ZeRO 统一成搜索问题