Continual Learning이라는 개념을 처음 접했던 건 23년 말 모두랜드라는 행사에 가서 논문 포스터 발표를 들었을 때였다. 그전까지 내가 찾고 있었던 개념이기에 큰 관심이 갔고, 이미지 관련 Continual Learning을 진행한 연구였기에 많은 질문을 드렸던 기억이 난다. Continual Learning은 개인적으로 AGI와 가장 가까운 연구 분야라고 생각하며, 결국 모든 딥러닝 모델이 추구해야 하는 학습 방법론이지 않나 싶다. 개념적인 부분은 알고 있었지만 그 디테일은 몰랐기에 이번 기회에 다시 복습하며 잘 정리해보려 한다.
그렇다면 우선 Continual Learning이란 무엇일까?
→ Continual Learning이란 모델이 시간이 지나면서 새로운 데이터를 학습하더라도 기존에 학습한 내용을 잃지 않도록 하는 학습방법이다.
Continual Learning이 떠오르기 이전에 Online Learning과 Lifelong Learning이라는 방법론들이 존재하였는데 이들은 비슷하면서도 다른 특징이 있다.
- Online Learning: 데이터가 연속적으로 들어올 때 실시간으로 모델을 업데이트(학습)하는 것이 목적으로 과거에 학습한 정보를 유지하는 것은 고려하지 않는다.
- Lifelong Learning: Multitask Learning과 Online Learning의 개념을 결합한 방법론으로 여러 도메인에서 지식을 축적하며 새로운 문제를 해결하는 것에 목적을 둔다.
결국 Continual Learning의 핵심은 이전에 학습한 내용에 대해 망각을 방지하고(Catastrophic Forgetting) 그렇게 얻은 지식을 다른 도메인으로 잘 Transfer 할 수 있는 능력을 키우는 것이다.
Continual Learning을 잘 수행하기 위해 여러 Setting 방법들이 있는데, 이번에는 가장 대표적인 3가지 방법에 대해서만 정리해보도록 하겠다.
1. Domain-Incremental Learning(DIL)
: 같은 Task이지만, 입력 데이터의 도메인이 변화하는 환경에서 학습하는 방법으로 Task들이 같은 Data Label Space를 가지지만 Input Distribution이 다른 경우를 말한다. 이 경우에는 모델이 데이터의 Distribution 변화에 적응해야 한다.
2. Task-Incremental Learning(TIL)
: 서로 다른 Task를 하나씩 추가하며 학습하는 방법으로 Task들이 서로 다른 Data Label Space를 가진다. 때문에 각 Task 별로 다른 모델을 학습하여 현재 어떤 Task를 학습하고 있는지를 안다.
3. Class-Incremental Learning(CIL)
: 새로운 Class를 추가해가며 학습하는 방법으로 Task ID를 모르는 상태에서 새로운 Class에 대한 학습이 이루어진다. 예를 들어 처음 Task에서 2개를 분류하였고 그다음 Task에서 3개를 분류하는 문제가 주어졌을 경우, 학습된 모델은 5가지 Class를 잘 구분할 수 있어야 한다. 따라서 가장 어려운 설정이라고 볼 수 있다.
MNIST 데이터 셋을 가지고 예시를 들어보자. 우선 MNIST 데이터 셋의 Task 구성이 다음과 같다고 가정한다(두 숫자를 구분하는 Task).
- Task 1: Class 0 / Class 1
- Task 2: Class 2 / Class 3
- Task 3: Class 3 / Class 4
- Task 4: Class 5 / Class 6
- Task 5: Class 7 / Class 8
이 때 각각의 Setting에서는 다음과 같은 학습을 진행한다.
- DIL: 모든 Task에 대해 같은 Class 분류를 수행하기 때문에 앞의 Class에 어떤 숫자가 들어오는지, 뒤의 Class에 어떤 숫자가 들어오는지로 생각하여 학습
- TIL: 각 Task를 별도로 학습(0/1 분류, 2/3 분류 등..)
- CIL: Task ID를 모르기 때문에 어떤 숫자 이미지가 들어왔을 때 해당 이미지가 무슨 숫자를 나타내는지를 학습
어떠한 Setting에서든 Continual Learning은 Catastrophic Forgetting을 막는 것이 가장 중요하며, 그러기 위해 Stability-Plasticity 간 Trade-Off를 잘 관리하는 것이 필요하다.
- Stability: 이전 Task에 대한 스킬을 잘 보유할 수 있는 능력
- Plasticity: 새로운 Task에 잘 적응할 수 있는 능력
이러한 능력을 발현하기 위해 Continual Learning에서는 여러 방법론을 가지는데, 크게 Regularization-based, Replay-based, Optimization-based, Representation-based, Architecture-based 방법으로 구분해서 볼 수 있다.
(1) Regularization-based
: 이전에 학습한 지식을 유지하기 위해 Loss Function에 Regularization을 적용하는 방식. 때문에 이전 Task에 대한 정보를 직접적으로 가지고 있을 필요가 없다. → 메모리를 많이 사용하지 않음! But 성능이 그렇게 좋지 않음
ex. EWC(Elastic Weight Consolidation)
(2) Replay-based
: 이전 데이터를 일부 저장하고 있다가 재사용한다. 또는 생성 모델을 활용한 데이터를 만들어 학습에 활용하기도 한다.
ex. Experience Replay, Generative Replay
(3) Optimization-based
: Task별로 적절한 최적화 기법을 적용한다.
(4) Representation-based
: 이전 Task와 새로운 Task가 서로 영향을 덜 미치도록 Representation을 분리한다.
(5) Architecture-based
: 새로운 Task를 학습할 때 새로운 신경망 구조를 추가하여 학습한다. 이 때 Task 마다 다른 모델 파라미터를 사용한다.
※ Continual Learning 평가지표
- Average Accuracy
- Average incremental accuracy
- Backward transfer(BWT): 새로운 Task를 학습한 후 이전 Task의 성능이 얼마나 떨어졌는지
- Forward transfer(FWT): 이전 Task에서 학습한 것이 새로운 Task 학습에 얼마나 도움되는지
(일반적인 기존 Online Learning Problem 평가: Regret Loss)
Regret Loss
최적 모델 대비 현재 Continual Learning 모델의 성능을 확인
Continual Learing에서 사용되는 또다른 Catastrophic Forgetting을 줄이는 방법이 있는데, 이는 새로운 Task를 학습하는 도중 이전 Task의 정보가 손실되지 않도록 Gradient Direction을 조절하는 방법인 Gradient Episodic Memory(GEM)이다.
GEM의 핵심 아이디어는 이전 Task에서 학습한 데이터를 Gradient Constraint로 사용하여 새로운 Task의 업데이트가 기존 정보를 망가뜨리지 않도록 조절하는 것이다.
이 개념이 지금 Titans에서 말하는 앞선 정보를 뒤까지 보내는 방법론과 비슷한 것으로 느껴지는데 어떤 점이 비슷한지 확인해보자.
Gradient Episodic Memory(GEM) 알고리즘
1. Memory Buffer를 사용하여 이전 Task에서 일부 데이터를 저장한다.
2. 새로운 Task를 학습할 때, 기존 Task의 데이터에 대한 Gradient를 계산한다.
3. 새로운 Gradient가 기존 Task의 Gradient 방향과 충돌하지 않도록 조정한다.
여기서부터는 메타 러닝을 결합!
Online Meta-Learning
: 이전 Task를 빠르게 기억하고 잊지 않게 하는 것이 목적. 새로 들어온 t번째 Task를 어떻게 adaptation 할지가 핵심
Meta-Continual Learning vs Continual Meta-Learning
Meta-Continual Learning | Continual Meta-Learning |
Meta-Learning을 활용하여 Continual Learning을 더 잘 수행 | Continual Learning 환경에서 Meta-Learning을 학습 |
새로운 Task를 빠르게 학습 + 망각 방지 | Outer Loop가 Stream of Episode가 되도록 구성 |
Meta-Continual Learning은 결국 하나의 Meta Parameter가 있어 학습을 해가며 하나의 Meta Parameter가 처리하는 Distribution이 점차 늘어난다고 생각하면 좋다.
Continual Learning과 Meta-Learning의 방법론을 합친 OSAKA(Online fast adaptation and knowledge accumulation)이라는 모델도 있으며, 비지도 학습 방법론을 결합한 Continual Unsupervised Representation Learning(CURL) 모델도 존재한다.
결국 Continual Learning은 과거 데이터의 Catastrophic Forgetting을 막으면서 새로운 데이터를 잘 학습시키는게 목적인 학습 방법이다. 시계열적인 흐름을 분석하는데 있어 중요한 학습법으로 생각되며 한 이미지 속에서도 이러한 Context를 고려하는 것이 중요하다는 연구를 관련해서 진행할 수 있지 않을까 생각된다.
본 개념 정리 글은 제가 수업을 듣고 이해한 내용과 인터넷 검색을 통해 찾은 정보를 바탕으로 작성되었습니다. 잘못된 개념이 있다면 언제든 알려주시면 감사하겠습니다.
'AI & CS 지식 > 메타러닝' 카테고리의 다른 글
[메타러닝] 7. Hyperparameter Optimization (0) | 2025.02.02 |
---|---|
[메타러닝] 6. AutoAugment (0) | 2024.12.31 |
[메타러닝] 5. NAS (0) | 2024.12.22 |
댓글