PyTorch FSDP — 把大模型切成 N 份分到 N 张卡
是什么
FSDP(Fully Sharded Data Parallel)是 PyTorch 官方的多卡训练大模型方案:把模型参数、梯度、优化器状态都等分到所有 GPU 上,每张卡平时只持有 1/N。要算的那一层临时拼回完整权重,算完立刻释放。
日常类比:一本 1000 页的书 8 个人读,传统做法是每人复印一整本(DDP),FSDP 是把书撕成 8 摞每人拿一摞,谁要读哪页谁去其他人那借一份,读完归还。常驻”书页”只有 1/8。
为什么重要
不理解 FSDP,下面这些事都没法解释:
- 为什么 Llama / Mistral / Qwen 的训练脚本默认写
FullyShardedDataParallel而不是DistributedDataParallel - 为什么 175B 参数的模型能在 512 张 A100 上训出来,按 DDP 算每张 A100 要装 700GB 才够
- 为什么 HuggingFace
accelerate、torchtune、trl默认后端都从 DeepSpeed 切到了 FSDP - 为什么同一份模型代码加一个 wrap 就从单卡训变成多卡训,不用改前向
核心要点
FSDP 工作流可以拆成 三步:
-
切片(Shard):模型按 wrap 单元切成 N 份,每张卡只存自己那 1/N 的参数 + 梯度 + Adam 状态。
-
临时拼回(AllGather):算到某层前向时,所有卡把这一层的碎片用 AllGather 拼成完整权重。算完前向立刻丢弃别人的碎片,只留自己那份。
-
反向后归约切回(ReduceScatter):反向算出梯度后用 ReduceScatter,一次完成”全员求和 + 切回每人 1/N”。优化器只更新自己那 1/N。
ZeRO-3 是 DeepSpeed 2020 年提出的同款思路,FSDP 是 Meta 把它写进 PyTorch 核心。
实践案例
案例 1:一行 wrap 把单卡变多卡
from torch.distributed.fsdp import FullyShardedDataParallel as FSDPmodel = MyTransformer() # 单卡写法model = FSDP(model) # 加这一行就分片到所有 GPUout = model(x) # 前向反向写法不变FSDP 这个壳在前向时偷偷插 AllGather,反向时插 ReduceScatter。你的训练循环一字不改。
案例 2:三档 sharding 策略
| 策略 | 参数 | 梯度 | 优化器状态 | 等价于 |
|---|---|---|---|---|
FULL_SHARD | 切 | 切 | 切 | ZeRO-3 |
SHARD_GRAD_OP | 复制 | 切 | 切 | ZeRO-2 |
NO_SHARD | 复制 | 复制 | 复制 | DDP |
显存最省是 FULL_SHARD,通信最少是 NO_SHARD。模型放得下就用越右边的,放不下才往左切。
案例 3:FlatParameter 为什么重要
一个 transformer block 里有几十个小 tensor(Q/K/V/O 各四份权重 + LayerNorm + MLP 两份)。如果每个 tensor 单独发一次 AllGather,几十次小 collective 把网络打满。
FSDP 把一个 wrap 单元里所有参数拼成一个扁平 tensor(FlatParameter),AllGather 一次搞定。代价是 state_dict 变形了——存盘、加载、推理切换时要在 FULL_STATE_DICT / SHARDED_STATE_DICT 之间转换。
案例 4:Hybrid Sharding 应付跨机带宽瓶颈
512 张 A100 分成 64 节点 × 8 卡,节点内走 NVLink(600 GB/s),节点间走 InfiniBand(200 Gb/s)。如果在全 512 张上做 FULL_SHARD,每次 AllGather 都要跨机,IB 带宽被打满。
HSDP(Hybrid Sharding)让节点内部做 FSDP(完整切片),节点之间做 DDP(参数复制)。每张卡常驻状态变成 1/8 而不是 1/512,但跨机只走梯度归约一次,带宽友好得多。论文里 1T 模型训练就是用 HSDP。
论文里的工程经验
Meta 在 175B / 1T 训练里踩出来的几条工程心得,比 API 更值得学:
- deferred init:1T 参数模型初始化时如果先在 rank 0 上构图再分发,rank 0 单卡就爆了。FSDP 用 meta device 先建空壳,wrap 时再按 shard 量分配真实显存。
- rate limiter:AllGather 是异步的,反向算到第 N 层时如果第 N+5 层已经预取完,缓冲堆积反而把显存吃光。FSDP 加了一个限速器控制 prefetch 深度。
- mixed precision 选项粒度:参数用 fp32 / bf16、梯度用 fp32 reduce、buffer 用 fp32 三档独立设。Adam 状态保 fp32 是数值稳定关键,光 bf16 算梯度会发散。
- TFLOPS 利用率 ~84%:175B 在 A100 上跑到 84% 峰值算力利用率,对 sharding 类训练算很高。剩下 16% 主要是通信和等待。
踩过的坑
-
小模型上 FSDP 反而比 DDP 慢:参数 < 1B 时 AllGather 通信开销盖过显存收益。FSDP 是给 7B 起步的模型用的,2B 以下用 DDP 就行。
-
wrap 单元粒度难调:太粗(整个模型一个 unit)→ 通信少但拼起来那一刻峰值显存爆。太细(每个 Linear 一个 unit)→ AllGather 几百次。实践上按 transformer block 切是好默认。
-
必须配 activation checkpointing:参数切了但中间激活没切,128k 上下文 + 70B 模型仍然爆。FSDP +
torch.utils.checkpoint是事实标配。 -
state_dict 加载坑:训练存的是 sharded ckpt,推理或微调要加载时如果不切
FULL_STATE_DICT模式,会拿到形状不对的扁平 tensor。新人首次跑通后换部署环境常踩这个。 -
FSDP1 的 FlatParameter 不兼容 LoRA:LoRA 只训部分参数,FlatParameter 把训练和冻结的揉在一起。2024 年 FSDP2 改成 per-parameter sharding 修了这个洞——新代码优先用
torch.distributed._composable.fsdp。 -
CPU offload 是最后的稻草:参数和优化器还能再 offload 到 CPU 内存(
CPUOffload(offload_params=True)),代价是每步多一次 PCIe 来回,吞吐下降 2-3 倍。除非显存真的不够,不然不开。 -
backward_prefetch=BACKWARD_PREvsBACKWARD_POST:默认 PRE 在反向算上一层时就预取下一层参数,吞吐高但峰值显存高。POST 等当前层算完再预取,省显存但慢。论文给的经验是默认 PRE,OOM 了再降级。
适用 vs 不适用场景
适用:
- 7B 以上、单卡放不下的 LLM 预训练 / 全参微调
- 需要和 PyTorch 原生功能(autograd /
torch.compile/ 混合精度 / activation checkpoint)兼容 - 多机多卡,每节点 8 张 A100/H100 的常见配置
不适用:
- 模型参数 < 2B:DDP 已经够用,FSDP 通信开销不划算
- 推理:FSDP 是训练方案,推理用 vLLM / TensorRT-LLM
- 张量并行需求(把单层 Linear 切到多卡):用 Megatron-LM 的 tensor parallel,常和 FSDP 组合成 2D / 3D 并行
- 流水并行(按层切到不同节点):用 PyTorch
pippy或 DeepSpeed pipeline,FSDP 不解决这个维度
历史小故事(可跳过)
- 2019 年:DeepSpeed 团队(Microsoft)观察到 DDP 把 Adam 状态在每卡都复制一份是巨大浪费——Adam 一阶二阶矩占的显存是参数本身的 2 倍。
- 2020 年:DeepSpeed 发表 ZeRO(Zero Redundancy Optimizer)三阶段论文,ZeRO-1 切优化器状态、ZeRO-2 加切梯度、ZeRO-3 连参数也切。GPT-3 训练用的就是 ZeRO。
- 2021 年:FairScale(Meta 的实验库)实现 FSDP 原型,验证 ZeRO-3 思路能写进 PyTorch 主线。
- 2022 年:FSDP 进 PyTorch 1.11 主分支。
- 2023 年:Meta 发表本论文,总结 175B / 1T 模型上的工程经验,FSDP 成为 PyTorch 官方推荐路径。
- 2024 年:FSDP2(per-parameter sharding)发布,修 LoRA / 部分冻结场景。
torch.distributed._composable.fsdp.fully_shard取代旧FullyShardedDataParallel类成为推荐 API。 - 2025 年:FSDP2 +
torch.compile+ tensor parallel 的组合成为 7B-70B 模型训练的事实默认,HuggingFace 整个训练栈基本都站在 FSDP2 上。 - 从 2019 年的 ZeRO 提案到 2025 年的事实标准,整个演化用了 6 年——典型的 “学术想法 → 工业框架 → 默认选项” 三段式落地。
学到什么
- 同一思路两套实现:ZeRO-3 是外挂库,FSDP 是官方一等公民。学术想法落地到工业框架是独立的二次工程,且第二次往往做得更稳。
- AllGather + ReduceScatter 是分片训练的两个核心 collective:理解这俩比理解 FSDP 本身更通用,张量并行、序列并行也都靠它们。
- 峰值显存 ≠ 常驻显存:FSDP 减的是常驻,AllGather 那一刻峰值仍然要装下完整层。这是 wrap 粒度选择的根本权衡。
- 训练框架的”零侵入”是核心 UX:一行 wrap 不改前向,是 FSDP 比 ZeRO-3 流行的关键工程价值。
- 多维并行的拼图位置:FSDP 解决”参数太大单卡放不下”这一维。“单层太大单卡放不下”用张量并行,“层数太多串行太慢”用流水并行。3D 并行 = FSDP × Megatron TP × Pipeline。
- 网络拓扑决定算法选择:HSDP 的存在揭示了一件事——分片训练的最优策略不取决于模型,取决于你机器之间怎么连。论文反复强调这点。
延伸阅读
- 论文 PDF:arXiv:2304.11277(Meta 工程笔记,重经验轻公式)
- 官方教程:PyTorch FSDP Tutorial(小模型跑通示例)
- ZeRO 原始论文:Rajbhandari et al., SC 2020(思路源头)
- FSDP2 文档:torch.distributed._composable.fsdp(per-parameter 新版本)
- deepspeed-zero —— 同款思路的 DeepSpeed 实现,FSDP 的直接灵感来源
- pytorch —— FSDP 的宿主框架
- megatron-lm —— 张量并行,常和 FSDP 组合成 2D 并行
关联
- deepspeed-zero —— ZeRO-3 是 FSDP 的思路源头,FSDP 是它在 PyTorch 的官方化身
- pytorch —— FSDP 是
torch.distributed的一等公民 - megatron-lm —— 张量并行,FSDP(数据并行分片)+ Megatron(张量并行)是大模型训练的常见 2D 组合
- gpipe-2019 —— 流水并行,3D 并行的第三维
- gshard-2020 —— Google 的 sharding 抽象,思路接近但绑定 JAX/TPU
- pipedream-2019 —— 流水并行调度,与 FSDP 互补
反向链接
- deepspeed-zero —— DeepSpeed ZeRO — 微软优化大模型训练显存
- dwork-our-data-ourselves-2006 —— 分布式噪声生成 — 去掉可信管理员也能保护隐私
- gpipe-2019 —— GPipe — micro-batch 流水线让 GPU 排成生产线
- gshard-2020 —— GShard — 用注解让 600B 模型自动跨设备切片
- pipedream-2019 —— PipeDream — 1F1B 调度让流水线工位别空等
- pytorch —— PyTorch — 深度学习主流框架
- zero-2020 —— ZeRO 2020 — 把训练状态切成 N 份让万亿参数成为可能