To really understand FlashAttention, the best learning path is to retrace how the field got there: start with NVIDIA's 2018 online-softmax [3][4], then move to Memory-Efficient Attention (hereafter MEA [5]), proposed by Google Research in 2021. You will find that the two main contributions of the FlashAttention paper, Tiling and Recompute, were not firsts (and yet the citation counts of the online-softmax and MEA papers today are still only a small fraction of FlashAttention's 🐶).
GPU SRAM (Static Random Access Memory): SRAM is fast on-chip memory, typically used for the GPU's internal caches such as L1 and L2. The SMs rely on SRAM to access the data they operate on: everything to be computed must first be copied from HBM into SRAM. This mirrors the CPU architecture, where the execution units cannot consume data directly from main memory; it has to pass through the L1/L2 caches.
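To make the HBM → SRAM data path concrete, here is a minimal CUDA sketch (not from the FlashAttention paper; the kernel name, tile size, and the trivial scaling computation are illustrative assumptions). Each thread block explicitly stages a tile of its input from global memory (HBM) into `__shared__` memory, which lives in the SM's on-chip SRAM, before computing on it — the same staging pattern that FlashAttention's Tiling builds on.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

#define TILE 256  // tile size per block; chosen arbitrarily for illustration

// Each block stages a TILE-sized chunk of x from global memory (HBM) into
// shared memory (on-chip SRAM), then computes on the staged copy.
__global__ void scale_kernel(const float* __restrict__ x, float* __restrict__ y,
                             float alpha, int n) {
    __shared__ float tile[TILE];              // resides in on-chip SRAM (shared memory)
    int gid = blockIdx.x * TILE + threadIdx.x;

    if (gid < n) tile[threadIdx.x] = x[gid];  // HBM -> SRAM copy
    __syncthreads();                          // wait until the whole tile is staged

    if (gid < n) y[gid] = alpha * tile[threadIdx.x];  // compute reads SRAM, writes result back to HBM
}

int main() {
    const int n = 1 << 20;
    float *x, *y;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&y, n * sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 1.0f;

    scale_kernel<<<(n + TILE - 1) / TILE, TILE>>>(x, y, 2.0f, n);
    cudaDeviceSynchronize();

    printf("y[0] = %f\n", y[0]);  // expect 2.0
    cudaFree(x);
    cudaFree(y);
    return 0;
}
```

For this toy element-wise kernel the explicit staging buys nothing, but once several threads reuse the same tile (as with blocks of K and V in attention), reading from SRAM instead of repeatedly hitting HBM is exactly where the savings come from.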
[1] Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
[2] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
[3] Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
[4] Ye, Zihao. "From Online Softmax to FlashAttention." (2023).
[5] Rabe, Markus N., and Charles Staats. "Self-attention does not need $O(n^2)$ memory." arXiv preprint arXiv:2112.05682 (2021).