Tensor Parallelism

마지막 수정:

trainingdistributedtensor-parallelismmatmul

Tensor Parallelism은 layer 내부의 tensor 계산 자체를 여러 GPU가 나눠 수행하는 방법이다.

Column-wise

XW = X[W0 | W1 | W2] = [XW0 | XW1 | XW2]

Row-wise

XW = [X0 | X1 | X2] [W0; W1; W2] = sum_i XiWi
Column-wise sharding split W by output columns
GPU 0
X
[m x k]
x W0
[k x n/3]
= Y0
[m x n/3]
GPU 1
X
[m x k]
x W1
[k x n/3]
= Y1
[m x n/3]
GPU 2
X
[m x k]
x W2
[k x n/3]
= Y2
[m x n/3]
optional all-gather [Y0 | Y1 | Y2] -> full Y when the next op needs full output
Row-wise sharding split X and W along the shared inner dimension
GPU 0
X0
[m x k/3]
x W0
[k/3 x n]
= P0
[m x n]
GPU 1
X1
[m x k/3]
x W1
[k/3 x n]
= P1
[m x n]
GPU 2
X2
[m x k/3]
x W2
[k/3 x n]
= P2
[m x n]
all-reduce P0 + P1 + P2 -> final Y on every GPU
Column-wise sharding produces output shards. Row-wise sharding produces same-shaped partial outputs that must be summed.

ZeRO는 parameter, gradient, optimizer state를 shard해서 저장 메모리를 줄였다. 하지만 activation memory가 커지면 ZeRO만으로는 부족하다.

ZeRO-3도 계산할 때는 layer parameter를 all-gather해서 각 DP rank가 자기 micro-batch forward를 수행한다.

ZeRO-3:
store parameters sharded
compute each layer with temporary full parameters
activation stays local to each DP rank

TP는 접근이 다르다. 행렬곱 자체를 여러 GPU가 나눠 계산한다.

store tensors sharded
compute with sharded tensors
communicate only the result shape needed by the next operation

기본 행렬곱

neural network의 linear layer는 보통 다음처럼 쓴다.

Y = XW

여기서:

X = input 또는 activation
W = linear layer weight
Y = output activation

행렬곱은 두 가지 방식으로 나눌 수 있다.

1. W를 column 방향으로 나누기
2. X와 W를 inner dimension 방향으로 나누기

이 두 방식이 TP의 기본 재료다.

Column-wise sharding

column-wise sharding은 weight W를 output column 방향으로 나눈다.

W = [W0 | W1 | W2]

그러면:

Y = XW = [XW0 | XW1 | XW2]

각 GPU는 같은 full X를 가지고 자기 weight shard만 곱한다.

GPU 0: X x W0 -> Y0
GPU 1: X x W1 -> Y1
GPU 2: X x W2 -> Y2

결과 Y는 column shard로 나뉘어 있다.

Y = [Y0 | Y1 | Y2]

다음 연산이 sharded Y를 그대로 받을 수 있으면 all-gather를 미룰 수 있다. 하지만 full Y가 필요하면 all-gather로 합친다.

Y0, Y1, Y2 -> all-gather -> full Y

Row-wise sharding

row-wise sharding은 weight W를 input row 방향으로 나눈다. 이때 X도 같은 inner dimension 기준으로 나눠야 한다.

X = [X0 | X1 | X2]

W = [ W0
      W1
      W2 ]

그러면:

Y = XW = X0W0 + X1W1 + X2W2

각 GPU는 partial output을 만든다.

GPU 0: X0 x W0 -> partial Y0
GPU 1: X1 x W1 -> partial Y1
GPU 2: X2 x W2 -> partial Y2

각 partial output은 이미 Y와 같은 shape이지만, 값은 전체 결과가 아니다. 최종 Y를 만들려면 partial output을 더해야 한다.

partial Y0 + partial Y1 + partial Y2 -> all-reduce -> Y

두 방식의 통신 차이

column-wise sharding은 입력 X가 모든 GPU에 필요하다.

full X 필요
W column shard
output Y shard
필요하면 all-gather

row-wise sharding은 입력 X도 나눠야 한다.

X shard 필요
W row shard
partial Y 생성
all-reduce로 합산

그래서 TP를 이해할 때는 항상 이 질문을 해야 한다.

