MoE 模型里有个反直觉的现象:8 个 expert 各跑一遍 Linear,每个 expert 只用了 1/8 的输入 token,但 8 个 expert 的权重全得从显存搬一遍。搬权重的开销比计算本身还大。ops-transformer 的 MergedMatMul 算子,把共享输入的多路矩阵乘法合成一次计算,昇腾NPU的 Cube 单元只需读一遍输入,8 路权重连续读取,吞吐量直接翻倍。
问题:MoE 的 hidden 层在干什么
Mixtral 8x7B 的 FFN 层,每个 expert 是这样的:
输入 x → Linear_up (4096×14336) → SiLU → Linear_gate (4096×14336) → × → Linear_down (14336×4096) → 输出8 个 expert,3 个 Linear,一共 24 次 MatMul。每个 token 只路由到 2 个 expert,所以实际执行 6 次 MatMul——但 Gate 的路由结果在运行时才知道,框架必须准备所有 8 个 expert 的权重。
标准实现的处理方式是:逐 expert 调用 MatMul。每次调用,输入数据从显存读到 Cube 单元,算完写回。2 个 expert 就读 2 遍输入。
输入是一样的,读了 2 遍。
MergedMatMul 的合并策略
MergedMatMul 把多路 MatMul 合成一次计算,核心是 Batch GEMM:
标准方式: x @ W_expert0 → y0 x @ W_expert1 → y1 2次输入读取,2次kernel launch MergedMatMul: x @ [W_expert0; W_expert1] → [y0, y1] 1次输入读取,1次kernel launch不是简单拼接。昇腾NPU的 Cube 单元支持 Batch GEMM 指令,一次提交多组矩阵乘法,Cube 单元自动流水执行。MergedMatMul 把 expert 权重按[num_experts, in_dim, out_dim]排列,调用 Batch GEMM 时只传一个 batch 参数。
这样输入 x 只读一次,8 路权重连续排布在显存里,DMA 引擎可以预取下一组权重,Cube 单元不停顿。
不只是 MoE
MergedMatMul 在非 MoE 场景也有用。Transformer 的 FFN 层有个常见优化:把 up_proj 和 gate_proj 合并:
# 标准写法:两次独立 MatMulup=x @ W_up# [batch, seq, 14336]gate=x @ W_gate# [batch, seq, 14336]y=silu(gate)*up out=y @ W_down# MergedMatMul 写法:合并 up 和 gatecombined=merged_matmul(x,[W_up,W_gate])# 1次读取,1次kernelup,gate=combined.chunk(2,dim=-1)y=silu(gate)*up out=y @ W_downup_proj 和 gate_proj 共享同一个输入 x,合并后省一次输入读取和一次 kernel launch。在 Llama2-70B 上,这个优化单独就能带来 8-12% 的 FFN 层加速。
性能数据
Atlas 800I A2,Mixtral 8x7B 推理:
| 配置 | FFN 层延迟 (ms) | 吞吐 (tokens/s) |
|---|---|---|
| 逐 expert MatMul | 18.3 | 1,240 |
| MergedMatMul (2 expert) | 10.7 | 1,890 |
| MergedMatMul + up/gate 合并 | 8.9 | 2,150 |
FFN 层延迟降了 40%。up/gate 合并又额外省了 12%。
用法
通过torch_npu.npu.merged_matmul接口调用:
importtorch_npu# weights: [num_experts, in_dim, out_dim]# x: [batch, seq, in_dim]# expert_ids: [batch, seq] 每个token路由到的expert索引y=torch_npu.npu.merged_matmul(x,weights,expert_ids)ATB 路径下,MergedMatMul 默认启用,不需要手动调用。ATB 内部会检测 FFN 层是否存在多路共享输入的 Linear,自动触发合并。
踩坑
MergedMatMul 要求所有合并的 MatMul 维度一致——in_dim 和 out_dim 必须相同。MoE 模型的 expert 权重天然满足这个条件。但如果你想把 up_proj 和 gate_proj 合并,得确保它们的 out_dim 一样。Llama 系列没问题,但有些模型的 gate_proj 维度跟 up_proj 不同,这时候只能合并 expert 之间的,不能跨层合并。
另外,expert 数量超过 8 时,Batch GEMM 的 register pressure 会上升。在 Ascend 910 上,8 expert 是甜点值,16 expert 也能跑但收益递减,32 expert 以上建议拆成多批。
如果你的 MoE 推理服务 FFN 层是瓶颈(看 NPU 利用率曲线上有没有周期性低谷),大概率是逐 expert MatMul 在反复搬数据。MergedMatMul 改动不大,但收益实打实。仓库在这里:
https://atomgit.com/cann/ops-transformer