We propose an algorithm for meta-learning that is model-agnostic, in the sense that it is compatible with any model trained with gradient descent and applicable to a variety of different learning problems, including classification, regression, and reinforcement learning. The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples. In our approach, the parameters of the model are explicitly trained such that a small number of gradient steps with a small amount of training data from a new task will produce good generalization performance on that task. In effect, our method trains the model to be easy to fine-tune. We demonstrate that this approach leads to state-of-the-art performance on two few-shot image classification benchmarks, produces good results on few-shot regression, and accelerates fine-tuning for policy gradient reinforcement learning with neural network policies.
Source: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (2017-03-09). Paper: https://arxiv.org/abs/1703.03400
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.autograd import grad
def mlp(x, params):
    """Evaluate a small ReLU MLP from an explicit list of parameter tensors.

    The network is written functionally (no ``nn.Module``) so it can be run with
    any parameter list, e.g. task-adapted copies of the meta-parameters.

    Args:
        x: input tensor; its last dimension must equal the in-features of ``params[0]``.
        params: indexable sequence ``[W1, b1, W2, b2, W3, b3]`` of weights and biases.

    Returns:
        Output of the final linear layer (no activation applied).
    """
    linear = torch.nn.functional.linear
    hidden = x
    # Two hidden layers, each a linear map followed by ReLU.
    for weight, bias in ((params[0], params[1]), (params[2], params[3])):
        hidden = torch.relu(linear(hidden, weight, bias=bias))
    # Linear read-out layer.
    return linear(hidden, params[4], bias=params[5])
class Task:
    """A sinusoid regression task with targets ``y = a * sin(x + b)``."""

    def __init__(self, a, b):
        self.a = a  # amplitude
        self.b = b  # phase

    def sample(self, K):
        """Draw ``K`` labelled examples from this task.

        Inputs are drawn uniformly via ``torch.rand`` and rescaled to [-5, 5).

        Returns:
            Tuple ``(x, y, loss_fct)``: ``x`` and ``y`` of shape ``(K, 1)`` and a
            fresh mean-squared-error criterion for this task.
        """
        inputs = 10 * torch.rand((K, 1)) - 5
        targets = self.a * torch.sin(inputs + self.b)
        return inputs, targets, nn.MSELoss()
@torch.no_grad()
def sample_task():
    """Create a random sinusoid task.

    The amplitude is drawn uniformly from [0.1, 5.0) and then the phase
    uniformly from [0, pi), in that order.
    """
    amplitude = 0.1 + 4.9 * torch.rand(1).item()
    phase = np.pi * torch.rand(1).item()
    return Task(amplitude, phase)
def perform_k_training_steps(params, task, batch_size, inner_training_steps, alpha, device='cpu',
                             first_order=False):
    """Adapt ``params`` to ``task`` with a few steps of gradient descent (MAML inner loop).

    The update is written functionally, ``theta' = theta - alpha * grad``, rather than
    by mutating ``p.data`` in place. Edits to ``.data`` are invisible to autograd, so
    the meta-gradient taken through the adapted parameters would silently reduce to
    the first-order approximation. With ``first_order=False`` (the default), the
    inner gradients are built with ``create_graph=True``, so a later ``backward()``
    differentiates through the adaptation as in MAML. Pass ``first_order=True`` for
    the cheaper first-order variant.

    A single support set of ``batch_size`` (K) examples is sampled once and reused
    for every inner step, so adaptation only ever sees K examples of the task.

    Args:
        params: list of tensors requiring grad (e.g. clones of the meta-parameters).
            They are not modified.
        task: object whose ``sample(K)`` returns ``(x, y, loss_fct)``.
        batch_size: number of support examples K.
        inner_training_steps: number of gradient steps to take.
        alpha: inner-loop step size.
        device: device the sampled inputs/targets are moved to.
        first_order: if True, do not keep the inner gradients in the autograd graph.

    Returns:
        List of adapted parameter tensors (``params`` itself when no step is taken).
    """
    if inner_training_steps < 1:
        return params
    x_support, y_support, loss_fct = task.sample(batch_size)
    x_support, y_support = x_support.to(device), y_support.to(device)
    adapted = list(params)
    for _ in range(inner_training_steps):
        loss = loss_fct(mlp(x_support, adapted), y_support)
        # create_graph keeps d(loss)/d(theta) differentiable for the meta-update.
        gradients = grad(loss, adapted, create_graph=not first_order)
        adapted = [p - alpha * g for p, g in zip(adapted, gradients)]
    return adapted
