news 2026/5/23 1:08:32

大模型推理卡在哪?FlashAttention算子在昇腾NPU上的实现拆解

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
大模型推理卡在哪?FlashAttention算子在昇腾NPU上的实现拆解

为什么 Attention 是瓶颈?

先回顾一下问题本身。标准 Self-Attention 的计算过程:

Q, K, V = Linear(x) # 投影 S = Q @ K^T # 注意力分数 P = Softmax(S) # 归一化 O = P @ V # 加权求和

看起来就四步,但问题出在显存访问上。Q、K、V 的 shape 是[batch, heads, seq_len, dim],当 seq_len 到 8192 甚至更长的时候,中间矩阵 S 的 shape 是[batch, heads, seq_len, seq_len],这个矩阵大得离谱。以 LLaMA 13B 为例,32 个注意力头,seq_len=8192,S 矩阵光是 FP16 就要占 32GB 显存,根本放不下。

而且这个 S 矩阵算完 Softmax 之后还要跟 V 做矩阵乘法,意味着要再读一遍。来回读写 HBM(显存)的带宽就成了瓶颈。

FlashAttention 的核心思路:不分步计算,把 Attention 整个流程放在片上 SRAM 里完成,避免中间结果写回 HBM。

听起来简单,做起来要处理两个问题:Softmax 的在线计算(因为不知道全局最大值没法直接算 Softmax)和分块策略(SRAM 容量有限,得分块处理)。

标准实现 vs IO-Aware 实现

先看标准实现的问题在哪。

标准实现(Naive Attention):

import torch import torch.nn.functional as F def naive_attention(query, key, value): """标准 Self-Attention,中间结果全部落回 HBM""" d_k = query.size(-1) # Q @ K^T 产生 [batch, heads, seq_len, seq_len] 的巨大矩阵 scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5) # Softmax 结果也要写回 HBM p_attn = F.softmax(scores, dim=-1) # 再读一遍 p_attn,跟 V 做矩阵乘 return torch.matmul(p_attn, value)

4 次 HBM 读 + 4 次 HBM 写,中间矩阵 S 和 P 都要落回显存。seq_len 一大,HBM 带宽直接被撑爆。

FlashAttention 实现(基于 ops-transformer 的调用方式):

import torch import torch_npu from ops_transformer import flash_attention def flash_attention_inference(query, key, value, seq_len, head_dim): """调用 ops-transformer 的 FlashAttention 算子 分块计算,中间 Softmax 结果留在 UB(片上 SRAM),不写回 HBM""" # query/key/value: [batch, num_heads, seq_len, head_dim] attn_output = flash_attention.flash_attention_score( query, key, value, drop_mask=None, padding_mask=None, attn_head_num=query.shape[1], attn_dim_per_head=head_dim, scale_value=1.0 / (head_dim ** 0.5), input_layout="BSND", # batch-seq-head-dim 排布 seed=0, pre_tokens=seq_len, next_tokens=0, keep_prob=1.0, # 推理不 dropout ) return attn_output

HBM 读写次数大幅减少。代价是计算量略增(Softmax 的在线修正需要额外计算),但在现代硬件上计算远比显存访问快,所以总体是赚的。

昇腾 NPU 上的关键差异

到这一步,算法思路是一样的,NVIDIA 和昇腾都这么干。但落到具体实现上,昇腾 NPU 有几个关键差异:

差异一:SRAM 结构不同

NVIDIA GPU 的 SRAM 是 shared memory,一个 thread block 内的线程共享,大小通常 48KB-164KB。昇腾达芬奇架构的 SRAM 叫 Unified Buffer(UB),每个 AI Core 独享,大小是 1.5MB。

UB 比 shared memory 大很多,这意味着分块策略可以不一样。NVIDIA 那边每个 block 处理的 tile 更小,需要更细粒度的分块;昇腾这边 tile 可以更大,减少循环次数。