어느 dimension을 shard했는가?
각 GPU의 local matmul 결과 shape은 무엇인가?
다음 layer가 원하는 shape을 만들려면 어떤 collective가 필요한가?

Transformer MLP에 적용하기

Transformer layer의 MLP는 보통 두 개의 linear layer로 볼 수 있다.

X -> Linear up -> activation -> Linear down -> Y

hidden size를 h, MLP intermediate size를 4h라고 하면 모양은 대략 이렇다.

X:      [batch, h]
W_up:   [h, 4h]
W_down: [4h, h]

TP에서는 이 두 linear를 다음 순서로 배치하는 것이 핵심이다.

1. W_up은 column-wise로 shard한다.
2. W_down은 row-wise로 shard한다.
Transformer MLP with TP Column-linear up projection followed by row-linear down projection
X [b x h] -> H [b x 4h] -> Y [b x h]
1. Column-linear up split W_up by output columns, keep H sharded
GPU 0
X
[b x h]
x W_up0
[h x 4h/3]
=
GPU 1
X
[b x h]
x W_up1
[h x 4h/3]
=
GPU 2
X
[b x h]
x W_up2
[h x 4h/3]
=
No intermediate all-gather H stays split as [H0 | H1 | H2], which is exactly what the next row-linear layer wants.
2. Row-linear down split W_down by input rows, sum partial outputs
GPU 0
x W_down0
[4h/3 x h]
= P0
[b x h]
GPU 1
x W_down1
[4h/3 x h]
= P1
[b x h]
GPU 2
x W_down2
[4h/3 x h]
= P2
[b x h]
All-reduce once P0 + P1 + P2 -> Y on every TP rank.
Column -> Row X replicated/synced -> H sharded -> final all-reduce
Row -> Column row-linear needs an all-reduce before the next split can continue
The MLP uses column-linear first so the intermediate activation shard can flow directly into row-linear without being gathered.

첫 번째 linear인 W_up을 column-wise로 나누면 각 GPU는 intermediate activation의 일부만 만든다.

GPU 0: X x W_up0 -> H0
GPU 1: X x W_up1 -> H1
GPU 2: X x W_up2 -> H2

여기서 H = [H0 | H1 | H2]이지만, 굳이 full H로 all-gather하지 않는다. 바로 다음 layer인 W_down을 row-wise로 나누면 H0, H1, H2를 그대로 입력 shard로 쓸 수 있기 때문이다.

GPU 0: H0 x W_down0 -> P0
GPU 1: H1 x W_down1 -> P1
GPU 2: H2 x W_down2 -> P2

마지막 output은 partial output들의 합이다.

Y = P0 + P1 + P2

그래서 forward에서 필요한 주요 communication은 마지막 all-reduce다.

column-linear -> no intermediate all-gather
row-linear    -> all-reduce at the end

원문에서 말하는 broadcast는 “각 TP rank가 같은 X를 갖도록 복사한다”는 뜻이다. 하지만 실제 학습에서는 이전 단계의 출력이 이미 TP rank들 사이에서 같은 값으로 맞춰져 있게 설계할 수 있어서, 매번 명시적인 broadcast가 필요하지 않은 경우가 많다.

반대로 row-linear -> column-linear 순서로 시작하면 첫 row-linear의 결과를 만들기 위해 먼저 all-reduce가 필요하다. 그 다음 column-linear를 수행하므로, 두 split 사이에 중간 통신이 끼어든다. 그래서 MLP에서는 보통 column-linear -> row-linear 조합이 더 자연스럽다.

Attention에 적용하기

Multi-head attention도 같은 패턴을 쓴다.

