Flash attention变长batching API使用
主要记录flash_attn.flash_attn_varlen_func
这个接口的使用, 精髓在于理解函数签名和输入形状: 函数签名需要每个seq的offset, 输入形状需要(bs, seqlen)
平坦化后的(total_num, nhead, headdim)
1 | from flash_attn import flash_attn_varlen_func |
需要注意使用causal
这个参数才能进入causal模式哦。
1 | def flash_attn_varlen_func( |
值得注意的是qkv输入形状上需要是(total_num, nheads, headdim)
而不是(batch_size, seqlen, nheads, headdim)
, 和flash_attn_func
是不同的。这是因为在变长batching中把bs * seqlen
打平展开了,然后再结合offset去找到每个batch的其实位置做计算。
最重要的参数就是cu_seqlens_q
和cu_seqlens_k
, 用于记录找到每个batch需要的offset。比如seq0的offset=0, seq1的offset=seq0.len, seq2的offset=seq0.len+seq1.len, 因此就是一个不包含自身的前缀和, 可以通过torch.cumsum
减去各自的seqlen获得:
1 | prefill_start_pos = torch.cumsum(seq_len, dim=0, dtype=torch.int32) - seq_len |
有因为API要求的cu_seqlens_k
的形状的batch_size+1
还需在末尾追加一个”总token数”:
1 | prefill_start_pos = torch.cat([prefill_start_pos, torch.tensor([torch.sum(seq_len)], dtype=torch.int32, device="cuda")], dim=0) |
完整示例如下:
Demo
1 | import torch |
Python setup.py和开发流程
Python setup.py和开发流程cheat sheet 1234567# setup.pyfrom setuptools import setup, find_packagessetup...
Label Words are Anchors An Information Flow Perspective for Understanding In-Context Learning深度解析
深入理解Label Words are Anchors: An Information Flow Perspective for Understanding In-Context Learnin...