%%time
import numpy as np
# Two-layer fully connected network (linear -> ReLU -> linear) trained with
# plain gradient descent on random data; the backward pass is written by hand.
n_in, n_h, n_out = 5, 4, 2  # input, hidden and output layer sizes
W1 = np.random.rand(n_h, n_in)  # first-layer weights
W2 = np.random.rand(n_out, n_h)  # second-layer weights
M = 1000  # no of training examples
x_in = np.random.rand(n_in, M)  # one training example per column
y = np.random.rand(n_out, M)  # one target per column
learning_rate = 1e-6
for t in range(1000):
    # forward pass
    h = W1.dot(x_in)
    relu_h = np.maximum(h, 0)
    out = W2.dot(relu_h)
    # calculate loss: sum of squared errors over every output and example
    loss = np.square(out - y).sum()
    # backprop: apply the chain rule layer by layer, from the loss back to W1
    out_grad = 2 * (out - y)
    W2_grad = out_grad.dot(relu_h.T)
    grad_relu_h = W2.T.dot(out_grad)
    grad_h = grad_relu_h.copy()
    grad_h[h < 0] = 0  # ReLU passes no gradient where its input was negative
    W1_grad = grad_h.dot(x_in.T)
    # update params (in place, so W1 and W2 keep their identity)
    W1 -= learning_rate * W1_grad
    W2 -= learning_rate * W2_grad
    if t % 100 == 0:
        # loss was computed before this iteration's update
        print("loss {0} at run {1}".format(loss, t))
    # repeat
loss 9726.30003915 at run 0 loss 814.959949267 at run 100 loss 320.872497012 at run 200 loss 259.944891952 at run 300 loss 249.216935453 at run 400 loss 245.718992076 at run 500 loss 243.451981739 at run 600 loss 241.492220103 at run 700 loss 239.674597278 at run 800 loss 237.9424254 at run 900 CPU times: user 284 ms, sys: 576 ms, total: 860 ms Wall time: 1.06 s
%%time
import torch
from torch.autograd import Variable
# Two-layer fully connected network (linear -> ReLU -> linear) trained with
# plain gradient descent on random data; gradients come from torch autograd.
n_in, n_h, n_out = 5, 4, 2  # input, hidden and output layer sizes
# Tensors track gradients directly; the deprecated Variable wrapper is not
# needed. Only the weights require gradients.
W1 = torch.rand(n_h, n_in, requires_grad=True)
W2 = torch.rand(n_out, n_h, requires_grad=True)
M = 1000  # no of training examples
x_in = torch.rand(n_in, M)  # one training example per column
y = torch.rand(n_out, M)  # one target per column
learning_rate = 1e-6
for t in range(1000):
    # forward pass; clamp(min=0) is the ReLU
    out = W2.mm(W1.mm(x_in).clamp(min=0))
    # calculate loss: sum of squared errors over every output and example
    loss = (out - y).pow(2).sum()
    if t % 100 == 0:
        # loss is a 0-dim tensor: .item() extracts the Python float
        # (indexing it with [0] raises IndexError on current PyTorch)
        print("loss {0} at run {1}".format(loss.item(), t))
    # backprop
    loss.backward()
    # param update; no_grad keeps the in-place step out of the autograd graph
    with torch.no_grad():
        W1 -= learning_rate * W1.grad
        W2 -= learning_rate * W2.grad
    # setting in graph grads zero (backward() accumulates into .grad)
    W1.grad.zero_()
    W2.grad.zero_()
    # repeat
loss 14459.4521484 at run 0 loss 745.285827637 at run 100 loss 260.097869873 at run 200 loss 222.204483032 at run 300 loss 218.247543335 at run 400 loss 217.245193481 at run 500 loss 216.542648315 at run 600 loss 215.900558472 at run 700 loss 215.294845581 at run 800 loss 214.721298218 at run 900 CPU times: user 320 ms, sys: 48 ms, total: 368 ms Wall time: 500 ms