Context Parallelism
마지막 수정:
Context Parallelism은 긴 sequence를 여러 GPU가 token chunk 단위로 나눠 처리하는 방법이다.
sequence length s -> s / CP per GPU Q0 K0 V0 Q1 K1 V1 Q2 K2 V2 Q3 K3 V3 local tokens GPU 1
local tokens GPU 2
local tokens GPU 3
local tokens
masked for causal rows when future
Every GPU materializes full K/V before attention.
Compute with current K/V chunk while the next chunk is in flight.
SP도 sequence dimension을 나누지만, SP는 TP를 보조하는 기법이다. TP region에 들어갈 때는 full sequence를 다시 all-gather하는 순간이 있다.
SP:
non-TP region만 sequence sharded
TP region에서는 full sequence가 필요한 순간이 있음
CP는 sequence split을 모델 전체에 적용한다.
CP:
GPU 0: tokens 0..s/4
GPU 1: tokens s/4..s/2
GPU 2: tokens s/2..3s/4
GPU 3: tokens 3s/4..s
즉 한 GPU가 full sequence activation을 들지 않게 해서 long context memory를 줄인다.
쉬운 부분: token-wise 연산
MLP, LayerNorm, residual add 같은 연산은 대부분 token별로 독립이다.
token i의 MLP는 token j를 몰라도 된다
token i의 LayerNorm은 token i의 hidden vector만 보면 된다
그래서 CP에서는 각 GPU가 자기 token chunk만 처리해도 된다.
GPU 0: 자기 token chunk의 MLP / LayerNorm 계산
GPU 1: 자기 token chunk의 MLP / LayerNorm 계산
어려운 부분: attention
Attention은 다르다. 각 token의 query는 다른 token들의 key/value를 봐야 한다.
scores = Q K^T
out = softmax(scores) V
CP에서는 각 GPU가 자기 token chunk의 Q/K/V만 갖고 있다.
GPU 0: Q0, K0, V0
GPU 1: Q1, K1, V1
GPU 2: Q2, K2, V2
GPU 3: Q3, K3, V3
GPU 2가 자기 token들의 attention을 계산하려면 K0/V0, K1/V1, K2/V2 같은 다른 chunk도 필요하다. causal attention에서는 미래 token은 보지 않지만, 이전 token chunk는 봐야 한다.
그래서 CP의 핵심은 attention에서 필요한 K/V를 어떻게 교환하느냐다.
방법 1: all-gather
가장 단순한 방법은 모든 GPU가 전체 K/V를 모으는 것이다.
K0/V0, K1/V1, K2/V2, K3/V3
-> all-gather
-> every GPU has full K/V
이 방식은 이해하기 쉽다. 하지만 attention 계산 중에 전체 K/V를 임시로 저장해야 하므로 peak memory가 커진다.
CP를 쓰는 목적이 long context memory를 낮추는 것이므로, all-gather 방식은 그 장점을 약하게 만든다.
방법 2: Ring Attention
Ring Attention은 K/V를 한 번에 다 모으지 않는다. 각 GPU가 K/V chunk를 하나씩 받아가며 attention 결과를 누적한다.
1. 내 Q는 고정한다
2. 현재 가진 K/V chunk로 attention 일부를 계산한다
3. K/V chunk를 다음 GPU로 보낸다
4. 이전 GPU에서 다음 K/V chunk를 받는다
5. 모든 K/V chunk를 볼 때까지 반복한다
이때 중요한 수학적 도구가 online softmax다.
softmax는 전체 score의 분모가 필요하므로 chunk별 softmax를 따로 계산해서 더하면 틀린다.
softmax(score_chunk_0) + softmax(score_chunk_1) # wrong
하지만 online softmax를 쓰면 chunk를 하나씩 보면서도 전체 attention과 같은 결과를 만들 수 있다.
m = 지금까지 본 score max
l = 지금까지 본 exp(score - m)의 누적 합
o = 지금까지 누적한 output
새 K/V chunk가 올 때마다 m, l, o를 갱신한다. 모든 chunk를 처리하면 full K/V를 한 번에 모아서 계산한 것과 같은 attention output이 나온다.
FlashAttention과 닮은 점
Ring Attention은 FlashAttention과 같은 핵심 아이디어를 공유한다.
FlashAttention:
한 GPU 안에서 K/V block을 나눠 읽고 online softmax로 누적
Ring Attention:
여러 GPU 사이에서 K/V chunk를 돌리고 online softmax로 누적
차이는 chunk가 어디서 오느냐다.
FlashAttention: 같은 GPU 메모리의 block
Ring Attention: 다른 GPU가 가진 K/V chunk
causal attention의 불균형
causal attention에서는 앞쪽 token과 뒤쪽 token의 계산량이 다르다.
초반 token: 볼 수 있는 과거 token이 적음
후반 token: 볼 수 있는 과거 token이 많음
sequence를 순서대로 나누면 앞쪽 GPU는 일이 적고 뒤쪽 GPU는 일이 많아질 수 있다.
GPU 0: early tokens -> less work
GPU 3: late tokens -> more work
그래서 Zig-Zag Attention처럼 각 GPU에 앞쪽 token과 뒤쪽 token을 섞어 배치해서 계산량을 맞추는 구현이 필요할 수 있다.
정리
CP의 본질은 sequence length 문제를 직접 나누는 것이다.
CP:
sequence/token dimension을 모델 전체에서 shard
MLP/LayerNorm은 token chunk별로 계산
Attention은 K/V chunk 교환이 필요
Ring Attention은 K/V를 chunk 단위로 돌리며 online softmax로 누적
SP와 CP의 차이는 적용 범위다.
SP:
TP의 activation memory 보완
TP region 바깥에서 sequence shard
CP:
long context를 위한 병렬화 축
모델 전체에서 sequence shard
attention K/V 통신이 핵심
확인
- CP는 SP와 달리 sequence split을 어디에 적용하는가?
- MLP와 LayerNorm은 왜 CP에서 상대적으로 쉬운가?
- attention에서는 왜 다른 GPU의 K/V가 필요한가?
- all-gather 방식의 장점과 단점은 무엇인가?
- Ring Attention이 online softmax를 필요로 하는 이유는 무엇인가?