Transformer 里 Attention 的核心是 Softmax——它把注意力分数变成概率分布。没有 Softmax,注意力分数就只是一个数值没有归一化的矩阵,无法作为权重来聚合 Value。
CANN 的 ops-softmax 仓库专门管理 Softmax 及其变体的实现。Softmax 的计算量不大——就是exp → sum → div三步——但它的数据访问模式决定了它是 Memory Bound 算子,在昇腾NPU 上需要针对大序列长度做专门的优化。
Softmax 为什么是 Transformer 核心
Attention 的计算公式:Attention(Q,K,V) = Softmax(Q × K^T / √d) × V
Q × K^T输出的注意力分数矩阵S是[n, n]的矩阵。矩阵中的每个元素S_ij表示第 i 个 Token 对第 j 个 Token 的注意力强度。但这些分数是未归一化的——可能很大也可能很小。Softmax 把它们归一化成概率分布,让sum(S_ij over j) = 1。
Softmax 的步骤:
exp(x_i)——指数化,把分数转为正数sum(exp(x_i))——求所有指数值的和exp(x_i) / sum——每个指数值除以总和,归一化为概率
Softmax 为什么会成为性能瓶颈
Softmax 的计算量很小——每个元素一次指数运算、一次除法。但它的数据访问模式很差:
- 输入读取
[n, n]矩阵的全部元素(从 DDR 搬到 L1) - 对所有元素做指数运算(Vector Unit 执行)
- 在行方向做 sum(归约操作,需要对整行扫描)
- 再读取一次,每个元素除以 sum(从 DDR 搬到 L1)
对于 n=4096 的序列,Score 矩阵 32MB。整个流程需要搬运约 64MB——两次读S、一次写S_softmax。计算/搬运比很低。
FlashAttention 中的 Softmax 优化
FlashAttention 对 Softmax 的优化是让它原地完成——Score 矩阵不落地 DDR。具体做法:Score 矩阵被切成block×block的子块,每次只搬运一个子块到 L1。在 L1 上做完 Softmax 后立即跟 Value 做矩阵乘,Softmax 的结果不需要写回 DDR。
这个过程需要 Online Softmax 算法——在不知道全局最大值的情况下分块计算:
初始化:max_val = -inf, sum_val = 0 循环每个 K/V 块: 当前块的最大值 local_max = max(S_ij) 更新 max_val = max(max_val, local_max) 缩放旧的 sum_val:sum_val *= exp(max_val - local_max) 当前块的 exp 和:local_sum = sum(exp(S_ij - max_val)) 累积:sum_val += local_sumOnline Softmax 的计算精度跟标准 Softmax 完全一致,但避免了 Score 矩阵的整体搬运。在长序列场景中,Softmax 不再是性能瓶颈。
Online Softmax 的数值稳定性
Softmax 的朴素实现:exp(x_i)在x_i很大时(如 Attention Score 的值可能超过 30)会导致 float16 溢出。标准做法是减去最大值:exp(x_i - max(x)) / sum(exp(x_j - max(x)))。
FlashAttention 的 Online Softmax 在分块计算时也保持了数值稳定性——每个分块独立减去自己的局部最大值,跨分块时用 running max 修正。这个修正的数值误差在10^-5级别——不影响推理精度。
ops-softmax 在 Vector Unit 上的实现
ops-softmax 在 Vector Unit 上的实现不是直接写一条softmax指令——Vector Unit 只有基本的数学指令。Softmax 被拆解为:
vec_max(x)— SIMD 找最大值vec_sub(x, max)— 每个元素减最大值vec_exp(x)— SIMD 指数运算(使用多项式近似)vec_sum(exp_x)— SIMD 求和vec_div(exp_x, sum)— 每个元素除以总和
这 5 条 Vector 指令在 L1 上执行,不需要写 DDR。对于 4096 个元素的 Softmax,Vector Unit 的执行时间约 1-2μs。
大序列长度(n > 4096)时,Score 矩阵[n, n]超出了一次 Kernel 可以处理的 L1 容量。ops-softmax 把 Score 矩阵按行分成多块——每块在 L1 上做完完整的 Softmax 后再写回 DDR。
参考仓库
ops-softmax 仓库
FlashAttention 融合优化