Activation Checkpointing
Activation checkpointing은 모든 activation을 저장하지 않고 일부만 저장하는 기법이다.
A0
A1
A2
A3
A4
저장: 일부 activation만 checkpoint로 남긴다.
재계산: backward 때 버린 activation을 다시 forward 계산으로 복원한다.
일반적인 학습은 forward 중에 많은 activation을 저장한다.
A0, A1, A2, A3, A4 모두 저장
checkpointing을 쓰면 일부만 저장한다.
A0, A2, A4만 저장
그 대신 backward 때 A1, A3 같은 중간 activation이 필요해지면, 가까운 checkpoint에서 다시 forward 계산을 수행해 복원한다.
trade-off
얻는 것:
- peak activation memory 감소
- 더 긴 sequence나 더 큰 batch를 학습할 가능성
잃는 것:
- backward 중 재계산이 들어가므로 compute 증가
- step time이 늘어날 수 있음
이름 때문에 gradient checkpointing이라고 부르는 경우도 많지만, 실제로 크게 줄이는 것은 gradient memory가 아니라 activation memory다.
확인
- checkpointing은 어떤 텐서를 덜 저장하는가?
- backward 때 저장하지 않은 activation은 어떻게 복원하는가?
- checkpointing의 핵심 trade-off는 무엇인가?