Tensor Parallel Linear From Picotron

마지막 수정:

trainingdistributedtensor-parallelismpicotronnanotronmegatron

TP를 코드로 읽을 때 첫 번째 기준은 두 layer다.

ColumnParallelLinear
RowParallelLinear

Picotron, Nanotron, Megatron 모두 이름과 세부 옵션은 달라도 이 구조로 돌아온다.

Column-wise

XW = X[W0 | W1 | W2] = [XW0 | XW1 | XW2]

Row-wise

XW = [X0 | X1 | X2] [W0; W1; W2] = sum_i XiWi
Column-wise sharding split W by output columns
GPU 0
X
[m x k]
x W0
[k x n/3]
= Y0
[m x n/3]
GPU 1
X
[m x k]
x W1
[k x n/3]
= Y1
[m x n/3]
GPU 2
X
[m x k]
x W2
[k x n/3]
= Y2
[m x n/3]
optional all-gather [Y0 | Y1 | Y2] -> full Y when the next op needs full output
Row-wise sharding split X and W along the shared inner dimension
GPU 0
X0
[m x k/3]
x W0
[k/3 x n]
= P0
[m x n]
GPU 1
X1
[m x k/3]
x W1
[k/3 x n]
= P1
[m x n]
GPU 2
X2
[m x k/3]
x W2
[k/3 x n]
= P2
[m x n]
all-reduce P0 + P1 + P2 -> final Y on every GPU
Column-wise sharding produces output shards. Row-wise sharding produces same-shaped partial outputs that must be summed.

Column parallel

PyTorch Linear의 weight shape은 보통 [out_features, in_features]다.

Column parallel은 output feature 방향을 나눈다.

W_full: [out, in]
W_0:    [out/tp, in]
W_1:    [out/tp, in]

각 TP rank는 같은 input X를 보고 자기 weight shard로 output shard를 만든다.

Y_0 = X @ W_0.T
Y_1 = X @ W_1.T
Y   = concat(Y_0, Y_1)

Picotron의 ColumnParallelLinearout_featurestp_world_size로 나누어지는지 확인하고, master weight를 dim 0으로 split한다. gather_output=True일 때만 output shard를 다시 모은다.

Row parallel

Row parallel은 input feature 방향을 나눈다.

X_full: [batch, in]
X_0:    [batch, in/tp]
X_1:    [batch, in/tp]

W_0:    [out, in/tp]
W_1:    [out, in/tp]

각 rank가 partial output을 만든 뒤 합친다.

P_0 = X_0 @ W_0.T
P_1 = X_1 @ W_1.T
Y   = P_0 + P_1

분산 환경에서는 이 합이 TP group 안의 all-reduce다. Picotron의 RowParallelLinear는 forward 끝에서 ReduceFromModelParallelRegion을 호출한다.

Transformer MLP with TP Column-linear up projection followed by row-linear down projection
X [b x h] -> H [b x 4h] -> Y [b x h]
1. Column-linear up split W_up by output columns, keep H sharded
GPU 0
X
[b x h]
x W_up0
[h x 4h/3]
=
GPU 1
X
[b x h]
x W_up1
[h x 4h/3]
=
GPU 2
X
[b x h]
x W_up2
[h x 4h/3]
=
No intermediate all-gather H stays split as [H0 | H1 | H2], which is exactly what the next row-linear layer wants.
2. Row-linear down split W_down by input rows, sum partial outputs
GPU 0
x W_down0
[4h/3 x h]
= P0
[b x h]
GPU 1
x W_down1
[4h/3 x h]
= P1
[b x h]
GPU 2
x W_down2
[4h/3 x h]
= P2
[b x h]
All-reduce once P0 + P1 + P2 -> Y on every TP rank.
Column -> Row X replicated/synced -> H sharded -> final all-reduce
Row -> Column row-linear needs an all-reduce before the next split can continue
The MLP uses column-linear first so the intermediate activation shard can flow directly into row-linear without being gathered.

왜 MLP는 column 다음 row인가

Transformer MLP는 보통 hidden dimension을 키웠다가 다시 줄인다.

X [h] -> up/gate [4h] -> down [h]

up projection을 column parallel로 나누면 중간 activation H가 output feature 방향으로 shard된다. 바로 다음 down projection은 input feature 방향으로 shard된 H_i를 원한다. 그래서 중간 all-gather 없이 이어진다.

Column up:
  X -> [H_0 | H_1]

Row down:
  H_0 W_0 + H_1 W_1 -> Y

이 패턴이 Megatron 논문식 MLP TP의 핵심이다.

Nanotron과 Megatron에서 달라지는 점

Nanotron은 Picotron보다 framework 책임이 더 크다. TensorParallelColumnLinearTensorParallelRowLinearpg, mode, async_communication, SplitConfig, sharded parameter metadata를 함께 가진다.

Megatron은 여기에 production 옵션이 더 붙는다. ColumnParallelLinearRowParallelLineargather_output, input_is_parallel, sequence_parallel, gradient_accumulation_fusion, is_expert, communication buffer 같은 옵션을 다룬다.

하지만 독자가 먼저 잡아야 할 것은 옵션이 아니다.

Column: output shard를 만든다
Row: partial output을 만들고 sum한다

이 두 문장을 놓치지 않으면 Nanotron과 Megatron의 복잡한 옵션도 “언제 gather/reduce를 하느냐”로 읽을 수 있다.

실습

GPU 없이 같은 계산을 확인한다.

cd labs/large-scale-training-parallelism
python3 tp_linear_sim.py

이 실습은 full linear 결과와 column/row sharding으로 복원한 결과가 같은지 확인한다.

확인

  • Column parallel에서 weight는 어느 dimension으로 나뉘는가?
  • Row parallel에서 왜 all-reduce가 필요한가?
  • Transformer MLP가 column parallel 다음 row parallel을 쓰면 어떤 gather를 피할 수 있는가?