Cutlass通俗理解
核心目的
核心目标: 消除矩阵乘法中对A, B矩阵的重复读取
- 核心: 内积变外积 -> 每个数据只会读第一次, 避免对A/B矩阵的重复读取
- 最内层外积, 可复用的地方也用外积, 体现就是MMA的warp level op
内积的问题: 如果每个thread负责C矩阵的一个元素C[m, n], 并且用下面这种方式循环(内积), 可以发现A, B矩阵会被重复读取。
1 | # 内积的问题 |
外积的情况: 把k循环提到最外层, 那thread的任务要如何排布? -> MMA
利用MMA, 我们可以: 多个thread负责不同A/B矩阵元素的加载, 同时算同时算C的不同位置
1 | for k in range(K): |
没有MMA的年代是怎么操作的? 我们只能优化gmem -> smem的读取了, 一次协作式加载到smem, 然后用内积方法做smem数据的矩阵乘
cutlass tiling
- Thread block tile: C矩阵分块, 每个块C[Mb, Nb]分给不同的thread block
- 一个C分块需要读取A, B矩阵的一个长条
- warp tile: MMA, smem后C的分块, 每个warp负责C[mb, nb]的一个小块
- thread tile: MMA内部电路交换, 一般不使用thread level的操作, 外积真正起作用的地方
1 | for (int mb = 0; mb < M; mb += Mtile) { |
WarpTileOp展开
1 | for (int m = 0; m < Mtile; m += warp_m) { |
misc
- thread block tile
- warp tile
- …
- predicate: 避免使用ifelse等跳转指令, 对硬件流水线不友好
- QA
- 为什么不能直接从寄存器写到global memory
- tensor core的设计: 每个thread拥有的结果在逻辑上是比连续的, e.g.
- v0, v1连续, v2, v3连续,但是(v0, v1)和(v2, v3)之间不连续
- 直接写回一个thread每次只能写回两个结果, 但是每个thread每次最多是可以写回128bit的数据的
- 使用smem来整理数据
- tensor core的设计: 每个thread拥有的结果在逻辑上是比连续的, e.g.
- 为什么不能直接从寄存器写到global memory
TODO:
- pipeline
- swizzle