Tensor Parallel Linear From Picotron
마지막 수정:
TP를 코드로 읽을 때 첫 번째 기준은 두 layer다.
ColumnParallelLinear
RowParallelLinear
Picotron, Nanotron, Megatron 모두 이름과 세부 옵션은 달라도 이 구조로 돌아온다.
Column-wise
XW = X[W0 | W1 | W2] = [XW0 | XW1 | XW2] Row-wise
XW = [X0 | X1 | X2] [W0; W1; W2] = sum_i XiWi [m x k] x W0
[k x n/3] = Y0
[m x n/3]
[m x k] x W1
[k x n/3] = Y1
[m x n/3]
[m x k] x W2
[k x n/3] = Y2
[m x n/3]
[m x k/3] x W0
[k/3 x n] = P0
[m x n]
[m x k/3] x W1
[k/3 x n] = P1
[m x n]
[m x k/3] x W2
[k/3 x n] = P2
[m x n]
Column parallel
PyTorch Linear의 weight shape은 보통 [out_features, in_features]다.
Column parallel은 output feature 방향을 나눈다.
W_full: [out, in]
W_0: [out/tp, in]
W_1: [out/tp, in]
각 TP rank는 같은 input X를 보고 자기 weight shard로 output shard를 만든다.
Y_0 = X @ W_0.T
Y_1 = X @ W_1.T
Y = concat(Y_0, Y_1)
Picotron의 ColumnParallelLinear도 out_features가 tp_world_size로 나누어지는지 확인하고, master weight를 dim 0으로 split한다. gather_output=True일 때만 output shard를 다시 모은다.
Row parallel
Row parallel은 input feature 방향을 나눈다.
X_full: [batch, in]
X_0: [batch, in/tp]
X_1: [batch, in/tp]
W_0: [out, in/tp]
W_1: [out, in/tp]
각 rank가 partial output을 만든 뒤 합친다.
P_0 = X_0 @ W_0.T
P_1 = X_1 @ W_1.T
Y = P_0 + P_1
분산 환경에서는 이 합이 TP group 안의 all-reduce다. Picotron의 RowParallelLinear는 forward 끝에서 ReduceFromModelParallelRegion을 호출한다.
X [b x h] -> H [b x 4h] -> Y [b x h] [b x h] x W_up0
[h x 4h/3] =
[b x h] x W_up1
[h x 4h/3] =
[b x h] x W_up2
[h x 4h/3] =
[4h/3 x h] = P0
[b x h]
[4h/3 x h] = P1
[b x h]
[4h/3 x h] = P2
[b x h]
왜 MLP는 column 다음 row인가
Transformer MLP는 보통 hidden dimension을 키웠다가 다시 줄인다.
X [h] -> up/gate [4h] -> down [h]
up projection을 column parallel로 나누면 중간 activation H가 output feature 방향으로 shard된다. 바로 다음 down projection은 input feature 방향으로 shard된 H_i를 원한다. 그래서 중간 all-gather 없이 이어진다.
Column up:
X -> [H_0 | H_1]
Row down:
H_0 W_0 + H_1 W_1 -> Y
이 패턴이 Megatron 논문식 MLP TP의 핵심이다.
Nanotron과 Megatron에서 달라지는 점
Nanotron은 Picotron보다 framework 책임이 더 크다. TensorParallelColumnLinear와 TensorParallelRowLinear는 pg, mode, async_communication, SplitConfig, sharded parameter metadata를 함께 가진다.
Megatron은 여기에 production 옵션이 더 붙는다. ColumnParallelLinear와 RowParallelLinear는 gather_output, input_is_parallel, sequence_parallel, gradient_accumulation_fusion, is_expert, communication buffer 같은 옵션을 다룬다.
하지만 독자가 먼저 잡아야 할 것은 옵션이 아니다.
Column: output shard를 만든다
Row: partial output을 만들고 sum한다
이 두 문장을 놓치지 않으면 Nanotron과 Megatron의 복잡한 옵션도 “언제 gather/reduce를 하느냐”로 읽을 수 있다.
실습
GPU 없이 같은 계산을 확인한다.
cd labs/large-scale-training-parallelism
python3 tp_linear_sim.py
이 실습은 full linear 결과와 column/row sharding으로 복원한 결과가 같은지 확인한다.
확인
- Column parallel에서 weight는 어느 dimension으로 나뉘는가?
- Row parallel에서 왜 all-reduce가 필요한가?
- Transformer MLP가 column parallel 다음 row parallel을 쓰면 어떤 gather를 피할 수 있는가?