跳转到内容

DLRM — Meta 把工业推荐模型拆成 4 个标准积木

是什么

DLRM 是 Meta(当时 Facebook)2019 年开源的工业级推荐模型参考实现。它做的事情是:把过去各家公司各写各的 CTR(点击率)模型,抽象成 4 个标准积木,然后把 PyTorch 和 Caffe2 实现一起开源。

日常类比:想象超市收银员要预测你会不会买某商品。

  • 你的购物车有数字特征(金额、件数、停留时长)
  • 也有类别特征(早上还是晚上来、会员等级、常买生鲜还是零食)

收银员的脑子里会做四件事:

  1. 把数字归到统一坐标系(“消费 200 元在本店算中等”)
  2. 给每个类别贴一张身份证向量(“会员金卡 = [0.3, -0.5, 0.8…]”)
  3. 让所有这些向量两两握手,看哪对默契度高
  4. 把握手强度汇总,吐出”会买”的概率

DLRM 就是把这四步做成可训练的神经网络。

为什么重要

不理解 DLRM,下面这些事都没法解释:

  • 为什么 2019 年之后所有大型推荐基础设施(NVIDIA Merlin、阿里 PAI-Rec、TorchRec)都拿 DLRM 做对标——它是 MLPerf Recommendation 的官方 workload
  • 为什么推荐模型需要TB 级显存——embedding table 一张就能几百 GB,HBM 装不下
  • 为什么 Meta 后来要专门设计 ZionEX 训练硬件——DLRM 的瓶颈不是算力是带宽
  • 为什么推荐模型的并行和大模型完全不一样——它需要把 embedding 做 model parallel,MLP 做 data parallel,混着跑

核心要点

DLRM 由 4 个模块 组成,按顺序拼起来。

  1. Bottom MLP(处理连续特征):年龄、消费金额这类数字过一个小 MLP,投影到和 embedding 一样的维度(比如 16 维)。让”数字”和”类别”在同一坐标系里说话。

  2. Embedding tables(处理类别特征):每个类别字段一张表。比如”国家”字段有 200 个国家,那就是一张 [200 × 16] 的可训练矩阵。给一个用户,去表里查他对应的那一行。一个工业模型可能有几十张表,加起来参数量到 TB 级

  3. Pairwise dot interaction(特征交叉):把上面所有向量(bottom MLP 输出 1 个 + embedding lookup 出来的 N 个)两两做点积,得到 N(N+1)/2 个标量。点积大 = 这两个特征”对得上”。

  4. Top MLP(输出层):把所有点积结果 + 原始 bottom MLP 输出拼成一个长向量,过几层 MLP,最后 sigmoid 输出”会点击/会买”的概率。

整套架构定下来后,研究的重点从”怎么改结构”变成”怎么把它跑起来更快、容量更大”。

实践案例

案例 1:embedding table 怎么吃显存

假设一个广告系统有:

  • 10 亿个 user ID(每个 64 维 embedding,FP32)→ 256 GB
  • 1 亿个 item ID(每个 64 维 embedding,FP32)→ 25 GB
  • 几十个其他类别字段加起来 → 几十 GB

总计几百 GB 到 TB。一张 H100 才 80 GB HBM。所以 embedding 必须切到多卡,每卡只放一部分行,查表时通过 all-to-all 互相要数据。

案例 2:异构并行方案

DLRM 论文最有价值的工程贡献——同一份模型用两种并行:

embedding tables → model parallel(每卡只存一部分行)
通信:all-to-all(查表后交换结果)
bottom MLP / top MLP → data parallel(每卡完整副本)
通信:all-reduce(梯度同步)

这种异构并行在大语言模型里见不到(LLM 全模型 parallel)。推荐模型独有,因为 embedding 大但算力小,MLP 小但每条样本都要算。

案例 3:一次前向的张量形状

假设 batch=4,embedding 维 d=4,类别字段 3 个,连续特征 8 维。

