TensorFlow — 把神经网络拆成数据流图再跑到任何机器上
是什么
TensorFlow 是 Google 2015 年开源的机器学习系统。它的核心想法是把一个神经网络先画成一张图,再让运行时把这张图切片、分发到 CPU/GPU/TPU 上跑。
日常类比:像菜谱与厨房分离。你写菜谱(图)的时候不关心几口锅几个灶;执行的时候系统自己分锅分火。换厨房(GPU 改 TPU)只换执行器,不改菜谱。
它的前身叫 DistBelief,是 Google 内部 2011 年的训练框架,只能跑特定模型架构。TensorFlow 是 Google 写给”任意模型 + 任意硬件 + 任意规模”的第二代系统。
为什么重要
不理解 TensorFlow 这一代设计,很多事情解释不清:
- 为什么 PyTorch 的 eager 模式当年被吹成”革命”——因为它的对手就是 TF 的”先建图再跑”
- 为什么后来出现
tf.function/torch.compile——大家又想要图带来的优化 - 为什么 JAX 的
jit + grad + vmap看起来熟悉——它继承了 TF 的图思想再做减法 - 为什么”算子”(operator)这个词在 ML 系统里这么基础——TF 把它定成了图的节点单位
简单说:现代 ML 系统的设计词汇表(图、算子、设备、自动微分、参数服务器)大半是 TF 这一代论文沉淀下来的。
核心要点
TensorFlow 的设计可以拆成 四块:
-
数据流图(dataflow graph):节点是算子(matmul / relu / conv),边是张量(多维数组)。整张图描述一次前向 + 反向 + 更新。
-
可变状态作为节点:模型参数(weights)不是普通张量,是图里一种特殊节点叫 Variable。这让”参数更新”也变成图里的一条边。DistBelief 把状态藏在外部参数服务器,TF 把它拉进图里。
-
自动切分到设备:图建好后,运行时根据设备约束(“这个算子放 GPU0”)和数据依赖把图切成子图,每台机器跑一片,跨机用 send/recv 节点桥接。
-
同一抽象同时支持训练与推理:训练时多一条反向 + 优化器子图,推理时只跑前向。两边共享算子和图调度器。
四块加起来给出一个统一描述:「程序 = 图」。
实践案例
案例 1:一个最小图
import tensorflow as tfW = tf.Variable(tf.random.normal([784, 10]))b = tf.Variable(tf.zeros([10]))x = tf.placeholder(tf.float32, [None, 784])logits = tf.matmul(x, W) + b这四行没跑任何计算,只是建了一张图。logits 是图里一个节点的引用,不是数字。要真正算,得 sess.run(logits, feed_dict={x: ...})——这一步运行时才把图切片、分到 GPU、执行。
把这个流程画出来:
Python 脚本(建图阶段) → GraphDef(图的 protobuf 描述) ↓ Session(执行阶段) ↓ 运行时切片 → 各设备本地子图 → 算子内核执行建图和执行是两个世界,这是 TF 1 的设计核心,也是它最被诟病的点。
案例 2:自动微分怎么发生
你写 loss = ...,调一次 tf.gradients(loss, [W, b]),TF 反向遍历图,对每个算子查”反向规则表”(matmul 的反向是另一次 matmul),自动生成反向子图。你不用手写偏导。
这是把 1970 年代的反向模式自动微分(reverse-mode AD)工业化的关键——有图就能自动生成反向。
案例 3:分布式训练的两种姿势
- 数据并行 + 异步参数服务器:worker 各算各的梯度,发给 PS,PS 异步更新。快但收敛抖。
- 数据并行 + 同步:所有 worker 算完梯度做 all-reduce 再更新。慢一点但稳。
TF 把这两种都画成图——参数服务器、worker、梯度交换都是图节点。换策略只换图结构,模型代码不动。
论文里跑 Inception v3 给的数据:50 GPU 同步训练比单 GPU 快约 40 倍,异步参数服务器可以扩到 100+ worker 但每 worker 利用率下降。这两条曲线后来被 ResNet 时代的 all-reduce 改写。
案例 4:算子和内核分离
tf.matmul 在图里是一个算子(operator),它有多个内核(kernel)实现:CPU x86 一个、CUDA GPU 一个、TPU 一个。算子是接口,内核是平台特化实现。
这套分层让 TF 加新硬件不用改图——XLA 编译器、TPU 后端、ROCm(AMD GPU)后端都按这个接口接进来。后来的 PyTorch dispatcher、JAX 的 lax 都沿用了同样的”算子 + 后端”两层设计。
踩过的坑
-
静态图调试痛:图建好才能跑,print 一个中间张量得插
tf.Print节点。新人来一句print(x)出来一个Tensor("matmul:0", shape=...),看不到值。这就是后来 PyTorch eager 模式吃掉 TF 用户群的最大原因。 -
Session 与 Graph 是两个对象:图描述结构,Session 才管设备和执行。新人常把两件事混。
tf.Session().run(...)一不小心创了多个 session,状态对不上。 -
Variable 共享要用 scope:写两遍
tf.Variable(...)会创两个变量,不会复用。复用得用tf.variable_scope("foo", reuse=True)。这个 API 后来被骂太复杂,TF 2 用 Keras 的 layer 重做了。 -
同步/异步训练收敛性不同:异步参数服务器在大规模时收敛速度可能不如同步 all-reduce,但论文当时主推异步。后来 ResNet 时代发现同步更稳,工业界大多回到 all-reduce。
适用 vs 不适用场景
适用:
- 大规模分布式训练(多机多 GPU)——图 + 设备分配天然支持
- 推理服务化(TensorFlow Serving)——同一张图直接部署
- 跨硬件后端(CPU / GPU / TPU)——换设备不换代码
- 需要图层面优化(XLA 编译、算子融合)的场景
不适用:
- 研究阶段快速实验——eager 模式(PyTorch / JAX)调试更快
- 控制流复杂的模型(树、变长 RNN、强化学习)——静态图表达起来别扭,要
tf.cond/tf.while_loop - 小模型小数据——图建图开销大于收益
历史小故事(可跳过)
- 2011 年:Google Brain 内部用 DistBelief 训 ImageNet,发现框架太死,只支持 feedforward + 参数服务器异步。
- 2014 年:Jeff Dean 立项第二代系统,目标是”任意模型 + 任意硬件 + 开源”。
- 2015 年 11 月:TensorFlow 1.0 开源,带着 DistBelief 内部经验和 Theano 的图思想。
- 2016 年 5 月:本论文 OSDI 投稿。Google 同期论文 TPU 第一代也发了。
- 2017-2019 年:PyTorch 用 eager 模式抢走研究市场。
- 2019 年:TF 2.0 默认 eager +
tf.function装饰器把图思想保留给生产部署。
之后图思想没死——JAX 的 jit、PyTorch 2.0 的 torch.compile 都是回到图。
学到什么
- 「程序 = 图」是 ML 系统的关键抽象——把计算和调度解耦,才能跨硬件跨规模复用。
- 状态进入图比 DistBelief 的”状态在外”更优雅——参数更新也成了一条边,分布式策略只是图变换。
- 自动微分需要图——有了图,反向规则表加一遍遍历就能生成反向。
- 静态图 vs 动态图不是非此即彼——eager 调试 + 图部署是后来的共识,TF 1 / PyTorch 1 各占一头,TF 2 / PyTorch 2 都向中间收敛。
- 基础设施的设计语言会沉淀十年——「算子 / 图 / 设备 / Variable」这套词到今天还在用。
延伸阅读
- 论文 PDF(18 页):TensorFlow OSDI 2016
- 前作论文:DistBelief NIPS 2012(理解 TF 在解决什么问题)
- 设计回顾:Jeff Dean — A Look Back at TensorFlow
- pytorch —— 主要竞品,先 eager 后图
- jax —— Google 自己的下一代,图思想 + 函数式
- alpa-2022 —— 把 TF 的设备分配思想推到自动并行
关联
- pytorch —— 同代框架,靠 eager 调试体验击穿研究市场
- jax —— Google 第三代 ML 系统,把 TF 的图思想做成纯函数变换
- alpa-2022 —— 把 TF 的设备分配自动化为搜索问题
- keras —— TF 2 默认前端,把”建图”包成 layer 对象
- mlx —— Apple Silicon 的同思想精简版
- ssa —— 编译器的图思想,TF 的图与 SSA 是远亲
- kildall-dataflow —— 数据流分析的理论祖先,图节点 + 边的同构思路
反向链接
- adam-2014 —— Adam — 让深度学习自己挑步长的优化器
- alpa-2022 —— Alpa — 把张量/流水/数据并行统一成一道搜索题
- amdahl-law-1967 —— Amdahl 定律 — 串行比例决定并行加速比的上界
- jax —— JAX — Google 函数式数值计算
- kildall-dataflow —— Kildall 数据流框架 — 用一套格论统一所有全局编译优化
- pytorch —— PyTorch — 深度学习主流框架
- ssa —— SSA — 静态单赋值形式
- tvm-2018 —— TVM OSDI 2018 — 把 Halide 思想搬到深度学习