v03 Prefill / Decode Split

마지막 수정:

inferencenano-vllmprefilldecode

LLM inference는 크게 두 단계로 나뉜다.

prefill:
prompt 전체를 처리하고 KV cache를 만든다.

decode:
이전 token과 KV cache를 사용해 다음 token 하나를 만든다.

읽을 코드:

labs/nano-vllm/versions/v03_prefill_decode/run.py

이번 버전은 toy model 안에서도 prefill()decode()를 나눈다.

ToyModel과 LLMEngine.step() labs/nano-vllm/versions/v03_prefill_decode/run.py:39-72

  
      
      class ToyModel:
    
      
          def prefill(self, prompt_tokens: list[int]) -> int:
    
      
              return (sum(prompt_tokens) + len(prompt_tokens) + 1) % VOCAB_SIZE
    
      
       
    
      
          def decode(self, last_token: int, position: int) -> int:
    
      
              return (last_token * 3 + position + 1) % VOCAB_SIZE
    
      
       
    
      
       
    
      
      class LLMEngine:
    
      
          def __init__(self):
    
      
              self.model = ToyModel()
    
      
              self.seqs: list[Sequence] = []
    
      
       
    
      
          def add_request(self, prompt: list[int], max_tokens: int) -> None:
    
      
              self.seqs.append(Sequence(prompt, max_tokens))
    
      
       
    
      
          def is_finished(self) -> bool:
    
      
              return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
    
      
       
    
      
          def step(self) -> list[tuple[int, list[int]]]:
    
      
              outputs = []
    
      
              for seq in self.seqs:
    
      
                  if seq.status == SequenceStatus.FINISHED:
    
      
                      continue
    
      
                  seq.status = SequenceStatus.RUNNING
    
      
                  if seq.is_prefill:
    
      
                      token = self.model.prefill(seq.prompt_token_ids)
    
      
                      seq.is_prefill = False
    
      
                  else:
    
      
                      token = self.model.decode(seq.last_token, len(seq.token_ids))
    
      
                  seq.append_token(token)
    
      
                  if seq.status == SequenceStatus.FINISHED:
    
      
                      outputs.append((seq.seq_id, seq.output_token_ids))
    
      
              return outputs
    

실제 모델에서는 이 차이가 더 중요하다. prefill은 큰 matrix 연산에 가깝고, decode는 작은 batch와 KV cache lookup의 반복에 가깝다.

다음 문제

아직 여러 request를 효율적으로 섞지 못한다.

특히 decode 단계에서는 여러 sequence의 다음 token을 한 batch로 묶어야 GPU를 더 잘 쓸 수 있다.