Flax — JAX 上的神经网络库
是什么
Flax 是给 jax 配的神经网络库。日常类比:JAX 像一台只卖”面粉、糖、酵母”的食材店——它给你 grad / jit / vmap 这些原料,但不直接卖蛋糕。Flax 是开在 JAX 楼上的”蛋糕店”,把神经网络层(Linear / Conv / Attention)、参数管理、训练状态打包好。
你写:
import flax.linen as nnimport jax, jax.numpy as jnp
class MLP(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(128)(x); x = nn.relu(x) return nn.Dense(10)(x)
model = MLP()params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))logits = model.apply(params, jnp.ones((1, 784)))注意一个反直觉的设计:params 是和 model 分开存的——模型只是一份”配方”,参数是外面的”食材袋”。这是 Flax 函数式哲学的核心。
后端是 jax + XLA,跑在 CPU / GPU / TPU。Google DeepMind / Anthropic 的大模型训练栈都站在 Flax 上——Gemini 前端、AlphaFold 2/3、Gemma 开源权重官方实现都是 Flax。
为什么重要
不理解 Flax,下面这些事都没法解释:
- 为什么 jax 这么强还要再套一层——JAX 只卖原料,Flax 才是蛋糕
- 为什么 DeepMind 开源的 AlphaFold / Gemma 代码看起来”参数到处传”,而 pytorch 里参数都”藏”在
self.linear.weight - 为什么 Flax 2024 又推出一套全新的
nnxAPI,和老的linen风格完全不一样 - 为什么大模型训练偏好 Flax 而不是 keras——Keras 3 是”跨后端最大公约数”,Flax 是”为 JAX 量身定做”
核心要点
Flax 现在有两套并存的 API,新人最容易在这里迷路:
-
linen(旧版,目前主流):Module 是纯函数式 dataclass——它自己不存参数,调
model.init(rng, x)才返回参数 PyTree,调model.apply(params, x)才推理。好处:参数永远显式,对 jax 的jit/grad/vmap极友好。 -
nnx(2024 新版):Module 像 pytorch 一样内部存可变状态——参数是
nnx.Param字段,可以直接model.linear.kernel = ...改。要进 jit 时用nnx.split把”状态”和”纯函数”拆开。牺牲一点函数式纯度,换 PyTorch 用户能秒懂。 -
TrainState:
flax.training.train_state.TrainState把params/optimizer state/step打成一个 PyTree,配合 jax 的jit整体传进训练函数。配套优化器是 Optax。 -
rng 显式:和 jax 一样,每次随机操作都要传
rng key。init时给一把,Dropout时给一把,永远不能复用——纯函数 + 可复现的代价。
实践案例
案例 1:linen 风格的训练循环
import optaxfrom flax.training import train_state
state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=optax.adam(1e-3))
@jax.jitdef step(state, x, y): def loss_fn(p): logits = state.apply_fn(p, x) return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean() grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads)注意 state 是个 PyTree,@jax.jit 整体编译;params 永远显式传,没有 pytorch 的 .backward()。
案例 2:nnx 风格——更像 PyTorch
from flax import nnx
class MLP(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.l1 = nnx.Linear(784, 128, rngs=rngs) self.l2 = nnx.Linear(128, 10, rngs=rngs) def __call__(self, x): return self.l2(nnx.relu(self.l1(x)))
model = MLP(nnx.Rngs(0))y = model(jnp.ones((1, 784))) # 直接调,参数在 model 里对从 pytorch / pytorch-lightning 转过来的人,nnx 学习曲线短一大截。
案例 3:参数 PyTree 长什么样
jax.tree.map(lambda p: p.shape, params)# {'params': {'Dense_0': {'kernel': (784,128), 'bias': (128,)},# 'Dense_1': {'kernel': (128,10), 'bias': (10,)}}}参数是嵌套字典,新人常误以为是普通 dict——实际是 PyTree,要用 jax.tree.map 操作才能保持 jit 友好。
踩过的坑
-
linen 两种风格混用:
@nn.compact内联定义层 vssetup()显式定义。别在同一个 Module 混用——会报”variable already defined”。 -
BatchNorm / Dropout 的状态分集合:linen 里
params和batch_stats是两个独立集合。训练时model.apply({'params': p, 'batch_stats': bs}, x, mutable=['batch_stats']),写错就 silent bug。 -
nnx 是 2024 新 API,社区代码大多还是 linen:找开源参考时先看版本。
pip install flax默认装最新,但 README 大半还是 linen 例子。 -
保存模型用 orbax,不是 pickle:
orbax.checkpoint.PyTreeCheckpointer是官方推荐——pickle 在多卡 sharding 下会爆。 -
rng 不能复用:忘记 split 直接传同一个 key 给 Dropout,整个 batch 的 mask 完全一样,模型不收敛——和 jax 的坑同宗。
适用 vs 不适用场景
适用:
- JAX 上的任意神经网络训练(MLP / CNN / Transformer / Diffusion)
- TPU 大规模训练(jax sharding + Flax Module)
- 需要高阶导数 / 自定义训练循环的研究(meta-learning、隐式微分)
- 复现 DeepMind / Google 论文的 reference 实现
不适用:
- 纯推理部署 → 直接 jax + 编译,Flax 抽象多余
- 已有大量 pytorch 代码 → 迁移成本远高于继续用 pytorch-lightning
- 想用 HuggingFace 主流模型 → 大部分先有 PT 版,Flax 实现少
- 想要”跨后端最大公约数” → 用 keras 3,它把 jax / pytorch / tensorflow 都当后端
历史小故事(可跳过)
- 2020:Google Brain 推出 Flax,替代早期
jax.experimental.stax,对标 DeepMind 的 Haiku - 2021:AlphaFold 2 开源,训练栈全 Flax
- 2022-2023:linen API 稳定,PaLM / Gemma 都用它写
- 2024:
nnx作为新一代 API 公开,定位”linen 的精神继任者”,但 linen 不会废弃——两套长期并存
学到什么
- 参数和模型分开是函数式深度学习的核心——和 pytorch 把参数藏在
self.linear.weight是两条路 - Flax 不是替代 jax——它在 JAX 之上加”层 / 参数 / 训练状态”三件套,下面还是 grad/jit/vmap
- nnx vs linen 反映了一个张力:函数式纯度 vs 命令式手感,Flax 选择两套都给
- Optax + orbax + Flax 是 JAX 生态的”三件套”——分别对应优化器、checkpoint、神经网络
延伸阅读
- 官方文档:Flax Documentation — 先读 “Quick Start”,再选 linen 或 nnx
- nnx 教程:Why NNX — 解释为什么要新 API
- 参考实现:Gemma in Flax — Google 官方开源 LLM
- jax —— 必须先理解 JAX 的函数变换思想
- pytorch —— 对照看能秒懂 Flax 函数式哲学的反差
关联
- jax —— Flax 是 JAX 之上的神经网络层,下面引擎完全是 JAX
- pytorch —— 命令式 + 参数藏在 self.x.weight;Flax linen 反过来 params 完全外置
- keras —— Keras 3 把 JAX 当后端之一;Flax 只为 JAX 服务,更深度
- tensorflow —— Flax 与 TF 共享 XLA 后端,但前端范式(函数式 vs 图)完全不同
- pytorch-lightning —— PyTorch 上的训练循环抽象;Flax 的 TrainState 是 JAX 侧对应物
- accelerate —— HuggingFace 的设备/分布式抽象;JAX 自带 sharding 不需要它