The following additional libraries are needed to run this notebook. Note that running on Colab is experimental; please report a GitHub issue if you have any problems.
!pip install d2l==0.17.6
:label:sec_minibatch_sgd
So far we have encountered two extremes in gradient-based learning: :numref:sec_gd uses the full dataset to compute gradients and to update parameters, one pass at a time. Conversely, :numref:sec_sgd processes one observation at a time to make progress. Each has its own drawbacks. Gradient descent is not particularly data efficient whenever the data are very similar. Stochastic gradient descent is not particularly computationally efficient since CPUs and GPUs cannot exploit the full power of vectorization. This suggests that there might be a happy medium, and in fact, that is what we have been using so far in the examples we discussed.
At the heart of the decision to use minibatches is computational efficiency. This is most easily understood when considering parallelization to multiple GPUs and multiple servers. In this case we need to send at least one image to each GPU. With 8 GPUs per server and 16 servers we already arrive at a minibatch size of 128.
Things are a bit more subtle when it comes to single GPUs or even CPUs. These devices have multiple types of memory, often multiple types of compute units, and different bandwidth constraints between them. For instance, a CPU has a small number of registers and then L1, L2, and in some cases even L3 cache (which is shared between the different processor cores). These caches are of increasing size and latency (and at the same time they are of decreasing bandwidth). Suffice it to say, the processor is capable of performing many more operations than the main memory interface is able to supply data for.
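The way to alleviate these constraints is to use a hierarchy of CPU caches which are actually fast enough to supply the processor with data. This is the driving force behind batching in deep learning. To keep matters simple, consider matrix-matrix multiplication, say $\mathbf{A} = \mathbf{B}\mathbf{C}$. We have a number of options for calculating $\mathbf{A}$. For instance we could try the following:
1. We could compute $\mathbf{A}_{ij} = \mathbf{B}_{i,:} \mathbf{C}_{:,j}$, i.e., we could compute it element-wise by means of dot products.
2. We could compute $\mathbf{A}_{:,j} = \mathbf{B} \mathbf{C}_{:,j}$, i.e., we could compute it one column at a time. Likewise we could compute $\mathbf{A}$ one row $\mathbf{A}_{i,:}$ at a time.
3. We could simply compute $\mathbf{A} = \mathbf{B} \mathbf{C}$.
4. We could break $\mathbf{B}$ and $\mathbf{C}$ into smaller block matrices and compute $\mathbf{A}$ one block at a time.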
If we follow the first option, we will need to copy one row and one column vector into the CPU each time we want to compute an element $\mathbf{A}_{ij}$. Even worse, because matrix elements are stored sequentially, we are required to access many disjoint locations for one of the two vectors as we read them from memory. The second option is much more favorable. In it, we are able to keep the column vector $\mathbf{C}_{:,j}$ in the CPU cache while we keep on traversing through $\mathbf{B}$. This halves the memory bandwidth requirement with correspondingly faster access. Of course, option 3 is most desirable. Unfortunately, most matrices might not entirely fit into cache (this is what we are discussing after all). However, option 4 offers a practically useful alternative: we can move blocks of the matrix into cache and multiply them locally. Optimized libraries take care of this for us. Let us have a look at how efficient these operations are in practice.
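Before measuring anything, the block-wise strategy (option 4) can be sketched in a few lines of NumPy. This is only an illustration of the idea, not how optimized libraries implement it: the helper name `blocked_matmul` and the block size of 64 are assumptions chosen for the example, and a Python loop will not show the cache benefits that a tuned BLAS kernel gets from the same decomposition.

```python
import numpy as np

def blocked_matmul(B, C, block=64):
    """Compute A = B @ C one (block x block) tile at a time.

    Hypothetical helper for illustration only; real libraries choose tile
    sizes to match the cache hierarchy of the target hardware.
    """
    n, k = B.shape
    k2, m = C.shape
    assert k == k2, "inner dimensions must match"
    A = np.zeros((n, m), dtype=np.result_type(B, C))
    for i in range(0, n, block):          # tile rows of A
        for j in range(0, m, block):      # tile columns of A
            for p in range(0, k, block):  # accumulate over the inner dimension
                A[i:i + block, j:j + block] += (
                    B[i:i + block, p:p + block] @ C[p:p + block, j:j + block])
    return A

rng = np.random.default_rng(0)
B = rng.standard_normal((256, 256))
C = rng.standard_normal((256, 256))
A_blocked = blocked_matmul(B, C)
# The tiled result should agree with the one-shot product up to rounding
print(np.allclose(A_blocked, B @ C))
```

Each tile product touches only three small sub-matrices, which is what lets a tuned implementation keep the working set in cache.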
Beyond computational efficiency, the overhead introduced by Python and by the deep learning framework itself is considerable. Recall that each time we execute a command, the Python interpreter sends a command to the deep learning framework's engine, which needs to insert it into the computational graph and deal with it during scheduling. Such overhead can be quite detrimental. In short, it is highly advisable to use vectorization (and matrices) whenever possible.
%matplotlib inline
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l
timer = d2l.Timer()
A = tf.Variable(tf.zeros((256, 256)))
B = tf.Variable(tf.random.normal([256, 256], 0, 1))
C = tf.Variable(tf.random.normal([256, 256], 0, 1))
Element-wise assignment simply iterates over all rows and columns of $\mathbf{B}$ and $\mathbf{C}$ respectively to assign the value to $\mathbf{A}$.
# Compute A = BC one element at a time
timer.start()
for i in range(256):
    for j in range(256):
        A[i, j].assign(tf.tensordot(B[i, :], C[:, j], axes=1))
timer.stop()
123.80362010002136
A faster strategy is to perform column-wise assignment.
timer.start()
for j in range(256):
    A[:, j].assign(tf.tensordot(B, C[:, j], axes=1))
timer.stop()
0.3721303939819336
Finally, the most effective approach is to perform the entire operation in one block. Let us see what the respective speeds of the operations are.
timer.start()
A.assign(tf.tensordot(B, C, axes=1))
timer.stop()
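# Multiply and add count as separate operations (fused in practice).
# Caveat: a 256 x 256 product takes 2 * 256**3 (about 0.034e9) operations, so
# the constant 2 overstates absolute throughput; relative speeds are unaffected.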
gigaflops = [2/i for i in timer.times]
print(f'performance in Gigaflops: element {gigaflops[0]:.3f}, '
      f'column {gigaflops[1]:.3f}, full {gigaflops[2]:.3f}')
performance in Gigaflops: element 0.016, column 5.374, full 84.145
:label:sec_minibatches
In the past we took it for granted that we would read minibatches of data rather than single observations to update parameters. We now give a brief justification for it. Processing single observations requires us to perform many single matrix-vector (or even vector-vector) multiplications, which is quite expensive and which incurs a significant overhead from the underlying deep learning framework. This applies both to evaluating a network when applied to data (often referred to as inference) and to computing gradients to update parameters. That is, this applies whenever we perform $\mathbf{w} \leftarrow \mathbf{w} - \eta_t \mathbf{g}_t$ where
$$\mathbf{g}_t = \partial_{\mathbf{w}} f(\mathbf{x}_{t}, \mathbf{w})$$
We can increase the computational efficiency of this operation by applying it to a minibatch of observations at a time. That is, we replace the gradient $\mathbf{g}_t$ over a single observation by one over a small batch
$$\mathbf{g}_t = \partial_{\mathbf{w}} \frac{1}{|\mathcal{B}_t|} \sum_{i \in \mathcal{B}_t} f(\mathbf{x}_{i}, \mathbf{w})$$
Let us see what this does to the statistical properties of $\mathbf{g}_t$: since both $\mathbf{x}_t$ and all elements of the minibatch $\mathcal{B}_t$ are drawn uniformly at random from the training set, the expectation of the gradient remains unchanged. The variance, on the other hand, is reduced significantly. Since the minibatch gradient is composed of $b := |\mathcal{B}_t|$ independent gradients which are being averaged, its standard deviation is reduced by a factor of $b^{-\frac{1}{2}}$. This, by itself, is a good thing, since it means that the updates are more reliably aligned with the full gradient.
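The $b^{-\frac{1}{2}}$ scaling can be checked empirically with a small simulation. The sketch below makes a strong simplifying assumption: per-example gradients are modeled as independent scalar draws from a standard normal distribution, which is not what real gradients look like. The names `num_trials` and `minibatch_grad_std` are hypothetical and serve only this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_trials = 20000  # number of simulated minibatch gradients per batch size

def minibatch_grad_std(b):
    """Empirical std of the average of b i.i.d. unit-variance 'gradients'."""
    per_example = rng.standard_normal((num_trials, b))
    return per_example.mean(axis=1).std()

for b in (1, 4, 16, 64):
    print(f'b={b:3d}  empirical std={minibatch_grad_std(b):.3f}  '
          f'predicted b**-0.5={b ** -0.5:.3f}')
```

Under these assumptions each fourfold increase in $b$ halves the standard deviation, so ever larger batches are needed for the same absolute reduction in noise.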
Naively this would indicate that choosing a large minibatch $\mathcal{B}_t$ would be universally desirable. Alas, after some point, the additional reduction in standard deviation is minimal when compared to the linear increase in computational cost. In practice we pick a minibatch that is large enough to offer good computational efficiency while still fitting into the memory of a GPU. To illustrate the savings let us have a look at some code. In it we perform the same matrix-matrix multiplication, but this time broken up into "minibatches" of 64 columns at a time.
timer.start()
for j in range(0, 256, 64):
    A[:, j:j+64].assign(tf.tensordot(B, C[:, j:j+64], axes=1))
timer.stop()
print(f'performance in Gigaflops: block {2 / timer.times[3]:.3f}')
performance in Gigaflops: block 283.083
As we can see, the computation on the minibatch is essentially as efficient as on the full matrix. A word of caution is in order. In :numref:sec_batch_norm we used a type of regularization that was heavily dependent on the amount of variance in a minibatch. As we increase the minibatch size, the variance decreases and with it the benefit of the noise injection due to batch normalization. See e.g., :cite:Ioffe.2017 for details on how to rescale and compute the appropriate terms.
Let us have a look at how minibatches are efficiently generated from data. In the following we use a dataset developed by NASA to test the wing noise from different aircraft to compare these optimization algorithms. For convenience we only use the first $1,500$ examples. The data is whitened for preprocessing, i.e., we remove the mean and rescale the variance to $1$ per coordinate.
#@save
d2l.DATA_HUB['airfoil'] = (d2l.DATA_URL + 'airfoil_self_noise.dat',
                           '76e5be1548fd8222e5074cf0faae75edff8cf93f')
#@save
def get_data_ch11(batch_size=10, n=1500):
    data = np.genfromtxt(d2l.download('airfoil'),
                         dtype=np.float32, delimiter='\t')
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data_iter = d2l.load_array((data[:n, :-1], data[:n, -1]),
                               batch_size, is_train=True)
    return data_iter, data.shape[1]-1
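The preprocessing step in `get_data_ch11` can be seen in isolation on synthetic data. The sketch below uses made-up features (the array `raw` and its scales are assumptions, not the airfoil data) and applies the same per-column transformation. Note that this standardizes each coordinate separately; it does not remove correlations between coordinates.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical raw features with very different means and scales per column
raw = rng.normal(loc=[10.0, -3.0, 500.0], scale=[2.0, 0.1, 50.0],
                 size=(1000, 3))

# Same transformation as in get_data_ch11: subtract the column mean and
# divide by the column standard deviation
standardized = (raw - raw.mean(axis=0)) / raw.std(axis=0)

print(standardized.mean(axis=0).round(6))  # close to 0 in every column
print(standardized.std(axis=0).round(6))   # close to 1 in every column
```

With all coordinates on a comparable scale, a single learning rate is a more reasonable choice for every parameter.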
Recall the minibatch stochastic gradient descent implementation from :numref:sec_linear_scratch. In the following we provide a slightly more general implementation. For convenience it has the same call signature as the other optimization algorithms introduced later in this chapter. Specifically, we add the state input `states` and place the hyperparameters in the dictionary `hyperparams`. In addition, we will average the loss of each minibatch example in the training function, so the gradient in the optimization algorithm does not need to be divided by the batch size.
def sgd(params, grads, states, hyperparams):
    for param, grad in zip(params, grads):
        param.assign_sub(hyperparams['lr']*grad)
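The remark about averaging the loss can be verified numerically: the gradient of the averaged squared loss equals the gradient of the summed loss divided by the batch size. The following NumPy sketch is an illustration on random data; `mean_loss`, `summed_grad`, and the array shapes are assumptions made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 5))   # one minibatch of 10 examples
y = rng.standard_normal(10)
w = rng.standard_normal(5)

def mean_loss(w):
    """Squared loss averaged over the minibatch, with the 1/2 factor."""
    return 0.5 * np.mean((X @ w - y) ** 2)

# Analytic gradient of the *summed* loss 0.5 * sum((Xw - y)**2)
summed_grad = X.T @ (X @ w - y)

# Numerical gradient of the *averaged* loss via central differences
eps = 1e-6
numeric_mean_grad = np.array([
    (mean_loss(w + eps * e) - mean_loss(w - eps * e)) / (2 * eps)
    for e in np.eye(len(w))])

# Dividing the summed gradient by the batch size gives the same vector
print(np.allclose(numeric_mean_grad, summed_grad / len(y), atol=1e-6))
```

This is why `sgd` above can apply `hyperparams['lr'] * grad` directly, provided the training loop averages the loss.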
Next, we implement a generic training function to facilitate the use of the other optimization algorithms introduced later in this chapter. It initializes a linear regression model and can be used to train the model with minibatch stochastic gradient descent and other algorithms introduced subsequently.
#@save
def train_ch11(trainer_fn, states, hyperparams, data_iter,
               feature_dim, num_epochs=2):
    # Initialization
    w = tf.Variable(tf.random.normal(shape=(feature_dim, 1),
                                     mean=0, stddev=0.01), trainable=True)
    b = tf.Variable(tf.zeros(1), trainable=True)
    # Train
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    for _ in range(num_epochs):
        for X, y in data_iter:
            with tf.GradientTape() as g:
                l = tf.math.reduce_mean(loss(net(X), y))
            dw, db = g.gradient(l, [w, b])
            trainer_fn([w, b], [dw, db], states, hyperparams)
            n += X.shape[0]
            if n % 200 == 0:
                timer.stop()
                p = n/X.shape[0]
                q = p/tf.data.experimental.cardinality(data_iter).numpy()
                r = (d2l.evaluate_loss(net, data_iter, loss),)
                animator.add(q, r)
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')
    return timer.cumsum(), animator.Y[0]
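For readers who want to see the update loop without the plotting and framework machinery, here is a framework-free sketch of minibatch stochastic gradient descent for linear regression in NumPy. It is not the training function above: the function name `minibatch_sgd_linreg`, the synthetic data, and the chosen hyperparameters are assumptions made for illustration.

```python
import numpy as np

def minibatch_sgd_linreg(X, y, batch_size, lr, num_epochs, seed=0):
    """Fit y ~ X @ w + b with minibatch SGD on the averaged squared loss."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(num_epochs):
        perm = rng.permutation(n)                 # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            err = X[idx] @ w + b - y[idx]         # residuals on the minibatch
            w -= lr * X[idx].T @ err / len(idx)   # gradient of the mean loss
            b -= lr * err.mean()
    return w, b

# Synthetic data with known parameters (assumed values for this example)
rng = np.random.default_rng(1)
true_w, true_b = np.array([2.0, -3.4]), 4.2
X = rng.standard_normal((1500, 2))
y = X @ true_w + true_b + 0.01 * rng.standard_normal(1500)

w, b = minibatch_sgd_linreg(X, y, batch_size=100, lr=0.4, num_epochs=5)
print(w, b)
```

Setting `batch_size` to the number of examples turns this into gradient descent, and setting it to 1 gives stochastic gradient descent.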
Let us see how optimization proceeds for batch gradient descent. This can be achieved by setting the minibatch size to 1500 (i.e., to the total number of examples). As a result the model parameters are updated only once per epoch. There is little progress. In fact, after 6 steps progress stalls.
def train_sgd(lr, batch_size, num_epochs=2):
    data_iter, feature_dim = get_data_ch11(batch_size)
    return train_ch11(
        sgd, None, {'lr': lr}, data_iter, feature_dim, num_epochs)
gd_res = train_sgd(1, 1500, 10)
loss: 0.247, 0.032 sec/epoch
When the batch size equals 1, we use stochastic gradient descent for optimization. For simplicity of implementation we picked a constant (albeit small) learning rate. In stochastic gradient descent, the model parameters are updated whenever an example is processed. In our case this amounts to 1500 updates per epoch. As we can see, the decline in the value of the objective function slows down after one epoch. Although both procedures processed 1500 examples within one epoch, stochastic gradient descent consumes more time than gradient descent in our experiment. This is because stochastic gradient descent updates the parameters more frequently and because it is less efficient to process single observations one at a time.
sgd_res = train_sgd(0.005, 1)
loss: 0.244, 0.570 sec/epoch
Finally, when the batch size equals 100, we use minibatch stochastic gradient descent for optimization. The time required per epoch is shorter than for both stochastic gradient descent and batch gradient descent.
mini1_res = train_sgd(.4, 100)
loss: 0.245, 0.008 sec/epoch
When we reduce the batch size to 10, the time for each epoch increases because the workload for each batch is less efficient to execute.
mini2_res = train_sgd(.05, 10)
loss: 0.243, 0.060 sec/epoch
Now we can compare the time vs. loss for the previous four experiments. As can be seen, although stochastic gradient descent converges faster than GD in terms of the number of examples processed, it uses more time to reach the same loss than GD because computing the gradient example by example is not as efficient. Minibatch stochastic gradient descent is able to trade off convergence speed and computational efficiency. A minibatch size of 10 is more efficient than stochastic gradient descent; a minibatch size of 100 even outperforms GD in terms of runtime.
d2l.set_figsize([6, 3])
d2l.plot(*list(map(list, zip(gd_res, sgd_res, mini1_res, mini2_res))),
         'time (sec)', 'loss', xlim=[1e-2, 10],
         legend=['gd', 'sgd', 'batch size=100', 'batch size=10'])
d2l.plt.gca().set_xscale('log')
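Part of the difference between the four runs comes down to how many parameter updates each one performs per epoch. With the 1500 examples used here, a short calculation makes this explicit; the dictionary name `updates_per_epoch` is introduced only for this illustration.

```python
import math

n_examples = 1500
settings = [('gd', 1500), ('sgd', 1),
            ('batch size=100', 100), ('batch size=10', 10)]

# One update per minibatch, so an epoch needs ceil(n / batch_size) updates
updates_per_epoch = {name: math.ceil(n_examples / batch_size)
                     for name, batch_size in settings}

for name, updates in updates_per_epoch.items():
    print(f'{name:>15}: {updates:5d} parameter updates per epoch')
```

Fewer, larger updates amortize the per-call overhead, while more, smaller updates make more progress per example processed.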
In TensorFlow, we can use the optimizer classes in `tf.keras.optimizers` to call optimization algorithms. This is used to implement a generic training function. We will use this throughout the current chapter.
#@save
def train_concise_ch11(trainer_fn, hyperparams, data_iter, num_epochs=2):
    # Initialization
    net = tf.keras.Sequential()
    net.add(tf.keras.layers.Dense(1,
            kernel_initializer=tf.random_normal_initializer(stddev=0.01)))
    optimizer = trainer_fn(**hyperparams)
    loss = tf.keras.losses.MeanSquaredError()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    for _ in range(num_epochs):
        for X, y in data_iter:
            with tf.GradientTape() as g:
                out = net(X)
                l = loss(y, out)
            params = net.trainable_variables
            grads = g.gradient(l, params)
            optimizer.apply_gradients(zip(grads, params))
            n += X.shape[0]
            if n % 200 == 0:
                timer.stop()
                p = n/X.shape[0]
                q = p/tf.data.experimental.cardinality(data_iter).numpy()
                # `MeanSquaredError` computes squared error without the 1/2
                # factor
                r = (d2l.evaluate_loss(net, data_iter, loss) / 2,)
                animator.add(q, r)
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')
Using the Keras optimizer to repeat the last experiment shows identical behavior.
data_iter, _ = get_data_ch11(10)
trainer = tf.keras.optimizers.SGD
train_concise_ch11(trainer, {'learning_rate': 0.05}, data_iter)
loss: 0.244, 0.104 sec/epoch
Use the `set_learning_rate` function of the `Trainer` class to reduce the learning rate of the minibatch stochastic gradient descent to 1/10 of its previous value after each epoch.
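The schedule described here is a simple geometric decay. The sketch below only computes the learning rate that would apply in each epoch; the function `decayed_lr` is a hypothetical helper, and how the value is passed to an optimizer depends on the framework and its version, so that step is deliberately left out.

```python
def decayed_lr(initial_lr, epoch, factor=0.1):
    """Learning rate in `epoch` (0-based) when it shrinks by `factor`
    after every completed epoch."""
    return initial_lr * factor ** epoch

# Starting value 0.05 is an assumption taken from the batch-size-10 run above
for epoch in range(3):
    print(f'epoch {epoch}: lr = {decayed_lr(0.05, epoch):.5f}')
```

Such a rapid decay quickly freezes the parameters, so it is worth comparing the resulting loss curve against the constant learning rate used earlier.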