跳转到内容

PyTorch Lightning — PyTorch 训练循环抽象

是什么

PyTorch Lightning 是一个把 PyTorch 训练里”工程代码”和”研究代码”切开的薄壳框架。日常类比:像搬家公司——你只负责说”这箱子叫宝贝、那箱子怕摔”(研究逻辑),搬运、卡车、电梯调度、跨城物流(device / 分布式 / 混合精度 / checkpoint / 日志)全交给搬家队(Trainer)。

你写:

import pytorch_lightning as pl
class MyModel(pl.LightningModule):
def training_step(self, batch, idx):
x, y = batch
loss = self.loss_fn(self(x), y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=4, precision="bf16")
trainer.fit(model, dataloader)

一行 Trainer(...) 就拿到 4 卡 DDP + bf16 混合精度,研究代码一行没改

为什么重要

不理解 Lightning,下面这些事都没法解释:

  • 为什么 2020 年后论文复现仓库一半是 LightningModule——审稿人看代码结构一致,省精力
  • 为什么”研究代码上生产”过去要重写一遍,现在只需把 Trainer 参数从 devices=1 改成 devices=8, strategy='ddp'
  • 为什么 HuggingFace、NVIDIA NeMo、PyTorch Geometric 等上游库都内置 Lightning 适配
  • 为什么 2.0 又拆出 Fabric——承认”全收口的 Trainer 不适合所有人”,提供半自动选项

核心要点

Lightning 的能力可以拆成 四块

  1. LightningModule(你的研究代码):继承一个类,写 training_step / validation_step / configure_optimizers。模型本身还是普通 nn.Module,多了几个钩子方法。

  2. Trainer(工程代码收口):一个对象拿走所有”和模型逻辑无关的事”——max_epochs / accelerator='gpu' / devices=4 / strategy='ddp' / precision='bf16' / callbacks=[...]

  3. Callbacks(可插拔工件)ModelCheckpoint 自动存最佳模型,EarlyStopping 自动停不再下降的训练,LearningRateMonitor 记录 lr 曲线。新增功能不改 LightningModule。

  4. Lightning Fabric(2.0 新增的底层 API):保留你自己的 for batch in dataloader: 循环,但 fabric.setup(model, optimizer) 后自动做 device 搬运 / 分布式 / 混合精度。给”不想被 Trainer 接管全循环”的人一个中间层。

实践案例

案例 1:从 raw PyTorch 迁到 Lightning

raw 版本(每次都要抄):

model.to(device)
for epoch in range(10):
for x, y in loader:
x, y = x.to(device), y.to(device)
opt.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
opt.step()

Lightning 版本(device / zero_grad / backward / step 全没了):

class M(pl.LightningModule):
def training_step(self, batch, idx):
x, y = batch
return self.loss_fn(self(x), y)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
pl.Trainer(max_epochs=10, accelerator="auto").fit(M(), loader)

少 6 行。换 GPU / 多卡 / TPU 时,只动 Trainer 参数。

案例 2:4 卡 DDP 一行切换

trainer = pl.Trainer(devices=4, strategy="ddp", precision="bf16-mixed")

raw PyTorch 要写 torch.distributed.init_process_group / DistributedDataParallel(model) / DistributedSampler(dataset) / torchrun 启动脚本——Lightning 全包了。

案例 3:callback 替代手动 if-else

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
trainer = pl.Trainer(callbacks=[
ModelCheckpoint(monitor="val_loss", save_top_k=3),
EarlyStopping(monitor="val_loss", patience=5),
])

无需在 training_step 里写”如果 val_loss 5 轮没降就 break”——callback 监听事件总线自动触发。

踩过的坑

  1. 版本不兼容剧烈:0.x → 1.x → 2.x 大改三轮,training_step_end 等 hook 被删,2023 年前的教程一半跑不了。永远先看官方 docs 的版本号,别信博客。

  2. automatic_optimization 默认 True 会替你 zero_grad / backward / step:写 GAN 等多优化器场景必须 self.automatic_optimization = False 后手动调 self.manual_backward(loss)opt.step()。新人 GAN 只更新一个判别器找不到原因。

  3. self.log() 默认按 epoch 聚合:想看 step 级曲线必须 self.log("loss", loss, on_step=True, on_epoch=False),否则 TensorBoard 看起来”日志没记录”。

  4. DDP 模式下 Trainer 会 spawn 子进程:脚本顶层副作用(数据预处理、wandb.init)会跑多次。入口必须包进 if __name__ == "__main__":,否则 fork 出 N 个数据加载进程互踩。

  5. hook 触发顺序复杂on_train_batch_end vs training_step_end vs on_train_epoch_end 各自时机不同,文档要查清楚才动手。

