Here we will understand the reparameterization trick used by Kingma and Welling (2014) to train their variational autoencoder.

Assume we have a normal distribution $q$ that is parameterized by $\theta$, specifically $q_{\theta}(x) = N(\theta,1)$. We want to solve the following problem: $$ \min_{\theta} \quad E_q[x^2] $$ This is of course a rather silly problem and the optimal $\theta$ is obvious, but it lets us focus on how the reparameterization trick helps in calculating the gradient of the objective $E_q[x^2]$.
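For this toy objective we can also write down the answer in closed form, which gives a ground truth to check the estimators against later: $$ E_q[x^2] = \operatorname{Var}_q[x] + \left(E_q[x]\right)^2 = 1 + \theta^2, \qquad \nabla_{\theta} E_q[x^2] = 2\theta $$ so the minimum is at $\theta = 0$, and at $\theta = 2$ the true gradient is $4$.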

One way to calculate $\nabla_{\theta} E_q[x^2]$ is as follows $$ \nabla_{\theta} E_q[x^2] = \nabla_{\theta} \int q_{\theta}(x) x^2 dx = \int x^2 \nabla_{\theta} q_{\theta}(x) \frac{q_{\theta}(x)}{q_{\theta}(x)} dx = \int q_{\theta}(x) \nabla_{\theta} \log q_{\theta}(x) x^2 dx = E_q[x^2 \nabla_{\theta} \log q_{\theta}(x)] $$ This uses the log-derivative identity $\nabla_{\theta} q_{\theta} = q_{\theta} \nabla_{\theta} \log q_{\theta}$ and is known as the score-function (or REINFORCE) estimator.

For our example where $q_{\theta}(x) = N(\theta,1)$, this method gives $$ \nabla_{\theta} E_q[x^2] = E_q[x^2 (x-\theta)] $$

The reparameterization trick instead writes $x$ as a deterministic, differentiable transform of parameter-free noise: $x = \theta + \epsilon$ with $\epsilon \sim N(0,1)$. The expectation is then taken over a distribution that does not depend on $\theta$, so the gradient can be moved inside it: $$ \nabla_{\theta} E_q[x^2] = \nabla_{\theta} E_{\epsilon}[(\theta+\epsilon)^2] = E_{\epsilon}[2(\theta+\epsilon)] $$ Below we compare both Monte Carlo estimators.

In [64]:

```
import numpy as np

N = 1000
theta = 2.0
eps = np.random.randn(N)
x = theta + eps

# Score-function estimator: E_q[x^2 (x - theta)]
grad1 = lambda x: np.sum(np.square(x) * (x - theta)) / x.size
# Reparameterized estimator: E_eps[2 (theta + eps)]
grad2 = lambda eps: np.sum(2 * (theta + eps)) / eps.size

print(grad1(x))
print(grad2(eps))
```

3.86872102149
4.03506045463
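As a sanity check (not part of the original notebook), both estimators should converge to the closed-form gradient $2\theta = 4$ as the sample size grows. A minimal sketch with a fixed seed and a large sample:

```python
import numpy as np

theta = 2.0
rng = np.random.default_rng(0)
eps = rng.standard_normal(1_000_000)
x = theta + eps

# Score-function (REINFORCE) estimate of the gradient of E_q[x^2]
g_score = np.mean(x ** 2 * (x - theta))
# Reparameterized estimate: d/dtheta (theta + eps)^2 = 2 (theta + eps)
g_reparam = np.mean(2 * x)

# Analytic value: E_q[x^2] = theta^2 + 1, so the gradient is 2 * theta = 4
print(g_score, g_reparam)
```

Both numbers land close to 4, but the score-function estimate wanders further from it, which foreshadows the variance comparison below.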

Let us estimate the variance of the two gradient estimators for different sample sizes, and then plot it.

In [66]:

```
Ns = [10, 100, 1000, 10000, 100000]
reps = 100

means1 = np.zeros(len(Ns))
vars1 = np.zeros(len(Ns))
means2 = np.zeros(len(Ns))
vars2 = np.zeros(len(Ns))

est1 = np.zeros(reps)
est2 = np.zeros(reps)
for i, N in enumerate(Ns):
    for r in range(reps):
        # Score-function estimate from N fresh samples
        x = np.random.randn(N) + theta
        est1[r] = grad1(x)
        # Reparameterized estimate from N fresh noise samples
        eps = np.random.randn(N)
        est2[r] = grad2(eps)
    means1[i] = np.mean(est1)
    means2[i] = np.mean(est2)
    vars1[i] = np.var(est1)
    vars2[i] = np.var(est2)

print(means1)
print(means2)
print()
print(vars1)
print(vars2)
```

In [67]:

```
%matplotlib inline
import matplotlib.pyplot as plt

plt.plot(vars1)
plt.plot(vars2)
plt.xlabel('sample size index')
plt.ylabel('variance of gradient estimate')
plt.legend(['no rt', 'rt'])
```


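The takeaway from the plot is that the reparameterized estimator has much lower variance at every sample size. A small self-contained check (an illustration, not part of the original notebook) comparing the two variances directly at a fixed sample size:

```python
import numpy as np

theta = 2.0
rng = np.random.default_rng(1)
reps, N = 1000, 100

est_score = np.empty(reps)
est_reparam = np.empty(reps)
for r in range(reps):
    eps = rng.standard_normal(N)
    x = theta + eps
    # Score-function estimate: mean of x^2 (x - theta)
    est_score[r] = np.mean(x ** 2 * (x - theta))
    # Reparameterized estimate: mean of 2 (theta + eps)
    est_reparam[r] = np.mean(2 * x)

print(np.var(est_score), np.var(est_reparam))
```

Both estimators are unbiased for the same gradient, but the reparameterized one exploits the known dependence of $x$ on $\theta$, which is why the VAE of Kingma and Welling trains well with it.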