跳转到内容

TASO — 让机器自己发现深度学习图重写规则

是什么

TASO 是 2019 年 SOSP 上 Stanford 发的论文,提出一个想法:深度学习编译器里的”图重写规则”,不应该让专家一条条手写,而应该让机器自动枚举、再用数学证明它对,最后让搜索算法选用哪几条

日常类比:以前 XLA / TensorRT 的工程师像中医,靠经验背几百条”这两味药可以替换那一味”的口诀,背错了病人就死。TASO 的思路是把这件事工业化——让机器把所有可能的”两味药等于一味药”的组合枚举出来,每一条都让计算机证明”在任何病人身上都等价”,然后给医生一本有 700 多条经过数学验证的处方手册。

它的核心论点:图级优化的瓶颈不是搜索算法,而是规则集是否够大、够对

为什么重要

不理解这篇论文,下面这些事都没法解释:

  • 为什么 PyTorch 2.0 的 torch.compile 做算子融合时不再像 XLA 1.0 那样写死 pattern——它的 Inductor 后端引用了 TASO 的设计
  • 为什么 MLIR 出现后会专门加一种 PDL(Pattern Description Language)配合验证器——TASO 把”自动生成 + 自动验证”的模式带进了主流编译器栈
  • 为什么”自动机器学习编译器”成立——只要规则正确性可证明,规则数量可以放大一个数量级
  • 为什么 xla-compiler 在 2020 年后明显加快了重写规则的更新节奏——TASO 给出了规则的工业化生产线

核心要点

TASO 的设计是三个解耦的阶段

  1. 枚举(Generator):把基本算子(conv / matmul / add / split / concat 等)当原子,枚举所有节点数不超过 4 的子图。这一步不管对不对,纯粹”先把候选写满”。

  2. 验证(Verifier):从一组代数公理(结合律、分配律、线性性、矩阵乘的关联性等)出发,自动判断”子图 A 和子图 B 是否在任意输入下输出相同”。等价的两个子图就成为一条候选重写规则

  3. 搜索(Optimizer):把验证过的约 743 条规则喂给一个 cost-based 回溯搜索器。搜索器以”整图执行总时间”为目标函数,反复尝试用规则替换图中的子结构,挑选最快的重写序列。

类比:第一步像把所有”可能的拼图块”印出来;第二步像让验厂师挨个检查”这块和那块是不是真的能互换”;第三步像让机器人在拼图盘上反复换块,直到拼出最快的那种摆法。

实践案例

案例 1:一条 TASO 自动发现的规则

下面是论文里的一条典型规则(手写库里没有这一条):

Conv(x, w1) + Conv(x, w2) 等价于 Conv(x, concat(w1, w2)) 的某个切片

逐部分解释:

  • 左边:用 x 分别和两个卷积核 w1w2 做卷积,再相加——两次访存、两次卷积
  • 右边:先把 w1w2 在通道维度拼起来,再做一次卷积——一次访存、一次卷积
  • 验证器靠卷积对加法的分配律证明两边等价
  • 搜索器在 GPU 上比较时间,发现右边在大 batch 下快 2 倍

工程师没写这条规则,TASO 自己枚举出来并证明了。这种”两次小卷积合一次大卷积”在传统库里要靠人观察 ResNet 才会想到,TASO 把它升级成对任意网络都成立的通用模式。

案例 2:搜索过程长什么样

给 TASO 一张 BERT 的计算图(约 1500 个节点),它的搜索循环大致是:

  1. 在图里找所有”能套进某条规则左边模式”的子图位置
  2. 试着替换成右边模式,估算新图的总执行时间(用 cost model)
  3. 如果变快就保留,变慢就回退
  4. 反复迭代到没有规则能让图变快为止

最后 BERT 推理比 TensorFlow XLA 快 1.6 倍。整个过程没有人工干预。

案例 3:验证器为什么必要

不验证会怎样?看这条错误规则:

Reshape(x, [a, b]) + Reshape(x, [b, a]) 假装等于 2 * Reshape(x, [a, b])

形状不一样,不能加,但模式匹配引擎可能因为算子名匹配就替换了。手写库里这种 bug 修过很多次。TASO 的验证器会立刻拒绝这条——因为代数公理推不出两边的形状相等。

踩过的坑

  1. 枚举爆炸:节点数从 4 加到 5,子图候选数量涨 30 倍。论文工程上限定到 4 节点,这是务实但也是局限。

  2. 公理库不闭合:TASO 的代数公理是手写的。新加一个算子(比如 LayerNorm),如果没补对应公理,验证器会把所有相关规则都拒绝。自动化只挪到了”规则发现”这一步,没消除人工

  3. cost model 不准:搜索器靠 cost model 判断快慢,但 GPU 实际性能受 kernel launch 开销、cache、内存带宽影响。论文用真机 profiling 兜底,否则会选出”理论快但实际慢”的方案。

  4. 数值精度变形:浮点的结合律不严格成立——(a + b) + ca + (b + c) 在 float32 下可能差最后一位。TASO 的等价是”代数等价”,不保证比特等价;做训练时要小心引起 loss 曲线漂移。

  5. 规则之间的相互作用:单条规则验证过等价,但两条规则连用未必生成最优解。搜索器要尝试不同顺序,回溯空间因此变大;论文里靠剪枝 + cost 估算控制爆炸。