但 UB 的带宽分配也有讲究。达芬奇架构里,UB 同时要服务于向量计算单元和矩阵计算单元(Cube Unit),如果 FlashAttention 里 Softmax 的向量计算和 QK^T 的矩阵计算争抢 UB 带宽,性能就会打折扣。ops-transformer 里的实现做了一些调度上的优化,尽量让矩阵计算和向量计算流水线化,减少等待。

差异二:矩阵计算单元的指令不同

NVIDIA 的矩阵乘用的是 Tensor Core,通过 WMMA 指令触发。昇腾的矩阵计算单元叫 Cube Unit,通过专门的矩阵乘指令触发。两者的数据排布要求不同:

  • Tensor Core 要求数据按 128x128 的分块排布(FP16 场景下)
  • Cube Unit 要求数据按 16x16 的分块排布(FP16 场景下)

这意味着 Q、K、V 在进入矩阵乘之前要做数据重排(layout transform)。这个重排本身也要消耗算力和带宽,如果做得不精细,重排的开销可能抵消掉 FlashAttention 带来的收益。ops-transformer 里的实现在数据加载阶段就做了 prefetch 和 layout 转换,尽量把这个开销隐藏在计算流水线里。

差异三:Softmax 的在线实现细节

FlashAttention 的核心难点是 Softmax 的在线计算。标准 Softmax 需要先扫一遍求全局最大值(防止数值溢出),再扫一遍算 exp 和归一化。但分块计算的时候,你不知道后面块的最大值是多少,所以需要一种增量更新机制。

NVIDIA 的实现用的是 FlashAttention 论文里的 online softmax 方案,每次处理新块时用当前最大值修正之前的累加结果。昇腾上的实现在算法层面是一样的,但利用了达芬奇架构的向量计算单元做一些并行化的规约操作(reduce),比 GPU 上逐元素串行修正要快。

具体来说,online softmax 的核心逻辑是这样的:

import torch def online_softmax_update(prev_max, prev_sum, prev_out, cur_scores, cur_values): """FlashAttention 中 Softmax 的增量更新逻辑 每处理一个新的 KV 块,用新块的最大值修正之前的累加结果""" # 当前块的最大值 cur_max = cur_scores.max(dim=-1, keepdim=True).values # 全局最大值更新 new_max = torch.maximum(prev_max, cur_max) # 修正之前的累加结果(因为分母变了) correction = torch.exp(prev_max - new_max) prev_sum_corrected = prev_sum * correction prev_out_corrected = prev_out * correction # 当前块用新最大值做 Softmax cur_weights = torch.exp(cur_scores - new_max) cur_sum = cur_weights.sum(dim=-1, keepdim=True) cur_out = torch.matmul(cur_weights, cur_values) # 合并 new_sum = prev_sum_corrected + cur_sum new_out = (prev_out_corrected + cur_out) / new_sum return new_max, new_sum, new_out

在昇腾上,torch.maximumtorch.exp.sum()这些操作会被编译成 Vector Unit 的单条向量指令,一整行数据并行处理,而 GPU 上需要多个 CUDA thread 协作完成同样的操作。

ops-transformer 里的实现长什么样

ops-transformer 仓库里 FlashAttention 的代码结构大致是这样:

ops-transformer/ └── flash_attention/ ├── flash_attention_score.py # 主入口 ├── flash_attention_grad.py # 反向传播 └── kernel/ ├── flash_attention_tiling.py # 分块策略 └── flash_attention_kernel.cpp # Ascend C 核心实现

核心逻辑在flash_attention_kernel.cpp里,用 Ascend C 写的。如果你熟悉 CUDA 编程,看这个文件会有种似曾相识的感觉,但编程模型完全不同。

几个关键点:

Tiling 策略flash_attention_tiling.py里根据 seq_len、head_dim、UB 容量自动计算最优的 tile 大小。这个策略直接影响性能,太大了 UB 放不下,太小了循环次数多、HBM 访问频繁。

