跳转到内容

MLX — Apple Silicon 统一内存原生 ML 框架

是什么

MLX 是 Apple 机器学习研究团队 2023 年 12 月开源的数组与机器学习框架,专为 Apple Silicon(M1/M2/M3/M4 系列)芯片设计。日常类比:像给自家硬件量身做的西装——通用框架是百货店成衣,能穿但贴不住身;MLX 顺着 M 芯片的版型剪裁,每一寸都贴肉。

它分两层 API:

  • 低层 mlx.core——形状像 NumPy(mx.array / mx.matmul / mx.exp
  • 高层 mlx.nn / mlx.optimizers——形状像 pytorchnn.Module / optim.AdamW
import mlx.core as mx
import mlx.nn as nn
x = mx.random.normal((4, 768)) # 直接在统一内存里
y = nn.Linear(768, 256)(x) # GPU 算,CPU 也能读,零拷贝
mx.eval(y) # lazy:到这一步才真算

短短四行就摸到了 MLX 的三大特征:统一内存 / NumPy 形状 / lazy evaluation

为什么重要

不理解 MLX,下面这些事会卡住你:

  • 为什么 Mac mini M2 / MacBook Pro 上跑 70B 大模型推理比 24GB 的 RTX 4090 还顺——统一内存能直接吃满整机 64GB/128GB
  • 为什么 PyTorch 的 MPS 后端总是有 op fall back 到 CPU、速度差一截——它把 Apple Silicon 当成”另一种 GPU”来适配,没专门为统一内存重写
  • 为什么 Hugging Face 上的模型出现 mlx-community/Llama-3.1-70B-4bit 这种独立分支——MLX 有自己的权重量化格式
  • 为什么 iOS app 要嵌端侧模型时越来越多人选 MLX 而不是 Core ML——Core ML 偏部署、MLX 偏研究和定制

核心要点

3 个底层选择 + 1 个生态层

  1. 统一内存(Unified Memory)零拷贝:M 系列芯片 CPU/GPU 物理上共享同一块 DRAM。传统框架(pytorch CUDA / TensorFlow)按”host RAM → device VRAM”模型设计,每次跨设备都 tensor.to(device)。MLX 的 mx.array 一次分配在共享池,CPU op 和 GPU op 能直接读同一份数据。这是 MLX 与所有”移植到 Mac”框架的根本分水岭

  2. Lazy evaluation(延迟求值)y = x @ w + b 只是构图,真正算要等 mx.eval(y) 或访问 y.item()。好处:编译器能把多个 op 融合成一个 Metal kernel,省内存带宽。代价:调试时 print(y) 会”突然卡一下”,因为它隐式触发 eval。借鉴 jax,但保留 PyTorch eager 的写法直觉。

  3. Composable function transformsgrad(loss_fn) 直接返回梯度函数;vmap(fn) 把单样本函数自动批量化;jit(fn) 把动态图编译成静态图加速。这一招完全继承自 JAX。

  4. 生态三件套

    • mlx-lm——LLM 推理与微调脚手架(Llama / Mistral / Qwen / Phi 开箱即用)
    • mlx-examples——官方 Stable Diffusion / Whisper / LoRA 微调示例
    • mlx-swift——Swift 绑定,给 iOS / macOS app 端侧嵌入

合起来一句话:MLX 把统一内存这个硬件红利,从底层数组一路顶到上层 LLM/Swift 接口

实践案例

案例 1:本地跑 70B 模型,命令两行

Terminal window
pip install mlx-lm
mlx_lm.generate --model mlx-community/Meta-Llama-3.1-70B-Instruct-4bit \
--prompt "用中文写一首关于深秋的俳句"

Mac mini M2 Max 64GB 上能跑,token 速度大约 10 tok/s。同一台机器走 PyTorch+MPS 路径要么 OOM 要么 fall back 到 CPU,慢一个数量级。

案例 2:用 MLX 写一个简单 MLP 训练循环

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def __call__(self, x):
return self.fc2(nn.relu(self.fc1(x)))
model = MLP()
opt = optim.AdamW(learning_rate=1e-3)
loss_fn = lambda m, x, y: nn.losses.cross_entropy(m(x), y, reduction='mean')
loss_and_grad = nn.value_and_grad(model, loss_fn)
for x, y in dataloader:
loss, grads = loss_and_grad(model, x, y)
opt.update(model, grads)
mx.eval(model.parameters(), opt.state) # 显式 sync 一次

写法几乎和 PyTorch 一比一对应,只是显式 mx.eval 替代了隐式的 loss.backward()

案例 3:LoRA 微调 7B 模型在 32GB MacBook 上

Terminal window
mlx_lm.lora --model mlx-community/Mistral-7B-v0.3-4bit \
--train --data ./my_dataset \
--iters 600 --batch-size 1 --lora-layers 16

32GB 统一内存吃得下。同样的事用 PyTorch+MPS 通常 OOM,要么得退到 CPU offload,速度慢 10 倍。这是 MLX 在个人研究场景最实用的卖点

踩过的坑

  1. print(array) 是同步点:lazy eval 下,print 会触发计算。在循环里随手 print 中间结果,会让本来融合的 kernel 全部展开,性能掉一档。调试时改用 mx.eval 显式 sync,正式代码里把 print 删掉。

  2. 第一次跑某个 op 会”冷启动”:Metal kernel 是即时编译的,第一次出现 mx.matmul((m, n), (n, k)) 这种新 shape 会编译几百毫秒。benchmark 时务必先 warmup 几轮再计时。

  3. bf16 不是所有芯片都有:M1 / M1 Pro / M1 Max 不支持 bf16,写 mx.bfloat16 会 fall back 到 fp32。M2 起才有真正的 bf16 单元。看清自己芯片代际再选 dtype。

  4. 生态小,论文代码要自己 port:Hugging Face transformers 不直接支持 MLX 后端(要 mlx-lm 中转)。最新论文里的奇怪 op(FlashAttention v3、Mamba SSM 自定义 kernel)经常要自己写 Metal——比 PyTorch 累。

  5. 不能跨平台部署:MLX 只跑 Apple Silicon。生产部署到 Linux 服务器 GPU 集群,必须回到 pytorch/ONNX/TensorRT。把 MLX 当研究 + 端侧推理的工具,不当生产训练工具。

适用 vs 不适用场景

适用

  • 用 Mac 本地跑 LLM 推理 / LoRA 微调(个人 / 研究 / 学习)
  • 想给 iOS / macOS app 嵌入端侧模型(mlx-swift)
  • Mac 上做 Stable Diffusion / Whisper 离线推理
  • 做 Apple Silicon 性能研究、教学示范、Hackathon 原型

不适用

  • 训练大规模生产模型(仍走 NVIDIA + pytorch + accelerate
  • 需要跨平台部署到 Windows / Linux 服务器
  • 重度依赖 PyTorch 生态特定库(vLLM / DeepSpeed / Megatron-LM)
  • 需要分布式多机训练(MLX 设计是单机多设备,没多机抽象)

历史小故事(可跳过)

  • 2023-12:Apple ML Research 团队(Awni Hannun 主导)开源 MLX,首版只有 Python + C++ 接口,定位”给研究员的 NumPy 替代”。
  • 2024 Q1mlx-swift 发布,把同一套数组 API 暴露给 iOS / macOS 原生开发者。
  • 2024 中mlx-lm 成熟,Hugging Face 出现 mlx-community 组织专门发布量化权重。
  • 2025:随 M3 / M4 Pro / Max 释放更大显存(128GB / 192GB),Mac 成为社区跑 70B+ 模型门槛最低的平台。

短短两年从一个研究工具变成 Apple 生态端侧 AI 事实标准。

学到什么

  1. 硬件特性决定框架设计:统一内存不是”省一次拷贝”那么小的事,它让”哪个设备算”变成纯调度问题,不再是数据布局问题
  2. lazy + eager 可以共存:MLX 学 jax 的 lazy 拿到融合优化,但保留 PyTorch eager 的写法直觉,不强制 jit
  3. 生态厚度 vs 硬件契合度的取舍:选 PyTorch 拿生态,选 MLX 拿性能/显存——没有银弹,按场景挑
  4. 本地推理 + 端侧 app 是 MLX 真正的护城河,不是和 CUDA 抢训练市场

延伸阅读

关联

  • pytorch —— 同领域主流框架,PyTorch+MPS 在 Mac 上跑但不如 MLX 贴硬件
  • jax —— grad / vmap / jit 三大函数变换的设计原型
  • accelerate —— PyTorch 生态分布式薄壳,与 MLX 走完全不同的路线(兼容多后端 vs 专注单平台)
  • hindley-milner —— 不直接相关,但同样体现”为特定目标专门设计 vs 通用”的取舍