JAX MLP와 Decoder Block 조립
마지막 수정:
JAX lab의 MLP는 두 개의 weight와 GELU로 충분하다.
def mlp(params, x):
return linear(params["fc2"], jax.nn.gelu(linear(params["fc1"], x)))
decoder block은 function composition이다.
def decoder_block(params, x, config):
x = x + self_attention(params["attn"], rms_norm(params["norm1"], x, config.eps), config)
x = x + mlp(params["mlp"], rms_norm(params["norm2"], x, config.eps))
return x
여기서 params는 block 하나의 subtree다. 전체 모델 parameter pytree 안에 blocks list가 있고, 각 block은 다시 norm1, attn, norm2, mlp subtree를 가진다.
params
blocks[0]
norm1
attn
norm2
mlp
이 구조가 중요한 이유는 나중에 optimizer update, gradient norm, sharding rule이 모두 같은 pytree 구조를 따라가기 때문이다.
확인
- decoder block에서 parameter subtree를 나눠두면 어떤 후속 작업이 쉬워지는가?
- JAX의
tree_map이 optimizer 구현에서 자연스러운 이유는 무엇인가? - block을 class로 만들지 않아도 모델 구조를 읽을 수 있게 하려면 무엇을 명확히 해야 하는가?