Context Parallel Ring Attention
마지막 수정:
이 카드는 CP 이론이 아니라 ring attention 구현을 읽는 법을 다룬다.
sequence length s -> s / CP per GPU Q0 K0 V0 Q1 K1 V1 Q2 K2 V2 Q3 K3 V3 local tokens GPU 1
local tokens GPU 2
local tokens GPU 3
local tokens
masked for causal rows when future
Every GPU materializes full K/V before attention.
Compute with current K/V chunk while the next chunk is in flight.
Picotron의 CP 구현은 작지만 핵심이 잘 보인다.
for step in range(comm.world_size):
if step + 1 != comm.world_size:
next_k = comm.send_recv(k)
next_v = comm.send_recv(v)
comm.commit()
if not is_causal or step <= comm.rank:
block_out, block_lse = ring_attention_forward(q, k, v, sm_scale, ...)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != comm.world_size:
comm.wait()
k = next_k
v = next_v
여기서 local q는 고정된다. 움직이는 것은 k/v chunk다.
rank 0: Q0 고정, K0/V0 -> K3/V3 -> K2/V2 ...
rank 1: Q1 고정, K1/V1 -> K0/V0 -> K3/V3 ...
rank 2: Q2 고정, K2/V2 -> K1/V1 -> K0/V0 ...
Picotron의 ContextCommunicate는 이 이동을 send_recv로 표현한다.
send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=cp_group)
recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=cp_group)
dist.batch_isend_irecv([send_operation, recv_operation])
TP의 collective와 다르게, ring attention은 각 rank가 이웃 rank와 K/V chunk를 교환하면서 attention을 누적한다.
왜 online softmax가 필요한가
Attention은 chunk별로 softmax를 따로 계산해서 더하면 틀린다.
softmax([scores for K0]) V0 + softmax([scores for K1]) V1 # wrong
softmax의 분모는 전체 score를 기준으로 정해져야 한다. 그래서 Picotron은 block_out, block_lse를 누적한다.
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
lse는 log-sum-exp다. 이미 본 chunk의 정규화 상수와 새 chunk의 정규화 상수를 합쳐서, full K/V를 한 번에 본 것과 같은 output을 만든다.
current_lse = logsumexp(previous scores)
block_lse = logsumexp(new chunk scores)
new_lse = logsumexp(current_lse, block_lse)
따라서 CP의 attention은 “approximation”이 아니다. 올바르게 구현하면 full attention과 같은 값을 만든다.
causal attention에서는 future chunk를 건너뛴다
Picotron forward에는 이 조건이 있다.
if not is_causal or step <= comm.rank:
...
causal attention에서 rank 2의 query는 rank 3의 future token chunk를 보면 안 된다. 그래서 자기보다 미래인 chunk는 계산하지 않는다.
rank 0 Q: K0만 필요
rank 1 Q: K0, K1 필요
rank 2 Q: K0, K1, K2 필요
rank 3 Q: K0, K1, K2, K3 필요
이 구조 때문에 naive contiguous sequence split은 load imbalance를 만든다. 뒤쪽 rank일수록 더 많은 과거 K/V chunk를 본다. Megatron 문서가 seq_length % (2 * context_parallel_size) == 0 같은 제약과 zig-zag 배치를 강조하는 이유가 여기에 있다.
backward는 dK/dV도 ring으로 되돌린다
Forward에서는 K/V chunk가 rank들을 돈다. Backward에서는 local query가 만든 dK/dV contribution을 원래 K/V owner에게 되돌려야 한다.
Picotron은 communication object를 둘로 나눈다.
kv_comm = ContextCommunicate("kv_comm")
d_kv_comm = ContextCommunicate("d_kv_comm")
루프 안에서는 현재 K/V chunk로 dQ, dK, dV를 계산하고, dK/dV를 다시 ring으로 전달해 owner 쪽 gradient가 합쳐지게 한다.
forward ring: K/V chunk moves to query owners
backward ring: dK/dV contribution moves back to K/V owners
Nanotron은 FlashAttention varlen 위에 ring을 얹는다
Nanotron의 ring_attention.py는 Picotron과 같은 구조를 갖지만 block attention 계산을 FlashAttention varlen kernel로 바꾼다.
next_k, next_v = comm.send_recv_kv(k, v)
outputs = _flash_attn_varlen_forward(...)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
즉 Nanotron에서 배울 점은 ring attention의 수학보다 integration이다.
Picotron: PyTorch matmul로 원리 확인
Nanotron: FlashAttention varlen + ring communication
또 llama3_ring_attention.py는 sequence packing과 head stride 같은 실제 모델 조건이 들어오면 K/V slice 계산이 얼마나 복잡해지는지 보여준다.
Megatron은 CP를 attention backend와 topology 문제로 다룬다
Megatron은 CP를 직접 작은 Python ring attention으로 보여주기보다 Transformer Engine attention에 CP group과 communication type을 넘긴다.
extra_kwargs["cp_group"] = pg_collection.cp
extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(pg_collection.cp)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
extra_kwargs["cp_comm_type"] = cp_comm_type
cp_comm_type에는 p2p, a2a, a2a+p2p 같은 선택지가 있다. 특히 a2a+p2p는 hierarchical context parallel group이 필요하다.
Picotron: ring p2p를 직접 구현
Nanotron: ring + FlashAttention integration
Megatron: TE backend + cp_comm_type + hierarchical CP
이 차이는 규모의 차이다. 작은 코드에서는 ring이 핵심이고, production에서는 어떤 topology에서 어떤 communication path를 선택할지가 핵심이 된다.
실습
cp_ring_attention_sim.py는 full causal attention과 chunked online attention을 비교한다.
python3 labs/large-scale-training-parallelism/cp_ring_attention_sim.py --cp 4 --tokens-per-rank 3 --rank 2
예상 출력은 다음 형태다.
cp=4, tokens_per_rank=3, rank=2
rank 2 receives K/V chunks: [0, 1, 2]
max_abs_diff(full, chunked_online)=...
이 실습은 GPU 통신을 실행하지 않는다. 대신 CP rank 하나의 query가 여러 K/V chunk를 순서대로 보면서 online softmax로 full attention과 같은 결과를 만드는지 검증한다.
읽는 순서
CP 구현은 다음 순서로 읽으면 된다.
- Picotron
ContextCommunicate에서 ring의 send rank와 recv rank를 확인한다. ring_attention.forward에서 local Q는 고정되고 K/V가 회전한다는 점을 본다.update_out_and_lse가 chunk별 결과를 full softmax 결과로 병합하는 방식을 읽는다.- backward에서
d_kv_comm이 왜 따로 필요한지 확인한다. - Nanotron에서 FlashAttention varlen과 sequence packing이 붙는 지점을 본다.
- Megatron에서
cp_comm_type, hierarchical CP, Transformer Engine integration을 확인한다.
CP를 이해했다는 신호는 “sequence를 나눈다”가 아니다. K/V가 어느 방향으로 움직이고, dK/dV가 어디로 돌아가며, online softmax가 왜 정확성을 보장하는지를 말할 수 있어야 한다.