def maml(p_model, meta_optimizer, inner_training_steps, nb_epochs, batch_size_K, alpha, nb_tasks=10, device='cpu',
         nb_query_samples=25):
    """Meta-train ``p_model`` with MAML for few-shot supervised learning.

    Algorithm from https://arxiv.org/pdf/1703.03400v3.pdf (MAML for Few-Shot Supervised Learning).

    Args:
        p_model: list of meta-parameter tensors (theta) optimised by ``meta_optimizer``.
        meta_optimizer: optimizer constructed over ``p_model``.
        inner_training_steps: inner-loop gradient steps per task.
        nb_epochs: number of meta-iterations.
        batch_size_K: support examples per task (K).
        alpha: inner-loop step size.
        nb_tasks: tasks per meta-batch.
        device: device the sampled data is moved to.
        nb_query_samples: size of each task's query set D_i' (previously hard-coded to 25).

    Returns:
        List holding, for each meta-iteration, the mean query loss over its tasks.
    """
    training_loss = []
    for _ in tqdm(range(nb_epochs)):  # Line 2 in the pseudocode
        # Sample batch of tasks
        tasks = [sample_task() for _ in range(nb_tasks)]  # Line 3 in the pseudocode
        meta_optimizer.zero_grad()
        batch_training_loss = []
        for task in tasks:
            theta_i_prime = perform_k_training_steps([p.clone() for p in p_model], task, batch_size_K,
                                                     inner_training_steps, alpha, device=device)
            # Sample data points D_i' for the meta-update (line 8 in the pseudocode)
            x, y, loss_fct = task.sample(nb_query_samples)
            loss = loss_fct(mlp(x.to(device), theta_i_prime), y.to(device))
            # Accumulate this task's query-loss gradient into p_model's .grad through the
            # clone() links. Second-order terms are included only if the inner loop kept
            # its gradients in the graph. Calling backward per task frees each task's
            # graph immediately instead of holding every task's graph until the end.
            loss.backward()
            batch_training_loss.append(loss.item())
        meta_optimizer.step()  # Line 10 in the pseudocode
        training_loss.append(np.mean(batch_training_loss))
    return training_loss
if __name__ == "__main__":
    # Use the GPU when present, otherwise fall back to CPU (previously 'cuda' was
    # hard-coded, so the script crashed on machines without CUDA).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hyper-parameters for the sinusoid experiment. The paper's regression setup uses
    # a fixed inner step size alpha = 0.01 with K = 10; the previous 1e-3 left the
    # single inner step nearly a no-op.
    K_SHOT = 10
    ALPHA = 0.01
    META_LR = 1e-3
    INNER_STEPS = 1
    NB_EPOCHS = 70_000
    NB_TASKS = 10
    FINE_TUNE_STEPS = 10

    def glorot_uniform(*shape, fan):
        """Trainable tensor drawn from U(-sqrt(6/fan), sqrt(6/fan)) on ``device``."""
        bound = np.sqrt(6. / fan)
        return torch.empty(*shape, device=device).uniform_(-bound, bound).requires_grad_()

    # 1 -> 40 -> 40 -> 1 network, weights Glorot-uniform, biases zero.
    params = [glorot_uniform(40, 1, fan=41),
              torch.zeros(40, device=device).requires_grad_(),
              glorot_uniform(40, 40, fan=80),
              torch.zeros(40, device=device).requires_grad_(),
              glorot_uniform(1, 40, fan=41),
              torch.zeros(1, device=device).requires_grad_()]
    meta_optimizer = torch.optim.Adam(params, lr=META_LR)
    maml(params, meta_optimizer, INNER_STEPS, NB_EPOCHS, K_SHOT, ALPHA, device=device, nb_tasks=NB_TASKS)

    plt.title('MAML, K=10')
    x = torch.linspace(-5, 5, 50).to(device)
    x_plot = x.cpu().numpy()
    # Evaluation forward passes need no autograd graph.
    with torch.no_grad():
        y = mlp(x[..., None], params)
    plt.plot(x_plot, y.cpu().numpy(), c='lightgreen', linestyle='--', linewidth=2.2,
             label='pre-update')

    # New task
    task = sample_task()
    ground_truth_y = task.a * torch.sin(x + task.b)
    plt.plot(x_plot, ground_truth_y.cpu().numpy(), c='red', label='ground truth')

    # Fine-tuning with a few gradient steps on K examples (needs autograd enabled).
    new_params = perform_k_training_steps([p.clone() for p in params], task, K_SHOT, FINE_TUNE_STEPS, ALPHA,
                                          device=device)
    with torch.no_grad():
        y = mlp(x[..., None], new_params)
    plt.plot(x_plot, y.cpu().numpy(), c='darkgreen', linestyle='--', linewidth=2.2,
             label='10 grad step')
    plt.legend()
    plt.savefig('maml.png')
    plt.show()
Python implementation of Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks in 100 lines.
2017-03-09