import torch
import numpy as np
import matplotlib.pyplot as plt
# Fix the RNG seed so the synthetic data generated below is reproducible.
torch.manual_seed(17);
# Request GPU execution; the actual device is chosen later from this flag.
use_gpu = True
# Number of samples in the synthetic dataset.
n = 100
# Input matrix of shape (n, 2): both columns start at 1; column 0 is then
# overwritten in place with uniform samples from [-1, 1), leaving column 1 as
# the constant input that multiplies the intercept weight.
x = torch.ones(n, 2)
x[:,0].uniform_(-1., 1);
# Ground-truth weights used to generate the targets (slope 3, intercept 2).
w_y = torch.tensor([3., 2]); w_y
# Targets: the linear model output plus uniform noise drawn from [0, 1).
y = x@w_y + torch.rand(n)
# Turn y into a column of shape (n, 1).
y = y[:,None]
# Scatter plot of the single feature against the noisy targets.
plt.scatter(x[:,0], y);
def mse(y_hat, y):
    """Return the mean squared error between ``y_hat`` and ``y``.

    The squared differences are averaged over dimension 0 only, so any
    remaining dimensions (after broadcasting) are kept in the result.
    """
    residual = y_hat - y
    squared_error = residual ** 2
    return squared_error.mean(0)
# Use CUDA only when it was requested via use_gpu and is actually available;
# otherwise fall back to the CPU.
device = torch.device('cuda') if use_gpu and torch.cuda.is_available() else torch.device('cpu')
# Move the dataset to the selected device.
x = x.to(device)
y = y.to(device)
# initial weights
# Two values drawn uniformly from [-500, 500); the trailing bare name displays
# the tensor when this runs as the last statement of a notebook cell.
w0 = torch.rand(2) * 1000 - 500; w0
def train_stepper_sgd(w0, lr, n_epochs=100, min_loss=0.1, verbose=True):
    """Fit the linear model with plain gradient descent and a fixed learning rate.

    Reads the module-level ``x``, ``y``, ``device`` and ``mse``. Every epoch
    uses the whole of ``x`` / ``y`` (no mini-batching).

    Args:
        w0: 1-D tensor of initial weights; it is cloned, so the caller's
            tensor is left untouched.
        lr: learning rate applied to every update.
        n_epochs: maximum number of epochs to run; must be at least 1.
        min_loss: training stops early once the loss computed at the start of
            an epoch is below this value.
        verbose: when true, print a header and one line per epoch.

    Returns:
        None. Progress and the final loss are printed.

    Raises:
        ValueError: if ``n_epochs`` is smaller than 1 (the summary line would
            otherwise reference a loss that was never computed).
    """
    if n_epochs < 1:
        raise ValueError('n_epochs must be at least 1')

    # Column vector of weights, tracked by autograd on the training device.
    w = w0[:, None].clone().to(device).requires_grad_()
    if verbose:
        print('Epoch\tLoss')

    for i in range(n_epochs):
        y_hat = x @ w
        loss = mse(y_hat, y)
        loss.backward()
        with torch.no_grad():
            # Manual update; gradients are cleared so they do not accumulate
            # into the next epoch.
            w -= lr * w.grad
            w.grad.zero_()

        # Read the scalar once: each .item() call forces a device sync.
        loss_value = loss.item()
        if verbose:
            print(f'{i+1}\t{loss_value:.3f}')
        if loss_value < min_loss:
            break

    print(f'Final loss: {loss_value:.3f} in {i+1} epochs.')
