메타러닝의 방법론은 크게 3가지로 나눌 수 있다.
(1) Model-based
(2) Metric-based
(3) Optimization-based
이번 시간에는 이 방법론들에 대해 중요한 부분들만 특징적으로 알아보도록 하자
1. Model-Based Meta-Learrning
Model-Based Meta-Learning은 학습 알고리즘 자체를 Model로써 설계하여 새로운 Task에 대한 빠른 Adaptation과 일반화를 가능하게 하는 방법론이다. 다시말하면 Model 자체가 새로운 Task에 빠르게 적응할 수 있는 매커니즘을 학습한다는 의미이다. 여기서 말하는 매커니즘이란 특정 메타 파라미터일 수도 있고, 정보를 담고 있는 특정 형태의 무언가 일 수 있다.
그렇다면 당연하게도 모델링을 얼마나 잘하는지가 성능에 큰 영향을 미치겠지..?
대표적인 모델로는 MANN, SNAIL 등이 있다.
Memory-Augmented Neural Networks(MANN)
MANN은 Neural Network에 External Memory를 추가하여 사용한다. Neural Turing Machine을 Base model로 사용하였으며 External Memory를 사용해서 필요할 때마다 정보를 꺼내쓰도록 구성했다고 보면 이해가 쉽다(RNN 계열의 모델들이 학습이 진행될수록 정보를 잊는 문제 보완) .
그럼 이게 메타러닝에 어떻게 사용되느냐 사실 이게 이해가 좀 어려웠는데, Task간에 전환이 일어나더라도 External Memory에 데이터가 유지된다는 장점을 가진다고 생각하니 메타러닝 학습에 있어 이전 Task들에 대한 정보가 다른 Task를 학습하는데 도움을 줄 수 있겠구나~ 라고 이해했다.
Simple Neural AttentIve meta-Learner(SNAIL)
i를 너무 억지로 붙인거 아닌가
SNAIL은 시간 순서를 가지는 Task에서 데이터를 처리하는데 최적화된 모델이다. Temporal Convolution과 Causal Attention을 사용하는데, Temporal Convolution은 이전의 데이터 context를 고려하여 현재 상태를 예측하는 데 필요한 정보를 모델에 제공한다. Causal Attention은 Task간의 주요 정보를 추출하고 주요한 포인트를 강조하는 역할을 수행한다(Self-Attention 매커니즘).
Temporal Convolution은 한정된 context size에서 병렬화된 방식으로 효율적으로 많은 데이터를 한번에 처리할 수 있고(high-bandwidth access), Causal Attention은 large context를 고려하여 상호 보완적이다.
HyperNetwork
HyperNetwork는 Main Model의 파라미터를 동적으로 생성하거나 조정하는 별도의 네트워크이다. HyperNetwork의 출력이 파라미터로서 Main Model(Target Network)에 전달된다. 이러한 방식을 사용한 모델로는 HyperTransformer라는 모델이 있으며, 여러 도메인에 대해 적용시킬 수 있는 큰 모델을 구성해두고 타겟에 대해서는 작게 구성하도록 디자인하였다.
Black-Box Adaptation
Black-Box Adaptiation이란 모델의 구조나 정보 없이 학습하는 것을 의미한다(Parameter 업데이트 없이 입력만으로 Adaptation하는 것). 모델이 학습될 때 Adaptation이 이루어지지 않고, 입력 데이터의 Context에 의존한다.
대표적인 예시로 GPT-3가 존재하는데, Fine-Tuning 없이도 빠르고 다양한 Task에 쉽게 Adaptation할 수 있지만, 그 능력에 한계가 존재한다는 단점이 있다.
2. Metric-Based Meta-Learning
Metric-Based Meta-Learning은 새로운 Task에 대해 데이터를 분류할 때 Similarity(Distance)를 기준으로 측정을 진행하는 방법이다. 이 방법은 특히 더 적은 데이터 환경에서 유용한 장점을 가진다.
Training Data가 있고 Test Data가 있다고 생각할 때, Test Data와 가장 Distance가 가까운 Data의 Class로 분류한다고 생각하면 된다.
Siamese Networks
두 이미지가 같은 class인지 아닌지를 예측하는 방법으로, Weight를 공유하는 2개의 동일한 Neural Network 구조를 가진다. 두 이미지가 비슷할 수록 0의 값을 가진다.
Matching Networks
Embedding Network와 유사도 측정기가 결합된 구조로, Attention을 기반으로 새로운 데이터와 기존 데이터 간의 유사성을 측정한다.
Prototypical Networks
각 Class의 Prototype(Centroid)를 학습하는 방법으로 새로 들어온 Query 데이터가 어떤 Class와 가장 가까운지를 측정한다. 간단한 Euclidean Distance로 거리를 측정하였다.
Relation Networks
Prototypical Networks에서 간단한 Euclidean Distance를 사용한 것과 달리 Relation을 모델링할 수 있는 학습 가능한 네트워크를 사용하였다. 복잡한 데이터 간의 Relation을 더 잘 학습할 수 있어(Non-Linear) Class간의 경계가 복잡하더라도 유연하게 대처할 수 있다.
GNNs
앞서 소개한 Matching Networks, Prototypical Networks, Relation Networks에 대한 많은 Variation 모델들이 등장하고 있다. 그 중 Relation Networks에서 Class간의 Full Context를 볼 수 있도록 설계한 모델이 Graph Neural Networks이다. Class간의 관계를 확인할 수 있는 구조로 설계되었으며, GNN이 가지는 Node간의 관계 파악 방법을 Node와 Edge와의 관계를 파악하는 방법으로 발전시킨 모델이 Edge-Labeling Graph Neural Networks(EGNN)이다.
3. Optimization-Based Meta-Learning
Optimization-Based Meta-Learning은 모델이 새로운 Task를 학습할 때 효율적인 최적화 과정을 통해 빠르게 적응할 수 있도록 설계된다. Inner Learning Process에서의 Optimization을 얼마나 효율적으로 만들 것인지가 중요하며, 초기 모델 파라미터나 최적화 알고리즘의 구성요소를 학습한다.
그 결과 새로운 Task가 주어졌을 때 모델은 몇 번의 Gradient Descent만으로 해당 Task에 쉽게 적응할 수 있게 된다.
➝ Task-Specific Parameter(Common Knowledge)를 잘 만드는 것이 중요!
이 때 모든 Class에 대해 Optimizer를 만드는 것이 어렵기 때문에 Learnable Optimizer를 활용한다. 간단한 예시로 RNN 기반의 Optimizer를 생각할 수 있는데(최적화과정을 RNN으로 학습), 이 경우 파라미터가 너무 많다는 문제가 생기고 그렇다고 Optimizer를 Coordinate-wise하게 만들면 파라미터간의 Correlation을 고려하지 못한다는 문제가 발생한다.
Model-Agnostic Meta-Learning(MAML)
그래서 등장한 가장 대표적인 메타러닝 방법으로, 메타학습 단계에서 모델 파라미터를 학습하여 새로운 작업에 대해 적은 수의 Gradient Descent만으로 쉽게 적응할 수 있도록 만든다. 이 때 초기 파라미터를 잘 만들기 위해 많은 Task들을 학습하며, 학습이 진행될 수록 Validation Loss가 줄어드는 쪽으로 메타 파라미터 $\theta$를 옮긴다고 생각하면 이해가 쉽다.
하지만, 2차 도함수를 사용하기 때문에 계산 비용이 높다는 단점이 있다.
Reptile
이러한 점을 극복하여 1차 도함수를 사용해 계산 비용을 낮춘 모델이 Reptile이다. MAML에서 원했던 것이 내가 구하고자 하는 $\theta$를 Loss가 줄어드는 쪽으로 이동시키는 것이었기 때문에, 거리를 줄이면 $\theta$가 해당 방향으로 움직일 것이라는게 이 모델의 핵심이다(Task Specific Parameter - Meta Parrameter).
Implicit MAML
Implicit MAML은 MAML ➝ Reptile로 가며 줄어든 gradient의 정보를 살려보고자(2nd Order를 조금이라도 살려보고자) 설계된 모델이다. 각각의 Task마다 Specific하게 Gradient를 계산하며 최적의 Optimal Point 방향에 대해서만 Hessian을 조금 써서 위치 이동을 반영한다.
Meta-SGD
MAML의 확장된 버전으로 Learning rate를 학습가능한 파라미터로 설정하여, Learning rate를 최적화하고 새로운 Task에 대해 빠르게 학습할 수 있게 만든다.
4. Combined Methods
당연하게도 지금까지 소개한 Model-based, Metric-based, Optimization-based 방법을 모두 사용한 방법이 존재한다.
Latent Embedding Optimization(LEO)
(Metric + Optimization) MAML을 확장하여 고차원 Parameter 공간을 효율적으로 탐색한 방법으로, 파라미터를 저차원의 Latent Space로 매핑한 후 효율적인 최적화를 진행하였다.
Embedding 추출(RelationNet) + Latent Optimization + Decoding
Task-Dependent Initialization을 위한 Encoder와 Latent에 Adaptation을 할 수 있는 Decoder를 사용하였는데, Encoder로는 Metric-based 방법인 RelationNet을 활용하였다. 이를 통해 LEO는 Latent Space에서 데이터의 관계를 명확히 이해하고 최적화 할 수 있도록 만든다.
본 개념 정리 글은 제가 수업을 듣고 이해한 내용과 인터넷 검색을 통해 찾은 정보를 바탕으로 작성되었습니다. 잘못된 개념이 있다면 언제든 알려주시면 감사하겠습니다.
'AI & CS 지식 > 메타러닝' 카테고리의 다른 글
[메타러닝] 4. Bayesian Meta-Learning (0) | 2024.12.16 |
---|---|
[메타러닝] 3. Meta Reinforcement Learning (1) | 2024.12.16 |
[메타러닝] 1. Meta-Learning이란? (0) | 2024.12.13 |
댓글