JAX Causal Self-Attention 구현

마지막 수정:

jaxtransformerattentioncausal-mask

JAX attention 구현도 수식은 PyTorch와 같다.

Q = X Wq
K = X Wk
V = X Wv
scores = Q K^T / sqrt(d)
probs = softmax(mask(scores))
out = probs V

lab에서는 QKV를 한 번에 만든 뒤 head dimension으로 reshape한다.

qkv = x @ params["qkv"]
q, k, v = jnp.split(qkv, 3, axis=-1)
q = q.reshape(batch, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)

causal mask는 array로 만든다.

causal = jnp.tril(jnp.ones((q_len, k_len), dtype=bool), k_len - q_len)
scores = jnp.where(causal[None, None, :, :], scores, jnp.finfo(scores.dtype).min)

JAX에서 특히 조심할 점은 shape가 compile contract라는 것이다. prefill처럼 q_len == k_len인 경우와 decode처럼 q_len == 1인 경우는 서로 다른 compiled executable이 될 수 있다.

실습

python3 labs/jax-transformer/bench_attention.py --workload prefill
python3 labs/jax-transformer/bench_attention.py --workload decode

확인

  • attention tensor의 기본 layout [B, H, T, Dh]에서 각 축은 무엇을 뜻하는가?
  • causal mask가 decode workload에서 prefill과 다르게 해석될 수 있는 이유는 무엇인가?
  • JAX에서 sequence length 변화가 compile cache에 영향을 주는 이유는 무엇인가?