PyTorch SDPA Attention Benchmark
마지막 수정:
PyTorch는 scaled_dot_product_attention을 제공한다.
import torch.nn.functional as F
out = F.scaled_dot_product_attention(
q,
k,
v,
is_causal=True,
)
SDPA는 같은 attention 의미를 더 최적화된 backend로 실행할 수 있다. 환경과 shape에 따라 math, memory-efficient, FlashAttention 계열 backend가 선택될 수 있다.
비교는 manual attention과 같은 조건에서 한다.
same B, H, T, Dh
same dtype
same causal setting
same dropout setting
same correctness tolerance
중요한 질문은 “SDPA가 빠른가?”가 아니라 “어떤 shape와 dtype에서 어떤 backend가 이득을 주는가?”이다.
확인
scaled_dot_product_attention은 manual attention의 어떤 연산들을 하나의 interface로 묶는가?- SDPA benchmark에서 dtype과 causal flag를 맞춰야 하는 이유는 무엇인가?
- SDPA가 항상 빠르지 않다면 어떤 조건을 의심해야 하는가?