跳转到内容

TensorFlow — 把神经网络拆成数据流图再跑到任何机器上

是什么

TensorFlow 是 Google 2015 年开源的机器学习系统。它的核心想法是把一个神经网络先画成一张图,再让运行时把这张图切片、分发到 CPU/GPU/TPU 上跑。

日常类比:像菜谱与厨房分离。你写菜谱(图)的时候不关心几口锅几个灶;执行的时候系统自己分锅分火。换厨房(GPU 改 TPU)只换执行器,不改菜谱。

它的前身叫 DistBelief,是 Google 内部 2011 年的训练框架,只能跑特定模型架构。TensorFlow 是 Google 写给”任意模型 + 任意硬件 + 任意规模”的第二代系统。

为什么重要

不理解 TensorFlow 这一代设计,很多事情解释不清:

  • 为什么 PyTorch 的 eager 模式当年被吹成”革命”——因为它的对手就是 TF 的”先建图再跑”
  • 为什么后来出现 tf.function / torch.compile——大家又想要图带来的优化
  • 为什么 JAX 的 jit + grad + vmap 看起来熟悉——它继承了 TF 的图思想再做减法
  • 为什么”算子”(operator)这个词在 ML 系统里这么基础——TF 把它定成了图的节点单位

简单说:现代 ML 系统的设计词汇表(图、算子、设备、自动微分、参数服务器)大半是 TF 这一代论文沉淀下来的。

核心要点

TensorFlow 的设计可以拆成 四块

  1. 数据流图(dataflow graph):节点是算子(matmul / relu / conv),边是张量(多维数组)。整张图描述一次前向 + 反向 + 更新。

  2. 可变状态作为节点:模型参数(weights)不是普通张量,是图里一种特殊节点叫 Variable。这让”参数更新”也变成图里的一条边。DistBelief 把状态藏在外部参数服务器,TF 把它拉进图里。

  3. 自动切分到设备:图建好后,运行时根据设备约束(“这个算子放 GPU0”)和数据依赖把图切成子图,每台机器跑一片,跨机用 send/recv 节点桥接。

  4. 同一抽象同时支持训练与推理:训练时多一条反向 + 优化器子图,推理时只跑前向。两边共享算子和图调度器。

四块加起来给出一个统一描述:「程序 = 图」

实践案例

案例 1:一个最小图

