Sequence Parallelism

마지막 수정:

trainingdistributedtensor-parallelismsequence-parallelismactivation-memory

Sequence Parallelism은 이 경로에서는 TP와 함께 쓰이는 activation memory 절약 기법이다.

TP + Sequence Parallelism Keep either sequence or hidden dimension sharded to reduce activation memory.
TP = 2, activation base shape = (b, s, h)
SP region sequence sharded, hidden full (b, s/2, h)
g
all-gather seq
TP region sequence full, hidden/intermediate sharded (b, s, h/2)
g*
reduce-scatter seq
SP region sequence sharded, hidden full (b, s/2, h)
1. LayerNorm / Dropout in SP region each GPU owns different tokens, but full hidden dimension
GPU 0 tokens 0..s/2
(b, s/2, h)
LayerNorm
Dropout
Y0*
(b, s/2, h)
GPU 1 tokens s/2..s
(b, s/2, h)
LayerNorm
Dropout
Y1*
(b, s/2, h)
Why this works LayerNorm needs full hidden per token, not the full sequence. Dropout is elementwise and can run on the same sequence shard.
2. SP -> TP: all-gather along sequence
GPU 0
(b, s/2, h)
GPU 1
(b, s/2, h)
-> each GPU gets full X
(b, s, h)

Column-linear TP splits output hidden columns, so every rank needs all tokens as input.

3. TP MLP region column-linear produces hidden shards, row-linear produces partial full-hidden outputs
GPU 0 X
(b, s, h)
x W_up0 x W_down0 P0
(b, s, h)
GPU 1 X
(b, s, h)
x W_up1 x W_down1 P1
(b, s, h)
4. TP -> SP: reduce-scatter along sequence
P0 + P1
partial full sequence
-> GPU 0
Y tokens 0..s/2
GPU 1
Y tokens s/2..s

Reduce makes the row-linear result correct; scatter avoids storing full (b, s, h) on every rank.

TP only peak (b, s, h) full activation appears in non-TP regions
TP + SP peak (b, s, h) / tp sequence or hidden dimension stays sharded
SP is a layout transition strategy: sequence-sharded outside TP, hidden-sharded inside TP.

TP는 MLP와 Attention의 큰 matrix multiplication을 hidden dimension 기준으로 나눈다.

TP region:
hidden sharded
(b, s, h/tp)

하지만 Transformer block에는 TP matmul이 아닌 구간도 있다.

LayerNorm
Dropout
residual add
elementwise activation

이 구간에서 매번 모든 GPU가 full activation (b, s, h)를 들고 있으면 activation memory peak가 다시 커진다.

SP는 이 남는 구간을 sequence dimension으로 나눈다.

SP region:
sequence sharded
(b, s/tp, h)

LayerNorm과 Dropout

LayerNorm은 각 token마다 hidden vector 전체를 보고 평균과 분산을 계산한다.

one token: [h values]
mean/variance over hidden dimension

그래서 LayerNorm에서는 hidden dimension이 full이어야 한다.

OK:  (b, s/tp, h)
Bad: (b, s, h/tp)

Dropout은 LayerNorm처럼 hidden 전체를 봐야 하는 연산은 아니다. Dropout은 원소별 mask를 곱한다.

y = x * mask / keep_prob

다만 Dropout도 TP가 직접 처리하는 matmul 영역이 아니라 elementwise 영역에 있으므로, SP region에서 sequence-sharded activation 위에 수행한다. 이때 rank별 random seed와 mask 관리는 deterministic correctness를 위해 맞춰야 한다.

SP에서 TP로 갈 때 왜 all-gather를 하나

SP region의 activation은 sequence가 나뉘어 있다.

GPU 0: (b, s/2, h)
GPU 1: (b, s/2, h)

그런데 TP의 column-linear는 hidden output을 나누는 방식이다.

W = [W0 | W1]

GPU 0: X x W0 -> Y0
GPU 1: X x W1 -> Y1

각 GPU는 자기 hidden shard를 모든 token에 대해 계산해야 한다. 그래서 입력 X는 각 GPU에 full sequence로 있어야 한다.

column-linear input expected:
GPU 0: (b, s, h)
GPU 1: (b, s, h)

따라서 SP에서 TP로 들어갈 때 sequence dimension을 all-gather한다.

(b, s/tp, h) -> all-gather -> (b, s, h)

TP에서 SP로 돌아갈 때 왜 reduce-scatter를 하나

row-linear의 local output은 partial sum이다.

GPU 0: H0 x W_down0 -> P0
GPU 1: H1 x W_down1 -> P1

Y = P0 + P1

TP만 쓰면 여기서 all-reduce를 해서 모든 GPU가 full Y = (b, s, h)를 갖는다.

하지만 SP의 목적은 full activation을 복제하지 않는 것이다. 그래서 all-reduce 대신 reduce-scatter를 사용한다.

reduce:  P0 + P1로 정답 Y를 만든다
scatter: Y를 sequence chunk로 나눠서 각 GPU에 준다

결과는 다시 SP region이 원하는 모양이다.

(b, s, h) partials -> reduce-scatter -> (b, s/tp, h)

통신 비용

TP only에서는 row-linear 뒤에 all-reduce가 있다.

all-reduce = reduce-scatter + all-gather

TP+SP는 이 all-reduce를 명시적으로 all-gatherreduce-scatter의 전환으로 바꿔 쓰는 셈이다.

그래서 통신 횟수는 많아 보이지만, ring all-reduce 관점의 총 통신량은 비슷하게 볼 수 있다.

하지만 TP와 마찬가지로 이 통신은 forward/backward 경로에 직접 들어간다. 완전히 compute 뒤에 숨기기 어렵기 때문에 보통 TP+SP는 빠른 GPU interconnect가 있는 단일 노드 안에서 쓴다.

한계

SP는 CP가 아니다.

SP:
TP와 결합
LayerNorm/Dropout 같은 non-TP region의 activation peak를 줄임

CP:
long context를 위해 sequence 자체를 모델 전체에 걸쳐 나눔
attention에서 K/V 교환이 핵심

SP를 써도 TP region에서는 full sequence를 다루는 순간이 있다. 아주 긴 context가 문제라면 다음 단계에서는 Context Parallelism이 필요하다.

확인

  • LayerNorm은 왜 hidden dimension이 full이어야 하는가?
  • Dropout은 왜 LayerNorm과 같은 이유로 full hidden이 필요한 것은 아닌가?
  • SP region의 activation shape은 무엇인가?
  • SP에서 TP column-linear로 들어가기 전에 왜 sequence all-gather가 필요한가?
  • row-linear 뒤에 all-reduce 대신 reduce-scatter를 쓰면 어떤 두 일을 동시에 하는가?