PyTorch에서 Attention Kernel 비교하기

cudaattentionpytorchbenchmark

Path 4의 마지막은 PyTorch workflow에서 attention 구현들을 비교하는 것이다.

host setup H2D copy kernel D2H copy verify
CPU timer
cudaEvent elapsed inside stream
PyTorch benchmark loop
최적화는 먼저 같은 구간을 같은 방식으로 재는 것에서 시작한다.

비교 대상

1. 직접 작성한 naive attention
2. torch.nn.functional.scaled_dot_product_attention
3. FlashAttention 계열 backend

항상 correctness를 먼저 본다. dtype, causal flag, dropout, shape가 다르면 비교가 깨진다.

benchmark shape

batch, heads, seq_len, head_dim

특히 seq_len을 바꾸면 FlashAttention의 IO 이득이 어떻게 드러나는지 보기 좋다.

확인

  • attention benchmark에서 warmup이 필요한 이유는 무엇인가?
  • float32, float16, bfloat16 결과 tolerance는 왜 다르게 봐야 하는가?
  • FlashAttention이 항상 빠르지 않다면 어떤 shape나 환경을 의심해야 하는가?