#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60,000 training images and 10,000 test images, each of size 28 x 28.
We are using custom training loops to train our model because they give us flexibility and greater control over training. They also make it easier to debug the model and the training loop.
from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow
!pip install tf-nightly-gpu
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
train_labels = train_labels.astype('int64')
test_labels = test_labels.astype('int64')
How does the tf.distribute.MirroredStrategy strategy work? All of the variables and the model graph are replicated on each replica. Input is evenly distributed across the replicas. Each replica calculates the loss and gradients for the input it received. The gradients are synced across all the replicas by summing them, and after the sync the same update is made to the copies of the variables on each replica.
Note: You can put all the code below inside a single scope. We are dividing it into several code cells for illustration purposes.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
tf.distribute.Strategy.experimental_distribute_dataset evenly distributes the dataset across all the replicas.
with strategy.scope():
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  train_ds = strategy.experimental_distribute_dataset(train_dataset)

  test_dataset = tf.data.Dataset.from_tensor_slices(
      (test_images, test_labels)).batch(BATCH_SIZE)
  test_ds = strategy.experimental_distribute_dataset(test_dataset)
Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu',
                             input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      # Output raw logits; tf.nn.sparse_softmax_cross_entropy_with_logits in
      # the training step below applies the softmax itself.
      tf.keras.layers.Dense(10)
  ])

  optimizer = tf.train.GradientDescentOptimizer(0.001)
Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input.
So, how should the loss be calculated when using a tf.distribute.Strategy?
As an example, let's say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), each replica getting an input of size 16.
Why do this? The gradients calculated on each replica are summed across all the replicas during the sync. If each replica divided its loss by its per-replica batch size of 16, the summed gradient would be 4 times too large. Scaling the loss by the GLOBAL_BATCH_SIZE of 64 on every replica instead makes the summed gradient equal to the gradient of the average loss over the whole global batch, which matches what you would get on a single device.
How to do this in TensorFlow? If you are writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE:
scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
Or you can use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.
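For instance, here is a minimal sketch (not part of the tutorial's training code) showing that the two scalings agree; per_example_loss stands in for a hypothetical batch of per-example losses on one replica, and GLOBAL_BATCH_SIZE plays the role of BATCH_SIZE defined above:
# A minimal sketch: both scalings below produce the same scalar.
# `per_example_loss` is a hypothetical tensor of per-example losses on one
# replica; GLOBAL_BATCH_SIZE stands in for the global batch size (BATCH_SIZE above).
GLOBAL_BATCH_SIZE = 64
per_example_loss = tf.constant([0.5, 1.0, 1.5, 2.0])
# Manual scaling: sum the per-example losses and divide by the global batch size.
manual_scaled = tf.reduce_sum(per_example_loss) * (1. / GLOBAL_BATCH_SIZE)
# Helper: performs the same division by the global batch size.
helper_scaled = tf.nn.compute_average_loss(per_example_loss,
                                           global_batch_size=GLOBAL_BATCH_SIZE)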
If you are using regularization losses in your model, you also need to scale the loss value by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function.
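As an illustrative sketch only (the model in this tutorial has no regularization losses; the L2 penalty below is hypothetical):
# Sketch of combining a prediction loss with a regularization term under
# tf.distribute.Strategy. The L2 penalty here is hypothetical; the tutorial's
# model does not use regularization.
def loss_with_regularization(per_example_loss, model_weights):
  prediction_loss = tf.nn.compute_average_loss(per_example_loss,
                                               global_batch_size=BATCH_SIZE)
  l2_penalty = 1e-4 * tf.add_n([tf.nn.l2_loss(w) for w in model_weights])
  # scale_regularization_loss divides by the number of replicas, so the
  # regularization gradient is not multiplied by the replica count when the
  # per-replica gradients are summed.
  return prediction_loss + tf.nn.scale_regularization_loss(l2_penalty)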
Using tf.reduce_mean is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.
This reduction and scaling is done automatically in Keras model.compile and model.fit.
If you are using tf.keras.losses classes, the loss reduction needs to be explicitly specified to be one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should explicitly think about which reduction they want, to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by the per-replica batch size and leave the division by the number of replicas to the user, which might be easy to miss. So instead we ask the user to do the reduction themselves explicitly.
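For example, one possible pattern looks like the sketch below. This is illustrative only; the training step in this tutorial uses tf.nn.sparse_softmax_cross_entropy_with_logits rather than a tf.keras.losses class.
# Sketch of the tf.keras.losses pattern described above. Reduction.NONE makes
# the loss object return one value per example, and compute_average_loss then
# divides the sum by the global batch size.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
  per_example_loss = loss_object(labels, predictions)
  return tf.nn.compute_average_loss(per_example_loss,
                                    global_batch_size=BATCH_SIZE)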
with strategy.scope():
  def train_step(dist_inputs):
    def step_fn(inputs):
      images, labels = inputs
      logits = model(images)
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
      # Scale the loss by the global batch size, as discussed above.
      loss = tf.nn.compute_average_loss(cross_entropy, global_batch_size=BATCH_SIZE)
      train_op = optimizer.minimize(loss)
      with tf.control_dependencies([train_op]):
        return tf.identity(loss)

    per_replica_losses = strategy.experimental_run_v2(
        step_fn, args=(dist_inputs,))
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    return mean_loss
with strategy.scope():
  train_iterator = train_ds.make_initializable_iterator()
  iterator_init = train_iterator.initialize()
  var_init = tf.global_variables_initializer()
  loss = train_step(train_iterator.get_next())

  with tf.Session() as sess:
    sess.run([var_init])

    for epoch in range(EPOCHS):
      sess.run([iterator_init])
      # One pass over the training set per epoch.
      steps_per_epoch = len(train_images) // BATCH_SIZE
      for step in range(steps_per_epoch):
        # Fetching `loss` runs the training op via the control dependency.
        loss_value = sess.run(loss)
        if step % 100 == 0:
          print('Epoch {} Step {} Loss {:.4f}'.format(epoch + 1,
                                                      step,
                                                      loss_value))
Try out the new tf.distribute.Strategy API on your models.