JAX Token Embedding과 LM Head

마지막 수정:

jaxtransformerembeddinglogits

PyTorch에서는 보통 nn.Embeddingnn.Linear module을 만든다. JAX lab에서는 parameter를 data로 둔다.

params = {
    "token_embedding": ...,
    "position_embedding": ...,
}

token id를 vector로 바꾸는 일은 array indexing이다.

x = params["token_embedding"][input_ids]

position embedding도 같은 방식으로 더한다.

positions = jnp.arange(seq_len)
x = x + params["position_embedding"][positions][None, :, :]

LM head는 별도 parameter를 만들 수도 있지만, tiny lab에서는 embedding weight를 공유한다.

logits = x @ params["token_embedding"].T

이렇게 하면 모델 state가 어디 있는지 숨지 않는다. forward(params, input_ids, config)가 어떤 parameter를 읽고 어떤 logits를 만드는지 함수 signature에서 바로 보인다.

실습

python3 labs/jax-transformer/train.py --steps 1 --n-layers 1 --d-model 64 --n-heads 4

확인

  • JAX에서 embedding layer가 class가 아니라 array indexing으로 표현되는 이유는 무엇인가?
  • tied LM head는 어떤 parameter를 재사용하는가?
  • forward 함수가 params를 인자로 받으면 debugging이 쉬워지는 지점은 어디인가?