FlashAttention 1
- main idea
- IO感知, 即感知GPU的层级关系
- 手动算子融合, 实现CUDA算子
- 局限和Future
- 需要手写CUDA做融合, 希望可以用高级语言写在编程成CUDA
- IO感知的思路可以扩展到非Attention的场景
- 多GPU的IO感知也可以做优化
- 实现
- 尽可能少设计HBM读写
- 计算softmax时不需要访问整个输入
- 重新设计attn的计算, 让输入可以分块多次地计算: tiling
- 反向时不存储大量中间结果
- 保存前向时softmax normalization factor以快速重算, 而不是传统方法的需要读取中间数据: recomputation
- 计算softmax时不需要访问整个输入
- 具体实现: tiling, recomputation. ref
- tiling: 分块加载分块计算。Q, K, V分块加载到SRAM, 分块单独计算
- softmax公式转换, 关键在于如何通过局部值在最后换算出全局值
- 分母直接用最新标量值, 分子部分要将指数位更新成全局值, e.g. $(\sum e^{x_i^{(2)} - m(x^{(2)})}) * e^{m(x^{(2)} - m(x_{new}))}$
- in short 相乘等于指数位相加 从而替换上新值
- 分母直接用最新标量值, 分子部分要将指数位更新成全局值, e.g. $(\sum e^{x_i^{(2)} - m(x^{(2)})}) * e^{m(x^{(2)} - m(x_{new}))}$
- softmax公式转换, 关键在于如何通过局部值在最后换算出全局值
- recomputation: 不存储方向传播需要的中间值
- 通过存储softmax normalization statistics (m,l)和输出O就可以重计算S和P
- kernel融合
- tiling: 分块加载分块计算。Q, K, V分块加载到SRAM, 分块单独计算
- 尽可能少设计HBM读写
$$
softmax(x) = \frac{[f(x^{(1)}) \cdot e^{m(x^{(1)}) - m(x)} , f(x^{(2)}) \cdot e^{m(x^{(2)}) - m(x)}]}{\sum{[l^{new1}, l^{new2} ] }}
$$
FlashAttention 2
PS: block具体大小应随GPU变化
- main idea
- FlashAttention1还不是最优, 主要是因为任务的划分在不同GPU thread blocks, wraps下不是最优
- 一个grid包含多个block, 一个block包含多个wrap, 一个wrap包含多个thread
- 更好的work partitioning
- 减少非乘法(non-matmul)操作
- 并行计算attn, 即使是单头
- 考虑多在thread block内计算, 减少跨组通信
- FlashAttention1还不是最优, 主要是因为任务的划分在不同GPU thread blocks, wraps下不是最优
- discussion and future
- 让flashAttention2兼容更多设备和数据类型
- 利用编译器让编程更简单
- 实现
- matmul优化
- 并行化
- 额外在序列长度这一维度考虑并行, 提高GPU利用率
- 因为序列长度大时会降低batch size, 从而降低GPU利用率
- 额外在序列长度这一维度考虑并行, 提高GPU利用率
- wrap分配和循环调整 TODO
- 调整公式从而跳转循环的实现, 结果是HBM的读写更少了
- 内外循环调整: 原本KV在外循环, QO在内循环
- 跳转后Q在外循环, KVO在内循环, 降低wrap之间的通信TODO