import numpy as np
from bokeh.plotting import figure, show, output_notebook
def gradient_descent(F, dF, x, steps=100, lr=0.001):
    """Minimise ``F`` with fixed-step gradient descent.

    Args:
        F: Objective function; called as ``F(x)`` after every update to
            record the loss.
        dF: Gradient of the objective; called as ``dF(x)``.
        x: Starting point. The caller's object is never modified.
        steps: Number of update steps to perform.
        lr: Learning rate that scales each gradient step.

    Returns:
        A tuple ``(x, loss)``: the point reached after the last update and
        the list of ``F(x)`` values, one entry per step.
    """
    loss = []
    for _ in range(steps):
        # Rebind rather than use ``x -= ...``: the in-place form would
        # overwrite the caller's array and fails with a casting error when
        # the starting array has an integer dtype.
        x = x - lr * dF(x)
        loss.append(F(x))
    return x, loss
def rmsprop(F, dF, x, steps=100, lr=0.001, decay=.9, eps=1e-8):
    """Minimise ``F`` with RMSProp-style per-coordinate step scaling.

    Each gradient component is divided by the square root of an
    exponentially decaying mean of its own squared values.

    Args:
        F: Objective function; called as ``F(x)`` after every update to
            record the loss.
        dF: Gradient of the objective; called as ``dF(x)``.
        x: Starting point; must expose ``.shape``. The caller's array is
            never modified.
        steps: Number of update steps to perform.
        lr: Learning rate that scales each normalised step.
        decay: Weight kept from the previous mean of squared gradients.
        eps: Small constant added to the denominator to avoid division
            by zero.

    Returns:
        A tuple ``(x, loss)``: the point reached after the last update and
        the list of ``F(x)`` values, one entry per step.
    """
    loss = []
    dx_mean_sqr = np.zeros(x.shape, dtype=float)
    for _ in range(steps):
        dx = dF(x)
        dx_mean_sqr = decay * dx_mean_sqr + (1 - decay) * dx ** 2
        # Rebind rather than use ``x -= ...`` so the caller's array stays
        # untouched and integer-typed starting arrays are promoted to float
        # instead of raising a casting error.
        x = x - lr * dx / (np.sqrt(dx_mean_sqr) + eps)
        loss.append(F(x))
    return x, loss
def rmsprop_momentum(F, dF, x, steps=100, lr=0.001, decay=.9, eps=1e-8, mu=.9):
    """Minimise ``F`` with RMSProp-scaled steps accumulated as momentum.

    The normalised gradient step (as in RMSProp) is added to a velocity
    term that keeps a fraction ``mu`` of its previous value; the velocity
    is then subtracted from ``x``.

    Args:
        F: Objective function; called as ``F(x)`` after every update to
            record the loss.
        dF: Gradient of the objective; called as ``dF(x)``.
        x: Starting point; must expose ``.shape``. The caller's array is
            never modified.
        steps: Number of update steps to perform.
        lr: Learning rate that scales each normalised step.
        decay: Weight kept from the previous mean of squared gradients.
        eps: Small constant added to the denominator to avoid division
            by zero.
        mu: Fraction of the previous velocity carried into the next step.

    Returns:
        A tuple ``(x, loss)``: the point reached after the last update and
        the list of ``F(x)`` values, one entry per step.
    """
    loss = []
    dx_mean_sqr = np.zeros(x.shape, dtype=float)
    momentum = np.zeros(x.shape, dtype=float)
    for _ in range(steps):
        dx = dF(x)
        dx_mean_sqr = decay * dx_mean_sqr + (1 - decay) * dx ** 2
        momentum = mu * momentum + lr * dx / (np.sqrt(dx_mean_sqr) + eps)
        # Rebind rather than use ``x -= momentum`` so the caller's array
        # stays untouched and integer-typed starting arrays are promoted to
        # float instead of raising a casting error.
        x = x - momentum
        loss.append(F(x))
    return x, loss
