#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial demonstrates the use of Cyclical Learning Rate from the Addons package.
It has been shown that it is beneficial to adjust the learning rate as training progresses for a neural network. The benefits are manifold, ranging from saddle point recovery to preventing some of the numerical instabilities that may arise during backpropagation. But how does one know how much to adjust at a particular point in training?

In 2015, Leslie Smith noticed that you would want to increase the learning rate to traverse the loss landscape faster, but you would also want to reduce the learning rate when approaching convergence. To realize this idea, he proposed Cyclical Learning Rates (CLR), where the learning rate is adjusted with respect to the cycles of a function. For a visual demonstration, you can check out this blog. CLR is now available as a TensorFlow API. For more details, check out the original paper here.
!pip install -q -U tensorflow_addons
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.random.set_seed(42)
np.random.seed(42)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
BATCH_SIZE = 64
EPOCHS = 10
INIT_LR = 1e-4
MAX_LR = 1e-2
def get_training_model():
    model = tf.keras.Sequential(
        [
            layers.InputLayer((28, 28, 1)),
            layers.experimental.preprocessing.Rescaling(scale=1./255),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.SpatialDropout2D(0.2),
            layers.GlobalAvgPool2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model
def train_model(model, optimizer):
    model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer,
                  metrics=["accuracy"])
    history = model.fit(x_train,
                        y_train,
                        batch_size=BATCH_SIZE,
                        validation_data=(x_test, y_test),
                        epochs=EPOCHS)
    return history
In the interest of reproducibility, the initial model weights are serialized; you will reuse them to conduct the experiments below.
initial_model = get_training_model()
initial_model.save("initial_model")
standard_model = tf.keras.models.load_model("initial_model")
no_clr_history = train_model(standard_model, optimizer="sgd")
The tfa.optimizers.CyclicalLearningRate module returns a schedule that can be passed directly to an optimizer. The schedule takes a step as its input and outputs a value calculated using the CLR formula as laid out in the paper.
steps_per_epoch = len(x_train) // BATCH_SIZE
clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=INIT_LR,
    maximal_learning_rate=MAX_LR,
    scale_fn=lambda x: 1 / (2. ** (x - 1)),
    step_size=2 * steps_per_epoch,
)
optimizer = tf.keras.optimizers.SGD(clr)
Here, you specify the lower and upper bounds of the learning rate, and the schedule will oscillate in between that range ([1e-4, 1e-2] in this case). scale_fn defines the function that scales the learning rate up and down within a given cycle. step_size defines the duration of half a cycle, so a step_size of 2 means a total of 4 iterations are needed to complete one cycle. The recommended value for step_size is factor * steps_per_epoch, where factor lies within the [2, 8] range.
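To build intuition for how these parameters interact, below is a minimal NumPy sketch of the triangular-style formula from the paper, with the same per-cycle damping as the scale_fn above (i.e. triangular2). The manual_clr helper is hypothetical and only for illustration; the library's exact internals may differ slightly.
# A rough re-implementation of the paper's formula, for intuition only.
def manual_clr(step, init_lr=INIT_LR, max_lr=MAX_LR, step_size=2 * steps_per_epoch):
    cycle = np.floor(1 + step / (2 * step_size))   # which cycle the step falls in
    x = np.abs(step / step_size - 2 * cycle + 1)   # relative position within the cycle
    scale = 1 / (2. ** (cycle - 1))                # halve the amplitude after every cycle
    return init_lr + (max_lr - init_lr) * np.maximum(0., 1 - x) * scale

# Spot-check a few steps against the tfa schedule; the values should roughly agree.
for s in [0, steps_per_epoch, 2 * steps_per_epoch, 5 * steps_per_epoch]:
    print(s, manual_clr(s), float(clr(s)))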
In the same CLR paper, Leslie also presented a simple and elegant method to choose the bounds for the learning rate. You are encouraged to check it out as well. This blog post provides a nice introduction to the method.
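As a rough illustration of that bounds-selection idea (often called an LR range test), the sketch below trains for one epoch while increasing the learning rate linearly and records the loss reported at each step; a region where the loss falls steadily followed by a blow-up suggests reasonable lower and upper bounds. The LRRangeTest callback and its default range are hypothetical choices made for this example, not part of the original tutorial.
class LRRangeTest(tf.keras.callbacks.Callback):
    # Linearly increases the learning rate each batch and records (lr, loss) pairs.
    def __init__(self, min_lr=1e-6, max_lr=1e-1, num_steps=1000):
        super().__init__()
        self.lrs = np.linspace(min_lr, max_lr, num_steps)
        self.history = []

    def on_train_batch_begin(self, batch, logs=None):
        step = len(self.history)
        if step < len(self.lrs):
            tf.keras.backend.set_value(self.model.optimizer.lr, self.lrs[step])

    def on_train_batch_end(self, batch, logs=None):
        step = min(len(self.history), len(self.lrs) - 1)
        self.history.append((self.lrs[step], logs["loss"]))

range_test_model = tf.keras.models.load_model("initial_model")
range_test_model.compile(loss="sparse_categorical_crossentropy",
                         optimizer="sgd", metrics=["accuracy"])
range_test = LRRangeTest()
range_test_model.fit(x_train, y_train, batch_size=BATCH_SIZE,
                     epochs=1, callbacks=[range_test], verbose=0)

lrs, losses = zip(*range_test.history)
plt.plot(lrs, losses)
plt.xlabel("Learning Rate")
plt.ylabel("Reported Loss")
plt.show()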
Below, you visualize what the clr schedule looks like.
step = np.arange(0, EPOCHS * steps_per_epoch)
lr = clr(step)
plt.plot(step, lr)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()
In order to better visualize the effect of CLR, you can plot the schedule with an increased number of steps.
step = np.arange(0, 100 * steps_per_epoch)
lr = clr(step)
plt.plot(step, lr)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()
The function you are using in this tutorial is referred to as the triangular2 method in the CLR paper. Two other functions were also explored, namely triangular and exp (short for exponential); a sketch of how they might be configured follows below.
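For reference, those other two policies can also be expressed through scale_fn on the same schedule class. The snippet below is only a sketch of how they might be configured: gamma is an illustrative decay factor rather than a value prescribed by the paper, and the exponential variant relies on the schedule's scale_mode argument to apply the scaling per iteration instead of per cycle.
# triangular: no damping, so every cycle peaks at MAX_LR.
triangular_clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=INIT_LR,
    maximal_learning_rate=MAX_LR,
    scale_fn=lambda x: 1.0,
    step_size=2 * steps_per_epoch,
)

# exp: the amplitude decays exponentially with the iteration count.
gamma = 0.9994  # illustrative decay factor
exp_clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=INIT_LR,
    maximal_learning_rate=MAX_LR,
    scale_fn=lambda x: gamma ** x,
    scale_mode="iterations",
    step_size=2 * steps_per_epoch,
)
Either schedule could be wrapped in an SGD optimizer and passed to train_model in place of clr to compare the three policies.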
clr_model = tf.keras.models.load_model("initial_model")
clr_history = train_model(clr_model, optimizer=optimizer)
As expected, the loss starts higher than usual and then stabilizes as the cycles progress. You can confirm this visually with the plots below.
(fig, ax) = plt.subplots(2, 1, figsize=(10, 8))
ax[0].plot(no_clr_history.history["loss"], label="train_loss")
ax[0].plot(no_clr_history.history["val_loss"], label="val_loss")
ax[0].set_title("No CLR")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel("Loss")
ax[0].set_ylim([0, 2.5])
ax[0].legend()
ax[1].plot(clr_history.history["loss"], label="train_loss")
ax[1].plot(clr_history.history["val_loss"], label="val_loss")
ax[1].set_title("CLR")
ax[1].set_xlabel("Epochs")
ax[1].set_ylabel("Loss")
ax[1].set_ylim([0, 2.5])
ax[1].legend()
fig.tight_layout(pad=3.0)
fig.show()
Even though you did not see much of an effect from CLR in this toy example, note that it is one of the main ingredients behind Super Convergence and can have a really good impact when training in large-scale settings.