JAX Causal Self-Attention 구현
마지막 수정:
JAX attention 구현도 수식은 PyTorch와 같다.
Q = X Wq
K = X Wk
V = X Wv
scores = Q K^T / sqrt(d)
probs = softmax(mask(scores))
out = probs V
lab에서는 QKV를 한 번에 만든 뒤 head dimension으로 reshape한다.
qkv = x @ params["qkv"]
q, k, v = jnp.split(qkv, 3, axis=-1)
q = q.reshape(batch, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)
causal mask는 array로 만든다.
causal = jnp.tril(jnp.ones((q_len, k_len), dtype=bool), k_len - q_len)
scores = jnp.where(causal[None, None, :, :], scores, jnp.finfo(scores.dtype).min)
JAX에서 특히 조심할 점은 shape가 compile contract라는 것이다. prefill처럼 q_len == k_len인 경우와 decode처럼 q_len == 1인 경우는 서로 다른 compiled executable이 될 수 있다.
실습
python3 labs/jax-transformer/bench_attention.py --workload prefill
python3 labs/jax-transformer/bench_attention.py --workload decode
확인
- attention tensor의 기본 layout
[B, H, T, Dh]에서 각 축은 무엇을 뜻하는가? - causal mask가 decode workload에서 prefill과 다르게 해석될 수 있는 이유는 무엇인가?
- JAX에서 sequence length 변화가 compile cache에 영향을 주는 이유는 무엇인가?