Using a 7 x 6 x 5 x 2 four-layer neural network, we verify backpropagation by comparing the numeric gradient against the analytic gradient.
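Two formulas are used throughout: the numeric gradient is the central difference of the squared-error loss $E$ evaluated at a perturbed weight, and it is compared with the analytic gradient from backpropagation using the relative error suggested in cs231n:

$$E = \frac{1}{2}\sum_{o}{(t_{o} - out_{o})^2}, \qquad \frac{\partial E}{\partial w} \approx \frac{E(w + h) - E(w - h)}{2h}$$

$$relative\ error = \frac{\left|g_{analytic} - g_{numeric}\right|}{\max\left(\left|g_{analytic}\right|,\ \left|g_{numeric}\right|\right)}$$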
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
np.set_printoptions(suppress=True)
import pandas as pd
import sys
sys.path.append('/Users/kaonpark/workspace/github.com/likejazz/kaon-learn')
import kaonlearn
from kaonlearn.plots import plot_decision_regions, plot_history
def _gradient_check(analytic, numeric):
    numerator = abs(analytic - numeric)
    denominator = max(abs(analytic), abs(numeric))
    if denominator == 0:
        print("Correct!")
    else:
        difference = numerator / denominator
        # cs231n recommends a threshold of 1e-7, but this network cannot meet it.
        if difference < 1e-7:
            print("Correct!")
        else:
            print("\x1b[31mWrong!\x1b[0m")
def gradient_checking(nn, l=3):
    nn.__init__()
    nn.train()

    if l == 1:
        w = nn.w_1
    elif l == 2:
        w = nn.w_2
    elif l == 3:
        w = nn.w_3

    for k in range(w.shape[0]):
        # Check the weight gradients of row k.
        for j in range(w.shape[1]):
            nn.__init__()
            if l == 1:
                nn.w_1[k][j] += nn.h
            elif l == 2:
                nn.w_2[k][j] += nn.h
            elif l == 3:
                nn.w_3[k][j] += nn.h
            nn.query()
            e1 = np.sum((nn.t - nn.out_o) ** 2) / 2

            nn.__init__()
            if l == 1:
                nn.w_1[k][j] -= nn.h
            elif l == 2:
                nn.w_2[k][j] -= nn.h
            elif l == 3:
                nn.w_3[k][j] -= nn.h
            nn.query()
            e2 = np.sum((nn.t - nn.out_o) ** 2) / 2

            if l == 1:
                delta = nn.delta_w_1[k][j]
            elif l == 2:
                delta = nn.delta_w_2[k][j]
            elif l == 3:
                delta = nn.delta_w_3[k][j]
            numeric_gradient = (e1 - e2) / (2 * nn.h)

            # Verify that the numeric gradient matches the analytic gradient.
            print("%.16f, %.16f" % (delta, numeric_gradient), end=", ")
            _gradient_check(delta, numeric_gradient)

        # Check the bias gradient of row k.
        nn.__init__()
        if l == 1:
            nn.b_1[k] += nn.h
        elif l == 2:
            nn.b_2[k] += nn.h
        elif l == 3:
            nn.b_3[k] += nn.h
        nn.query()
        e1 = np.sum((nn.t - nn.out_o) ** 2) / 2

        nn.__init__()
        if l == 1:
            nn.b_1[k] -= nn.h
        elif l == 2:
            nn.b_2[k] -= nn.h
        elif l == 3:
            nn.b_3[k] -= nn.h
        nn.query()
        e2 = np.sum((nn.t - nn.out_o) ** 2) / 2

        print()
        if l == 1:
            delta = nn.delta_b_1[k]
        elif l == 2:
            delta = nn.delta_b_2[k]
        elif l == 3:
            delta = nn.delta_b_3[k]
        numeric_gradient = (e1 - e2) / (2 * nn.h)
        print("%.16f, %.16f" % (delta, numeric_gradient), end=", ")
        _gradient_check(delta, numeric_gradient)
        print()
def sigmoid(z: np.ndarray):
    return 1 / (1 + np.exp(-z))

def d_sigmoid(z: np.ndarray):
    return sigmoid(z) * (1.0 - sigmoid(z))

def relu(z: np.ndarray):
    return np.maximum(z, 0)

def d_relu(z: np.ndarray):
    return z > 0
# --
def GD(self, delta, t, l):
    return - self.lr * delta

def adam(self, delta, t, l):
    beta1 = .9
    beta2 = .999
    eps = 1e-8
    self.m[l] = beta1 * self.m[l] + (1. - beta1) * delta
    self.v[l] = beta2 * self.v[l] + (1. - beta2) * delta**2
    self.m_k_hat = self.m[l] / (1. - beta1**(t))
    self.v_k_hat = self.v[l] / (1. - beta2**(t))
    self.update_parameters = - (self.lr * self.m_k_hat / (np.sqrt(self.v_k_hat) + eps))
    return self.update_parameters

def momentum(self, delta, t, l):
    gamma = .9
    self.m[l] = gamma * self.m[l] + self.lr * delta
    return - self.m[l]
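For reference, `adam` above follows the standard Adam update with bias-corrected first and second moment estimates of the gradient $g_t$ (`delta` in the code):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}}, \qquad \Delta\theta_t = -\,\frac{lr \cdot \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

`momentum` keeps only the first accumulator, $m_t = \gamma m_{t-1} + lr \cdot g_t$, and applies $-m_t$.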
class NeuralNetwork:
    def __init__(self):
        self.i = np.array([0.4, -0.2, 0.1, 0.1, -0.15, 0.6, -0.9]).reshape(-1, 1)

        np.random.seed(12)
        self.w_1 = np.random.rand(6, 7)
        self.b_1 = np.random.rand(6).reshape(-1, 1)
        self.w_2 = np.random.rand(5, 6)
        self.b_2 = np.random.rand(5).reshape(-1, 1)
        self.w_3 = np.random.rand(2, 5)
        self.b_3 = np.random.rand(2).reshape(-1, 1)

        self.t = np.array([[0.87503811], [0.83690408]])

        self.lr = 0.1
        self.h = 1e-4

        # Optimizer Parameters
        self.iter = 1
        self.m = [
            np.zeros(self.w_3.shape),
            np.zeros(self.b_3.shape),
            np.zeros(self.w_2.shape),
            np.zeros(self.b_2.shape),
            np.zeros(self.w_1.shape),
            np.zeros(self.b_1.shape),
        ]
        self.v = [
            np.zeros(self.w_3.shape),
            np.zeros(self.b_3.shape),
            np.zeros(self.w_2.shape),
            np.zeros(self.b_2.shape),
            np.zeros(self.w_1.shape),
            np.zeros(self.b_1.shape),
        ]

    def _forward(self):
        self.net_h1 = np.dot(self.w_1, self.i) + self.b_1
        self.out_h1 = relu(self.net_h1)

        self.net_h2 = np.dot(self.w_2, self.out_h1) + self.b_2
        self.out_h2 = sigmoid(self.net_h2)

        self.net_o = np.dot(self.w_3, self.out_h2) + self.b_3
        self.out_o = sigmoid(self.net_o)

    def _backward(self, optimizer):
        d_o_errors = - (self.t - self.out_o)
        self.delta_w_3 = np.dot(d_o_errors * d_sigmoid(self.net_o), self.out_h2.T)
        self.w_3 += optimizer(self, self.delta_w_3, self.iter, 0)
        self.delta_b_3 = d_o_errors * d_sigmoid(self.net_o)
        self.b_3 += optimizer(self, self.delta_b_3, self.iter, 1)

        # Note that w_3 has already been updated above, so the error is
        # propagated back through the new weights.
        d_h2_errors = np.dot(self.w_3.T, d_o_errors * d_sigmoid(self.net_o))
        self.delta_w_2 = np.dot(d_h2_errors * d_sigmoid(self.net_h2), self.out_h1.T)
        self.w_2 += optimizer(self, self.delta_w_2, self.iter, 2)
        self.delta_b_2 = d_h2_errors * d_sigmoid(self.net_h2)
        self.b_2 += optimizer(self, self.delta_b_2, self.iter, 3)

        d_h1_errors = np.dot(self.w_2.T, d_h2_errors * d_sigmoid(self.net_h2))
        self.delta_w_1 = np.dot(d_h1_errors * d_relu(self.net_h1), self.i.T)
        self.w_1 += optimizer(self, self.delta_w_1, self.iter, 4)
        self.delta_b_1 = d_h1_errors * d_relu(self.net_h1)
        self.b_1 += optimizer(self, self.delta_b_1, self.iter, 5)

        self.iter += 1

    def train(self, optimizer=GD):
        self._forward()
        self._backward(optimizer)

    def query(self):
        self._forward()

    def result(self):
        print(self.t - self.out_o)
nn = NeuralNetwork()
If the output layer has no activation (here, sigmoid), the delta of the final weight matrix and the error assigned to the previous weight matrix are computed differently, as shown below.
def _forward():
    ...
    # If the final output layer has no activation (sigmoid):
    out_o = net_o

def _backward():
    ...
    # If the final output layer has no activation (sigmoid):
    delta_w_3 = np.dot(d_o_errors, out_h2.T)
    delta_b_3 = d_o_errors
    ...
    # The activation derivative also disappears from the previous layer's error.
    d_h2_errors = np.dot(w_3.T, d_o_errors)
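This follows from the chain rule: when $out_{o} = net_{o}$, the activation derivative is simply 1,

$$\frac{\partial out_{o}}{\partial net_{o}} = 1 \quad\Rightarrow\quad \frac{\partial E_{total}}{\partial net_{o}} = \frac{\partial E_{total}}{\partial out_{o}} = -(t - out_{o})$$

so the `d_sigmoid(net_o)` factor drops out of `delta_w_3`, `delta_b_3`, and `d_h2_errors`.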
The formula for delta_w_1, the gradient with respect to the hidden layer weights w_1, is:

$$\frac{\partial E_{total}}{\partial w_{1}} = \left(\sum\limits_{o}{\frac{\partial E_{total}}{\partial out_{o}} * \frac{\partial out_{o}}{\partial net_{o}} * \frac{\partial net_{o}}{\partial out_{h1}}}\right) * \frac{\partial out_{h1}}{\partial net_{h1}} * \frac{\partial net_{h1}}{\partial w_{1}}$$

Computing this error term ($y_n$) is the core of backpropagation; in the formula it is the

$$\frac{\partial out_{o}}{\partial net_{o}} * \frac{\partial net_{o}}{\partial out_{h1}}$$

part, that is, the derivative of the output layer's activation multiplied by the preceding weights (w5, w6).
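As a sketch (not part of the original notebook), the same chain rule can be written out term by term for this 7 x 6 x 5 x 2 network and checked against the delta stored by `_backward`. Note that `train()` updates the weights in place during the backward pass, so the recomputation below uses the weights as they stand after the call, which is exactly what `_backward` itself used:

nn.__init__()
nn.train()  # one forward/backward step so the stored activations and deltas are populated

d_o = -(nn.t - nn.out_o) * d_sigmoid(nn.net_o)        # dE/d net_o
d_h2 = np.dot(nn.w_3.T, d_o) * d_sigmoid(nn.net_h2)   # dE/d net_h2
d_h1 = np.dot(nn.w_2.T, d_h2) * d_relu(nn.net_h1)     # dE/d net_h1 (the summed term in the formula)
delta_w_1_check = np.dot(d_h1, nn.i.T)                # dE/d w_1 = error term times the layer input

assert np.allclose(delta_w_1_check, nn.delta_w_1)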
gradient_checking(nn, 3)
# The values are nearly identical and can be regarded as correct, but they do not meet cs231n's threshold.
gradient_checking(nn, 2)
gradient_checking(nn, 1)
# Gradient Descent training
nn.__init__()

delta_w_1_history = []
w_1_history = []
for _ in range(7):
    delta_w_1_history.append([])
    w_1_history.append([])

delta_b_1_history = []
b_1_history = []

for _ in range(2000):
    nn.train()
    for j in range(7):
        delta_w_1_history[j].append(nn.delta_w_1[1][j])
        w_1_history[j].append(nn.w_1[1][j])
    delta_b_1_history.append(nn.delta_b_1[1][0])
    b_1_history.append(nn.b_1[1][0])
nn.query()
nn.result()
# Plot the delta and weight histories of the first hidden layer.
plt.figure(1)

for j in range(7):
    plt.subplot(221)
    plt.plot(delta_w_1_history[j])
    plt.title("delta_w_1_history")

    plt.subplot(222)
    plt.plot(w_1_history[j])
    plt.title("w_1_history")

plt.subplot(223)
plt.plot(delta_b_1_history)
plt.title("delta_b_1_history")

plt.subplot(224)
plt.plot(b_1_history)
plt.title("b_1_history")

# Subplot layout adjustment, adapted from:
# https://matplotlib.org/gallery/pyplots/pyplot_scales.html#sphx-glr-gallery-pyplots-pyplot-scales-py
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.5, wspace=0.35)
plt.show()
When the biases are trained together with the weights, the gradient of w_1 reaches zero at around 1,200 epochs; without bias training it takes more than 1,500 epochs.
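The bias-free run is not reproduced in the code above; one hypothetical way to run it (not part of the original notebook, using a separate instance `nn2` so the GD results above are left untouched) is to revert the bias updates after every step so that only the weights are effectively trained:

nn2 = NeuralNetwork()
b_init = (nn2.b_1.copy(), nn2.b_2.copy(), nn2.b_3.copy())  # freeze the biases at their initial values
for _ in range(2000):
    nn2.train()
    nn2.b_1, nn2.b_2, nn2.b_3 = (b.copy() for b in b_init)  # undo the bias updates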
nn.w_1
# Adam training
nn.__init__()

adam_w_1_history = []
adam_b_1_history = []
for _ in range(2000):
    nn.train(adam)
    adam_w_1_history.append(nn.w_1[1][0])
    adam_b_1_history.append(nn.b_1[1][0])
nn.query()
nn.result()
nn.w_1
# Momentum training
nn.__init__()

momentum_w_1_history = []
momentum_b_1_history = []
for _ in range(2000):
    nn.train(momentum)
    momentum_w_1_history.append(nn.w_1[1][0])
    momentum_b_1_history.append(nn.b_1[1][0])
nn.query()
nn.result()
nn.w_1
plt.figure(1)
plt.subplot(221)
plt.plot(adam_w_1_history)
plt.plot(w_1_history[0])
plt.title("w_1_history")
plt.subplot(222)
plt.plot(adam_b_1_history)
plt.plot(b_1_history)
plt.title("b_1_history")
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.5, wspace=0.35)
plt.show()
# Adam needed only about 1/20 of GD's training time, but the values it converges to are quite different.
plt.figure(1)
plt.subplot(221)
plt.plot(momentum_w_1_history)
plt.plot(w_1_history[0])
plt.title("w_1_history")
plt.subplot(222)
plt.plot(momentum_b_1_history)
plt.plot(b_1_history)
plt.title("b_1_history")
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.5, wspace=0.35)
plt.show()
# Momentum reaches values very close to GD's, in only about 1/12 of the training time.
plt.figure(1)
plt.subplot(221)
plt.plot(momentum_w_1_history[:200])
plt.plot(adam_w_1_history[:200])
plt.title("w_1_history")
plt.subplot(222)
plt.plot(momentum_b_1_history[:200])
plt.plot(adam_b_1_history[:200])
plt.title("b_1_history")
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.5, wspace=0.35)
plt.show()
# Comparison of Adam and Momentum; the curve with the larger drop is Adam.
# Inspecting Adam's m and v
nn.__init__()
adam_m_1_history = []
adam_v_1_history = []
adam_m_1_hat_history = []
adam_v_1_hat_history = []
update_parameters = []
for _ in range(200):
    nn.train(adam)
    adam_m_1_history.append(nn.m[1][0])  # To match m_k_hat below, index 5 (the last layer updated) would have to be used here.
    adam_v_1_history.append(nn.v[1][0])
    adam_m_1_hat_history.append(nn.m_k_hat[5])
    adam_v_1_hat_history.append(nn.v_k_hat[5])
    update_parameters.append(nn.update_parameters[5])
nn.query()
nn.result()
plt.figure(1)
plt.subplot(221)
plt.plot(adam_m_1_history)
plt.title("adam_m_1_history")
plt.subplot(222)
plt.plot(adam_v_1_history)
plt.title("adam_v_1_history")
plt.subplot(223)
plt.plot(adam_m_1_hat_history)
plt.title("adam_m_1_hat_history")
plt.subplot(224)
plt.plot(adam_v_1_hat_history)
plt.title("adam_v_1_hat_history")
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.5, wspace=0.35)
plt.show()
# Inspecting Adam's update steps (update_parameters)
plt.subplot(224)
plt.plot(update_parameters)
plt.title("update_parameters")
plt.show()