JAX Token Embedding과 LM Head
마지막 수정:
PyTorch에서는 보통 nn.Embedding과 nn.Linear module을 만든다. JAX lab에서는 parameter를 data로 둔다.
params = {
"token_embedding": ...,
"position_embedding": ...,
}
token id를 vector로 바꾸는 일은 array indexing이다.
x = params["token_embedding"][input_ids]
position embedding도 같은 방식으로 더한다.
positions = jnp.arange(seq_len)
x = x + params["position_embedding"][positions][None, :, :]
LM head는 별도 parameter를 만들 수도 있지만, tiny lab에서는 embedding weight를 공유한다.
logits = x @ params["token_embedding"].T
이렇게 하면 모델 state가 어디 있는지 숨지 않는다. forward(params, input_ids, config)가 어떤 parameter를 읽고 어떤 logits를 만드는지 함수 signature에서 바로 보인다.
실습
python3 labs/jax-transformer/train.py --steps 1 --n-layers 1 --d-model 64 --n-heads 4
확인
- JAX에서 embedding layer가 class가 아니라 array indexing으로 표현되는 이유는 무엇인가?
- tied LM head는 어떤 parameter를 재사용하는가?
forward함수가params를 인자로 받으면 debugging이 쉬워지는 지점은 어디인가?