PyTorch에서 Attention Kernel 비교하기
Path 4의 마지막은 PyTorch workflow에서 attention 구현들을 비교하는 것이다.
host setup H2D copy kernel D2H copy verify
비교 대상
1. 직접 작성한 naive attention
2. torch.nn.functional.scaled_dot_product_attention
3. FlashAttention 계열 backend
항상 correctness를 먼저 본다. dtype, causal flag, dropout, shape가 다르면 비교가 깨진다.
benchmark shape
batch, heads, seq_len, head_dim
특히 seq_len을 바꾸면 FlashAttention의 IO 이득이 어떻게 드러나는지 보기 좋다.
확인
- attention benchmark에서 warmup이 필요한 이유는 무엇인가?
float32,float16,bfloat16결과 tolerance는 왜 다르게 봐야 하는가?- FlashAttention이 항상 빠르지 않다면 어떤 shape나 환경을 의심해야 하는가?