跳转到内容

RWKV — 让 RNN 拿到 Transformer 那张训练并行的入场券

是什么

RWKV(Receptance Weighted Key Value)是 2023 年由 BlinkDL(彭博)与社区合作提出的序列模型架构。它做了一件听起来很矛盾的事:训练时长得像 Transformer(全序列并行),推理时长得像 RNN(每来一个 token 只更新固定大小的状态)

日常类比:

  • Transformer 像一个学生做完形填空时翻回去看每一段话——题越长翻得越累
  • 传统 RNN 像一个人边听边记小本子——记得快但同学没法分工
  • RWKV 像一个会”一次性给全班做笔记的本子”——上课(训练)时大家分工抄写,回去(推理)时只翻自己那一页

它训练时把”加权累积”写成 prefix-sum,可以全序列并行(GPU 吃满);推理时把同一个公式写成递推,每个 token 只读写几个固定大小向量。这个”双形态等价”是 RWKV 真正的卖点。

为什么重要

不理解 RWKV,下面这些事都没法解释:

  • 为什么 2024 年突然冒出一波”线性 attention 复兴”——RetNet / GLA / mamba 都受它直接或间接启发
  • 为什么手机本地大模型(端侧 LLM)一开始几乎都选 RWKV——没有 KV cache,1.5B 模型在树莓派上也能跑长对话
  • 为什么 14B 是个心理关口——这是第一次有人把 RNN 路线开源训到 10B 量级并发布完整 checkpoint
  • 为什么”训练并行 vs 推理高效”在序列模型里是个长期张力——attention 偏前者,传统 RNN 偏后者,RWKV 想同时拿到

RWKV 的核心承诺:训练 O(N) 时间且并行推理每 token O(1) 显存质量与同规模 Transformer 持平。三件一起达到,是它和此前所有线性 attention 工作的最大区别。

核心要点

RWKV 的关键是 把 softmax(QK^T)V 替换成可累积的加权和,让同一个公式既能并行又能递推。

  1. R / W / K / V 四个角色:R 是 receptance(门控,读取记忆),W 是 time decay(每个通道独立学习的衰减率),K / V 沿用 attention 的命名。一层有两个 block:time-mix(类 attention)和 channel-mix(类 FFN)。

  2. WKV 算子:核心公式大致是

    wkv_t = sum_{i<=t} exp(-(t-i)·w + k_i) · v_i
    / sum_{i<=t} exp(-(t-i)·w + k_i)

    分子分母都是 prefix-sum 形式——训练时对全序列并行算(O(N) 时间,GPU 友好),推理时改写成”上一步状态 + 当前 token”的递推(每 token O(1) 显存)。

  3. time decay 让每个通道有自己的”记忆半衰期”:w 是逐 channel 学习的标量,等价于”这一通道关心多远的历史”。一些 channel 学到长程衰减(保留全局语义),一些学到快速衰减(关注局部 n-gram)。

三件加起来 = 数学上一道公式,工程上两套实现,业务上训练快 + 推理省

实践案例

案例 1:端侧 / 嵌入式 LLM 推理

任务:在手机或树莓派上跑 1.5B 模型做长对话。

  • Transformer:KV cache 随上下文线性增长,对话越长占用越大,长对话几百兆显存就吃不消
  • RWKV:state 固定大小(每层只有几个向量),1.5B 模型常驻几百 MB,对话多长都不涨

这是 RWKV 在 2023-2024 年最现实的产品落地窗口。Llama.cpp 之外,社区里另一条端侧 LLM 主线就是 rwkv.cpp。

案例 2:流式语音 / 翻译生成

实时语音边听边翻。Transformer 做 streaming 要分 chunk 配滑窗,要决定”多久重新算一次 attention”,工程复杂;RWKV 的 RNN 形态天然就是”每来一个 token 更新一次 state”,延迟稳定,代码简单——只要把上一步的 state 存住即可。

案例 3:超长上下文文档处理

100k+ token 的长文档。RWKV 推理显存与上下文长度无关,单卡可跑完全篇。但要注意它是”压缩式”理解——适合摘要、续写、风格延续,不擅长精确召回某一句原文。

案例 4:14B 上和同规模 Transformer 打平

RWKV-4 14B 在 The Pile 上训练后,HellaSwag / PIQA / ARC / LAMBADA 等 zero-shot benchmark 平均分与 Pythia / GPT-J 等同规模 Transformer 持平。这是 RNN 路线 第一次 在 10B+ 规模上证明”质量不差”——此前所有线性 attention 工作都只在 100M 以内规模上验证。

案例 5:训练时为什么能并行

把 wkv 写成两个 prefix-sum(分子分母各一)。prefix-sum 在 GPU 上有 work-efficient 并行算法(Blelloch 1990),时间复杂度 O(log N)。RWKV 训练时一次性算完全序列的 wkv,反向传播也是同样形式,因此与 attention 一样可以 fused-kernel 优化。

直觉理解:传统 RNN 是 h_t = f(h_{t-1}, x_t),强依赖 t-1,必须串行;RWKV 把每一步的”贡献”写成独立可加项,再用 prefix-sum 把它们累起来——加法本身满足结合律,所以可以二分合并并行算。这种”把递归改写成可结合累积”的技巧,在并行计算里叫 scan 模式,深度学习社区从 2023 年起才把它当成一线工具。

