DLRM — Meta 把工业推荐模型拆成 4 个标准积木
是什么
DLRM 是 Meta(当时 Facebook)2019 年开源的工业级推荐模型参考实现。它做的事情是:把过去各家公司各写各的 CTR(点击率)模型,抽象成 4 个标准积木,然后把 PyTorch 和 Caffe2 实现一起开源。
日常类比:想象超市收银员要预测你会不会买某商品。
- 你的购物车有数字特征(金额、件数、停留时长)
- 也有类别特征(早上还是晚上来、会员等级、常买生鲜还是零食)
收银员的脑子里会做四件事:
- 把数字归到统一坐标系(“消费 200 元在本店算中等”)
- 给每个类别贴一张身份证向量(“会员金卡 = [0.3, -0.5, 0.8…]”)
- 让所有这些向量两两握手,看哪对默契度高
- 把握手强度汇总,吐出”会买”的概率
DLRM 就是把这四步做成可训练的神经网络。
为什么重要
不理解 DLRM,下面这些事都没法解释:
- 为什么 2019 年之后所有大型推荐基础设施(NVIDIA Merlin、阿里 PAI-Rec、TorchRec)都拿 DLRM 做对标——它是 MLPerf Recommendation 的官方 workload
- 为什么推荐模型需要TB 级显存——embedding table 一张就能几百 GB,HBM 装不下
- 为什么 Meta 后来要专门设计 ZionEX 训练硬件——DLRM 的瓶颈不是算力是带宽
- 为什么推荐模型的并行和大模型完全不一样——它需要把 embedding 做 model parallel,MLP 做 data parallel,混着跑
核心要点
DLRM 由 4 个模块 组成,按顺序拼起来。
-
Bottom MLP(处理连续特征):年龄、消费金额这类数字过一个小 MLP,投影到和 embedding 一样的维度(比如 16 维)。让”数字”和”类别”在同一坐标系里说话。
-
Embedding tables(处理类别特征):每个类别字段一张表。比如”国家”字段有 200 个国家,那就是一张 [200 × 16] 的可训练矩阵。给一个用户,去表里查他对应的那一行。一个工业模型可能有几十张表,加起来参数量到 TB 级。
-
Pairwise dot interaction(特征交叉):把上面所有向量(bottom MLP 输出 1 个 + embedding lookup 出来的 N 个)两两做点积,得到 N(N+1)/2 个标量。点积大 = 这两个特征”对得上”。
-
Top MLP(输出层):把所有点积结果 + 原始 bottom MLP 输出拼成一个长向量,过几层 MLP,最后 sigmoid 输出”会点击/会买”的概率。
整套架构定下来后,研究的重点从”怎么改结构”变成”怎么把它跑起来更快、容量更大”。
实践案例
案例 1:embedding table 怎么吃显存
假设一个广告系统有:
- 10 亿个 user ID(每个 64 维 embedding,FP32)→ 256 GB
- 1 亿个 item ID(每个 64 维 embedding,FP32)→ 25 GB
- 几十个其他类别字段加起来 → 几十 GB
总计几百 GB 到 TB。一张 H100 才 80 GB HBM。所以 embedding 必须切到多卡,每卡只放一部分行,查表时通过 all-to-all 互相要数据。
案例 2:异构并行方案
DLRM 论文最有价值的工程贡献——同一份模型用两种并行:
embedding tables → model parallel(每卡只存一部分行) 通信:all-to-all(查表后交换结果)
bottom MLP / top MLP → data parallel(每卡完整副本) 通信:all-reduce(梯度同步)这种异构并行在大语言模型里见不到(LLM 全模型 parallel)。推荐模型独有,因为 embedding 大但算力小,MLP 小但每条样本都要算。
案例 3:一次前向的张量形状
假设 batch=4,embedding 维 d=4,类别字段 3 个,连续特征 8 维。
输入: dense shape=[4, 8] 连续特征 cat_1 shape=[4] 类别 ID cat_2 shape=[4] cat_3 shape=[4]
bottom MLP: [4, 8] → [4, 4] 投到 d=4embedding lookup: 3 张表各得 [4, 4]现在有 4 个 [4, 4] 向量(dense 1 个 + cat 3 个)
dot interaction: 把 4 个向量 stack 成 [4, 4, 4] 两两点积 → [4, 4, 4] @ [4, 4, 4]^T → [4, 4, 4] 取上三角(不含对角)→ [4, 6]
拼接 dense MLP 输出和上三角 → [4, 10]top MLP: [4, 10] → [4, 1] sigmoid 概率理解了形状变化,整个模型代码就能照着写。
案例 4:和 Wide & Deep 的关系
| 维度 | Wide & Deep (2016) | DLRM (2019) |
|---|---|---|
| 重点 | 算法(记忆 + 泛化) | 系统(标准结构 + 并行) |
| Wide 路 | 手工交叉特征 + LR | 没有 wide 路 |
| Deep 路 | embedding + MLP | embedding + dot interaction + MLP |
| 开源 | TensorFlow API | facebookresearch/dlrm(PyTorch + Caffe2) |
| 历史定位 | 让”两路并行”流行 | 让”工业推荐”有可对标的基线 |
可以理解为 Wide & Deep 给了思想,DLRM 给了可复现的工业骨架。
踩过的坑
-
以为 dot interaction 是先进结构:其实它是最朴素的二阶特征交叉。DCN / xDeepFM 都比它复杂。DLRM 选朴素是故意的——给一个简单可复现的基线比给最强结构更重要。
-
以为推荐模型是算力瓶颈:恰恰相反。embedding 占 99% 参数但只占 < 1% 算力。瓶颈是显存容量 + 通信带宽,不是 FLOPS。这导致推荐硬件选型逻辑和训练 LLM 完全不一样。
-
直接复制 LLM 的并行方案到推荐:会爆炸。LLM 是模型大、激活大;推荐是 embedding 大、MLP 小。embedding 不能 data parallel(每卡一个 TB 副本不可能),MLP 不能 model parallel(切了通信比算还贵)。必须异构。
-
忽略 embedding 的稀疏更新:每个 batch 只有少数 embedding 行被访问,梯度也只更新这些行。如果用普通 dense optimizer 会浪费 99% 算力——必须用稀疏 optimizer(SparseAdam 之类)。
适用 vs 不适用场景
适用:
- 工业 CTR 预估、广告排序、推荐召回后排序
- 类别特征 + 连续特征混合的结构化预测
- 需要工业级 baseline 对标的研究(MLPerf 等)
不适用:
- 序列推荐(用户行为序列)→ 需要 Transformer / RNN,DLRM 不建模时序
- 图结构推荐(社交网络)→ 用 GNN
- 冷启动场景 → DLRM 依赖大量训练数据学 embedding,新用户/新物品没历史
- 内容推荐(图像/视频本身)→ 需要先用 CNN/ViT 提特征再喂 DLRM
历史小故事(可跳过)
- 2016:Google 发 Wide & Deep,业界开始用”两路并行”做推荐
- 2017-2018:DCN / xDeepFM / DIN 各家发力改特征交互结构
- 2019:Meta 干脆开源 DLRM,意思是”结构创新就到这吧,咱比系统效率”
- 2020:MLPerf Recommendation track 直接用 DLRM 做 benchmark
- 2021:Meta 发 ZionEX 论文,专为 DLRM 设计训练硬件
- 2022:TorchRec 上线,PyTorch 官方支持 DLRM 训练
学到什么
- 工业开源不一定是结构最先进,而是最可复现——DLRM 用最朴素的 dot 交叉,正是为了让大家拿来就跑
- 推荐和大模型是两套世界观——一个 embedding 大算力小,一个权重大算力也大;并行方案、硬件选型完全不同
- 标准化是基础设施的前提——DLRM 把模型骨架钉死后,硬件、编译器、训练框架才有共同对标对象
- 有时候发”集大成的简化版”比发”最强结构”影响更大——DLRM 引用量远超同期更复杂的推荐论文
延伸阅读
- 论文 PDF:DLRM arXiv 1906.00091
- 官方实现:facebookresearch/dlrm
- ZionEX 后续:Mudigere et al., “Software-Hardware Co-design for Fast and Scalable Training of DLRM”, ISCA 2022
- wide-deep-2016 —— 推荐双路结构的鼻祖
- youtube-two-tower-2019 —— 召回阶段的双塔结构
关联
- wide-deep-2016 —— Wide & Deep 给思想,DLRM 给工业骨架
- youtube-two-tower-2019 —— 召回 vs 排序:双塔做召回,DLRM 做排序
- ampere-architecture-2020 —— DLRM 推动了 GPU HBM 容量与 NVLink 带宽设计
- pytorch —— DLRM 官方实现框架,后来催生 TorchRec