跳转到内容

TensorFlow — Google 端到端 DL 平台

是什么

TensorFlow 是一个让神经网络从训练到上线全套都有官方答案的框架。日常类比:像一座带有自家发电厂、铁路、卡车、便利店的工业园区——你只要在园区里干活,从原材料到出厂上货架的每一段路都已经铺好。

你写:

import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=10)
model.save('my_model') # 一行存出 SavedModel

最后那个 my_model 目录可以原封不动塞进 TF Serving(云上 gRPC 服务)、TFLite(手机)、TF.js(浏览器)、TFX(pipeline)。这就是 TensorFlow 的核心卖点——一次训练,全平台落地

为什么重要

不理解 TensorFlow,下面这些事都没法解释:

  • 为什么 2015-2018 年几乎所有公司的 ML 工程岗都要求 TF 经验
  • 为什么 Google 的 TPU(自家 AI 芯片)只有 TF 是一等公民——XLA 编译器是它和硬件之间的桥
  • 为什么手机上的人脸解锁、Google Translate 离线翻译大多跑在 TFLite
  • 为什么 PyTorch 抢了研究界,TF 仍稳坐工业部署——SavedModel + Serving + TFX 三件套没有同等替代

核心要点

TensorFlow 的能力可以拆成 四层

  1. Tensor + 计算图:所有数据是 tensor(多维数组),所有运算先编进一张(Graph),再交给后端跑。TF 1.x 强制先建图后跑,TF 2.x 默认 eager(边写边跑),加 @tf.function 才进图模式。

  2. Keras(高层 API):TF 2.x 把 Keras 吸收为官方门面(tf.keras)。Sequential / Model 三五行就拼出网络,.fit() 一行训完。

  3. XLA(编译器):把图编译成融合后的 GPU/TPU/CPU kernel。性能跃升的来源,也是 TPU 唯一能跑 TF 的原因。

  4. 部署矩阵:SavedModel(标准格式)→ TF Serving(云)/ TFLite(移动)/ TF.js(浏览器)/ TFX(pipeline)。同一个模型文件走完整条工业链。

实践案例

案例 1:Keras 三行训一个分类器

model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10),
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

新人最常见的入门栈,比 TF 1.x 那套 placeholder + session.run 简单 10 倍。

案例 2:@tf.function 把 eager 变图

@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
loss = loss_fn(model(x), y)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss

加一个装饰器,TF 自动把这段 Python 代码 trace 成静态图,之后每次调用都直接跑编译过的版本——慢的 Python 开销消失,性能逼近 TF 1.x。

案例 3:SavedModel 一次训练多端部署

Terminal window
# 训练完
model.save('mnist_savedmodel')
# 转 TFLite 上手机
tflite_convert --saved_model_dir=mnist_savedmodel --output_file=mnist.tflite
# 起 TF Serving 上服务器
docker run -p 8501:8501 -v $(pwd)/mnist_savedmodel:/models/mnist tensorflow/serving

一个目录 = 云端服务 + 手机推理 + 浏览器端,PyTorch 至今没有同等开箱方案。

踩过的坑

  1. TF 1.x → 2.x 是断崖升级tf.Sessiontf.placeholdertf.global_variables_initializer 全废。2019 年前的教程几乎全部失效,新人 Stack Overflow 抄来代码跑不动很常见。

  2. @tf.function 里的 Python 副作用只跑一次print()list.append() 只在首次 trace 时执行,之后图固化了不会再跑。调试时 print 没输出新人常误以为函数没被调。改用 tf.print

  3. eager 模式好调试 vs graph 模式快:两种模式行为不完全一致——eager 下能跑的代码加 @tf.function 后可能因 Python 控制流 trace 失败而报错。

  4. TFLite 算子是 TF 的子集:训练用了某个新层,转 TFLite 可能直接报”unsupported op”。生产前必须先跑转换检查。

  5. CUDA / cuDNN / TF 版本矩阵严苛:TF 2.15 配 CUDA 12.2 + cuDNN 8.9,错一位就 Could not load dynamic library。装 tensorflow[and-cuda] 是新版省事路径。

适用 vs 不适用场景

适用

  • 端到端工业部署(云 + 移动 + 浏览器 + 嵌入式 → TFLite/TFJS/Serving/TFX)
  • Google Cloud + TPU 训练(XLA 把 TF 翻成 TPU 原生指令)
  • 需要 SavedModel 标准格式跨团队交付的大组织
  • 嵌入式 / 边缘 ML(TFLite Micro 是少数能跑 MCU 的 DL runtime)

不适用

  • 学术研究快速试错(PyTorch 动态图 + Python 调试体验更顺)→ 见 pytorch
  • 单 GPU 个人项目快速实验(TF 安装 + 启动开销不划算)
  • 极客自定义算子(PyTorch + Triton 比 TF 自定义 op 友好得多)
  • 想要”一切皆函数”的范式(JAX / Flax 的 vmap/grad/pmap 更纯)

历史小故事(可跳过)

  • 2011 年:Google Brain 内部框架 DistBelief 跑出第一代大规模分布式训练,但耦合 Google 内部基建,外部用不了。
  • 2015-11:Google 把 DistBelief 重写为 TensorFlow 并开源(Apache 2.0)。当年立即成为深度学习最热项目。
  • 2017 年:TF 1.x 时代,静态图 + session.run 是行业默认,但门槛极高。
  • 2019-09TF 2.0 发布——eager 默认开、Keras 升为官方 API、@tf.function 出现,终于”像 Python 了”。
  • 现在:~187k stars,工业部署事实标准;研究界让位 PyTorch,但 TPU + 移动端是它无可替代的护城河。

学到什么

  1. 一个框架的命运在生态而不是 API——PyTorch 接管研究后 TF 仍立得住,靠的是 TFLite/Serving/TFX 这些”PyTorch 没有官方替代”的部署链。
  2. 静态图 vs 动态图之争最后的答案是”两者都要”——TF 2.x 的 @tf.function 和 PyTorch 2.0 的 torch.compile 殊途同归。
  3. 编译器 + 硬件协同才能解锁 ASIC——XLA 是 TPU 能跑 TF 的根本原因。
  4. 破坏性升级要付学费:TF 1→2 让 Google 自己付出三年迁移成本,新人也被旧教程坑。API 稳定是隐性资产。
  5. Keras 的吸收是聪明的并购——把第三方友好的高层 API 收为官方门面,门槛瞬间降一半。

延伸阅读

关联

  • pytorch —— 同代深度学习框架,TF 输研究、赢部署,互为镜像
  • fastai —— PyTorch 高层封装;Keras 之于 TF 即 fastai 之于 PyTorch(思路相通)
  • accelerate —— HuggingFace 在 PyTorch 之上的设备/分布式抽象,对应 TF 的 tf.distribute.Strategy
  • llvm —— XLA 的”多级 IR + 多后端”工程范式继承自 LLVM
  • lambda-calculus —— @tf.function 把 Python 函数 trace 成图,本质是高阶函数到表达式树的变换