Sequence Parallelism
마지막 수정:
Sequence Parallelism은 이 경로에서는 TP와 함께 쓰이는 activation memory 절약 기법이다.
TP = 2, activation base shape = (b, s, h) (b, s/2, h) all-gather seq
(b, s, h/2) reduce-scatter seq
(b, s/2, h) (b, s/2, h) LayerNorm
Dropout Y0*
(b, s/2, h)
(b, s/2, h) LayerNorm
Dropout Y1*
(b, s/2, h)
(b, s/2, h) GPU 1
(b, s/2, h) -> each GPU gets full X
(b, s, h)
Column-linear TP splits output hidden columns, so every rank needs all tokens as input.
(b, s, h) x W_up0 x W_down0 P0
(b, s, h)
(b, s, h) x W_up1 x W_down1 P1
(b, s, h)
partial full sequence -> GPU 0
Y tokens 0..s/2 GPU 1
Y tokens s/2..s
Reduce makes the row-linear result correct; scatter avoids storing full (b, s, h) on every rank.
(b, s, h) full activation appears in non-TP regions (b, s, h) / tp sequence or hidden dimension stays sharded TP는 MLP와 Attention의 큰 matrix multiplication을 hidden dimension 기준으로 나눈다.
TP region:
hidden sharded
(b, s, h/tp)
하지만 Transformer block에는 TP matmul이 아닌 구간도 있다.
LayerNorm
Dropout
residual add
elementwise activation
이 구간에서 매번 모든 GPU가 full activation (b, s, h)를 들고 있으면 activation memory peak가 다시 커진다.
SP는 이 남는 구간을 sequence dimension으로 나눈다.
SP region:
sequence sharded
(b, s/tp, h)
LayerNorm과 Dropout
LayerNorm은 각 token마다 hidden vector 전체를 보고 평균과 분산을 계산한다.
one token: [h values]
mean/variance over hidden dimension
그래서 LayerNorm에서는 hidden dimension이 full이어야 한다.
OK: (b, s/tp, h)
Bad: (b, s, h/tp)
Dropout은 LayerNorm처럼 hidden 전체를 봐야 하는 연산은 아니다. Dropout은 원소별 mask를 곱한다.
y = x * mask / keep_prob
다만 Dropout도 TP가 직접 처리하는 matmul 영역이 아니라 elementwise 영역에 있으므로, SP region에서 sequence-sharded activation 위에 수행한다. 이때 rank별 random seed와 mask 관리는 deterministic correctness를 위해 맞춰야 한다.
SP에서 TP로 갈 때 왜 all-gather를 하나
SP region의 activation은 sequence가 나뉘어 있다.
GPU 0: (b, s/2, h)
GPU 1: (b, s/2, h)
그런데 TP의 column-linear는 hidden output을 나누는 방식이다.
W = [W0 | W1]
GPU 0: X x W0 -> Y0
GPU 1: X x W1 -> Y1
각 GPU는 자기 hidden shard를 모든 token에 대해 계산해야 한다. 그래서 입력 X는 각 GPU에 full sequence로 있어야 한다.
column-linear input expected:
GPU 0: (b, s, h)
GPU 1: (b, s, h)
따라서 SP에서 TP로 들어갈 때 sequence dimension을 all-gather한다.
(b, s/tp, h) -> all-gather -> (b, s, h)
TP에서 SP로 돌아갈 때 왜 reduce-scatter를 하나
row-linear의 local output은 partial sum이다.
GPU 0: H0 x W_down0 -> P0
GPU 1: H1 x W_down1 -> P1
Y = P0 + P1
TP만 쓰면 여기서 all-reduce를 해서 모든 GPU가 full Y = (b, s, h)를 갖는다.
하지만 SP의 목적은 full activation을 복제하지 않는 것이다. 그래서 all-reduce 대신 reduce-scatter를 사용한다.
reduce: P0 + P1로 정답 Y를 만든다
scatter: Y를 sequence chunk로 나눠서 각 GPU에 준다
결과는 다시 SP region이 원하는 모양이다.
(b, s, h) partials -> reduce-scatter -> (b, s/tp, h)
통신 비용
TP only에서는 row-linear 뒤에 all-reduce가 있다.
all-reduce = reduce-scatter + all-gather
TP+SP는 이 all-reduce를 명시적으로 all-gather와 reduce-scatter의 전환으로 바꿔 쓰는 셈이다.
그래서 통신 횟수는 많아 보이지만, ring all-reduce 관점의 총 통신량은 비슷하게 볼 수 있다.
하지만 TP와 마찬가지로 이 통신은 forward/backward 경로에 직접 들어간다. 완전히 compute 뒤에 숨기기 어렵기 때문에 보통 TP+SP는 빠른 GPU interconnect가 있는 단일 노드 안에서 쓴다.
한계
SP는 CP가 아니다.
SP:
TP와 결합
LayerNorm/Dropout 같은 non-TP region의 activation peak를 줄임
CP:
long context를 위해 sequence 자체를 모델 전체에 걸쳐 나눔
attention에서 K/V 교환이 핵심
SP를 써도 TP region에서는 full sequence를 다루는 순간이 있다. 아주 긴 context가 문제라면 다음 단계에서는 Context Parallelism이 필요하다.
확인
- LayerNorm은 왜 hidden dimension이 full이어야 하는가?
- Dropout은 왜 LayerNorm과 같은 이유로 full hidden이 필요한 것은 아닌가?
- SP region의 activation shape은 무엇인가?
- SP에서 TP column-linear로 들어가기 전에 왜 sequence all-gather가 필요한가?
- row-linear 뒤에 all-reduce 대신 reduce-scatter를 쓰면 어떤 두 일을 동시에 하는가?