适用 vs 不适用场景

适用

  • 深度学习推理编译器(XLA / TensorRT / IREE / TVM Relay)的图级优化阶段
  • 需要把”规则正确性”当一等公民的场景(数据库查询优化也是同源思路)
  • 需要规则集快速扩张但不能引入 bug 的工业场景

不适用

  • 规则空间太大无法枚举(比如 LLM 训练的 megakernel 融合,子图节点数轻松破百)
  • 算子语义无法用代数公理刻画(比如随机性算子、控制流密集的代码)
  • 实时性要求高的场景——TASO 搜索本身要几分钟,适合编译期不适合运行时

一个完整的搜索循环长什么样

把上面三个案例串起来看一次 TASO 跑 ResNet-50 的全过程:

  1. 离线一次性:枚举器跑约 10 小时,生成约 743 条候选规则;验证器对每条规则在 SMT 求解器里证明等价(通过约 60% 的候选)
  2. 加载阶段:把验证过的规则集按”左边模式的算子签名”建索引,以便运行时快速匹配
  3. 在线搜索:拿到 ResNet-50 的图(约 200 个节点),运行回溯搜索约 2 分钟,遍历约 5000 个候选重写序列
  4. 真机校准:对每个候选最优解在 GPU 上实测一次,挑选实际最快的——避免 cost model 估错
  5. 输出:把最终图序列化回 ONNX,下游的 TVM / TensorRT / CUDA backend 接管做 codegen

整个流程对用户透明——用户只需把模型扔进去,几分钟后拿到一张快得多的等价图。

历史小故事(可跳过)

  • 2018 年:Zhihao Jia 在 Stanford 做的前作 MetaFlow 已经把”图重写 + 搜索”结合,但规则仍是手写
  • 2019 年:Jia 发现规则手写无法规模化,找 Oded Padon(验证背景)合作,把代数公理验证引擎加进来——TASO 诞生,比 MetaFlow 提速 1.3 到 2.8 倍
  • 2020 年起:Jia 接着做 FlexFlow / Unity,把”自动重写”思路从单机扩到分布式,影响了 alpa-2022 的并行策略搜索
  • 2022 年:PyTorch 2.0 团队在 Inductor 设计文档里引用 TASO,作为”为什么不再手写 fusion pattern”的论据
  • 2024 年:Jia 的团队继续把这套思路推到 LLM 训练,做出 Mirage 系统——同样是”枚举 + 验证 + 搜索”,但子图换成 megakernel 级别

学到什么

  1. 解耦比聪明更重要:把”规则发现 / 规则验证 / 规则使用”拆成三个独立阶段,每个阶段都能独立替换或加速——这是 TASO 真正的工程贡献
  2. 形式化验证不是学术玩具:在 700 多条规则的工业系统里,机器证明比人工 review 便宜得多
  3. 算法搜索的瓶颈往往在搜索空间本身:TASO 没换搜索算法(仍是回溯),只是把候选规则从 30 条扩到 743 条,效果就上去了
  4. 代数公理是规则的源真相:算子层加新东西时,先扩公理再扩规则,不要反过来
  5. 离线 / 在线分层:把昂贵的枚举与证明放离线一次跑完,把便宜的搜索放在线对每个模型跑——这是工业系统通用的省钱套路

延伸阅读

  • 论文 PDF:TASO SOSP 2019(16 页,前 8 页可读懂大部分)
  • 代码仓库:github.com/jiazhihao/TASO(含验证器实现,C++ 写的)
  • 后续工作:FlexFlow Unity OSDI 2022(同一作者把 TASO 思路扩到分布式)
  • 视频讲解:作者 Zhihao Jia 在 SOSP 2019 的 25 分钟报告——可以从 ACM Digital Library 找到
  • xla-compiler —— XLA 的 Pattern Matcher 是 TASO 思路的工业落地
  • tvm-2018 —— TVM 也有图级优化,但仍依赖手写规则;可以对比两者的取舍

关联

  • xla-compiler —— XLA 的图级 fusion 受 TASO 影响最大;二者的 pattern 库可以互译
  • tvm-2018 —— 同样做编译器栈,但 TVM 重在算子级 schedule,TASO 重在图级重写;两层正交
  • alpa-2022 —— Jia 后续作品,把”自动搜索”从图重写扩到并行策略
  • pytorch —— PyTorch 2.0 Inductor 引用 TASO 设计

反向链接

  • alpa-2022 —— Alpa — 把张量/流水/数据并行统一成一道搜索题
  • pytorch —— PyTorch — 深度学习主流框架
  • xla-compiler —— XLA — 给 TensorFlow / JAX 装一台真正的张量编译器