2024년 9월에 스터디에서 진행한 논문 리뷰를 기록용 포스트로 남겨둡니다.
원문: https://arxiv.org/pdf/2403.08763
연구 목적/동기
- Continual Pretraining 시 새로운 데이터로 인해 distribution shift가 일어나, 이전 데이터에 대한 성능 저하가 일어나거나, 새로운 데이터에 적응을 잘 못시키는 문제
- 몇 가지 훈련 테크닉으로 scratch로 훈련시키는 것과 맞먹는 성능을 낼 수 있다.
연구 방법
데이터셋 (train/val)
- SlimPajama: llama 데이터셋인 RedPajama의 중복 제거하고 퀄리티 up (300B token dataset)
- German CommonCrawl (~200B token dataset)
- Pile: 800GB dataset for language modeling
데이터 셋팅 (D_0→ D_1)
- two datasets, weak shift (Pile → SlimPajama)
- 둘 다 영어, domain 자체도 겹치는 부분이 많음
- 퀄리티가 더 좋은 데이터(SlimPajama)를 만들어서 업데이트하고자하는 경우에 대한 테스트 (많은 실무자들이 관심있어할 것)
- two dataset, stronger shift (Pile → German Common Crawl)
- 독일어 데이터로 추가 훈련
- 도메인이 다르거나, 새로운 언어, 코드 등 사전 학습 당시 사용한 단어가 다른 경우에 대한 테스트
- three datasets, no shift (SlimPajama 100B → SlimPajama 100B → SlimPajama 100B)
- 논문에서 제안된 테크닉들을 추가적으로 더 훈련할 때의 유효성을 테스트하기 위함
- Domain incremental continual pretraining
- SlimPajama을 도메인 별로 나누어서, 도메인을 차례대로 훈련 $\{D_0, D_1, ..., D_{N-1}\}$
- 도메인이 바뀔 때마다 distribution shift가 나타남
- 결과가 구린데, general-purpose LLM은 domain마다 업데이트를 하는 것이 아니라, domain을 섞어서 continual pretrain이 되어야한다는 것을 시사.
실험 셋팅
- backbone 모델: GPT-NeoX 405M, 10B
- Pile 데이터셋에 BPE algorithm으로 훈련된 토크나이저 (?)
실험 결과
- 보통의 잘 나가는 LLM: linear warmup + cosine decay schedule + low minimum learning rate
- 이 연구의 hypothesis: 비교적 높은 값으로 lr을 re-warming 한 후, 새로운 데이터에 효율적으로 적응될 수 있도록 re-decaying을 진행한다
실험1: Linear warmup 의 영향
- training iteration의 0.5%, 1%, 2%로 warmup 하도록 실험
- linear warmup 이 짧을 수록 빠르게 LR을 올리기 때문에, 빠르게 forget, adapt함 (아예 warmup을 안하면 빠르게 잊고, 적응시켜서 혼란의 phase를 겪음)
- 처음에는 좀 다른데 결국엔 비슷하게 forget과 adapt를 하더라
- 결론: linear warm-up의 duration은 처음 loss에서 스파이크 나는 것 외에는 그다지 영향이 없더라
실험2: Re-warming, re-decaying
- Setting
- re-warm, re-decay / no re-warm with $η_{min}$ / re-warm with $η_{max}$ but no re-decay
- 사전학습 LR의 반값 / 같은 값 / 2배값 (linear warmup → cosine decay)
- 보라색(constant min)이 제일 적게 잊어버리면서, 적응이 안됨
- 빨,주,노(rewarm, redecay)가 새로운 데이터에 적응함 → 적응을 위해서는 필수적이다.
- 높은 lr max → 적응이 잘되고 잘 잊어버림
- 분포 shift가 크면, 더 빨리 잊어버리고, 적응함
- 낮은 lr max로 설정하면 사전학습 지식을 제일 덜 잊어버리고, 다른 lr max과 비슷하게 새로운 데이터셋에 피팅되기 때문에 제일 낫지 않나…. 라는 생각
실험3: Compute-equivalent Replay
- Setting
- D0에서 replay 토큰들을 가져오고, 같은 숫자의 토큰들을 D1에서 제거하여 훈련
- 0.5, 1, 5, 10, 50% 의 replay
- rewarm, redecay와 함께 replay를 해야 forget을 덜함 (adapt-forget trade-off 가 나아짐)
**최종 실험 결과 (한줄 요약)
- rewarm up with η_{max} = 3e-05 (cosine redecay)
- 데이터 분포가 weak shift인 경우: 5% replay
- 데이터 분포가 strong shift인 경우: 25% replay
이렇게 훈련하면 사전 훈련 + 추가 데이터로 scratch full retraining하는 것의 효과를 낼 수 있다~!
연구의 contribution
LR re-warming, LR re-decaying, 이전 데이터의 replay → 이것들의 조합으로 scratch로 fully retraining하는 것만큼의 성능을 보여줬다.