Context Parallelism

마지막 수정:

trainingdistributedcontext-parallelismring-attentionlong-context

Context Parallelism은 긴 sequence를 여러 GPU가 token chunk 단위로 나눠 처리하는 방법이다.

Context Parallelism Shard the sequence across GPUs, exchange K/V chunks only when attention needs them.
sequence length s -> s / CP per GPU
GPU 0 tokens 0..s/4 Q0 K0 V0
GPU 1 tokens s/4..s/2 Q1 K1 V1
GPU 2 tokens s/2..3s/4 Q2 K2 V2
GPU 3 tokens 3s/4..s Q3 K3 V3
MLP / LayerNorm / residual Each token chunk can be processed locally.
GPU 0
local tokens
GPU 1
local tokens
GPU 2
local tokens
GPU 3
local tokens
Attention is the hard part Local Q needs K/V chunks from the rest of the sequence.
GPU 2 keeps Q2 Q2
Needed K/V chunks K0 V0 K1 V1 K2 V2 K3 V3
masked for causal rows when future
All-gather K/V
K0/V0 K1/V1 K2/V2 K3/V3
simple, but high temporary memory

Every GPU materializes full K/V before attention.

Ring Attention
GPU 0 -> GPU 1 -> GPU 2 -> GPU 3 ->
chunked K/V, lower peak memory

Compute with current K/V chunk while the next chunk is in flight.

Online softmax state Allows exact attention without holding all K/V at once.
m max score seen so far
l running softmax denominator
o running output numerator
KV chunk 0 update m,l,o KV chunk 1 update m,l,o KV chunk 2 final output
Sequential split early-token ranks do less work, late-token ranks do more
Zig-zag split mix early and late tokens to balance causal attention work
CP is easy for token-wise layers. The real work is making attention exact and memory-efficient while K/V is distributed.

SP도 sequence dimension을 나누지만, SP는 TP를 보조하는 기법이다. TP region에 들어갈 때는 full sequence를 다시 all-gather하는 순간이 있다.

SP:
non-TP region만 sequence sharded
TP region에서는 full sequence가 필요한 순간이 있음

CP는 sequence split을 모델 전체에 적용한다.

CP:
GPU 0: tokens 0..s/4
GPU 1: tokens s/4..s/2
GPU 2: tokens s/2..3s/4
GPU 3: tokens 3s/4..s

즉 한 GPU가 full sequence activation을 들지 않게 해서 long context memory를 줄인다.

쉬운 부분: token-wise 연산

MLP, LayerNorm, residual add 같은 연산은 대부분 token별로 독립이다.

token i의 MLP는 token j를 몰라도 된다
token i의 LayerNorm은 token i의 hidden vector만 보면 된다

그래서 CP에서는 각 GPU가 자기 token chunk만 처리해도 된다.

GPU 0: 자기 token chunk의 MLP / LayerNorm 계산
GPU 1: 자기 token chunk의 MLP / LayerNorm 계산

어려운 부분: attention

Attention은 다르다. 각 token의 query는 다른 token들의 key/value를 봐야 한다.

scores = Q K^T
out = softmax(scores) V

CP에서는 각 GPU가 자기 token chunk의 Q/K/V만 갖고 있다.

GPU 0: Q0, K0, V0
GPU 1: Q1, K1, V1
GPU 2: Q2, K2, V2
GPU 3: Q3, K3, V3

GPU 2가 자기 token들의 attention을 계산하려면 K0/V0, K1/V1, K2/V2 같은 다른 chunk도 필요하다. causal attention에서는 미래 token은 보지 않지만, 이전 token chunk는 봐야 한다.

그래서 CP의 핵심은 attention에서 필요한 K/V를 어떻게 교환하느냐다.

방법 1: all-gather

가장 단순한 방법은 모든 GPU가 전체 K/V를 모으는 것이다.

K0/V0, K1/V1, K2/V2, K3/V3
-> all-gather
-> every GPU has full K/V

이 방식은 이해하기 쉽다. 하지만 attention 계산 중에 전체 K/V를 임시로 저장해야 하므로 peak memory가 커진다.

CP를 쓰는 목적이 long context memory를 낮추는 것이므로, all-gather 방식은 그 장점을 약하게 만든다.

방법 2: Ring Attention

Ring Attention은 K/V를 한 번에 다 모으지 않는다. 각 GPU가 K/V chunk를 하나씩 받아가며 attention 결과를 누적한다.

1. 내 Q는 고정한다
2. 현재 가진 K/V chunk로 attention 일부를 계산한다
3. K/V chunk를 다음 GPU로 보낸다
4. 이전 GPU에서 다음 K/V chunk를 받는다
5. 모든 K/V chunk를 볼 때까지 반복한다

이때 중요한 수학적 도구가 online softmax다.

softmax는 전체 score의 분모가 필요하므로 chunk별 softmax를 따로 계산해서 더하면 틀린다.

softmax(score_chunk_0) + softmax(score_chunk_1)  # wrong

하지만 online softmax를 쓰면 chunk를 하나씩 보면서도 전체 attention과 같은 결과를 만들 수 있다.

m = 지금까지 본 score max
l = 지금까지 본 exp(score - m)의 누적 합
o = 지금까지 누적한 output

K/V chunk가 올 때마다 m, l, o를 갱신한다. 모든 chunk를 처리하면 full K/V를 한 번에 모아서 계산한 것과 같은 attention output이 나온다.

FlashAttention과 닮은 점

Ring Attention은 FlashAttention과 같은 핵심 아이디어를 공유한다.

FlashAttention:
한 GPU 안에서 K/V block을 나눠 읽고 online softmax로 누적

Ring Attention:
여러 GPU 사이에서 K/V chunk를 돌리고 online softmax로 누적

차이는 chunk가 어디서 오느냐다.

FlashAttention: 같은 GPU 메모리의 block
Ring Attention: 다른 GPU가 가진 K/V chunk

causal attention의 불균형

causal attention에서는 앞쪽 token과 뒤쪽 token의 계산량이 다르다.

초반 token: 볼 수 있는 과거 token이 적음
후반 token: 볼 수 있는 과거 token이 많음

sequence를 순서대로 나누면 앞쪽 GPU는 일이 적고 뒤쪽 GPU는 일이 많아질 수 있다.

GPU 0: early tokens -> less work
GPU 3: late tokens  -> more work

그래서 Zig-Zag Attention처럼 각 GPU에 앞쪽 token과 뒤쪽 token을 섞어 배치해서 계산량을 맞추는 구현이 필요할 수 있다.

정리

CP의 본질은 sequence length 문제를 직접 나누는 것이다.

CP:
sequence/token dimension을 모델 전체에서 shard
MLP/LayerNorm은 token chunk별로 계산
Attention은 K/V chunk 교환이 필요
Ring Attention은 K/V를 chunk 단위로 돌리며 online softmax로 누적

SP와 CP의 차이는 적용 범위다.

SP:
TP의 activation memory 보완
TP region 바깥에서 sequence shard

CP:
long context를 위한 병렬화 축
모델 전체에서 sequence shard
attention K/V 통신이 핵심

확인

  • CP는 SP와 달리 sequence split을 어디에 적용하는가?
  • MLP와 LayerNorm은 왜 CP에서 상대적으로 쉬운가?
  • attention에서는 왜 다른 GPU의 K/V가 필요한가?
  • all-gather 방식의 장점과 단점은 무엇인가?
  • Ring Attention이 online softmax를 필요로 하는 이유는 무엇인가?