Manual PyTorch Attention Benchmark

마지막 수정:

pytorchattentionbenchmarkprofiling

Attention benchmark의 첫 기준선은 직접 작성한 PyTorch attention이다.

구조는 명시적이다.

scores = Q @ K^T
scores = scores + causal_mask
probs = softmax(scores)
out = probs @ V

이 구현은 빠르기 위해 쓰는 것이 아니다. 어떤 intermediate tensor가 생기는지 보기 위해 쓴다.

Q, K, V: [B, H, T, Dh]
scores: [B, H, T, T]
probs:  [B, H, T, T]
out:    [B, H, T, Dh]

Profiler에서 확인할 것은 다음이다.

matmul time
softmax time
mask / fill overhead
peak memory

이 기준선이 있어야 PyTorch SDPA나 optimized attention이 무엇을 줄였는지 볼 수 있다.

확인

  • manual attention에서 가장 큰 intermediate tensor는 무엇인가?
  • causal mask를 materialize하면 memory와 time에 어떤 영향이 생기는가?
  • 이 구현이 느려도 benchmark 기준선으로 중요한 이유는 무엇인가?