Model-Agnostic Meta-Learning (MAML)

This is an exploration of the concept of Model-Agnostic Meta-Learning (MAML) as described in the 2017 paper Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. MAML is a fascinating algorithm because it makes so few assumptions about the model one chooses to use: the only requirement is that the model be trainable by gradient descent.

In the application below we seek to learn piecewise-linear approximations to perturbed step data. Each task is a random step function on an interval $J=[-x,x]$ with $x \in \mathbb{R}$: we partition $J$ into equal-width pieces (one step for each unit interval $[k,k+1)$ with $k \in \mathbb{Z} \cap J$), assign each piece a random height, sample inputs uniformly from $J$, and perturb the corresponding outputs with Gaussian noise. We then use MAML to train a single, task-generic set of weights, defining the meta-objective as the loss on any particular task (a training regression problem) after $K$ steps of gradient descent.
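Concretely, if $\mathcal{L}_{\mathcal{T}}$ denotes the regression loss on a sampled task $\mathcal{T}$ and $U_{\mathcal{T}}^{K}(\theta)$ denotes the parameters obtained from $\theta$ after $K$ gradient descent steps on $\mathcal{L}_{\mathcal{T}}$ with inner learning rate $\alpha$ (shorthand introduced here, loosely following the MAML paper), the meta-objective we minimize over tasks is

$$\min_{\theta}\; \mathbb{E}_{\mathcal{T}}\!\left[\mathcal{L}_{\mathcal{T}}\!\left(U_{\mathcal{T}}^{K}(\theta)\right)\right], \qquad U_{\mathcal{T}}^{1}(\theta) = \theta - \alpha \nabla_{\theta}\mathcal{L}_{\mathcal{T}}(\theta).$$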

Our architecture is a Multilayer Perceptron (MLP) with Rectified Linear Unit (ReLU) activations and two hidden layers, and we use a Mean Squared Error loss function for the inner objective. We consider datasets defined on the interval $[-3,3]$, like the following:

In [0]:
import autograd.numpy as np
import autograd as ag
from matplotlib import pyplot as plt

def relu(z):
    return np.maximum(z, 0.)

def net_predict(params, x):
    """Compute the output of a ReLU MLP with 2 hidden layers."""
    H1 = relu(np.outer(x, params['W1']) + params['b1'])
    H2 = relu(np.dot(H1, params['W2']) + params['b2'])
    return np.dot(H2, params['w3']) + params['b3']

def random_init(std, nhid):
    return {'W1': np.random.normal(0, std, size=nhid),
            'b1': np.random.normal(0., std, size=nhid),
            'W2': np.random.normal(0., std, size=(nhid,nhid)),
            'b2': np.random.normal(0., std, size=nhid),
            'w3': np.random.normal(0., std, size=nhid),
            'b3': np.random.normal(0., std)}
    
class ToyDataGen:
    """Samples a random piecewise linear function, and then samples noisy
    observations of the function."""
    def __init__(self, xmin, xmax, ymin, ymax, std, num_pieces):
        self.xmin = xmin
        self.xmax = xmax
        self.ymin = ymin
        self.ymax = ymax
        self.std = std
        self.num_pieces = num_pieces
        
    def sample_dataset(self, npts):
        x = np.random.uniform(self.xmin, self.xmax, size=npts)
        heights = np.random.uniform(self.ymin, self.ymax, size=self.num_pieces)
        bins = np.floor((x - self.xmin) / (self.xmax - self.xmin) * self.num_pieces).astype(int)
        y = np.random.normal(heights[bins], self.std)
        return x, y
    
def gd_step(cost, params, lrate):
    """Perform one gradient descent step on the given cost function with learning
    rate lrate. Returns a new set of parameters, and (IMPORTANT) does not modify
    the input parameters."""
    # Differentiate the cost with respect to every parameter, then step downhill
    grads = ag.grad(cost)(params)
    return {name: params[name] - lrate * grads[name] for name in params}


class InnerObjective:
    """Mean squared error."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __call__(self, params):
        return 0.5 * np.mean((self.y - net_predict(params, self.x)) ** 2)
    
class MetaObjective:
    """Mean squared error after some number of gradient descent steps
    on the inner objective."""
    def __init__(self, x, y, inner_lrate, num_steps):
        self.x = x
        self.y = y
        self.inner_lrate = inner_lrate
        self.num_steps = num_steps
        
    def __call__(self, params, return_traj=False):
        """Compute the meta-objective. If return_traj is True, also return
        a list of the parameters after each update (used for visualization)."""
        inner_cost = InnerObjective(self.x, self.y)

        # Run num_steps inner gradient descent steps, recording each parameter set
        trajectory = [params]
        for _ in range(self.num_steps):
            trajectory.append(gd_step(inner_cost, trajectory[-1], self.inner_lrate))

        # The meta-objective is the inner loss at the adapted parameters
        final_cost = inner_cost(trajectory[-1])

        if return_traj:
            return final_cost, trajectory
        return final_cost
    
    def visualize(self, params, title, ax):
        _, trajectory = self(params, return_traj=True)
        
        ax.plot(self.x, self.y, 'bx', ms=3.)
        px = np.linspace(XMIN, XMAX, 1000)
        for i, new_params in enumerate(trajectory):
            py = net_predict(new_params, px)
            ax.plot(px, py, 'r-', alpha=(i+1)/len(trajectory))
        ax.set_title(title)


OUTER_LRATE = 0.01
OUTER_STEPS = 12000
INNER_LRATE = 0.1
INNER_STEPS = 5

PRINT_EVERY = 100
DISPLAY_EVERY = 1000

XMIN = -3
XMAX = 3
YMIN = -3
YMAX = 3
NOISE = 0.1
BINS = 6
NDATA = 100

INIT_STD = 0.1
NHID = 50

def train():
    np.random.seed(0)
    data_gen = ToyDataGen(XMIN, XMAX, YMIN, YMAX, NOISE, BINS)
    params = random_init(INIT_STD, NHID)
    fig, ax = plt.subplots(3, 4, figsize=(16, 9))
    plot_id = 0
    
    # Generate validation set
    x_val, y_val = data_gen.sample_dataset(NDATA)
    
    for i in range(OUTER_STEPS):

        # Sample a fresh regression task (dataset) for this outer step
        x_train, y_train = data_gen.sample_dataset(NDATA)

        # One outer (meta) gradient descent step on this task's meta-objective
        meta_cost = MetaObjective(x_train, y_train, INNER_LRATE, INNER_STEPS)
        params = gd_step(meta_cost, params, OUTER_LRATE)

        # Periodically evaluate the meta-objective on the held-out validation set
        if (i+1) % PRINT_EVERY == 0:
            val_cost = MetaObjective(x_val, y_val, INNER_LRATE, INNER_STEPS)
            print('Iteration %d Meta-objective: %1.3f' % (i+1, val_cost(params)))

        if (i+1) % DISPLAY_EVERY == 0:
            val_cost.visualize(params, 'Iteration %d' % (i+1), ax.flat[plot_id])
            plot_id += 1
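
To get a feel for one such task before training, the minimal sketch below (using only the ToyDataGen class and constants defined above) samples a single noisy step dataset and plots it:

In [0]:
# Preview one sampled task: noisy observations of a random step function
preview_gen = ToyDataGen(XMIN, XMAX, YMIN, YMAX, NOISE, BINS)
x_demo, y_demo = preview_gen.sample_dataset(NDATA)

plt.plot(x_demo, y_demo, 'bx', ms=3.)
plt.title('One sampled regression task')
plt.show()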

We train our model by calling train() below.

The above program produces visualizations showing the progression of meta-learning on the validation dataset. Each panel plots the validation data together with the fits (red lines) obtained after each inner gradient step; as training progresses these post-adaptation fits should match the validation set more and more closely, even though a different regression dataset is sampled at every outer iteration.

In [0]:
train()
Iteration 100 Meta-objective: 1.254
Iteration 200 Meta-objective: 1.226
Iteration 300 Meta-objective: 1.215
Iteration 400 Meta-objective: 1.190
Iteration 500 Meta-objective: 1.177
Iteration 600 Meta-objective: 1.160
Iteration 700 Meta-objective: 1.137
Iteration 800 Meta-objective: 1.124
Iteration 900 Meta-objective: 1.087
Iteration 1000 Meta-objective: 1.056
Iteration 1100 Meta-objective: 1.039
Iteration 1200 Meta-objective: 1.030
Iteration 1300 Meta-objective: 1.018
Iteration 1400 Meta-objective: 1.004
Iteration 1500 Meta-objective: 0.995
Iteration 1600 Meta-objective: 0.986
Iteration 1700 Meta-objective: 0.971
Iteration 1800 Meta-objective: 0.978
Iteration 1900 Meta-objective: 0.960
Iteration 2000 Meta-objective: 0.939
Iteration 2100 Meta-objective: 0.911
Iteration 2200 Meta-objective: 0.891
Iteration 2300 Meta-objective: 0.893
Iteration 2400 Meta-objective: 0.883
Iteration 2500 Meta-objective: 0.881
Iteration 2600 Meta-objective: 0.860
Iteration 2700 Meta-objective: 0.831
Iteration 2800 Meta-objective: 0.824
Iteration 2900 Meta-objective: 0.777
Iteration 3000 Meta-objective: 0.799
Iteration 3100 Meta-objective: 0.774
Iteration 3200 Meta-objective: 0.717
Iteration 3300 Meta-objective: 0.777
Iteration 3400 Meta-objective: 0.721
Iteration 3500 Meta-objective: 0.668
Iteration 3600 Meta-objective: 0.663
Iteration 3700 Meta-objective: 0.671
Iteration 3800 Meta-objective: 0.635
Iteration 3900 Meta-objective: 0.621
Iteration 4000 Meta-objective: 0.599
Iteration 4100 Meta-objective: 0.635
Iteration 4200 Meta-objective: 0.619
Iteration 4300 Meta-objective: 0.570
Iteration 4400 Meta-objective: 0.592
Iteration 4500 Meta-objective: 0.530
Iteration 4600 Meta-objective: 0.525
Iteration 4700 Meta-objective: 0.510
Iteration 4800 Meta-objective: 0.527
Iteration 4900 Meta-objective: 0.545
Iteration 5000 Meta-objective: 0.530
Iteration 5100 Meta-objective: 0.506
Iteration 5200 Meta-objective: 0.618
Iteration 5300 Meta-objective: 0.603
Iteration 5400 Meta-objective: 0.504
Iteration 5500 Meta-objective: 0.499
Iteration 5600 Meta-objective: 0.486
Iteration 5700 Meta-objective: 0.478
Iteration 5800 Meta-objective: 0.644
Iteration 5900 Meta-objective: 0.639
Iteration 6000 Meta-objective: 0.492
Iteration 6100 Meta-objective: 0.548
Iteration 6200 Meta-objective: 0.563
Iteration 6300 Meta-objective: 0.584
Iteration 6400 Meta-objective: 0.562
Iteration 6500 Meta-objective: 0.547
Iteration 6600 Meta-objective: 0.467
Iteration 6700 Meta-objective: 0.467
Iteration 6800 Meta-objective: 0.463
Iteration 6900 Meta-objective: 0.455
Iteration 7000 Meta-objective: 0.424
Iteration 7100 Meta-objective: 0.412
Iteration 7200 Meta-objective: 0.382
Iteration 7300 Meta-objective: 0.430
Iteration 7400 Meta-objective: 0.426
Iteration 7500 Meta-objective: 0.426
Iteration 7600 Meta-objective: 0.415
Iteration 7700 Meta-objective: 0.360
Iteration 7800 Meta-objective: 0.379
Iteration 7900 Meta-objective: 0.413
Iteration 8000 Meta-objective: 0.345
Iteration 8100 Meta-objective: 0.420
Iteration 8200 Meta-objective: 0.352
Iteration 8300 Meta-objective: 0.334
Iteration 8400 Meta-objective: 0.313
Iteration 8500 Meta-objective: 0.334
Iteration 8600 Meta-objective: 0.323
Iteration 8700 Meta-objective: 0.312
Iteration 8800 Meta-objective: 0.316
Iteration 8900 Meta-objective: 0.311
Iteration 9000 Meta-objective: 0.307
Iteration 9100 Meta-objective: 0.297
Iteration 9200 Meta-objective: 0.295
Iteration 9300 Meta-objective: 0.296
Iteration 9400 Meta-objective: 0.289
Iteration 9500 Meta-objective: 0.282
Iteration 9600 Meta-objective: 0.289
Iteration 9700 Meta-objective: 0.278
Iteration 9800 Meta-objective: 0.281
Iteration 9900 Meta-objective: 0.284
Iteration 10000 Meta-objective: 0.299
Iteration 10100 Meta-objective: 0.277
Iteration 10200 Meta-objective: 0.294
Iteration 10300 Meta-objective: 0.271
Iteration 10400 Meta-objective: 0.271
Iteration 10500 Meta-objective: 0.288
Iteration 10600 Meta-objective: 0.284
Iteration 10700 Meta-objective: 0.281
Iteration 10800 Meta-objective: 0.264
Iteration 10900 Meta-objective: 0.270
Iteration 11000 Meta-objective: 0.261
Iteration 11100 Meta-objective: 0.260
Iteration 11200 Meta-objective: 0.258
Iteration 11300 Meta-objective: 0.259
Iteration 11400 Meta-objective: 0.258
Iteration 11500 Meta-objective: 0.261
Iteration 11600 Meta-objective: 0.256
Iteration 11700 Meta-objective: 0.254
Iteration 11800 Meta-objective: 0.251
Iteration 11900 Meta-objective: 0.249
Iteration 12000 Meta-objective: 0.247
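
Finally, the point of MAML is that the meta-learned initialization adapts quickly to an unseen task. The sketch below is a minimal illustration: it assumes train() is modified to end with return params (the version above does not return them), and otherwise reuses only the classes defined earlier to run the usual inner adaptation steps on a fresh dataset and plot the trajectory:

In [0]:
# Hypothetical: assumes train() has been changed to `return params` at the end
params = train()

# Sample a brand-new task and adapt the meta-learned weights with a few inner steps
test_gen = ToyDataGen(XMIN, XMAX, YMIN, YMAX, NOISE, BINS)
x_new, y_new = test_gen.sample_dataset(NDATA)

fig, ax = plt.subplots()
new_task = MetaObjective(x_new, y_new, INNER_LRATE, INNER_STEPS)
new_task.visualize(params, 'Adaptation to an unseen task', ax)
plt.show()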