Notes for "Why does deep and cheap learning work so well?" (arXiv:1608.08225v1 [cond-mat.dis-nn]) by Lin and Tegmark.
Let's implement a simple multiplication network as per figure 2 in the paper. Obviously this could be done much more efficiently with any of the deep-learning-in-a-box packages that exist today. My interest is strictly to gain some intuition about this result, so I'm doing it all by hand here.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = 10, 6
plt.rcParams['axes.facecolor'] = "0.92"
def σ(y):
    "Logistic sigmoid"
    return 1/(1+np.exp(-y))
y = np.linspace(-10, 10)
plt.plot(y, σ(y));
We need $\sigma''(0)$. Analytically I'm getting
$$ \sigma''(y) = e^{-y} \frac{e^{-y}-1}{(e^{-y}+1)^3} $$

I'm rusty, so let's double-check my algebra with SymPy:
import sympy as S
S.init_printing(use_latex=True)
y = S.symbols('y')
s = 1/(1+S.exp(-y))
s1 = S.diff(s, y)
s1
s2 = S.simplify(S.diff(s1, y))
s2
These appear to be different, but that's because SymPy simplified its result in terms of $e^y$ while I computed mine with $e^{-y}$. Since I'm lazy, let's have SymPy check that the two forms are identical (the algebra is trivial, but it's late and I'm tired):
ey = S.exp(-y)
s2f = ey*(ey-1)/(ey+1)**3
s2f
S.simplify(s2 - s2f)
Indeed.
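Belt and suspenders: as a purely numerical cross-check (the helper σ2, the step size h and the test points below are my own arbitrary choices), a central finite difference of the NumPy σ agrees with the analytic expression to roughly roundoff accuracy:
def σ2(y):
    "Analytic second derivative of the logistic sigmoid, as derived above"
    return np.exp(-y)*(np.exp(-y)-1)/(np.exp(-y)+1)**3

h = 1e-4                                  # finite-difference step (arbitrary)
yy = np.linspace(-5, 5, 11)               # a few test points (arbitrary)
fd = (σ(yy+h) - 2*σ(yy) + σ(yy-h))/h**2   # central second difference
np.max(np.abs(fd - σ2(yy)))               # should be small, roughly 1e-8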
As per the paper, $\mu$ must be defined as
$$ \mu = \frac{\lambda^{-2}}{8\sigma''(0)} $$

Note that in order to construct the full network, we need to shift the operating point of $\sigma$ by 1, since $\sigma''(0) = 0$ would make $\mu$ blow up. We can do this by adding a bias term $\mathbf{b} = [1, 1, 1, 1]$, as described in the paper. This means we need to evaluate $\sigma''(1)$ instead.
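Just to spell out why the shift is needed, plugging $y=0$ into the expression for $\sigma''$ above gives

$$ \sigma''(0) = \frac{e^{0}\,(e^{0}-1)}{(e^{0}+1)^{3}} = \frac{1\cdot 0}{8} = 0, $$

so the unshifted $\mu$ would involve a division by zero. Let's now get $\sigma''(1)$ numerically: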
S.init_printing(use_latex=False) # Turn it off again so it doesn't slow everything down.
# Let's evaluate the 2nd derivative at the origin (shifted to y=1), numerically
# (so we can use it below in the network). Since the point of evaluation of σ'' must be consistent
# with the definition of the vector b below, let's store it in a variable we also use to construct b
σ_origin = 1
s2_ori = float(s2.subs(y, σ_origin).n())
s2_ori
-0.09085774767294841
Let's define the generic form of an affine layer. To support proper composition in a multi-layer network, we need to express the affine layer as a function:
def A(W, b):
    "Affine layer"
    return lambda y: W @ y + b
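Quick throwaway check that the returned closure does what we want (the matrix W_test, bias b_test and input here are made-up values, just for illustration):
W_test = np.array([[1.0, 0.0],
                   [0.0, 2.0]])
b_test = np.array([0.5, -0.5])
layer = A(W_test, b_test)
layer(np.array([1.0, 1.0]))   # -> array([ 1.5,  1.5])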
Now, let's construct the full network in Fig. 2 (left side). It should converge to the multiplication operator as $\lambda \rightarrow 0$.
Warning: I had to redefine $\mu \rightarrow 2\mu$ in order to get the right numerical results. It may be an error in the paper's algebra or in my implementation, but I haven't had the time to track it down carefully yet. Feedback/hints welcome.
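That said, a quick Taylor expansion around the bias point (my own back-of-the-envelope check, not from the paper) is at least consistent with the extra factor of 2. For small $x$,

$$ \sigma(1+x) + \sigma(1-x) = 2\sigma(1) + \sigma''(1)\,x^2 + \mathcal{O}(x^4), $$

so the combination the network forms below is

$$ \sigma\!\left(1+\lambda(u+v)\right) + \sigma\!\left(1-\lambda(u+v)\right) - \sigma\!\left(1+\lambda(u-v)\right) - \sigma\!\left(1-\lambda(u-v)\right) = \sigma''(1)\,\lambda^2\left[(u+v)^2 - (u-v)^2\right] + \mathcal{O}(\lambda^4) = 4\,\sigma''(1)\,\lambda^2\,uv + \mathcal{O}(\lambda^4), $$

which suggests the output prefactor should be $\left[4\sigma''(1)\lambda^2\right]^{-1}$, i.e. $2\mu$ with $\mu = \lambda^{-2}/(8\sigma''(1))$ as used in the code.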
# The connection matrices for the two affine layers. These are just constants
# that only need to be defined once
w1 = np.array([1.0, 1, -1, -1, 1, -1, -1, 1]).reshape(4,2)
b = σ_origin * np.array([1.0, 1, 1, 1]) # The bias shift to avoid evaluating σ''(0)
w2 = np.array([1.0, 1, -1, -1])
λ = 0.00001 # as it -> 0, the multiply() function improves in accuracy
μ = 1.0/(8*s2_ori*λ**2) # from formula in fig. 2
# The actual matrices carry λ and μ:
W1 = λ*w1
W2 = 2*μ*w2 # This factor of 2 is not in the paper. Error in my algebra?
# Now we build the affine layers as functions
A1 = A(W1, b)
A2 = A(W2, 0)
# With these in place, we can then build the 3-layer network that approximates
# the multiplication operator
def multiply(u, v):
    "Multiply two numbers with a neural network."
    y = np.array([u, v], dtype=float)
    return A2(σ(A1(y)))
# Let's verify it with two numbers
u, v = 376, 432
uvn = multiply(u, v)
uv = u*v
err = abs(uv - uvn)
print("λ :", λ)
print("Network u*v:", uvn)
print("Exact u*v :", uv)
print("Abs. Error : %.2g" % err)
print("Rel. Error : %.2g" % (err/uv))
λ : 1e-05
Network u*v: 162430.792953
Exact u*v : 162432
Abs. Error : 1.2
Rel. Error : 7.4e-06
Let's have a quick look at how the approximation converges. From eq. (8) in the paper, we expect the relative error to be $\mathcal{O}(\lambda^2(u^2+v^2))$. Note that the equation in the paper states the error without the extra $\lambda^2$ factor, but that's because its analysis assumes $|u| \ll 1, |v| \ll 1$, whereas once implemented we relax this restriction by rescaling $u$ and $v$ by $\lambda$ via the first affine layer.
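As a quick spot check against the single multiplication above (the variable est below is just a throwaway name of mine), the estimate lands within an $\mathcal{O}(1)$ constant of the observed relative error:
# u, v, λ, err and uv still hold the values from the run above
est = λ**2 * (u**2 + v**2)
print("Estimate λ²(u²+v²):", est)       # ≈ 3.3e-05
print("Observed rel. error:", err/uv)   # ≈ 7.4e-06, within an O(1) factor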
In order to conveniently scan over values of $\lambda$, $u$ and $v$, it will be helpful to encapsulate the above construction into a callable object that precomputes all relevant quantities at construction time (for each $\lambda$) and then can be quickly called:
class NNMultiply:
    def __init__(self, λ=1e-5):
        self.λ = λ
        μ = 1.0/(8*s2_ori*λ**2)
        self.W1 = λ*w1
        self.W2 = 2*μ*w2  # This factor of 2 is not in the paper. Error in my algebra?
        self.b = σ_origin * np.array([1.0, 1, 1, 1])  # The bias shift to avoid evaluating σ''(0)

    def __call__(self, u, v):
        "Multiply two numbers with a neural network."
        y = np.array([u, v], dtype=float)
        # Since we'll be calling this a lot, let's make a small optimization and "unroll"
        # our network to avoid a few unnecessary function calls
        return self.W2 @ σ(self.W1 @ y + self.b)
# Let's verify it with the same values as above, as a sanity check
u, v = 376, 432
mult = NNMultiply(λ)
uvn = mult(u, v)
uv = u*v
err = abs(uv - uvn)
print("λ :", mult.λ)
print("Network u*v:", uvn)
print("Exact u*v :", uv)
print("Abs. Error : %.2g" % err)
print("Rel. Error : %.2g" % (err/uv))
λ : 1e-05
Network u*v: 162430.792953
Exact u*v : 162432
Abs. Error : 1.2
Rel. Error : 7.4e-06
Now, we can build an error plot at various values of $\lambda$. Since the function we're approximating is simply $uv$, we can keep $u$ constant and only scan over $v$ for each $\lambda$, as long as we cover a range that goes from $u \ll v$ to $u \gg v$.
In the figure below, next to each line for a given $\lambda$ and in the same color, is a dashed line that plots $\lambda^2(u^2+v^2)$, which should be (modulo a constant) a good estimate of the observed error.
u = 10
lambdas = np.logspace(-1, -7, 7)
vv = np.linspace(0.1, 100, 50)
fig, ax = plt.subplots()
for λ in lambdas:
    mult = NNMultiply(λ)
    err = []
    for v in vv:
        uvn = mult(u, v)
        uv = u*v
        err.append(abs(uv - uvn)/uv)
    l, = ax.semilogy(vv, err, label=r"$\lambda =$ %.2g" % λ)
    ax.semilogy(vv, (λ**2)*(u**2+vv**2), '--', color=l.get_color())
ax.set_xlabel('v')
ax.set_ylabel('Rel. error')
ax.set_title("u fixed at %g" % u)
ax.legend();
As we see, once $\lambda < 10^{-5}$ we start hitting numerical issues, and below $\lambda \approx 10^{-6}$ the error is not only worse than the analytical estimate, it actually starts growing again as $\lambda$ shrinks. This is because $\mu \sim 1/\lambda^2$: the output layer multiplies a tiny difference of four nearly equal $\sigma$ values by a huge constant, and in double precision we simply don't have enough digits to go further.
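To make the cancellation concrete, here is a small illustration (λ_tiny, hidden and combo are my own throwaway names): for $\lambda = 10^{-7}$ the four hidden-unit outputs are all $\approx \sigma(1)$ and agree to about five or six significant digits, so the signed sum the output layer forms is only $\sim 10^{-12}$, retaining a handful of trustworthy digits before it gets multiplied by $2\mu \sim 10^{14}$.
λ_tiny = 1e-7
mult = NNMultiply(λ_tiny)
u, v = 10.0, 50.0
hidden = σ(mult.W1 @ np.array([u, v]) + mult.b)        # the four hidden-unit outputs
print(hidden)                                          # all ≈ σ(1) ≈ 0.731
combo = hidden[0] + hidden[1] - hidden[2] - hidden[3]  # signed sum the output layer forms
print(combo)                                           # ≈ 4σ''(1)λ²uv ≈ -1.8e-12: few digits survive
print(mult(u, v), u*v)                                 # recovered vs exact product (500)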