跳转到内容

N-BEATS — 纯前馈网络在时序预测上打败统计派

是什么

N-BEATS 是一种只用全连接层 + 残差就把时间序列预测做到 SOTA 的神经网络。日常类比:像一队侦探接力——第一个看完案情先猜一版,剩下的疑点交给下一个,每个人只补前一个没解释的部分,最后把所有人的猜测加起来。

输入:过去 N 个数据点(lookback 窗口)。输出:未来 H 个数据点(forecast 窗口)。中间没有 RNN、没有 attention、没有 CNN,只有堆叠的全连接块

核心结果:在 M4 竞赛(10 万条真实时序、6 种频率:年/季/月/周/日/小时)上,纯神经网络第一次打败了 ETS / ARIMA / Theta 这些统计派老将,并且超过了 M4 冠军(一个统计 + ML 的混合方案)约 3% sMAPE。

为什么重要

不理解 N-BEATS,就解释不了下面的事:

  • 为什么 2018 年 M4 竞赛后大家说”深度学习在时序上不行”——纯 ML 方法那年全都输给了简单统计方法
  • 为什么 2020 年风向突然变了——N-BEATS 出来后,业界开始相信深度学习也能做工业级预测
  • 为什么后来的 N-HiTS / TFT / TimesFM 都把”残差堆叠”当默认 building block
  • 为什么”简单的 MLP + 巧妙的连接结构”有时候比花哨的 attention 更好用

核心要点

N-BEATS 的设计就三件事:

  1. 块(Block)一次产两个输出:每个块吃一段输入,输出 backcast(重建过去)和 forecast(预测未来)。类比:侦探不仅猜下一步,还要”复述自己理解到的案情”——这样下一个侦探能看到他理解错了什么。

  2. 双重残差堆叠(Doubly Residual Stacking):把 backcast 从输入里减掉,剩下的残差送给下一个块。所有块的 forecast 相加得到最终预测。类比:第一个侦探说”我懂了 70%,剩下 30% 你看”——下一个块只对那 30% 负责。

  3. 基函数(Basis)可选两种

    • Generic(通用):基函数自己学,可解释性差但精度高
    • Interpretable(可解释):约束某些块只学多项式基(→ 趋势),某些块只学 Fourier 基(→ 季节性),分离后可看图

整个网络通过 180 个模型集成得到 M4 上的最终成绩。

实践案例

案例 1:一个 Block 内部长什么样

lookback (N 点) → FC → FC → FC → FC →
├─→ θ_b → 基函数 → backcast (N 点)
└─→ θ_f → 基函数 → forecast (H 点)

四层全连接抽特征,最后分两个头:一个头输出”重建系数 θ_b”,另一个输出”预测系数 θ_f”。基函数把系数变回时间序列。

逐部分解释:

  • θ_b / θ_f:每个头吐出的系数向量(可以理解成”这个块认为信号是怎么组合的”)
  • 基函数:generic 模式下是学习出来的线性映射;interpretable 模式下是固定的多项式或 Fourier
  • backcast 用来从输入里减掉,让下一个块只看残差

案例 2:双重残差是怎么”接力”的

x_0 = lookback
块 1:backcast b_1, forecast f_1 → x_1 = x_0 - b_1
块 2:backcast b_2, forecast f_2 → x_2 = x_1 - b_2
块 3:backcast b_3, forecast f_3 → x_3 = x_2 - b_3
...
最终预测 = f_1 + f_2 + f_3 + ...

每个块只负责前一块没解释清楚的那部分残差。这和 ResNet 的”学残差”思想同源,但 N-BEATS 把残差双向接(backward 减输入、forward 加输出)。

案例 3:interpretable 模式下能看到什么

把 stacks 分成两组:trend stack 用多项式基(例如 t、t²、t³),seasonality stack 用 sin / cos 基。训练完后画图:

  • trend stack 的 forecast 显示一条平滑曲线(“长期方向”)
  • seasonality stack 的 forecast 显示周期波动(“周内 / 月内规律”)

代价:精度比 generic 略低(论文报告小几个百分点),但业务方能看图解释

案例 4:generic 模式下基函数是什么

generic 块的”基函数”实际上就是一个可学习的线性层

forecast = W_f × θ_f (W_f 形状为 H × dim(θ_f),从训练里学出来)
backcast = W_b × θ_b (W_b 形状为 N × dim(θ_b))

