JAX MLP와 Decoder Block 조립

마지막 수정:

jaxtransformermlpdecoder-block

JAX lab의 MLP는 두 개의 weight와 GELU로 충분하다.

def mlp(params, x):
    return linear(params["fc2"], jax.nn.gelu(linear(params["fc1"], x)))

decoder block은 function composition이다.

def decoder_block(params, x, config):
    x = x + self_attention(params["attn"], rms_norm(params["norm1"], x, config.eps), config)
    x = x + mlp(params["mlp"], rms_norm(params["norm2"], x, config.eps))
    return x

여기서 params는 block 하나의 subtree다. 전체 모델 parameter pytree 안에 blocks list가 있고, 각 block은 다시 norm1, attn, norm2, mlp subtree를 가진다.

params
  blocks[0]
    norm1
    attn
    norm2
    mlp

이 구조가 중요한 이유는 나중에 optimizer update, gradient norm, sharding rule이 모두 같은 pytree 구조를 따라가기 때문이다.

확인

  • decoder block에서 parameter subtree를 나눠두면 어떤 후속 작업이 쉬워지는가?
  • JAX의 tree_map이 optimizer 구현에서 자연스러운 이유는 무엇인가?
  • block을 class로 만들지 않아도 모델 구조를 읽을 수 있게 하려면 무엇을 명확히 해야 하는가?