输入:
dense shape=[4, 8] 连续特征
cat_1 shape=[4] 类别 ID
cat_2 shape=[4]
cat_3 shape=[4]
bottom MLP: [4, 8] → [4, 4] 投到 d=4
embedding lookup: 3 张表各得 [4, 4]
现在有 4 个 [4, 4] 向量(dense 1 个 + cat 3 个)
dot interaction:
把 4 个向量 stack 成 [4, 4, 4]
两两点积 → [4, 4, 4] @ [4, 4, 4]^T → [4, 4, 4]
取上三角(不含对角)→ [4, 6]
拼接 dense MLP 输出和上三角 → [4, 10]
top MLP: [4, 10] → [4, 1] sigmoid 概率

理解了形状变化,整个模型代码就能照着写。

案例 4:和 Wide & Deep 的关系

维度Wide & Deep (2016)DLRM (2019)
重点算法(记忆 + 泛化)系统(标准结构 + 并行)
Wide 路手工交叉特征 + LR没有 wide 路
Deep 路embedding + MLPembedding + dot interaction + MLP
开源TensorFlow APIfacebookresearch/dlrm(PyTorch + Caffe2)
历史定位让”两路并行”流行让”工业推荐”有可对标的基线

可以理解为 Wide & Deep 给了思想,DLRM 给了可复现的工业骨架

踩过的坑

  1. 以为 dot interaction 是先进结构:其实它是最朴素的二阶特征交叉。DCN / xDeepFM 都比它复杂。DLRM 选朴素是故意的——给一个简单可复现的基线比给最强结构更重要。

  2. 以为推荐模型是算力瓶颈:恰恰相反。embedding 占 99% 参数但只占 < 1% 算力。瓶颈是显存容量 + 通信带宽,不是 FLOPS。这导致推荐硬件选型逻辑和训练 LLM 完全不一样。

  3. 直接复制 LLM 的并行方案到推荐:会爆炸。LLM 是模型大、激活大;推荐是 embedding 大、MLP 小。embedding 不能 data parallel(每卡一个 TB 副本不可能),MLP 不能 model parallel(切了通信比算还贵)。必须异构。

  4. 忽略 embedding 的稀疏更新:每个 batch 只有少数 embedding 行被访问,梯度也只更新这些行。如果用普通 dense optimizer 会浪费 99% 算力——必须用稀疏 optimizer(SparseAdam 之类)。

适用 vs 不适用场景

适用

  • 工业 CTR 预估、广告排序、推荐召回后排序
  • 类别特征 + 连续特征混合的结构化预测
  • 需要工业级 baseline 对标的研究(MLPerf 等)

不适用

  • 序列推荐(用户行为序列)→ 需要 Transformer / RNN,DLRM 不建模时序
  • 图结构推荐(社交网络)→ 用 GNN
  • 冷启动场景 → DLRM 依赖大量训练数据学 embedding,新用户/新物品没历史
  • 内容推荐(图像/视频本身)→ 需要先用 CNN/ViT 提特征再喂 DLRM

历史小故事(可跳过)

  • 2016:Google 发 Wide & Deep,业界开始用”两路并行”做推荐
  • 2017-2018:DCN / xDeepFM / DIN 各家发力改特征交互结构
  • 2019:Meta 干脆开源 DLRM,意思是”结构创新就到这吧,咱比系统效率”
  • 2020:MLPerf Recommendation track 直接用 DLRM 做 benchmark
  • 2021:Meta 发 ZionEX 论文,专为 DLRM 设计训练硬件
  • 2022:TorchRec 上线,PyTorch 官方支持 DLRM 训练

学到什么

  1. 工业开源不一定是结构最先进,而是最可复现——DLRM 用最朴素的 dot 交叉,正是为了让大家拿来就跑
  2. 推荐和大模型是两套世界观——一个 embedding 大算力小,一个权重大算力也大;并行方案、硬件选型完全不同
  3. 标准化是基础设施的前提——DLRM 把模型骨架钉死后,硬件、编译器、训练框架才有共同对标对象
  4. 有时候发”集大成的简化版”比发”最强结构”影响更大——DLRM 引用量远超同期更复杂的推荐论文

延伸阅读

关联