#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial demonstrates manual image manipulations and augmentation using `tf.image`.
Data augmentation is a common technique to improve results and avoid overfitting, see Overfitting and Underfitting for others.
!pip install git+https://github.com/tensorflow/docs
import urllib
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers
AUTOTUNE = tf.data.experimental.AUTOTUNE
import tensorflow_docs as tfdocs
import tensorflow_docs.plots
import tensorflow_datasets as tfds
import PIL.Image
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12, 5)
import numpy as np
Let's check the data augmentation features on an image and then augment a whole dataset later to train a model.
Download this image, by Von.grzanka, for augmentation.
# Download the sample cat photo (cached under ~/.keras/datasets) and open it;
# in a notebook the PIL image on the last line renders inline.
image_path = tf.keras.utils.get_file("cat.jpg", "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg")
PIL.Image.open(image_path)
Read and decode the image to tensor format.
# Read the raw JPEG bytes and decode them into a 3-channel (RGB) uint8 tensor.
image_string=tf.io.read_file(image_path)
image=tf.image.decode_jpeg(image_string,channels=3)
A function to visualize and compare the original and augmented image side by side.
def visualize(original, augmented):
    """Plot the original and the augmented image side by side for comparison.

    Args:
      original: image tensor/array accepted by ``plt.imshow``.
      augmented: transformed version of the same image.
    """
    # Fresh figure so repeated calls don't draw on top of each other.
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.title('Original image')
    plt.imshow(original)
    plt.subplot(1, 2, 2)
    plt.title('Augmented image')
    plt.imshow(augmented)
Flip the image either vertically or horizontally.
# Mirror the image horizontally and compare with the original.
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)
Grayscale an image.
# Convert RGB to a single-channel grayscale image; tf.squeeze drops the
# trailing channel dimension of size 1 so plt.imshow renders it as 2-D.
grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
plt.colorbar()
Saturate an image by providing a saturation factor.
# Boost color saturation by a factor of 3.
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)
Change the brightness of image by providing a brightness factor.
# Brighten the image by adding a delta of 0.4 to pixel values.
bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)
Rotate an image by 90 degrees.
# Rotate the image 90 degrees counter-clockwise.
rotated = tf.image.rot90(image)
visualize(image, rotated)
Crop the central portion of the image, keeping the fraction of it you desire.
# Keep the central 50% of the image (crop equally from all sides).
cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image,cropped)
See the tf.image
reference for details about available augmentation options.
Train a model on an augmented dataset.
Note: The problem solved here is somewhat artificial. It trains a densely connected network to be shift invariant by jittering the input images. It's much more efficient to use convolutional layers instead.
# Load MNIST as (image, label) pairs; with_info=True also returns split metadata.
dataset, info = tfds.load('mnist', as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
# Total number of training examples, used later to size the shuffle buffer.
num_train_examples= info.splits['train'].num_examples
Write a function to augment the images. Map it over the dataset. This returns a dataset that augments the data on the fly.
def convert(image, label):
    """Cast the uint8 image to float32 scaled into [0, 1]; pass the label through."""
    return tf.image.convert_image_dtype(image, tf.float32), label
def augment(image, label):
    """Randomly perturb one MNIST example for training-time augmentation.

    Normalizes the image to float32 in [0, 1], pads it to 34x34, takes a
    random 28x28 crop (i.e. a random shift of up to +/-3 pixels), and applies
    a random brightness change.

    Args:
      image: uint8 image tensor of shape (28, 28, 1).
      label: integer class label, returned unchanged.

    Returns:
      (augmented_image, label) tuple.
    """
    image, label = convert(image, label)
    # NOTE: the original called tf.image.convert_image_dtype a second time
    # here; convert() already casts and normalizes, so it was redundant.
    image = tf.image.resize_with_crop_or_pad(image, 34, 34)  # Add 6 pixels of padding
    image = tf.image.random_crop(image, size=[28, 28, 1])    # Random crop back to 28x28
    image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness shift
    return image, label
# Mini-batch size used for both the augmented and non-augmented pipelines.
BATCH_SIZE = 64
# Only use a subset of the data so it's easier to overfit, for this tutorial
NUM_EXAMPLES = 2048
Create the augmented dataset.
augmented_train_batches = (
    train_dataset
    # Only train on a subset, so you can quickly see the effect.
    .take(NUM_EXAMPLES)
    # Cache BEFORE the random map so decoding is done once but the
    # augmentation is re-sampled every epoch.
    .cache()
    .shuffle(num_train_examples//4)
    # The augmentation is added here.
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
And a non-augmented one for comparison.
# Baseline pipeline: identical to the augmented one except the map only
# normalizes (convert), so the two training runs differ only in augmentation.
non_augmented_train_batches = (
    train_dataset
    # Only train on a subset, so you can quickly see the effect.
    .take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    # No augmentation.
    .map(convert, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
Set up the validation dataset. This doesn't change whether or not you're using the augmentation.
# Validation data is never augmented — only normalized. Larger batches are
# fine here since no gradients are computed.
validation_batches = (
    test_dataset
    .map(convert, num_parallel_calls=AUTOTUNE)
    .batch(2*BATCH_SIZE)
)
Create and compile the model. The model is a two-hidden-layer, fully-connected neural network without convolution.
def make_model():
    """Build and compile a fresh dense classifier for 28x28x1 MNIST images.

    Returns:
      A compiled tf.keras.Sequential model with two 4096-unit ReLU hidden
      layers and a 10-unit logits output, using Adam and sparse
      categorical cross-entropy on logits.
    """
    network = [
        layers.Flatten(input_shape=(28, 28, 1)),
        layers.Dense(4096, activation='relu'),
        layers.Dense(4096, activation='relu'),
        layers.Dense(10),  # raw logits, one per digit class
    ]
    model = tf.keras.Sequential(network)
    model.compile(
        optimizer='adam',
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model
Train the model, without augmentation:
# Baseline run: train on the normalized-only batches.
model_without_aug = make_model()
no_aug_history = model_without_aug.fit(non_augmented_train_batches, epochs=50, validation_data=validation_batches)
Train it again with augmentation:
# Augmented run: same architecture, same subset, augmentation applied on the fly.
model_with_aug = make_model()
aug_history = model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches)
In this example the augmented model converges to an accuracy ~95% on validation set. This is slightly higher (+1%) than the model trained without data augmentation.
# Compare validation accuracy curves of the two runs on one plot.
plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({"Augmented": aug_history, "Non-Augmented": no_aug_history}, metric = "accuracy")
plt.title("Accuracy")
plt.ylim([0.75,1])
In terms of loss, the non-augmented model is obviously in the overfitting regime. The augmented model, while a few epochs slower, is still training correctly and clearly not overfitting.
# Compare validation loss curves; divergence here reveals overfitting.
plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({"Augmented": aug_history, "Non-Augmented": no_aug_history}, metric = "loss")
plt.title("Loss")
plt.ylim([0,1])