MLX — Apple Silicon 统一内存原生 ML 框架
是什么
MLX 是 Apple 机器学习研究团队 2023 年 12 月开源的数组与机器学习框架,专为 Apple Silicon(M1/M2/M3/M4 系列)芯片设计。日常类比:像给自家硬件量身做的西装——通用框架是百货店成衣,能穿但贴不住身;MLX 顺着 M 芯片的版型剪裁,每一寸都贴肉。
它分两层 API:
- 低层
mlx.core——形状像 NumPy(mx.array/mx.matmul/mx.exp) - 高层
mlx.nn/mlx.optimizers——形状像 pytorch(nn.Module/optim.AdamW)
import mlx.core as mximport mlx.nn as nn
x = mx.random.normal((4, 768)) # 直接在统一内存里y = nn.Linear(768, 256)(x) # GPU 算,CPU 也能读,零拷贝mx.eval(y) # lazy:到这一步才真算短短四行就摸到了 MLX 的三大特征:统一内存 / NumPy 形状 / lazy evaluation。
为什么重要
不理解 MLX,下面这些事会卡住你:
- 为什么 Mac mini M2 / MacBook Pro 上跑 70B 大模型推理比 24GB 的 RTX 4090 还顺——统一内存能直接吃满整机 64GB/128GB
- 为什么 PyTorch 的 MPS 后端总是有 op fall back 到 CPU、速度差一截——它把 Apple Silicon 当成”另一种 GPU”来适配,没专门为统一内存重写
- 为什么 Hugging Face 上的模型出现
mlx-community/Llama-3.1-70B-4bit这种独立分支——MLX 有自己的权重量化格式 - 为什么 iOS app 要嵌端侧模型时越来越多人选 MLX 而不是 Core ML——Core ML 偏部署、MLX 偏研究和定制
核心要点
记 3 个底层选择 + 1 个生态层:
-
统一内存(Unified Memory)零拷贝:M 系列芯片 CPU/GPU 物理上共享同一块 DRAM。传统框架(pytorch CUDA / TensorFlow)按”host RAM → device VRAM”模型设计,每次跨设备都
tensor.to(device)。MLX 的mx.array一次分配在共享池,CPU op 和 GPU op 能直接读同一份数据。这是 MLX 与所有”移植到 Mac”框架的根本分水岭。 -
Lazy evaluation(延迟求值):
y = x @ w + b只是构图,真正算要等mx.eval(y)或访问y.item()。好处:编译器能把多个 op 融合成一个 Metal kernel,省内存带宽。代价:调试时print(y)会”突然卡一下”,因为它隐式触发 eval。借鉴 jax,但保留 PyTorch eager 的写法直觉。 -
Composable function transforms:
grad(loss_fn)直接返回梯度函数;vmap(fn)把单样本函数自动批量化;jit(fn)把动态图编译成静态图加速。这一招完全继承自 JAX。 -
生态三件套:
mlx-lm——LLM 推理与微调脚手架(Llama / Mistral / Qwen / Phi 开箱即用)mlx-examples——官方 Stable Diffusion / Whisper / LoRA 微调示例mlx-swift——Swift 绑定,给 iOS / macOS app 端侧嵌入
合起来一句话:MLX 把统一内存这个硬件红利,从底层数组一路顶到上层 LLM/Swift 接口。
实践案例
案例 1:本地跑 70B 模型,命令两行
pip install mlx-lmmlx_lm.generate --model mlx-community/Meta-Llama-3.1-70B-Instruct-4bit \ --prompt "用中文写一首关于深秋的俳句"Mac mini M2 Max 64GB 上能跑,token 速度大约 10 tok/s。同一台机器走 PyTorch+MPS 路径要么 OOM 要么 fall back 到 CPU,慢一个数量级。
案例 2:用 MLX 写一个简单 MLP 训练循环
import mlx.core as mximport mlx.nn as nnimport mlx.optimizers as optim
class MLP(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(784, 256) self.fc2 = nn.Linear(256, 10) def __call__(self, x): return self.fc2(nn.relu(self.fc1(x)))
model = MLP()opt = optim.AdamW(learning_rate=1e-3)loss_fn = lambda m, x, y: nn.losses.cross_entropy(m(x), y, reduction='mean')loss_and_grad = nn.value_and_grad(model, loss_fn)
for x, y in dataloader: loss, grads = loss_and_grad(model, x, y) opt.update(model, grads) mx.eval(model.parameters(), opt.state) # 显式 sync 一次写法几乎和 PyTorch 一比一对应,只是显式 mx.eval 替代了隐式的 loss.backward()。
案例 3:LoRA 微调 7B 模型在 32GB MacBook 上
mlx_lm.lora --model mlx-community/Mistral-7B-v0.3-4bit \ --train --data ./my_dataset \ --iters 600 --batch-size 1 --lora-layers 1632GB 统一内存吃得下。同样的事用 PyTorch+MPS 通常 OOM,要么得退到 CPU offload,速度慢 10 倍。这是 MLX 在个人研究场景最实用的卖点。
踩过的坑
-
print(array)是同步点:lazy eval 下,print会触发计算。在循环里随手 print 中间结果,会让本来融合的 kernel 全部展开,性能掉一档。调试时改用mx.eval显式 sync,正式代码里把 print 删掉。 -
第一次跑某个 op 会”冷启动”:Metal kernel 是即时编译的,第一次出现
mx.matmul((m, n), (n, k))这种新 shape 会编译几百毫秒。benchmark 时务必先 warmup 几轮再计时。 -
bf16 不是所有芯片都有:M1 / M1 Pro / M1 Max 不支持 bf16,写
mx.bfloat16会 fall back 到 fp32。M2 起才有真正的 bf16 单元。看清自己芯片代际再选 dtype。 -
生态小,论文代码要自己 port:Hugging Face transformers 不直接支持 MLX 后端(要
mlx-lm中转)。最新论文里的奇怪 op(FlashAttention v3、Mamba SSM 自定义 kernel)经常要自己写 Metal——比 PyTorch 累。 -
不能跨平台部署:MLX 只跑 Apple Silicon。生产部署到 Linux 服务器 GPU 集群,必须回到 pytorch/ONNX/TensorRT。把 MLX 当研究 + 端侧推理的工具,不当生产训练工具。
适用 vs 不适用场景
适用:
- 用 Mac 本地跑 LLM 推理 / LoRA 微调(个人 / 研究 / 学习)
- 想给 iOS / macOS app 嵌入端侧模型(mlx-swift)
- Mac 上做 Stable Diffusion / Whisper 离线推理
- 做 Apple Silicon 性能研究、教学示范、Hackathon 原型
不适用:
- 训练大规模生产模型(仍走 NVIDIA + pytorch + accelerate)
- 需要跨平台部署到 Windows / Linux 服务器
- 重度依赖 PyTorch 生态特定库(vLLM / DeepSpeed / Megatron-LM)
- 需要分布式多机训练(MLX 设计是单机多设备,没多机抽象)
历史小故事(可跳过)
- 2023-12:Apple ML Research 团队(Awni Hannun 主导)开源 MLX,首版只有 Python + C++ 接口,定位”给研究员的 NumPy 替代”。
- 2024 Q1:
mlx-swift发布,把同一套数组 API 暴露给 iOS / macOS 原生开发者。 - 2024 中:
mlx-lm成熟,Hugging Face 出现mlx-community组织专门发布量化权重。 - 2025:随 M3 / M4 Pro / Max 释放更大显存(128GB / 192GB),Mac 成为社区跑 70B+ 模型门槛最低的平台。
短短两年从一个研究工具变成 Apple 生态端侧 AI 事实标准。
学到什么
- 硬件特性决定框架设计:统一内存不是”省一次拷贝”那么小的事,它让”哪个设备算”变成纯调度问题,不再是数据布局问题
- lazy + eager 可以共存:MLX 学 jax 的 lazy 拿到融合优化,但保留 PyTorch eager 的写法直觉,不强制 jit
- 生态厚度 vs 硬件契合度的取舍:选 PyTorch 拿生态,选 MLX 拿性能/显存——没有银弹,按场景挑
- 本地推理 + 端侧 app 是 MLX 真正的护城河,不是和 CUDA 抢训练市场
延伸阅读
- 官方文档:ml-explore.github.io/mlx(quickstart 半小时跑通)
- mlx-lm 仓库:github.com/ml-explore/mlx-lm(LLM 推理 / 微调脚手架)
- Hugging Face MLX 社区:huggingface.co/mlx-community(量化权重集合)
- Awni Hannun 个人博客:awnihannun.com(MLX 主程,常发性能 benchmark)
- pytorch —— 主流对照,PyTorch+MPS 是 Mac 上的另一条路(兼容大、性能差)
- jax —— function transforms 范式来源
- accelerate —— HuggingFace 分布式抽象,对 MLX 不感知(生态分叉点)
关联
- pytorch —— 同领域主流框架,PyTorch+MPS 在 Mac 上跑但不如 MLX 贴硬件
- jax ——
grad/vmap/jit三大函数变换的设计原型 - accelerate —— PyTorch 生态分布式薄壳,与 MLX 走完全不同的路线(兼容多后端 vs 专注单平台)
- hindley-milner —— 不直接相关,但同样体现”为特定目标专门设计 vs 通用”的取舍