抱歉,您的浏览器无法访问本站

本页面需要浏览器支持(启用)JavaScript


了解详情 >

Roofline Model and Flash Attention

  • 计算强度: Flops/Byte => 反映每个字节的传输会产生多少计算
  • roofline model:
    • 纵轴: 算力, 单位Flop/s
    • 横轴: 计算强度I, 单位Flop/Byte
    • 斜率: 内存带宽, 单位Byte/s
    • 算力的上限决定了roofline model的屋顶
    • 带宽的上限决定了roofline model的屋檐斜率
    • 拐点左侧: mem-bound, 因为单位byte传输没能用满算力
      • 理论上平均一个Byte的传输带来F/B次的mm就能打满
    • 拐点右侧: compute-bound, 算力已经打满了
      • e.g. 你的算法在某个规模上的计算强度I > F_peak/B

矩阵乘法的计算强度

mm的flop数F = 2xMxNxK, 传输数B = 2(bf16)x(MxK + NxK)

平均计算强度I = F/B = (MxN)/(MxK + NxK), 计算强度随着M和N的增大而增大, 理论上就可以打满算力

Flash Attention的计算强度

设序列长度S, 维度数D, batch数B

  • flop
    • QK的flop = 2xBxS^2xD
    • SV的flop = 2xBxS^2xD
    • softmax的flop = 3xBxS^2
      • 3表示softmax的exp, sum, div
      • 当D>>1时, softmax的flop可以忽略不计
    • 整体flop = 4xBxS^2xD
  • 传输
    • Q的传输 = 2(bf16)xBxSxD
    • K的传输 = 2(bf16)xBxSxD
    • V的传输 = 2(bf16)xBxSxD
    • O的写回 = 2(bf16)xBxSxD
    • 整体传输 = 8(bf16)xBxSxD

A100的算力(bf16)FLOP/S = 312TFlops, 带宽为2TB/s, 计算强度拐点 = 312/2 = 156

Flash attention的计算强度I = F/B = (4xBxS^2xD)/(8(bf16)xBxSxD) = S/2

所以当S > 312时, Flash Attention就会打满A100的算力, 进入compute bound

评论