news 2026/3/26 23:21:21

面试官:FlashAttention 的实现原理与内存优化方式?为什么能做到 O(N²) attention 的显存线性化?

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
面试官:FlashAttention 的实现原理与内存优化方式?为什么能做到 O(N²) attention 的显存线性化?

如果你最近刷到过“FlashAttention”,那你一定见过那句经典介绍:“它让传统 O(N²) 的 Attention,显存占用变成 O(N)。”

很多人平时也都用FlashAttention,但是很少有人能够讲清楚其中的原理。
今天我们就拆开讲清楚:

  • 为什么普通 Attention 显存爆炸;
  • FlashAttention 究竟改了什么;
  • 为什么它能在保持 O(N²) 计算量的同时,让显存线性化。

一、普通 Attention 的计算与内存瓶颈

标准的自注意力(Self-Attention)计算如下:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V

假设输入序列长度为 N,特征维度为 d。

那么计算步骤:

  1. 计算相似度矩阵
S = QKᵀ → [N, N]
  1. 归一化
A = softmax(S)
  1. 加权求和
O = A * V

显存问题出在哪?

关键在于那一步S = QKᵀ
它是一个N×N 的矩阵,会直接占据 O(N²) 的显存。

举个例子:
假设 N=4096,单精度浮点数 4 字节:

4096² × 4B ≈ 64 MB

而在多头 attention、batch 堆叠后,这个数会直接上百 MB。

再加上中间 softmax 的缓存与梯度,整个过程几乎炸显存

二、FlashAttention 的核心思想

FlashAttention 的核心不是改公式,而是改计算顺序。论文题目里那句关键话非常准确:“An IO-aware exact attention algorithm.”

也就是说:

  • 数学上结果一模一样;
  • 但计算顺序被重排,
  • 最小化显存访问和缓存中间矩阵为目标。

普通实现流程:

QKᵀ → Softmax → Dropout → (Softmax * V)

问题是:

  • 每一步都需要完整的 [N, N] 矩阵;
  • 每层都要读写显存(global memory);
  • Softmax 的数值稳定性还要额外缓存maxsum

这些中间值不是算力瓶颈,而是IO 瓶颈
GPU 大部分时间都在“搬运数据”,而不是“算”。

三、FlashAttention 的关键优化

FlashAttention 的思路非常巧妙:把 Attention 计算拆成小块(tiles),每次只在显存中保留局部块,并在块级别完成 softmax 的归一化与累加。

分块计算 QKᵀ

把 Q 和 K 按块划分:

Q = [Q₁, Q₂, ..., Q_M] K = [K₁, K₂, ..., K_M]

对于每个 query 块 Qᵢ:

  • 依次读取每个 key 块 Kⱼ;
  • 计算局部相似度矩阵 Sᵢⱼ = QᵢKⱼᵀ;
  • 同时在寄存器中保留该块的最大值与和。

这样只需要存储一个 tile 的中间矩阵(比如 64×64),不会生成完整的 [N, N] 矩阵。

块内 Softmax 的数值稳定处理

为了保持数值精度,FlashAttention 在块内维护:

  • 当前最大值mᵢ
  • 累积和lᵢ

公式如下:

m i ( j ) = m a x ( m i ( j − 1 ) , m a x ( S i j ) ) l i ( j ) = e x p ( m i ( j − 1 ) − m i ( j ) ) ∗ l i ( j − 1 ) + s u m ( e x p ( S i j − m i ( j ) ) ) m_i^{(j)} = max(m_i^{(j-1)}, max(S_{ij})) l_i^{(j)} = exp(m_i^{(j-1)} - m_i^{(j)}) * l_i^{(j-1)} + sum(exp(S_{ij} - m_i^{(j)}))mi(j)=max(mi(j1),max(Sij))li(j)=exp(mi(j1)mi(j))li(j1)+sum(exp(Sijmi(j)))

这样,在不保存全局 S 的情况下,也能正确计算 softmax 归一化。

同步加权求和

每计算完一个块:O i ( j ) + = s o f t m a x ( S i j ) ∗ V j O_i^{(j)} += softmax(S_{ij}) * V_jOi(j)+=softmax(Sij)Vj

所有块处理完之后,就得到了完整的输出 Oᵢ。
整个过程是流式的(streaming)

  • 一边计算,一边归一化;
  • 中间结果立刻被消费;
  • 不需要缓存完整 attention 矩阵。