# Learning rate for the single-rate run.
lr = 0.3
# Train from w0 with the default settings (up to 100 epochs, verbose logging).
train_stepper_sgd(w0, lr)
def train_walker_sgd(w0, lrs, n_epochs=100, min_loss=0.1, verbose=True, record=False):
    """Gradient descent that tries several learning rates per step and keeps the best.

    One weight column is kept per learning rate. In every epoch the column
    with the lowest loss is selected, its gradient is computed, and all
    columns are reset to that best column before each one takes a single step
    with its own learning rate.

    Reads the module-level ``x``, ``y``, ``device`` and ``mse``.

    Args:
        w0: 1-D tensor of initial weights shared by every column; not modified.
        lrs: 1-D sequence or array of candidate learning rates.
        n_epochs: maximum number of epochs to run; must be at least 1.
        min_loss: training stops early once the best loss computed at the
            start of an epoch is below this value.
        verbose: when true, print a header and one line per epoch with the
            selected learning rate and its loss.
        record: when true, collect per-epoch history and return it.

    Returns:
        None when ``record`` is false. Otherwise a tuple ``(wgts, loss, best)``
        of NumPy arrays stacked over epochs: the weights at the start of each
        epoch, the loss of every column, and the index of the best column.

    Raises:
        ValueError: if ``n_epochs`` is smaller than 1 (the summary line would
            otherwise reference a loss that was never computed).
    """
    if n_epochs < 1:
        raise ValueError('n_epochs must be at least 1')

    lrs = torch.tensor(lrs, dtype=torch.float32).to(device)
    n_lrs = lrs.size(0)
    # One identical copy of w0 per learning rate, laid out as columns.
    w = w0.repeat(n_lrs, 1).transpose(0, 1).clone().to(device).requires_grad_()
    rec = []
    if verbose:
        print('Epoch\tLR\tLoss')

    for i in range(n_epochs):
        if record:
            # Snapshot before the update; w is modified in place further down,
            # so the history needs its own copy.
            w_start = w.detach().clone()

        y_hat = x @ w
        losses = mse(y_hat, y)

        # identify the best learning rate
        bst_lr_idx = losses.argmin()
        bst_loss = losses[bst_lr_idx]
        # Only the best column receives a non-zero gradient.
        bst_loss.backward()

        with torch.no_grad():
            # take the weights of the best lr and copy them over the others lrs,
            # dismissing weights from non-optimal lrs.
            best_w = w[:, bst_lr_idx].unsqueeze(1)
            best_grad = w.grad[:, bst_lr_idx].unsqueeze(1)
            # Broadcasting against lrs yields one stepped column per rate.
            w.copy_(best_w - lrs * best_grad)
            w.grad.zero_()

        if record:
            rec.append((
                w_start.to('cpu').numpy(),
                losses.detach().to('cpu').numpy(),
                bst_lr_idx.to('cpu').numpy(),
            ))

        # Read the scalar once: each .item() call forces a device sync.
        bst_loss_value = bst_loss.item()
        if verbose:
            print(f'{i+1}\t{lrs[bst_lr_idx].item():.2f}\t{bst_loss_value:.3f}')
        if bst_loss_value < min_loss:
            break

    print(f'Final loss: {bst_loss_value:.3f} in {i+1} epochs.')

    if record:
        # Transpose the list of per-epoch triples into three stacked arrays.
        w_hist, loss_hist, best_hist = zip(*rec)
        wgts = np.stack(w_hist)
        loss = np.stack(loss_hist)
        best = np.stack(best_hist)
        return wgts, loss, best
# Candidate learning rates starting at 0.1 in steps of 0.1, up to (nominally
# excluding) 1.
lrs = np.arange(0.1, 1, 0.1)
# Train from the same w0 with verbose logging; nothing is recorded or returned.
train_walker_sgd(w0, lrs);
# mplot3d is not referenced by name in the visible code; presumably imported
# for its side effect of enabling the '3d' projection on older Matplotlib.
from mpl_toolkits import mplot3d
from matplotlib import animation
# Represent animations as HTML5 video when they are displayed as rich output.
plt.rc('animation', html='html5')
def loss_wrt_wgts(w1, w2):
    """Return the dataset MSE, as a Python float, for the weight pair (w1, w2).

    Uses the module-level ``x``, ``y``, ``device`` and ``mse``.
    """
    weights = torch.Tensor([w1, w2]).to(device)
    # Predictions come out 1-D; add a trailing axis before comparing with y.
    predictions = torch.matmul(x, weights).unsqueeze(1)
    return mse(predictions, y).item()
