Sequence Parallelism as TP Layout

마지막 수정:

trainingdistributedsequence-parallelismtensor-parallelismnanotronmegatron

Picotron은 TP/PP/CP/DP를 배우기에 좋지만, Sequence Parallelism은 거의 구현되어 있지 않다. 코드에도 LayerNorm을 TP rank 사이에서 나누는 것은 TODO로 남아 있다.

그래서 SP는 Picotron이 아니라 Nanotron과 Megatron으로 넘어가면서 배워야 한다.

TP + Sequence Parallelism Keep either sequence or hidden dimension sharded to reduce activation memory.
TP = 2, activation base shape = (b, s, h)
SP region sequence sharded, hidden full (b, s/2, h)
g
all-gather seq
TP region sequence full, hidden/intermediate sharded (b, s, h/2)
g*
reduce-scatter seq
SP region sequence sharded, hidden full (b, s/2, h)
1. LayerNorm / Dropout in SP region each GPU owns different tokens, but full hidden dimension
GPU 0 tokens 0..s/2
(b, s/2, h)
LayerNorm
Dropout
Y0*
(b, s/2, h)
GPU 1 tokens s/2..s
(b, s/2, h)
LayerNorm
Dropout
Y1*
(b, s/2, h)
Why this works LayerNorm needs full hidden per token, not the full sequence. Dropout is elementwise and can run on the same sequence shard.
2. SP -> TP: all-gather along sequence
GPU 0
(b, s/2, h)
GPU 1
(b, s/2, h)
-> each GPU gets full X
(b, s, h)

Column-linear TP splits output hidden columns, so every rank needs all tokens as input.

3. TP MLP region column-linear produces hidden shards, row-linear produces partial full-hidden outputs
GPU 0 X
(b, s, h)
x W_up0 x W_down0 P0
(b, s, h)
GPU 1 X
(b, s, h)
x W_up1 x W_down1 P1
(b, s, h)
4. TP -> SP: reduce-scatter along sequence
P0 + P1
partial full sequence
-> GPU 0
Y tokens 0..s/2
GPU 1
Y tokens s/2..s

Reduce makes the row-linear result correct; scatter avoids storing full (b, s, h) on every rank.

TP only peak (b, s, h) full activation appears in non-TP regions
TP + SP peak (b, s, h) / tp sequence or hidden dimension stays sharded
SP is a layout transition strategy: sequence-sharded outside TP, hidden-sharded inside TP.

SP는 별도 축이라기보다 TP의 layout 전환이다

이 path에서는 SP를 다음처럼 이해한다.

TP only:
  row-linear 뒤에 all-reduce
  모든 TP rank가 full activation을 가짐

TP + SP:
  row-linear 뒤에 reduce-scatter
  각 TP rank가 sequence shard만 가짐

즉 SP는 새로운 process group을 만드는 독립 병렬화 축이라기보다, TP group 안에서 activation layout을 바꾸는 방식에 가깝다.

SP region: sequence sharded
  (batch, seq/tp, hidden)

TP matmul region: hidden/intermediate sharded
  (batch, seq, hidden/tp)

Nanotron에서 보는 법

Nanotron은 이 관점을 코드에 꽤 직접적으로 드러낸다.

ParallelismArgs의 설명은 tp_mode를 두 가지로 둔다.

ALL_REDUCE      -> normal TP
REDUCE_SCATTER  -> activate sequence parallelism

그래서 Nanotron에서 SP를 찾을 때는 sequence_parallel=True 같은 이름만 찾으면 부족하다. 핵심은 TensorParallelLinearMode.REDUCE_SCATTER다.

row_linear(..., tp_mode=REDUCE_SCATTER)
  -> differentiable_reduce_scatter_sum(...)

column_linear(..., tp_mode=REDUCE_SCATTER)
  -> input all-gather before matmul

이 흐름은 “SP region에서 sequence shard로 들고 있다가, column linear에 들어갈 때 다시 full sequence를 모은다”는 뜻이다.

Megatron에서 보는 법

Megatron은 사용법에서는 명시적이다.

--tensor-model-parallel-size 4
--sequence-parallel

문서도 TP를 쓸 때 sequence parallel을 권장한다. 코드에서는 sequence_parallel이 켜져 있으면 forward에서 input을 all-gather하고, backward에서 input gradient를 reduce-scatter하는 경로가 나온다.

Megatron의 복잡도는 여기서 끝나지 않는다. gradient accumulation fusion, async communication, Transformer Engine fused LayerNormLinear, FP8, pipeline schedule과 함께 SP가 섞인다. 하지만 학습자가 먼저 잡아야 할 핵심은 Nanotron과 같다.

all-reduce result를 모든 rank에 복제하지 말고,
reduce-scatter로 정답을 만들면서 sequence shard만 남긴다.

왜 이게 activation memory를 줄이나

LayerNorm, Dropout, residual add 같은 구간은 hidden dimension 전체가 필요하거나 elementwise로 처리된다. 이 구간에서 모든 TP rank가 full (batch, seq, hidden)을 들고 있으면 TP를 했는데도 activation memory가 중복된다.

SP는 이 구간을 sequence 방향으로 나눠 둔다.

rank 0: tokens 0..seq/tp
rank 1: tokens seq/tp..2seq/tp

LayerNorm은 한 token의 hidden vector 전체만 있으면 되므로 (batch, seq/tp, hidden)에서 계산할 수 있다.

실습

GPU 없이 layout 전환만 확인한다.

cd labs/large-scale-training-parallelism
python3 sp_layout_sim.py

이 실습은 다음 두 가지가 같은 결과를 만든다는 것을 확인한다.

full output
reduce-scatter(sum(partial outputs))

SP를 이해할 때 중요한 것은 “정답이 바뀌지 않는다”와 “각 rank가 들고 있는 activation shape이 바뀐다”를 동시에 보는 것이다.

확인

  • Nanotron에서 어떤 tp_mode가 sequence parallelism을 활성화하는가?
  • TP only의 all-reduce와 TP+SP의 reduce-scatter는 결과 저장 방식이 어떻게 다른가?
  • LayerNorm이 (batch, seq/tp, hidden)에서 계산될 수 있는 이유는 무엇인가?