import tensorflow as tf
W = tf.Variable(tf.random.normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
x = tf.placeholder(tf.float32, [None, 784])
logits = tf.matmul(x, W) + b

这四行没跑任何计算,只是建了一张图。logits 是图里一个节点的引用,不是数字。要真正算,得 sess.run(logits, feed_dict={x: ...})——这一步运行时才把图切片、分到 GPU、执行。

把这个流程画出来:

Python 脚本(建图阶段) → GraphDef(图的 protobuf 描述)
Session(执行阶段)
运行时切片 → 各设备本地子图 → 算子内核执行

建图和执行是两个世界,这是 TF 1 的设计核心,也是它最被诟病的点。

案例 2:自动微分怎么发生

你写 loss = ...,调一次 tf.gradients(loss, [W, b]),TF 反向遍历图,对每个算子查”反向规则表”(matmul 的反向是另一次 matmul),自动生成反向子图。你不用手写偏导。

这是把 1970 年代的反向模式自动微分(reverse-mode AD)工业化的关键——有图就能自动生成反向

案例 3:分布式训练的两种姿势

  • 数据并行 + 异步参数服务器:worker 各算各的梯度,发给 PS,PS 异步更新。快但收敛抖。
  • 数据并行 + 同步:所有 worker 算完梯度做 all-reduce 再更新。慢一点但稳。

TF 把这两种都画成图——参数服务器、worker、梯度交换都是图节点。换策略只换图结构,模型代码不动。

论文里跑 Inception v3 给的数据:50 GPU 同步训练比单 GPU 快约 40 倍,异步参数服务器可以扩到 100+ worker 但每 worker 利用率下降。这两条曲线后来被 ResNet 时代的 all-reduce 改写。

案例 4:算子和内核分离

tf.matmul 在图里是一个算子(operator),它有多个内核(kernel)实现:CPU x86 一个、CUDA GPU 一个、TPU 一个。算子是接口,内核是平台特化实现。

这套分层让 TF 加新硬件不用改图——XLA 编译器、TPU 后端、ROCm(AMD GPU)后端都按这个接口接进来。后来的 PyTorch dispatcher、JAX 的 lax 都沿用了同样的”算子 + 后端”两层设计。

踩过的坑

  1. 静态图调试痛:图建好才能跑,print 一个中间张量得插 tf.Print 节点。新人来一句 print(x) 出来一个 Tensor("matmul:0", shape=...),看不到值。这就是后来 PyTorch eager 模式吃掉 TF 用户群的最大原因。

  2. Session 与 Graph 是两个对象:图描述结构,Session 才管设备和执行。新人常把两件事混。tf.Session().run(...) 一不小心创了多个 session,状态对不上。

  3. Variable 共享要用 scope:写两遍 tf.Variable(...) 会创两个变量,不会复用。复用得用 tf.variable_scope("foo", reuse=True)。这个 API 后来被骂太复杂,TF 2 用 Keras 的 layer 重做了。

  4. 同步/异步训练收敛性不同:异步参数服务器在大规模时收敛速度可能不如同步 all-reduce,但论文当时主推异步。后来 ResNet 时代发现同步更稳,工业界大多回到 all-reduce。

适用 vs 不适用场景

适用

  • 大规模分布式训练(多机多 GPU)——图 + 设备分配天然支持
  • 推理服务化(TensorFlow Serving)——同一张图直接部署
  • 跨硬件后端(CPU / GPU / TPU)——换设备不换代码
  • 需要图层面优化(XLA 编译、算子融合)的场景

不适用

  • 研究阶段快速实验——eager 模式(PyTorch / JAX)调试更快
  • 控制流复杂的模型(树、变长 RNN、强化学习)——静态图表达起来别扭,要 tf.cond / tf.while_loop
  • 小模型小数据——图建图开销大于收益

历史小故事(可跳过)

  • 2011 年:Google Brain 内部用 DistBelief 训 ImageNet,发现框架太死,只支持 feedforward + 参数服务器异步。
  • 2014 年:Jeff Dean 立项第二代系统,目标是”任意模型 + 任意硬件 + 开源”。
  • 2015 年 11 月:TensorFlow 1.0 开源,带着 DistBelief 内部经验和 Theano 的图思想。
  • 2016 年 5 月:本论文 OSDI 投稿。Google 同期论文 TPU 第一代也发了。
  • 2017-2019 年:PyTorch 用 eager 模式抢走研究市场。
  • 2019 年:TF 2.0 默认 eager + tf.function 装饰器把图思想保留给生产部署。

之后图思想没死——JAX 的 jit、PyTorch 2.0 的 torch.compile 都是回到图。

学到什么

  1. 「程序 = 图」是 ML 系统的关键抽象——把计算和调度解耦,才能跨硬件跨规模复用。
  2. 状态进入图比 DistBelief 的”状态在外”更优雅——参数更新也成了一条边,分布式策略只是图变换。
  3. 自动微分需要图——有了图,反向规则表加一遍遍历就能生成反向。
  4. 静态图 vs 动态图不是非此即彼——eager 调试 + 图部署是后来的共识,TF 1 / PyTorch 1 各占一头,TF 2 / PyTorch 2 都向中间收敛。
  5. 基础设施的设计语言会沉淀十年——「算子 / 图 / 设备 / Variable」这套词到今天还在用。

延伸阅读

关联

  • pytorch —— 同代框架,靠 eager 调试体验击穿研究市场
  • jax —— Google 第三代 ML 系统,把 TF 的图思想做成纯函数变换
  • alpa-2022 —— 把 TF 的设备分配自动化为搜索问题
  • keras —— TF 2 默认前端,把”建图”包成 layer 对象
  • mlx —— Apple Silicon 的同思想精简版
  • ssa —— 编译器的图思想,TF 的图与 SSA 是远亲
  • kildall-dataflow —— 数据流分析的理论祖先,图节点 + 边的同构思路

反向链接

  • adam-2014 —— Adam — 让深度学习自己挑步长的优化器
  • alpa-2022 —— Alpa — 把张量/流水/数据并行统一成一道搜索题
  • amdahl-law-1967 —— Amdahl 定律 — 串行比例决定并行加速比的上界
  • jax —— JAX — Google 函数式数值计算
  • kildall-dataflow —— Kildall 数据流框架 — 用一套格论统一所有全局编译优化
  • pytorch —— PyTorch — 深度学习主流框架
  • ssa —— SSA — 静态单赋值形式
  • tvm-2018 —— TVM OSDI 2018 — 把 Halide 思想搬到深度学习