# Vectorise the scalar loss function so it can be evaluated element-wise over
# arrays of weight values.
loss_wgts = np.vectorize(loss_wrt_wgts)
# 50 x 50 grid of candidate weights spanning [-20, 20] on both axes.
w0_range = np.linspace(-20, 20, 50)
w1_range = np.linspace(-20, 20, 50)
mesh = np.meshgrid(w0_range, w1_range)
# Loss at every grid point; this is the surface drawn below.
loss_mesh = loss_wgts(*mesh)
# Replace the initial weights with a point near a corner of the plotted grid.
w0 = torch.tensor([-19., -19])
# Seven candidate learning rates evenly spaced from 0.1 to 1 (both included).
lrs = np.linspace(0.1, 1, 7)
# Record per-epoch weights, losses and best-rate indices for the animation.
wgts, loss, best = train_walker_sgd(w0, lrs, record=True)
fig = plt.figure(figsize=(12,10))
ax = plt.axes(projection='3d')
# Semi-transparent loss surface over the weight grid.
ax.plot_surface(*mesh, loss_mesh, cmap='viridis', alpha=0.8)
# Empty line artists; animate() fills them in frame by frame.
line0, = ax.plot3D([], [], [], c='r', marker='o', label='Current walk')
line1, = ax.plot3D([], [], [], c='b', marker='o', label='Learning curve')
# Marker-only artist (linewidth=0) for the best candidate of the current step.
line2, = ax.plot3D([], [], [], c='r', marker='*', markersize=20, label='Best step in walk', linewidth=0)
ax.set_xlabel('w0'); ax.set_ylabel('w1'); ax.set_zlabel('Loss')
fig.suptitle(f'"Walker SGD"', fontsize=22)
# Camera position: 30 degrees elevation, 20 degrees azimuth.
ax.view_init(30, 20)
ax.legend()
fig.tight_layout()
# Close the current figure; presumably so only the animation, not the static
# figure, is shown in the notebook. fig is still passed to FuncAnimation later.
plt.close()
def animate(i):
    """Update the animated artists for frame ``i`` and return them.

    Uses the module-level ``wgts``, ``loss`` and ``best`` history arrays and
    the ``line0`` / ``line1`` / ``line2`` artists. Frame 0 leaves the artists
    untouched.

    Args:
        i: zero-based frame number, also used as the epoch index into the
            recorded history.

    Returns:
        The tuple ``(line0, line1, line2)``.
    """
    if i > 0:
        top = best[i]

        # All candidates evaluated at step i, one point per learning rate.
        line0.set_data(wgts[i, 0], wgts[i, 1])
        line0.set_3d_properties(loss[i])

        # The best candidate is a single point. Wrap it in one-element lists
        # because newer Matplotlib releases reject scalar line data.
        line2.set_data([wgts[i, 0, top]], [wgts[i, 1, top]])
        line2.set_3d_properties([loss[i, top]])

        # Best point of every earlier step. The step indices must be an index
        # array (not a slice) so they pair element-wise with best[:i].
        steps = np.arange(i)
        line1.set_data(wgts[steps, 0, best[:i]], wgts[steps, 1, best[:i]])
        line1.set_3d_properties(loss[steps, best[:i]])
    return line0, line1, line2
# Build the animation: 6 frames (i = 0..5) shown 1000 ms apart.
# NOTE(review): the result is not assigned to a name, so this presumably relies
# on being the last expression of a notebook cell to be kept and displayed.
animation.FuncAnimation(fig, animate, 6, interval=1000)
# NOTE(review): n is reassigned here, but x and y are not rebuilt anywhere in
# the visible code, so the timings below appear to run on the dataset created
# earlier with the previous n. Confirm whether the data was meant to be
# regenerated at this size.
n = 1000000
# Reseed and draw fresh initial weights from [-500, 500) for the timing runs.
torch.manual_seed(17);
w0 = torch.rand(2) * 1000 - 500; w0
# Ten candidate learning rates evenly spaced from 0.1 to 1 (both included).
lrs = np.linspace(0.1, 1, 10)
# %time is an IPython magic: these lines only run in IPython/Jupyter, not as a
# plain Python script. Each call prints its final loss and epoch count.
%time train_walker_sgd(w0, lrs, verbose=False)
# Fixed-rate runs from the same w0 at several learning rates, for comparison.
lr = 0.1
%time train_stepper_sgd(w0, lr, verbose=False)
lr = 0.9
%time train_stepper_sgd(w0, lr, verbose=False)
lr = 0.5
%time train_stepper_sgd(w0, lr, verbose=False)
lr = 0.7 # optimal learning rate
%time train_stepper_sgd(w0, lr, verbose=False)