# Model-Agnostic Meta-Learning (MAML)

This is an exploration of Model-Agnostic Meta-Learning (MAML) as described in the 2017 paper Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. MAML is a fascinating algorithm because it makes so few assumptions about the model one chooses to use: the only requirement is that the model be trainable by gradient descent.

In the application below we seek to fit perturbed step data. We generate this data by sampling inputs uniformly from $J=[-x,x]$ for $x \in \mathbb{R}$, assigning each input the height of a step function (one constant level on each interval $(k,k+1)$ with $k \in \mathbb{Z} \cap J$), and adding perturbations. We then use MAML to train a single, task-generic set of weights, defining the meta-objective as the loss on any particular task (a sampled regression problem) after $K$ steps of gradient descent.
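Concretely, with model parameters $\theta$, per-task loss $\mathcal{L}_i$, and inner learning rate $\alpha$, a single inner gradient step ($K=1$) produces adapted parameters $\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_i(\theta)$, and MAML minimizes the post-adaptation loss across tasks:

$$\min_\theta \sum_i \mathcal{L}_i\big(\theta - \alpha \nabla_\theta \mathcal{L}_i(\theta)\big).$$

For $K > 1$ the inner update is simply iterated $K$ times before the loss is evaluated, which is exactly what the code below does.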

Our architecture is a multilayer perceptron (MLP) with rectified linear unit (ReLU) activations and two hidden layers, using mean squared error as the inner objective. We consider datasets defined on the interval $[-3, 3]$, like the following:

```python
import autograd.numpy as np
from matplotlib import pyplot as plt

def relu(z):
    return np.maximum(z, 0.)

def net_predict(params, x):
    """Compute the output of a ReLU MLP with two hidden layers."""
    H1 = relu(np.outer(x, params['W1']) + params['b1'])
    H2 = relu(np.dot(H1, params['W2']) + params['b2'])
    return np.dot(H2, params['w3']) + params['b3']

def random_init(std, nhid):
    """Initialize all weights and biases from a zero-mean Gaussian."""
    return {'W1': np.random.normal(0., std, size=nhid),
            'b1': np.random.normal(0., std, size=nhid),
            'W2': np.random.normal(0., std, size=(nhid, nhid)),
            'b2': np.random.normal(0., std, size=nhid),
            'w3': np.random.normal(0., std, size=nhid),
            'b3': np.random.normal(0., std)}
```
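To make the shapes concrete, here is a self-contained sketch of the same forward pass in plain NumPy (the toy sizes `nhid = 4` and `n = 7` are chosen purely for illustration): a batch of `n` scalar inputs is lifted to an `(n, nhid)` hidden representation and reduced back to one output per input.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.)

rng = np.random.default_rng(0)
nhid, n = 4, 7
params = {'W1': rng.normal(size=nhid), 'b1': rng.normal(size=nhid),
          'W2': rng.normal(size=(nhid, nhid)), 'b2': rng.normal(size=nhid),
          'w3': rng.normal(size=nhid), 'b3': rng.normal()}

x = np.linspace(-3, 3, n)
# np.outer broadcasts each scalar input against the first weight vector.
H1 = relu(np.outer(x, params['W1']) + params['b1'])   # (n, nhid)
H2 = relu(np.dot(H1, params['W2']) + params['b2'])    # (n, nhid)
out = np.dot(H2, params['w3']) + params['b3']         # (n,)
print(H1.shape, H2.shape, out.shape)  # (7, 4) (7, 4) (7,)
```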

```python
class ToyDataGen:
    """Samples a random piecewise-constant (step) function, and then samples
    noisy observations of the function."""
    def __init__(self, xmin, xmax, ymin, ymax, std, num_pieces):
        self.xmin = xmin
        self.xmax = xmax
        self.ymin = ymin
        self.ymax = ymax
        self.std = std
        self.num_pieces = num_pieces

    def sample_dataset(self, npts):
        x = np.random.uniform(self.xmin, self.xmax, size=npts)
        heights = np.random.uniform(self.ymin, self.ymax, size=self.num_pieces)
        # Assign each x to one of num_pieces equal-width bins, then observe
        # that bin's height plus Gaussian noise.
        bins = np.floor((x - self.xmin) / (self.xmax - self.xmin) * self.num_pieces).astype(int)
        y = np.random.normal(heights[bins], self.std)
        return x, y
```
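As a quick sanity check on the binning logic in `sample_dataset`, here is a self-contained sketch in plain NumPy (the constants mirror the settings used later; the variable names are illustrative). Because `np.random.uniform` draws from the half-open interval, every bin index lands strictly below `num_pieces`.

```python
import numpy as np

np.random.seed(0)
xmin, xmax, num_pieces, npts = -3, 3, 6, 100

x = np.random.uniform(xmin, xmax, size=npts)
heights = np.random.uniform(-3, 3, size=num_pieces)
# Map each x into [0, num_pieces) and floor to an integer bin index.
bins = np.floor((x - xmin) / (xmax - xmin) * num_pieces).astype(int)
y = np.random.normal(heights[bins], 0.1)

print(x.shape, y.shape)  # (100,) (100,)
assert bins.min() >= 0 and bins.max() < num_pieces
```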

```python
from autograd import grad

def gd_step(cost, params, lrate):
    """Perform one gradient descent step on the given cost function with
    learning rate lrate. Returns a new set of parameters, and (IMPORTANT)
    does not modify the input parameters."""
    # Differentiate the cost with respect to every entry of the params dict.
    cost_grad = grad(cost)(params)
    params_new = {name: params[name] - lrate * cost_grad[name]
                  for name in params}
    return params_new
```
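The important contract of `gd_step` is that it is functional: it builds a fresh parameter dict and never mutates its input, so a chain of such steps can be differentiated through. Here is an autograd-free sketch of that contract, where `quadratic_grad` and `gd_step_demo` are illustrative stand-ins using the hand-derived gradient of the cost $0.5\,w^2$:

```python
import numpy as np

def quadratic_grad(params):
    # Analytic gradient of cost(params) = 0.5 * params['w']**2
    # (a stand-in for what autograd's grad() computes).
    return {'w': params['w']}

def gd_step_demo(grad_fn, params, lrate):
    # Functional update: build a fresh dict, never mutate the input.
    g = grad_fn(params)
    return {name: params[name] - lrate * g[name] for name in params}

params = {'w': np.array(2.0)}
new_params = gd_step_demo(quadratic_grad, params, lrate=0.5)
print(params['w'], new_params['w'])  # 2.0 1.0  (input left untouched)
```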

```python
class InnerObjective:
    """Mean squared error on a single regression task."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __call__(self, params):
        return 0.5 * np.mean((self.y - net_predict(params, self.x)) ** 2)

class MetaObjective:
    """Mean squared error after some number of gradient descent steps
    on the inner objective."""
    def __init__(self, x, y, inner_lrate, num_steps):
        self.x = x
        self.y = y
        self.inner_lrate = inner_lrate
        self.num_steps = num_steps

    def __call__(self, params, return_traj=False):
        """Compute the meta-objective. If return_traj is True, also return a
        list of the parameters after each update (used for visualization)."""
        inner_cost = InnerObjective(self.x, self.y)
        trajectory = [params]

        # Take num_steps inner gradient descent steps, recording each update.
        for _ in range(self.num_steps):
            trajectory.append(gd_step(inner_cost, trajectory[-1], self.inner_lrate))

        final_cost = inner_cost(trajectory[-1])

        if return_traj:
            return final_cost, trajectory
        return final_cost

    def visualize(self, params, title, ax):
        _, trajectory = self(params, return_traj=True)

        ax.plot(self.x, self.y, 'bx', ms=3.)
        px = np.linspace(XMIN, XMAX, 1000)
        # Draw the fit after each inner step; later steps are more opaque.
        for i, new_params in enumerate(trajectory):
            py = net_predict(new_params, px)
            ax.plot(px, py, 'r-', alpha=(i + 1) / len(trajectory))
        ax.set_title(title)
```
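It is worth noting what autograd is doing for us here: it differentiates through the inner updates. For a single inner step $\theta' = \theta - \alpha \nabla_\theta \mathcal{L}(\theta)$, the chain rule gives

$$\nabla_\theta \mathcal{L}(\theta') = \big(I - \alpha \nabla_\theta^2 \mathcal{L}(\theta)\big)\,\nabla_{\theta'} \mathcal{L}(\theta'),$$

so the meta-gradient contains second-derivative (Hessian-vector) terms. Dropping those terms yields the first-order MAML approximation discussed in the paper; here we let autograd compute the full meta-gradient.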

```python
# Outer (meta) optimization and inner (per-task) adaptation settings.
OUTER_LRATE = 0.01
OUTER_STEPS = 12000
INNER_LRATE = 0.1
INNER_STEPS = 5

PRINT_EVERY = 100
DISPLAY_EVERY = 1000

# Toy-data settings.
XMIN = -3
XMAX = 3
YMIN = -3
YMAX = 3
NOISE = 0.1
BINS = 6
NDATA = 100

# Network settings.
INIT_STD = 0.1
NHID = 50
```

```python
def train():
    np.random.seed(0)
    data_gen = ToyDataGen(XMIN, XMAX, YMIN, YMAX, NOISE, BINS)
    params = random_init(INIT_STD, NHID)
    fig, ax = plt.subplots(3, 4, figsize=(16, 9))
    plot_id = 0

    # Generate a fixed validation task for monitoring progress.
    x_val, y_val = data_gen.sample_dataset(NDATA)

    for i in range(OUTER_STEPS):
        # Sample a fresh regression task for this meta-training step.
        x_train, y_train = data_gen.sample_dataset(NDATA)

        # One outer gradient step on this task's meta-objective.
        meta_cost = MetaObjective(x_train, y_train, INNER_LRATE, INNER_STEPS)
        params = gd_step(meta_cost, params, OUTER_LRATE)

        if (i + 1) % PRINT_EVERY == 0:
            val_cost = MetaObjective(x_val, y_val, INNER_LRATE, INNER_STEPS)
            print('Iteration %d Meta-objective: %1.3f' % (i + 1, val_cost(params)))

        if (i + 1) % DISPLAY_EVERY == 0:
            val_cost.visualize(params, 'Iteration %d' % (i + 1), ax.flat[plot_id])
            plot_id += 1
```


We train our model by calling train() below.

The program produces visualizations showing the progression of our learner on the validation set. We should see better approximations (the red lines, one per inner gradient step) to the validation set as training progresses, even though we sample a different regression dataset at every outer iteration.

```python
train()
```

Iteration 100 Meta-objective: 1.254
Iteration 200 Meta-objective: 1.226
Iteration 300 Meta-objective: 1.215
Iteration 400 Meta-objective: 1.190
Iteration 500 Meta-objective: 1.177
Iteration 600 Meta-objective: 1.160
Iteration 700 Meta-objective: 1.137
Iteration 800 Meta-objective: 1.124
Iteration 900 Meta-objective: 1.087
Iteration 1000 Meta-objective: 1.056
Iteration 1100 Meta-objective: 1.039
Iteration 1200 Meta-objective: 1.030
Iteration 1300 Meta-objective: 1.018
Iteration 1400 Meta-objective: 1.004
Iteration 1500 Meta-objective: 0.995
Iteration 1600 Meta-objective: 0.986
Iteration 1700 Meta-objective: 0.971
Iteration 1800 Meta-objective: 0.978
Iteration 1900 Meta-objective: 0.960
Iteration 2000 Meta-objective: 0.939
Iteration 2100 Meta-objective: 0.911
Iteration 2200 Meta-objective: 0.891
Iteration 2300 Meta-objective: 0.893
Iteration 2400 Meta-objective: 0.883
Iteration 2500 Meta-objective: 0.881
Iteration 2600 Meta-objective: 0.860
Iteration 2700 Meta-objective: 0.831
Iteration 2800 Meta-objective: 0.824
Iteration 2900 Meta-objective: 0.777
Iteration 3000 Meta-objective: 0.799
Iteration 3100 Meta-objective: 0.774
Iteration 3200 Meta-objective: 0.717
Iteration 3300 Meta-objective: 0.777
Iteration 3400 Meta-objective: 0.721
Iteration 3500 Meta-objective: 0.668
Iteration 3600 Meta-objective: 0.663
Iteration 3700 Meta-objective: 0.671
Iteration 3800 Meta-objective: 0.635
Iteration 3900 Meta-objective: 0.621
Iteration 4000 Meta-objective: 0.599
Iteration 4100 Meta-objective: 0.635
Iteration 4200 Meta-objective: 0.619
Iteration 4300 Meta-objective: 0.570
Iteration 4400 Meta-objective: 0.592
Iteration 4500 Meta-objective: 0.530
Iteration 4600 Meta-objective: 0.525
Iteration 4700 Meta-objective: 0.510
Iteration 4800 Meta-objective: 0.527
Iteration 4900 Meta-objective: 0.545
Iteration 5000 Meta-objective: 0.530
Iteration 5100 Meta-objective: 0.506
Iteration 5200 Meta-objective: 0.618
Iteration 5300 Meta-objective: 0.603
Iteration 5400 Meta-objective: 0.504
Iteration 5500 Meta-objective: 0.499
Iteration 5600 Meta-objective: 0.486
Iteration 5700 Meta-objective: 0.478
Iteration 5800 Meta-objective: 0.644
Iteration 5900 Meta-objective: 0.639
Iteration 6000 Meta-objective: 0.492
Iteration 6100 Meta-objective: 0.548
Iteration 6200 Meta-objective: 0.563
Iteration 6300 Meta-objective: 0.584
Iteration 6400 Meta-objective: 0.562
Iteration 6500 Meta-objective: 0.547
Iteration 6600 Meta-objective: 0.467
Iteration 6700 Meta-objective: 0.467
Iteration 6800 Meta-objective: 0.463
Iteration 6900 Meta-objective: 0.455
Iteration 7000 Meta-objective: 0.424
Iteration 7100 Meta-objective: 0.412
Iteration 7200 Meta-objective: 0.382
Iteration 7300 Meta-objective: 0.430
Iteration 7400 Meta-objective: 0.426
Iteration 7500 Meta-objective: 0.426
Iteration 7600 Meta-objective: 0.415
Iteration 7700 Meta-objective: 0.360
Iteration 7800 Meta-objective: 0.379
Iteration 7900 Meta-objective: 0.413
Iteration 8000 Meta-objective: 0.345
Iteration 8100 Meta-objective: 0.420
Iteration 8200 Meta-objective: 0.352
Iteration 8300 Meta-objective: 0.334
Iteration 8400 Meta-objective: 0.313
Iteration 8500 Meta-objective: 0.334
Iteration 8600 Meta-objective: 0.323
Iteration 8700 Meta-objective: 0.312
Iteration 8800 Meta-objective: 0.316
Iteration 8900 Meta-objective: 0.311
Iteration 9000 Meta-objective: 0.307
Iteration 9100 Meta-objective: 0.297
Iteration 9200 Meta-objective: 0.295
Iteration 9300 Meta-objective: 0.296
Iteration 9400 Meta-objective: 0.289
Iteration 9500 Meta-objective: 0.282
Iteration 9600 Meta-objective: 0.289
Iteration 9700 Meta-objective: 0.278
Iteration 9800 Meta-objective: 0.281
Iteration 9900 Meta-objective: 0.284
Iteration 10000 Meta-objective: 0.299
Iteration 10100 Meta-objective: 0.277
Iteration 10200 Meta-objective: 0.294
Iteration 10300 Meta-objective: 0.271
Iteration 10400 Meta-objective: 0.271
Iteration 10500 Meta-objective: 0.288
Iteration 10600 Meta-objective: 0.284
Iteration 10700 Meta-objective: 0.281
Iteration 10800 Meta-objective: 0.264
Iteration 10900 Meta-objective: 0.270
Iteration 11000 Meta-objective: 0.261
Iteration 11100 Meta-objective: 0.260
Iteration 11200 Meta-objective: 0.258
Iteration 11300 Meta-objective: 0.259
Iteration 11400 Meta-objective: 0.258
Iteration 11500 Meta-objective: 0.261
Iteration 11600 Meta-objective: 0.256
Iteration 11700 Meta-objective: 0.254
Iteration 11800 Meta-objective: 0.251
Iteration 11900 Meta-objective: 0.249
Iteration 12000 Meta-objective: 0.247 