def F(x):
    """Return the sum of squared entries of ``A @ x - I``.

    ``A`` is the module-level matrix and ``I`` is the identity with
    ``len(A)`` rows, so the value is zero exactly when ``A @ x`` equals
    the identity.
    """
    identity = np.eye(len(A), dtype=float)
    error = A @ x - identity
    return np.sum(np.square(error))
def dF(x):
    """Return the gradient of ``F`` at ``x``: ``2 * A.T @ (A @ x - I)``.

    ``A`` is the module-level matrix and ``I`` is the identity with
    ``len(A)`` rows.
    """
    identity = np.eye(len(A), dtype=float)
    error = A @ x - identity
    return 2 * A.T @ error
# Square 5x5 coefficient matrix read by F and dF.
A = np.array(
    [
        [2.0, 5.0, 1.0, 4.0, 6.0],
        [3.0, 5.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 3.0, 8.0],
        [6.0, 6.0, 2.0, 2.0, 1.0],
        [8.0, 3.0, 5.0, 1.0, 4.0],
    ],
    dtype=np.float64,
)
# Run each optimiser for 300 steps from the zero matrix (A * 0) and report
# A @ X (rounded to 2 decimals) together with the final loss. Because F
# measures the distance of A @ X from the identity, a better result is one
# whose product is closer to the identity matrix.
#
# The "Recorded output" comments hold the result text that accompanied this
# source. They are kept as comments because, executed as code, they would
# raise NameError (``array`` is not defined).

# Plain gradient descent.
X, loss1 = gradient_descent(F, dF, A * 0, steps=300)
print((A @ X).round(2), loss1[-1])
# Recorded output:
# (array([[ 0.79, -0.01,  0.18,  0.19, -0.08],
#         [-0.01,  0.8 ,  0.  ,  0.2 , -0.07],
#         [ 0.18,  0.  ,  0.85, -0.15,  0.07],
#         [ 0.19,  0.2 , -0.15,  0.66,  0.13],
#         [-0.08, -0.07,  0.07,  0.13,  0.95]]), 0.54691984767143453)

# RMSProp.
X, loss2 = rmsprop(F, dF, A * 0, steps=300)
print((A @ X).round(2), loss2[-1])
# Recorded output:
# (array([[ 0.84, -0.05,  0.1 ,  0.1 , -0.06],
#         [-0.04,  0.82,  0.03,  0.19, -0.03],
#         [ 0.12,  0.03,  0.9 , -0.08,  0.03],
#         [ 0.15,  0.2 , -0.12,  0.75,  0.06],
#         [-0.08, -0.09,  0.04,  0.1 ,  0.97]]), 0.32396954419819657)

# RMSProp with momentum.
X, loss3 = rmsprop_momentum(F, dF, A * 0, steps=300)
print((A @ X).round(2), loss3[-1])
# Recorded output:
# (array([[ 0.99,  0.01,  0.  , -0.01,  0.  ],
#         [-0.  ,  1.  ,  0.  , -0.  ,  0.  ],
#         [-0.  ,  0.01,  1.  , -0.01,  0.  ],
#         [-0.01,  0.01,  0.  ,  0.99,  0.  ],
#         [-0.01,  0.01,  0.  , -0.01,  1.  ]]), 0.00062303887772378397)
# Plot the per-step loss of the three optimisers on one figure, rendered
# inline in the notebook.
output_notebook()
plot = figure()
# ``legend_label`` replaces the ``legend`` keyword of glyph methods, which
# Bokeh deprecated in 1.4 and removed in 3.0.
plot.line(x=range(len(loss1)), y=loss1, color='steelblue', legend_label='gd')
plot.line(x=range(len(loss2)), y=loss2, color='green', legend_label='rmsprop')
plot.line(x=range(len(loss3)), y=loss3, color='red', legend_label='rmsprop+momentum')
show(plot)