TASO — 让机器自己发现深度学习图重写规则
是什么
TASO 是 2019 年 SOSP 上 Stanford 发的论文,提出一个想法:深度学习编译器里的”图重写规则”,不应该让专家一条条手写,而应该让机器自动枚举、再用数学证明它对,最后让搜索算法选用哪几条。
日常类比:以前 XLA / TensorRT 的工程师像中医,靠经验背几百条”这两味药可以替换那一味”的口诀,背错了病人就死。TASO 的思路是把这件事工业化——让机器把所有可能的”两味药等于一味药”的组合枚举出来,每一条都让计算机证明”在任何病人身上都等价”,然后给医生一本有 700 多条经过数学验证的处方手册。
它的核心论点:图级优化的瓶颈不是搜索算法,而是规则集是否够大、够对。
为什么重要
不理解这篇论文,下面这些事都没法解释:
- 为什么 PyTorch 2.0 的
torch.compile做算子融合时不再像 XLA 1.0 那样写死 pattern——它的 Inductor 后端引用了 TASO 的设计 - 为什么 MLIR 出现后会专门加一种 PDL(Pattern Description Language)配合验证器——TASO 把”自动生成 + 自动验证”的模式带进了主流编译器栈
- 为什么”自动机器学习编译器”成立——只要规则正确性可证明,规则数量可以放大一个数量级
- 为什么 xla-compiler 在 2020 年后明显加快了重写规则的更新节奏——TASO 给出了规则的工业化生产线
核心要点
TASO 的设计是三个解耦的阶段:
-
枚举(Generator):把基本算子(conv / matmul / add / split / concat 等)当原子,枚举所有节点数不超过 4 的子图。这一步不管对不对,纯粹”先把候选写满”。
-
验证(Verifier):从一组代数公理(结合律、分配律、线性性、矩阵乘的关联性等)出发,自动判断”子图 A 和子图 B 是否在任意输入下输出相同”。等价的两个子图就成为一条候选重写规则。
-
搜索(Optimizer):把验证过的约 743 条规则喂给一个 cost-based 回溯搜索器。搜索器以”整图执行总时间”为目标函数,反复尝试用规则替换图中的子结构,挑选最快的重写序列。
类比:第一步像把所有”可能的拼图块”印出来;第二步像让验厂师挨个检查”这块和那块是不是真的能互换”;第三步像让机器人在拼图盘上反复换块,直到拼出最快的那种摆法。
实践案例
案例 1:一条 TASO 自动发现的规则
下面是论文里的一条典型规则(手写库里没有这一条):
Conv(x, w1) + Conv(x, w2) 等价于 Conv(x, concat(w1, w2)) 的某个切片逐部分解释:
- 左边:用
x分别和两个卷积核w1、w2做卷积,再相加——两次访存、两次卷积 - 右边:先把
w1、w2在通道维度拼起来,再做一次卷积——一次访存、一次卷积 - 验证器靠卷积对加法的分配律证明两边等价
- 搜索器在 GPU 上比较时间,发现右边在大 batch 下快 2 倍
工程师没写这条规则,TASO 自己枚举出来并证明了。这种”两次小卷积合一次大卷积”在传统库里要靠人观察 ResNet 才会想到,TASO 把它升级成对任意网络都成立的通用模式。
案例 2:搜索过程长什么样
给 TASO 一张 BERT 的计算图(约 1500 个节点),它的搜索循环大致是:
- 在图里找所有”能套进某条规则左边模式”的子图位置
- 试着替换成右边模式,估算新图的总执行时间(用 cost model)
- 如果变快就保留,变慢就回退
- 反复迭代到没有规则能让图变快为止
最后 BERT 推理比 TensorFlow XLA 快 1.6 倍。整个过程没有人工干预。
案例 3:验证器为什么必要
不验证会怎样?看这条错误规则:
Reshape(x, [a, b]) + Reshape(x, [b, a]) 假装等于 2 * Reshape(x, [a, b])形状不一样,不能加,但模式匹配引擎可能因为算子名匹配就替换了。手写库里这种 bug 修过很多次。TASO 的验证器会立刻拒绝这条——因为代数公理推不出两边的形状相等。
踩过的坑
-
枚举爆炸:节点数从 4 加到 5,子图候选数量涨 30 倍。论文工程上限定到 4 节点,这是务实但也是局限。
-
公理库不闭合:TASO 的代数公理是手写的。新加一个算子(比如 LayerNorm),如果没补对应公理,验证器会把所有相关规则都拒绝。自动化只挪到了”规则发现”这一步,没消除人工。
-
cost model 不准:搜索器靠 cost model 判断快慢,但 GPU 实际性能受 kernel launch 开销、cache、内存带宽影响。论文用真机 profiling 兜底,否则会选出”理论快但实际慢”的方案。
-
数值精度变形:浮点的结合律不严格成立——
(a + b) + c和a + (b + c)在 float32 下可能差最后一位。TASO 的等价是”代数等价”,不保证比特等价;做训练时要小心引起 loss 曲线漂移。 -
规则之间的相互作用:单条规则验证过等价,但两条规则连用未必生成最优解。搜索器要尝试不同顺序,回溯空间因此变大;论文里靠剪枝 + cost 估算控制爆炸。
适用 vs 不适用场景
适用:
- 深度学习推理编译器(XLA / TensorRT / IREE / TVM Relay)的图级优化阶段
- 需要把”规则正确性”当一等公民的场景(数据库查询优化也是同源思路)
- 需要规则集快速扩张但不能引入 bug 的工业场景
不适用:
- 规则空间太大无法枚举(比如 LLM 训练的 megakernel 融合,子图节点数轻松破百)
- 算子语义无法用代数公理刻画(比如随机性算子、控制流密集的代码)
- 实时性要求高的场景——TASO 搜索本身要几分钟,适合编译期不适合运行时
一个完整的搜索循环长什么样
把上面三个案例串起来看一次 TASO 跑 ResNet-50 的全过程:
- 离线一次性:枚举器跑约 10 小时,生成约 743 条候选规则;验证器对每条规则在 SMT 求解器里证明等价(通过约 60% 的候选)
- 加载阶段:把验证过的规则集按”左边模式的算子签名”建索引,以便运行时快速匹配
- 在线搜索:拿到 ResNet-50 的图(约 200 个节点),运行回溯搜索约 2 分钟,遍历约 5000 个候选重写序列
- 真机校准:对每个候选最优解在 GPU 上实测一次,挑选实际最快的——避免 cost model 估错
- 输出:把最终图序列化回 ONNX,下游的 TVM / TensorRT / CUDA backend 接管做 codegen
整个流程对用户透明——用户只需把模型扔进去,几分钟后拿到一张快得多的等价图。
历史小故事(可跳过)
- 2018 年:Zhihao Jia 在 Stanford 做的前作 MetaFlow 已经把”图重写 + 搜索”结合,但规则仍是手写
- 2019 年:Jia 发现规则手写无法规模化,找 Oded Padon(验证背景)合作,把代数公理验证引擎加进来——TASO 诞生,比 MetaFlow 提速 1.3 到 2.8 倍
- 2020 年起:Jia 接着做 FlexFlow / Unity,把”自动重写”思路从单机扩到分布式,影响了 alpa-2022 的并行策略搜索
- 2022 年:PyTorch 2.0 团队在 Inductor 设计文档里引用 TASO,作为”为什么不再手写 fusion pattern”的论据
- 2024 年:Jia 的团队继续把这套思路推到 LLM 训练,做出 Mirage 系统——同样是”枚举 + 验证 + 搜索”,但子图换成 megakernel 级别
学到什么
- 解耦比聪明更重要:把”规则发现 / 规则验证 / 规则使用”拆成三个独立阶段,每个阶段都能独立替换或加速——这是 TASO 真正的工程贡献
- 形式化验证不是学术玩具:在 700 多条规则的工业系统里,机器证明比人工 review 便宜得多
- 算法搜索的瓶颈往往在搜索空间本身:TASO 没换搜索算法(仍是回溯),只是把候选规则从 30 条扩到 743 条,效果就上去了
- 代数公理是规则的源真相:算子层加新东西时,先扩公理再扩规则,不要反过来
- 离线 / 在线分层:把昂贵的枚举与证明放离线一次跑完,把便宜的搜索放在线对每个模型跑——这是工业系统通用的省钱套路
延伸阅读
- 论文 PDF:TASO SOSP 2019(16 页,前 8 页可读懂大部分)
- 代码仓库:github.com/jiazhihao/TASO(含验证器实现,C++ 写的)
- 后续工作:FlexFlow Unity OSDI 2022(同一作者把 TASO 思路扩到分布式)
- 视频讲解:作者 Zhihao Jia 在 SOSP 2019 的 25 分钟报告——可以从 ACM Digital Library 找到
- xla-compiler —— XLA 的 Pattern Matcher 是 TASO 思路的工业落地
- tvm-2018 —— TVM 也有图级优化,但仍依赖手写规则;可以对比两者的取舍
关联
- xla-compiler —— XLA 的图级 fusion 受 TASO 影响最大;二者的 pattern 库可以互译
- tvm-2018 —— 同样做编译器栈,但 TVM 重在算子级 schedule,TASO 重在图级重写;两层正交
- alpa-2022 —— Jia 后续作品,把”自动搜索”从图重写扩到并行策略
- pytorch —— PyTorch 2.0 Inductor 引用 TASO 设计
反向链接
- alpa-2022 —— Alpa — 把张量/流水/数据并行统一成一道搜索题
- pytorch —— PyTorch — 深度学习主流框架
- xla-compiler —— XLA — 给 TensorFlow / JAX 装一台真正的张量编译器