Q/K/V projection: column-wise
attention heads: GPU별 subset 계산
output projection: row-wise
Attention with TP Column-parallel Q/K/V, head-local attention, row-parallel output projection
8 heads / TP=2 -> 4 heads per GPU
GPU 0 heads 0, 1, 2, 3
0123
GPU 1 heads 4, 5, 6, 7
4567
1. Column-linear Q/K/V split projection outputs by attention heads
GPU 0
X
[b x h]
x Wq0/Wk0/Wv0
heads 0-3
= Q0 K0 V0
head shard
GPU 1
X
[b x h]
x Wq1/Wk1/Wv1
heads 4-7
= Q1 K1 V1
head shard
2. Head-local attention each GPU computes only its assigned heads
GPU 0
Q0 K0 V0 -> softmax(Q0 K0^T) V0 = A0
heads 0-3
GPU 1
Q1 K1 V1 -> softmax(Q1 K1^T) V1 = A1
heads 4-7
No head all-gather A stays split as [A0 | A1], then flows directly into the row-parallel output projection.
3. Row-linear output projection partial outputs are summed with all-reduce
GPU 0
A0
[b x h/2]
x Wo0
[h/2 x h]
= P0
[b x h]
GPU 1
A1
[b x h/2]
x Wo1
[h/2 x h]
= P1
[b x h]
All-reduce once P0 + P1 -> final attention output Y on every TP rank.
Good split 32 heads / TP 4 = 8 heads per GPU
Bad split 32 heads / TP 6 = uneven head shards
Attention TP shards the naturally independent head dimension, then uses row-linear output projection to avoid gathering head outputs before the final all-reduce.

Attention block은 먼저 input X에서 Q, K, V를 만든다.

Q = XWq
K = XWk
V = XWv

여기서 Wq, Wk, Wv를 column-wise로 나누면 output dimension이 head 단위로 나뉜다.

GPU 0: heads 0-3에 필요한 Q0, K0, V0 생성
GPU 1: heads 4-7에 필요한 Q1, K1, V1 생성

각 attention head는 자기 head의 Q, K, V만 필요로 한다.

head_i = softmax(Q_i K_i^T) V_i

그래서 GPU 0이 head 0-3을 계산할 때 GPU 1의 head 4-7 값을 기다릴 필요가 없다. 이게 attention에서 TP가 자연스럽게 맞는 이유다.

head별 attention output도 shard 상태로 남는다.

A = [A0 | A1]

이 intermediate attention output을 바로 all-gather하지 않고, output projection Wo를 row-wise로 나눈다.

GPU 0: A0 x Wo0 -> P0
GPU 1: A1 x Wo1 -> P1

최종 output은 partial output들의 합이다.

Y = P0 + P1

그래서 마지막에 all-reduce가 필요하다.

P0 + P1 -> all-reduce -> Y

MLP와 attention은 같은 구조로 볼 수 있다.

MLP:
column-linear W_up -> hidden shard
row-linear W_down -> all-reduce

Attention:
column-linear Wq/Wk/Wv -> head shard
row-linear Wo -> all-reduce

차이는 shard하는 독립 dimension이다.

MLP: intermediate hidden dimension
Attention: num_attention_heads dimension

TP size를 고를 때 head 수를 봐야 한다

Attention TP에서는 TP size가 attention head 수를 깔끔하게 나눌 수 있어야 한다.

num_attention_heads % TP_size == 0

예를 들어 head가 32개이고 TP size가 4이면 자연스럽다.

GPU 0: heads 0-7
GPU 1: heads 8-15
GPU 2: heads 16-23
GPU 3: heads 24-31

하지만 head가 32개인데 TP size가 6이면 균등하게 나눌 수 없다.

32 / 6 = 5.333...

이 경우 head가 “반으로 잘린다”기보다는, GPU별 tensor shape이 균일하지 않아서 일반적인 TP 구현이 처리하기 어려워진다. 대부분의 프레임워크는 이런 설정을 에러로 막는다.

GQA나 MQA에서는 K/V head 수가 query head 수보다 적다. 그래서 num_key_value_heads도 같이 봐야 한다.

num_attention_heads % TP_size == 0
num_key_value_heads % TP_size == 0
hidden_size % TP_size == 0
intermediate_size % TP_size == 0

즉 TP size는 모델의 shard 가능한 dimension을 깔끔하게 나누는 값이어야 한다.

확인

  • column-wise sharding에서 W = [W0 | W1]이면 output Y는 어떤 모양으로 나뉘는가?
  • row-wise sharding에서 왜 X도 같은 inner dimension 기준으로 나눠야 하는가?
  • column-wise는 보통 어떤 collective로 output shard를 합칠 수 있는가?
  • row-wise는 왜 all-reduce가 필요한가?
  • Transformer MLP에서 왜 column-linear -> row-linear 순서가 중간 all-reduce를 피하게 해주는가?
  • Attention TP에서 왜 head 수가 TP size로 나누어 떨어져야 하는가?