Cube 和 Vector 的流水线:矩阵乘(QK^T、PV)走 Cube Unit,Softmax 和 exp 走 Vector Unit。实现里用双缓冲机制让两套单元交替工作,Cube 算当前块的时候 Vector 在处理上一块的 Softmax。

反向传播:FlashAttention 的反向传播比前向复杂很多,需要保留前向的 Softmax 归一化因子和某些中间结果。ops-transformer 里的反向实现用了重计算策略(recomputation),不把所有中间结果都存下来,而是在反向时重新算一遍需要的中间值,用计算换显存。

实际性能对比

在昇腾 910B 上用 LLaMA 13B 做推理,FlashAttention vs 标准 Attention 的性能差异:

实现seq_len=2048seq_len=4096seq_len=8192
标准 Attention42ms156msOOM
FlashAttention18ms38ms82ms

seq_len 越长,FlashAttention 的优势越明显。8192 的时候标准实现直接 OOM 了,因为中间矩阵放不下。FlashAttention 通过分块计算把显存占用从 O(n²) 降到了 O(n),长序列场景下几乎是唯一的选择。

FlashAttention 看起来只是"把 Attention 分块算",但真正实现起来,每一个硬件差异都要针对性地处理。昇腾 NPU 的 UB 更大、Cube Unit 的数据排布不同、Vector Unit 的并行规约方式不同,这些差异决定了你不能直接把 NVIDIA 的实现搬过来用,得重新设计 tiling 策略和流水线调度。

好消息是 ops-transformer 仓库已经把这些都做好了,而且全面开源。如果你在做大模型推理优化,建议直接用仓库里的实现,不要自己从头写。如果性能还不满足需求,可以在现有实现基础上调 tiling 参数或者改进流水线策略。

理解了 FlashAttention 在昇腾上的实现方式,再看 MoE 算子、MC2 通信算子,思路是一样的:先搞清楚算法核心,再理解硬件差异,最后看具体实现怎么在两者之间做权衡。

  • Transformer 算子库:https://atomgit.com/cann/ops-transformer
  • Transformer 加速库:https://atomgit.com/cann/ascend-transformer-boost
  • 算子模板库:https://atomgit.com/cann/catlass
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/23 0:53:40

MultiHighlight插件:让代码阅读效率提升300%的终极解决方案

MultiHighlight插件:让代码阅读效率提升300%的终极解决方案 【免费下载链接】MultiHighlight Jetbrains IDE plugin: highlight identifiers with custom colors 🎨💡 项目地址: https://gitcode.com/gh_mirrors/mu/MultiHighlight 你…

作者头像 李华
网站建设 2026/5/23 0:27:39

在线水印去除怎么做?2026在线水印去除工具推荐与实操方法盘点

在短视频和图片社交已经成为日常表达方式的2026年,水印问题困扰着越来越多的用户。无论是保存自己发布的内容用于二次创作,还是处理素材中影响视觉观感的水印,找到高效可靠的在线水印去除方法,已经成为不少内容创作者的刚需。 本文…

作者头像 李华
网站建设 2026/5/23 0:23:09

5分钟快速上手:跨平台鼠标连点器的自动化新体验

5分钟快速上手:跨平台鼠标连点器的自动化新体验 【免费下载链接】MouseClick 🖱️ MouseClick 🖱️ 是一款功能强大的鼠标连点器和管理工具,采用 QT Widget 开发 ,具备跨平台兼容性 。软件界面美观 ,操作直…

作者头像 李华
网站建设 2026/5/23 0:19:05

HALAR® ECTFE光滑内壁:脱硫塔里,石膏垢为什么不贴它

苏福(深圳)科技有限公司 世索科HALAR ECTFE官方代理商一、脱硫塔结垢这事,运行维护的人最头疼湿法烟气脱硫(WFGD)系统里,脱硫塔内壁、除雾器、浆液循环管道,天天泡在含硫酸钙、亚硫酸钙的浆液里…

作者头像 李华