踩过的坑

  1. fp16 下数值溢出:time decay 累积写成 exp(-w·t + k) 形式,t 一大指数项就爆 fp16。社区惯例是训练时把 state 累加放 fp32、推理时再降精度;初学者直接 fp16 训练几乎必 NaN。

  2. 长程精确召回弱:state 是固定大小的 channel-wise 摘要,本质有损。needle-in-a-haystack(在 100k 上下文里找一句原文)比同规模 Transformer 差一截。不要指望它替代 attention 的所有用例。

  3. time decay w 的初始化敏感:w 太大(衰减太慢)数值不稳,w 太小(衰减太快)等于扔掉长程依赖。RWKV-4 用按 layer 递增的 head-wise 初始化,社区调参经验比 Transformer 少很多。

  4. batch 流式推理容易写错:state 是固定 shape,但 batch 内每条序列要各自维护一份 state;早期实现常把 batch 维度和 state 维度搞混,padding token 也参与了 state 更新,导致输出错位。

  5. prompt 顺序敏感:因为 state 是单向累积出来的,把同样几个示例换个顺序,输出可能差很多。Transformer 双向 attention 没有这个问题。做 in-context learning 时建议固定 prompt 模板。

适用 vs 不适用场景

适用

  • 端侧 / 嵌入式 / 边缘部署(state 仅几百 MB,无需 GB 级 KV cache)
  • 流式生成场景(实时语音 / 实时翻译 / 实时翻译字幕)
  • 超长上下文(>100k)做”压缩式”理解(摘要、续写、风格延续)
  • 推理流量远大于训练成本的产品(一次训练,海量低成本 inference)

不适用

  • 强 in-context learning / few-shot 场景(精确复制示例的任务)
  • 精确召回(RAG 命中、code search、法律条款引用)
  • 工具调用密集的 agent(需要稳定指令跟随)
  • 学术 SOTA 追逐——14B 之后规模优势越拉越大的还是 attention 系

历史小故事(可跳过)

  • 1990 年:Blelloch 发明 work-efficient parallel scan 算法(即 prefix-sum 并行版),本是并行计算理论玩具
  • 1997 年lstm-1997 让 RNN 实用化处理长依赖,但训练串行、扩不到大规模
  • 2014 年:GRU 简化 LSTM,工程上更轻;但本质还是串行 RNN
  • 2017 年:Transformer 发布,attention 一举击溃 RNN,序列建模十年范式定型
  • 2020 年:Linear Attention / Performer / Linformer 等线性化工作出现,但只在小模型上验证
  • 2021 年:RWKV 项目由 BlinkDL(彭博)个人发起,从 v1 一路迭代到 v4,社区驱动
  • 2022 年:S4 在 Long Range Arena 击败 Transformer,但语言建模仍弱
  • 2023 年 5 月:RWKV-4 论文 arXiv 上线,14B 参数完全开源,第一次把 RNN 路线推到大模型量级
  • 2023 年 12 月mamba 把 selectivity 加进 SSM,社区把 RWKV / Mamba / RetNet 一起称为”线性序列模型复兴”
  • 2024 年起:RWKV-5 / RWKV-6 / GLA / HGRN 等后续涌现,旗舰 LLM 仍以 attention 为主

学到什么

  1. 同一个数学公式可以有两种实现形态——训练并行 + 推理递推的等价改写,是序列模型一个深刻的设计杠杆
  2. prefix-sum 是被深度学习社区低估的并行原语——RWKV / Mamba / FlashAttention 都在用它把”看似串行”的累积变成 O(log N) 并行
  3. time decay 的本质是 input-independent gating——和 LSTM 的 forget gate、attention 的 softmax 是同一类思想:让每个通道有自己的记忆节奏
  4. 看 paper 看 limitation 比看 result 更重要——RWKV 论文老老实实写了 ICL 弱、长程召回弱,社区跳过这一节然后惊讶”为什么没替代 Transformer”
  5. 完全开源 + 社区驱动 + 大规模验证,三者凑齐才是让一个新架构”被认真讨论”的入场券
  6. 架构创新的真正成本是工程——只有公式不够,还要写 fused CUDA kernel、设计 fp16/fp32 混合策略、调 time decay 初始化;RWKV 早期社区花在工程上的精力远多于花在数学上的

延伸阅读

  • 论文 PDF:RWKV arXiv 2305.13048(30+ 作者,社区驱动;section 4 公式推导是核心)
  • 官方代码:BlinkDL/RWKV-LM(完整训练 + 推理 + checkpoint)
  • 中文讲解:RWKV 中文社区文档(按版本拆解,工程细节多)
  • attention —— RWKV 想替代的那个机制
  • mamba —— 同期”线性序列模型”代表,SSM 路线 vs RWKV 的 RNN 路线

关联

  • attention —— RWKV 的对照面;attention 是 lossless lookup,RWKV 是 lossy 累积
  • mamba —— 平行竞争路线;Mamba 用 SSM + selectivity,RWKV 用 RNN + time decay;二者常被一起讨论
  • lstm-1997 —— RWKV 的远亲,都用”隐藏状态压缩历史”,但 RWKV 的状态可线性并行
  • transformer —— RWKV 的竞争对手;2024 后双方更多走”混合”而非”替代”
  • flash-attention —— 同一类 hardware-aware 思路,只是优化对象一个是 attention 一个是 RNN 累积
  • gru-2014 —— GRU 简化了 LSTM,但本质还是串行;RWKV 算”GRU 的并行训练版”

反向链接

  • attention —— Attention Is All You Need
  • flash-attention —— FlashAttention — 不改算法,只改数据怎么进 GPU
  • gru-2014 —— GRU 2014 — 用两个门替代 LSTM 三个门,编码-解码范式登场
  • mamba —— Mamba — 选择性状态空间模型