Flash Attention v3技术点
hopper特性co-design
- hopper新特性
- 低精度
- 优化量化误差
- 异步
- intra-warpgroup, inter-warpgroup
- 低精度
异步
异步: 异步计算(WGMMA), 异步传输(TMA) => 软件级流水线
- producer-consumer模型: 生产者传输, 消费者计算
- 算力高的tensor-core(gemm)去覆盖算力低的cuda-core(softmax, exp)
- inter-warpgroup overlap
- gemm和softmax做overlap
- intra-warpgroup overlap
- gemm0算ntile + 1时, gemm1算ntile
- wait WGMMA0 complete
- do online softmax
- wait WGMMA1 complete
- gemm0算ntile + 1时, gemm1算ntile
warp-specialized
- hopper:
- warp的寄存器分配有区分, e.g. TMA只用少量寄存器
- hopper之前:
- 所有warp的寄存器分配一视同仁, 不区分,导致寄存器浪费
- hopper:
低精度
- 挑战
- per-block quant
- gemm-I的C和gemm-II的A矩阵layout不兼容
- per-block quant
- C = A * B * scale_A * scale_B
- fp8场景下gemm融合的layout不兼容问题: gemm-I的C layout和gemm-II的A layout不兼容
- 解决方案: cutlass3.5 => 构造tileMMA时指定permutationLayout
make_tiled_mma
- 解决方案: cutlass3.5 => 构造tileMMA时指定permutationLayout
permutationLayout
permutationLayout: 给你一个row-ptr, 你构造stride去重排这一行
permutationLayout表示new_layout变成old_layout的方法(映射关系)
1 | """ |