也就是说,generic N-BEATS 本质是一个深度 MLP + 双重残差骨架——没有任何特殊结构,全部参数都从数据里学。这也是为什么它精度更高、但解释性几乎为零。

踩过的坑

  1. Univariate only:原版只吃一条时间序列,不接收外生变量(节假日、促销标志、天气等)。后续 N-BEATSx(2022)才加上。

  2. 靠集成撑性能:M4 上的 SOTA 数字来自 180 个模型集成(不同 lookback 长度 × 不同损失函数 × 不同随机种子)。单模型成绩没那么夸张。

  3. interpretable 不总是干净:把基约束成多项式 + Fourier 后,仍可能学到分不清趋势和季节性——多项式头吃掉了一部分周期,Fourier 头吃掉了一部分趋势。需要人工看图判断。

  4. lookback 选错就崩:lookback 太短捕不到长周期,太长就过拟合噪声。论文用 lookback = k × horizon,k ∈ {2, 3, 4, 5, 6, 7},每个 k 各训一组再集成。

  5. 跨频段一份模型:M4 的年/季/月/周/日/小时频段差异巨大,论文最终对每种频段单独训一组。直接把六种频段塞一起训会拉低整体精度——这点经常被新手忽略。

  6. 损失函数选择敏感:sMAPE / MAPE / MASE 在不同频段表现不同,论文集成里同时用了多种损失函数。只用一种损失训出来的单模型很难达到报告的成绩。

适用 vs 不适用场景

适用

  • 单变量时序预测(销量、流量、负荷、汇率等只看自己历史的场景)
  • 训练数据较多(几千到几万条序列,跨序列共享参数)
  • 工业级精度需求 + 可接受集成开销

不适用

  • 强外生变量驱动(必须看天气 / 促销 / 上下游信号)→ 用 N-BEATSx / TFT
  • 数据极少(单条序列只有几十点)→ 统计方法 ETS / ARIMA 仍是首选
  • 需要给出预测分布 / 不确定性估计(原版只输出点预测)→ 用 DeepAR / TFT

历史小故事(可跳过)

  • 2018 年 M4 竞赛:Spyros Makridakis 主办,10 万条时序、61 个参赛队。冠军 Slawek Smyl 用 ES-RNN(统计 + LSTM 混合);纯 ML 方法平均输给简单统计方法。M4 报告下了硬结论:“纯 ML 方法目前不行”。
  • 2019 年 5 月:Element AI(Bengio 联合创办的实验室)放出 N-BEATS 论文。第一次有纯神经网络在 M4 全频段上超过 ES-RNN。
  • 2020 年 ICLR:论文录用,作者把代码 + 训好的模型开源,业界复现确认。“DL 时序预测”翻盘从这里开始。
  • 2022–2023 年:N-BEATSx 加外生变量、N-HiTS 用层级插值大幅提速、TimesFM / Chronos 等通用预测大模型把残差堆叠当默认架构。

学到什么

  1. 结构 > 容量:N-BEATS 没用 attention 没用 RNN,只靠”双向残差 + 块堆叠”就赢。说明在合适的归纳偏置下,朴素 MLP 仍有竞争力。

  2. backcast 不是装饰:让网络”复述输入”逼它真正理解,而不是只蒙输出。这种”自监督式辅助任务”在很多结构里都见过(autoencoder / BERT MLM)。

  3. 集成是 SOTA 的常规弹药:单模型很少打破纪录,但工程上 180 个模型集成对实时推理基本不可接受——读论文要分清”刷榜数字”和”可部署数字”。

  4. 可解释性可以”嵌进结构”:把基函数固定成多项式 + Fourier,模型本身就能输出”趋势”和”季节性”两条曲线,不需要事后归因工具。

  5. 领域无关 + 跨序列共享:N-BEATS 不为每条序列单独建模,所有序列共享同一组参数。这与统计派”每条序列拟一个 ARIMA”的范式根本不同——也是它能从大量数据里”借力”的关键。

延伸阅读

关联

  • resnet —— 残差思想的源头,N-BEATS 把它从”加”扩展到”双向减加”
  • transformer —— 同期主流时序模型的另一条路,N-BEATS 证明不用 attention 也能赢
  • deepar —— Amazon 早期的 RNN 时序预测,给概率预测设了 baseline
  • autoformer —— Transformer 路线在时序上的代表,与 N-BEATS 思路对照

反向链接

  • chronos-2024 —— Chronos — 把时间序列当语言来训练大模型
  • resnet —— ResNet — 残差连接