适用 vs 不适用场景

适用

  • 中大型研究项目(多人协作、需要”训练循环长得一样”)
  • 需要快速切单卡 / 多卡 / TPU / bf16
  • checkpoint / EarlyStopping / 日志想要开箱即用
  • 论文复现仓库(标准结构让审稿人快速理解)

不适用

  • 只跑一次的玩具脚本 → raw PyTorch 50 行更短
  • 高度自定义的训练循环(强化学习的 rollout 循环、需要 actor-learner 异步)→ Fabric 或 raw 更顺手
  • 不熟 PyTorch 的初学者 → 先把 raw 训练循环写明白再上 Lightning,否则报错时没法调试
  • 极简推理服务 → Lightning 是训练侧抽象,部署用 TorchScript / ONNX / Triton

历史小故事(可跳过)

  • 2019 年:William Falcon 在 NYU 读博士,把自己抄了 N 遍的训练循环模板开源。最初只是 personal toolkit
  • 2020 年:1.0 发布,公司 Grid AI 成立做云训练平台,Lightning 是开源前端
  • 2022 年:Grid AI 改名 Lightning AI,框架升级 Lightning 2.0 路线图
  • 2023-03:Lightning 2.0 发布,重写 Trainer,Lightning Fabric 作为底层 API 单独发布,torch.compile 一等公民
  • 现在:~29k stars,HuggingFace、NVIDIA NeMo、PyTorch Geometric 等上游内置 Lightning 适配

学到什么

  1. 抽象的代价是 API 学习成本——Lightning 替你省了工程代码,但你必须先学清楚 hook 触发顺序、callback 协议
  2. “约定优于配置”再次胜利——所有 LightningModule 长得一样,让大型项目协作和论文复现的认知负担骤降
  3. 2.0 引入 Fabric 是一次妥协——承认”完全包住 for 循环”挡住了部分用户(强化学习 / 自定义采样),需要分层抽象
  4. API 命名稳定性 >> 功能丰富度——Lightning 早期 hook 改名带走一批用户,后来吸取教训保持兼容
  5. 生态比框架重要——Lightning 真正护城河不是技术而是 callback 生态 + 上游适配(HuggingFace 直接给 LightningTrainer)

延伸阅读

关联

  • pytorch —— Lightning 是 PyTorch 上的薄壳,所有 tensor / autograd / nn.Module 概念原样保留
  • scikit-learn —— 同样用”统一的 fit 入口”消化各种算法/模型,Lightning 把这套思路推广到深度学习训练循环
  • hindley-milner —— LightningModule 接口约定(必须实现哪些方法)类似类型签名,是另一种”靠形状吃饭”的契约设计

反向链接

  • accelerate —— Accelerate — HuggingFace 设备/分布式抽象
  • bentoml —— BentoML — 模型打包部署
  • candle —— Candle — HuggingFace 出品的 Rust 推理框架
  • clearml —— ClearML — 自托管 MLOps 套件
  • colossal-ai —— Colossal-AI — 大模型训练系统
  • deepspeed —— DeepSpeed — 微软分布式训练库
  • fastai —— fastai — 三行代码做迁移学习
  • flax —— Flax — JAX 上的神经网络库
  • haystack —— Haystack — 企业 NLP / RAG 流水线
  • hindley-milner —— Hindley-Milner — 编译器自己猜变量类型
  • keras —— Keras 3 — 一份模型代码跑三套后端
  • megatron-lm —— Megatron-LM — NVIDIA 张量并行库
  • mlflow —— MLflow — 端到端 ML 生命周期
  • optax —— Optax — JAX 优化器组合库
  • piper —— Piper — 端侧低延迟 TTS
  • pytorch —— PyTorch — 深度学习主流框架
  • scikit-learn —— scikit-learn — 经典 ML 库
  • torchcodec —— TorchCodec — PyTorch 原生 GPU 视频解码与张量输出
  • torchtune —— torchtune — PyTorch 官方 LLM 微调库
  • trl —— TRL — RLHF / DPO / GRPO 训练库
  • wandb —— Weights & Biases — 几行 init 把指标系统代码自动入库