多快好省:神奇的FlashAttention
password
Sub-item
type
status
date
summary
tags
icon
Parent item
slug
category
引子
为什么要学习FlashAttention?
一言以蔽之:多块好省:
- 用得多:2024年的今天,FlashAttention[1]已经在事实上成为了大模型训练的标配。
- 速度快:相较标准Attention[2],FlashAttention的速度提升5倍以上。
- 效果好:FlashAttention并非是对标准Attention的近似计算,效果上没有任何损失。
- 显存省:标准Attention的显存消耗随序列长度 二次增长,而FlashAttention将显存消耗从 降低到 。
要看懂FlashAttention,最好的学习路径是沿着业界的探索道路:从英伟达2018年的online-softmax[3][4]开始,再到Google Research 2021年提出的Memory-Efficient Attention(下文称MEA[5])。你会发现,FlashAttention论文中的两个主要贡献:Tiling、Recompute,都算不上是首创(然而online-softmax和MEA这两篇论文的引用量,如今还达不到FlashAttention论文的零头🐶)。
问题分析
先了解关于显存的三个概念:
- SM(Streaming Multiprocessor):SM是GPU实际计算单元。
- GPU HBM(High Bandwidth Memory):HMB就是我们平时所说的显存,空间(相对)大,但数据传输带宽(相对)低。
- GPU SRAM(Static Random Access Memory):SRAM是一种高速运行的存储器,通常用于GPU内部的缓存,如L1和L2缓存。SM依赖SRAM访问要计算的数据,所有要计算的数据,都要先从HBM拷贝到SRAM。这点和CPU架构一样,运算单元不能直接使用内存的数据,需要经过L1/L2缓存。
HMB的问题在于传输带宽低,在标准Attention中,影响运行速度最重要的因素便是HBM访问次数[1]。
我们来分析标准Attention的显存占用和HBM访问:
HBM访问次数(以访问单个float值为基准),其中 为序列长度(通常为4096/8192), 为attention头维度(通常为64/128)
- Line1: 读 的HMB访问量为 , 写 的HMB访问量为 。
- Line2: 读 的HMB访问量为 , 写 的HMB访问量为 。
- Line3: 读 的HMB访问量为 , 读 的HMB访问量为 , 写 的HMB访问量次数为 。
- 上述所有加起来的总HBM访问量为 。忽略掉其中的常数项, 可以将复杂度写为 。由于 ,可以近似为 。
前面讲过,HBM的访问带宽低,过多的HBM访问因此成为标准Attention的瓶颈。
同时,标准FlashAttention的显存占用为,因为要显式再HBM中存储attention矩阵。
FlashAttention做了啥
FlashAttention通过Tiling和recompute两个技巧,在HBM访问和显存占用进行了优化:
- HBM访问量从 降低到 ,其中 为 SRAM 大小。之所以说是“降低”,因为 满足
- 显存占用从 降低到
对于FlashAttention,最核心的优化目标是:消灭掉 大小的attention weight矩阵。只要算法还依赖这个矩阵,HBM访问和显存占用就降低不下来。为此需要做两件事:
- 在不访问完整logits矩阵的情况下计算 softmax
- 不为反向传播存储大的中间 attention weight 矩阵
FlashAttention 借用(提出)了两个idea来解决上述问题:
- online softmax [3][4]:通过等价变换使得Softmax没有了行方向依赖,可以Tiling并行计算。
- recomputation [5] : 存储来自前向的 softmax 的分母项,在反向中重新计算 attention,这比从HBM读取中间矩阵的标准Attention更快。
具体的算法如下:
- 外循环:对于分块的Key、Value遍历,从HBM加载进SRAM
- 内循环:对于分块的 Query遍历,从HBM加载进SRAM
- 在SRAM上完成 Attention 的计算(只是中间结果,外层遍历结束后的结果才是正确的)
至于为什么可以迭代式地计算softmax,《From Online Softmax to FlashAttention》讲得比较清楚,对推公式感兴趣的内容可以继续阅读下面的内容:
step 1:two-passes self-atttention
- :第 个query
- :第 个 KV
- :序列长度
step 2:flash attention (one-pass self-atttention)
再看如何做到一次遍历,首先把式 12 带入式 11,得到
类似online-softmax的做法,构建新序列:
展开看看:
很容易观察到递推关系:
进而,我们可以仅对K、V做一次遍历,便得到一个query token对应的self-attention结果:
step 3:flash attention with tiling
这里得到的是单个Q的self-attention结果(没有对Q分块)
- 方法:遍历K、V时,不再逐个token遍历,而是逐tile遍历(一次遍历多个K、V)
参考文献
[1] Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
[2] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
[3] Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
[4] Ye, Zihao. "From Online Softmax to FlashAttention." (2023).
[5] Rabe, Markus N., and Charles Staats. "Self-attention does not need $ O (n^ 2) $ memory." arXiv preprint arXiv:2112.05682 (2021).
- Utterance