四、显存线性化的本质

普通 Attention:

  • 必须保存 O(N²) 的相似度矩阵;
  • 所以显存复杂度是 O(N²)。

FlashAttention:

  • 只保存 O(N) 的输入输出(Q, K, V, O);
  • 中间矩阵被分块并立即释放;
  • 显存复杂度降为 O(N)。

计算量仍然是 O(N²),但显存访问和缓存规模线性化了。

简而言之,FlashAttention 不是降低计算复杂度,而是降低内存访问复杂度

五、梯度计算也能高效吗?

梯度计算中,FlashAttention 也优化了反向传播。
它同样采用流式重计算(recompute):

  • 前向不保存完整中间激活;
  • 反向时重新计算需要的局部块;
  • 减少显存峰值,但增加少量算力消耗。

这种设计非常适合训练大模型,因为 GPU 的主要瓶颈往往是显存,而不是算力。

FlashAttention v2采用了更高并行度 + kernel 调度来提升吞吐率,v3支持FP8、序列并行、多 query 批融合,进一步提速并适配大模型推理。如果想详细了解FlashAttentionV2 V3的详细算法和思想,文章末尾有专门分析它们的文章。

FlashAttention的精妙之处不在数学,而在工程调度

它通过分块(tiling)计算流式(streaming)softmaxkernel 融合(fusion),让原本需要 O(N²) 显存的注意力计算,在保持 O(N²) 计算量的同时实现了显存 O(N) 的线性化

📚推荐阅读

FlashAttention怎么提升速度的?

FlashAttention2:更快的注意力机制,更好的并行效率

FlashAttention3 全解析:速度、精度、显存的再平衡

FlashDecoding:让大模型推理提速的关键突破

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/14 18:04:33

Langchain-Chatchat公式识别支持:LaTeX数学表达式解析尝试

Langchain-Chatchat 公式识别支持:LaTeX 数学表达式解析探索 在科研、教育和工程领域,文档中频繁出现的数学公式构成了知识传递的核心。然而,当我们将这些富含 LaTeX 表达式的学术资料导入智能问答系统时,常常发现模型“视而不见”…

作者头像 李华
网站建设 2026/3/22 16:34:26

字节跳动M3多智能体框架:让AI团队协作效率提升85%

字节跳动M3多智能体框架:让AI团队协作效率提升85% 【免费下载链接】M3-Agent-Control 项目地址: https://ai.gitcode.com/hf_mirrors/ByteDance-Seed/M3-Agent-Control 你是否曾遇到过这样的场景?当服务器出现故障时,运维团队需要像侦…

作者头像 李华
网站建设 2026/3/19 11:47:43

五年之后,我们想把这场大会当作一份送给行业的礼物

铛铛铛!很高兴告诉大家,第12届全球边缘计算大会即将于12月27日在上海虹桥雅乐轩酒店举办!这是第12届大会,也是我们筹备最久的一次。时光倒回2020年11月7日,在北京,那是我们第一次举办全球边缘计算大会。说实…

作者头像 李华
网站建设 2026/3/24 7:36:29

Langchain-Chatchat上下文感知问答:理解对话历史的连贯性

Langchain-Chatchat上下文感知问答:理解对话历史的连贯性 在企业知识管理日益复杂的今天,员工常常面临这样的困扰:想查一条年假政策,却要在几十页PDF中反复翻找;技术支持人员被客户追问“上次说的那个配置参数是多少”…

作者头像 李华
网站建设 2026/3/24 5:26:41

终极指南:用Oxigraph在30分钟内构建高性能语义网应用

终极指南:用Oxigraph在30分钟内构建高性能语义网应用 【免费下载链接】oxigraph SPARQL graph database 项目地址: https://gitcode.com/gh_mirrors/ox/oxigraph 想要构建符合W3C标准的语义网应用,却苦于找不到既高性能又易于使用的RDF数据库&…

作者头像 李华
网站建设 2026/3/24 3:18:36

3步实现高精度人脸特征点实时检测系统

3步实现高精度人脸特征点实时检测系统 【免费下载链接】face-alignment 项目地址: https://gitcode.com/gh_mirrors/fa/face-alignment 人脸特征点检测技术正逐渐成为计算机视觉领域的核心技术之一,它能够从图像中精准定位人脸的68个关键特征点,…

作者头像 李华