Activation Checkpointing

trainingmemorycheckpointing

Activation checkpointing은 모든 activation을 저장하지 않고 일부만 저장하는 기법이다.

A0
A1
A2
A3
A4

저장: 일부 activation만 checkpoint로 남긴다.

재계산: backward 때 버린 activation을 다시 forward 계산으로 복원한다.

일반적인 학습은 forward 중에 많은 activation을 저장한다.

A0, A1, A2, A3, A4 모두 저장

checkpointing을 쓰면 일부만 저장한다.

A0, A2, A4만 저장

그 대신 backward 때 A1, A3 같은 중간 activation이 필요해지면, 가까운 checkpoint에서 다시 forward 계산을 수행해 복원한다.

trade-off

얻는 것:

  • peak activation memory 감소
  • 더 긴 sequence나 더 큰 batch를 학습할 가능성

잃는 것:

  • backward 중 재계산이 들어가므로 compute 증가
  • step time이 늘어날 수 있음

이름 때문에 gradient checkpointing이라고 부르는 경우도 많지만, 실제로 크게 줄이는 것은 gradient memory가 아니라 activation memory다.

확인

  • checkpointing은 어떤 텐서를 덜 저장하는가?
  • backward 때 저장하지 않은 activation은 어떻게 복원하는가?
  • checkpointing의 핵심 trade-off는 무엇인가?