# Install imgaug (provides the RandAugment implementation used below).
!pip install -U -q imgaug --user
import tensorflow as tf
# Fix the TF-level random seed for reproducibility.
tf.random.set_seed(42)
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
from imgaug import augmenters as iaa
import imgaug as ia
# Seed imgaug as well so the RandAugment transforms are reproducible.
ia.seed(42)
from tensorflow.keras import mixed_precision
# Mixed precision: float16 compute with float32 variables. Speeds up
# training on GPUs with compute capability >= 7.0 (see the check below).
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Tesla V100-SXM2-16GB, compute capability 7.0
For this example, we will be using the CIFAR-10 dataset.
# Load CIFAR-10: 32x32x3 images in 10 classes (50k train / 10k test).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
Total training examples: 50000 Total test examples: 10000
AUTO = tf.data.AUTOTUNE  # let tf.data pick parallelism/prefetch sizes
BATCH_SIZE = 512
EPOCHS = 100  # upper bound; early stopping (below) usually ends sooner
# RandAugment object: apply n=3 randomly chosen transformations per image,
# each with magnitude m=7.
rand_aug = iaa.RandAugment(n=3, m=7)
def augment(images):
    """Run the RandAugment pipeline on a batch of images.

    `imgaug` cannot consume TensorFlow tensors directly, so the batch is
    first cast to uint8 and materialized as a NumPy array.
    """
    uint8_batch = tf.cast(images, tf.uint8)
    return rand_aug(images=uint8_batch.numpy())
# Dataset object for the RandAugment training pipeline: shuffle, batch,
# resize to the 72x72 model input size, then augment via `tf.py_function`.
train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (72, 72)), y),
        num_parallel_calls=AUTO,
    )
    # `tf.py_function` returns a list of tensors (one per entry in the
    # output-dtype list); `[0]` unwraps the single augmented image batch.
    .map(
        lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)
def _resize_batch(x, y):
    # Match the 72x72 input resolution the model expects.
    return tf.image.resize(x, (72, 72)), y

# Evaluation pipeline: batching and resizing only, no augmentation.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(_resize_batch, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
For comparison purposes, let's also define a simple augmentation pipeline consisting of random flips, random rotations, and random zooms.
# Baseline augmentation pipeline: resize plus random flip/rotation/zoom.
_preproc = layers.experimental.preprocessing
simple_aug = tf.keras.Sequential(
    [
        _preproc.Resizing(72, 72),
        _preproc.RandomFlip("horizontal"),
        _preproc.RandomRotation(factor=0.02),
        _preproc.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Training dataset that applies the simple augmentation pipeline batch-wise.
train_ds_simple = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(100 * BATCH_SIZE)
    .batch(BATCH_SIZE)
    .map(
        lambda images, labels: (simple_aug(images), labels),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)
# Visualize nine RandAugment-ed images from the training pipeline.
sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for idx in range(9):
    plt.subplot(3, 3, idx + 1)
    plt.imshow(sample_images[idx].numpy().astype("int"))
    plt.axis("off")
# Visualize nine images produced by the simple_aug pipeline.
sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")
def get_training_model():
    """Return a randomly initialized ResNet50V2 classifier for 72x72 inputs.

    Pixel values are rescaled from [0, 255] to [-1, 1] inside the model.
    The trailing linear Activation layer carries `dtype="float32"` so the
    mixed-precision (float16) outputs are cast back to float32.
    """
    backbone = tf.keras.applications.ResNet50V2(
        weights=None,
        include_top=True,
        input_shape=(72, 72, 3),
        classes=10,
    )
    return tf.keras.Sequential(
        [
            layers.Input((72, 72, 3)),
            layers.experimental.preprocessing.Rescaling(
                scale=1.0 / 127.5, offset=-1
            ),
            backbone,
            layers.Activation("linear", dtype="float32"),
        ]
    )
# Instantiate once to print the architecture (~23.5M parameters, see below).
get_training_model().summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling (Rescaling) (None, 72, 72, 3) 0 _________________________________________________________________ resnet50v2 (Functional) (None, 10) 23585290 _________________________________________________________________ activation (Activation) (None, 10) 0 ================================================================= Total params: 23,585,290 Trainable params: 23,539,850 Non-trainable params: 45,440 _________________________________________________________________
# For reproducibility, serialize the initial (untrained) weights so that
# both models below start from the exact same parameters.
initial_model = get_training_model()
initial_model.save_weights("initial_weights.h5")
# Early stopping guards against overfitting; the best weights (lowest
# validation loss) are restored when training stops.
es = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True
)
# Train on the RandAugment pipeline, starting from the shared initial weights.
rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial_weights.h5")
rand_aug_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
rand_aug_model.fit(
    train_ds_rand,
    validation_data=test_ds,
    epochs=EPOCHS,
    callbacks=[es],
)
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
Epoch 1/100 98/98 [==============================] - 85s 705ms/step - loss: 3.1592 - accuracy: 0.1834 - val_loss: nan - val_accuracy: 0.1045 Epoch 2/100 98/98 [==============================] - 69s 653ms/step - loss: 1.7934 - accuracy: 0.3600 - val_loss: 7.4934 - val_accuracy: 0.1491 Epoch 3/100 98/98 [==============================] - 68s 646ms/step - loss: 1.8542 - accuracy: 0.3434 - val_loss: nan - val_accuracy: 0.1157 Epoch 4/100 98/98 [==============================] - 68s 640ms/step - loss: 1.7277 - accuracy: 0.3930 - val_loss: 4.5487 - val_accuracy: 0.2414 Epoch 5/100 98/98 [==============================] - 71s 626ms/step - loss: 1.6077 - accuracy: 0.4401 - val_loss: 1.5698 - val_accuracy: 0.4525 Epoch 6/100 98/98 [==============================] - 68s 643ms/step - loss: 1.5469 - accuracy: 0.4623 - val_loss: 2.3275 - val_accuracy: 0.2820 Epoch 7/100 98/98 [==============================] - 69s 645ms/step - loss: 1.5825 - accuracy: 0.4494 - val_loss: 1.4070 - val_accuracy: 0.5059 Epoch 8/100 98/98 [==============================] - 68s 636ms/step - loss: 1.3531 - accuracy: 0.5279 - val_loss: 1.5508 - val_accuracy: 0.4852 Epoch 9/100 98/98 [==============================] - 69s 644ms/step - loss: 1.3057 - accuracy: 0.5373 - val_loss: 1.2720 - val_accuracy: 0.5548 Epoch 10/100 98/98 [==============================] - 67s 634ms/step - loss: 1.2089 - accuracy: 0.5755 - val_loss: 1.2662 - val_accuracy: 0.5588 Epoch 11/100 98/98 [==============================] - 68s 645ms/step - loss: 1.1647 - accuracy: 0.5924 - val_loss: 1.0406 - val_accuracy: 0.6329 Epoch 12/100 98/98 [==============================] - 67s 632ms/step - loss: 1.0877 - accuracy: 0.6202 - val_loss: 0.9969 - val_accuracy: 0.6526 Epoch 13/100 98/98 [==============================] - 66s 622ms/step - loss: 1.0120 - accuracy: 0.6445 - val_loss: 1.0054 - val_accuracy: 0.6457 Epoch 14/100 98/98 [==============================] - 67s 629ms/step - loss: 0.9721 - accuracy: 0.6620 - val_loss: 0.9890 - 
val_accuracy: 0.6531 Epoch 15/100 98/98 [==============================] - 67s 627ms/step - loss: 1.0118 - accuracy: 0.6485 - val_loss: 0.8240 - val_accuracy: 0.7164 Epoch 16/100 98/98 [==============================] - 67s 634ms/step - loss: 0.8661 - accuracy: 0.6961 - val_loss: 0.8105 - val_accuracy: 0.7148 Epoch 17/100 98/98 [==============================] - 67s 631ms/step - loss: 0.8365 - accuracy: 0.7076 - val_loss: 0.8673 - val_accuracy: 0.7106 Epoch 18/100 98/98 [==============================] - 67s 638ms/step - loss: 0.7939 - accuracy: 0.7200 - val_loss: 0.9348 - val_accuracy: 0.7002 Epoch 19/100 98/98 [==============================] - 70s 659ms/step - loss: 0.7548 - accuracy: 0.7371 - val_loss: 0.9441 - val_accuracy: 0.7137 Epoch 20/100 98/98 [==============================] - 72s 684ms/step - loss: 0.7254 - accuracy: 0.7493 - val_loss: 0.7852 - val_accuracy: 0.7308 Epoch 21/100 98/98 [==============================] - 67s 637ms/step - loss: 0.6884 - accuracy: 0.7604 - val_loss: 0.7827 - val_accuracy: 0.7539 Epoch 22/100 98/98 [==============================] - 66s 627ms/step - loss: 0.6703 - accuracy: 0.7661 - val_loss: 0.6435 - val_accuracy: 0.7755 Epoch 23/100 98/98 [==============================] - 67s 637ms/step - loss: 0.6349 - accuracy: 0.7790 - val_loss: 0.6341 - val_accuracy: 0.7806 Epoch 24/100 98/98 [==============================] - 67s 633ms/step - loss: 0.6040 - accuracy: 0.7882 - val_loss: 1.1893 - val_accuracy: 0.7360 Epoch 25/100 98/98 [==============================] - 71s 678ms/step - loss: 0.5804 - accuracy: 0.7971 - val_loss: 0.6380 - val_accuracy: 0.7992 Epoch 26/100 98/98 [==============================] - 70s 662ms/step - loss: 0.5789 - accuracy: 0.7970 - val_loss: 0.5683 - val_accuracy: 0.8024 Epoch 27/100 98/98 [==============================] - 67s 630ms/step - loss: 0.5501 - accuracy: 0.8078 - val_loss: 0.6145 - val_accuracy: 0.7967 Epoch 28/100 98/98 [==============================] - 66s 622ms/step - loss: 0.5329 - 
accuracy: 0.8135 - val_loss: 0.5557 - val_accuracy: 0.8100 Epoch 29/100 98/98 [==============================] - 68s 641ms/step - loss: 0.5112 - accuracy: 0.8227 - val_loss: 0.5435 - val_accuracy: 0.8206 Epoch 30/100 98/98 [==============================] - 66s 627ms/step - loss: 0.4964 - accuracy: 0.8272 - val_loss: 0.7384 - val_accuracy: 0.7746 Epoch 31/100 98/98 [==============================] - 67s 629ms/step - loss: 0.4890 - accuracy: 0.8289 - val_loss: 0.6565 - val_accuracy: 0.7866 Epoch 32/100 98/98 [==============================] - 68s 641ms/step - loss: 0.4700 - accuracy: 0.8383 - val_loss: 0.7538 - val_accuracy: 0.7881 Epoch 33/100 98/98 [==============================] - 67s 633ms/step - loss: 0.4668 - accuracy: 0.8368 - val_loss: 0.5000 - val_accuracy: 0.8297 Epoch 34/100 98/98 [==============================] - 67s 637ms/step - loss: 0.4425 - accuracy: 0.8480 - val_loss: 0.5721 - val_accuracy: 0.8093 Epoch 35/100 98/98 [==============================] - 68s 641ms/step - loss: 0.4215 - accuracy: 0.8531 - val_loss: 0.6007 - val_accuracy: 0.8047 Epoch 36/100 98/98 [==============================] - 67s 635ms/step - loss: 0.4156 - accuracy: 0.8553 - val_loss: 0.5727 - val_accuracy: 0.8144 Epoch 37/100 98/98 [==============================] - 67s 635ms/step - loss: 0.4091 - accuracy: 0.8592 - val_loss: 0.5009 - val_accuracy: 0.8340 Epoch 38/100 98/98 [==============================] - 67s 635ms/step - loss: 0.4002 - accuracy: 0.8595 - val_loss: 0.5630 - val_accuracy: 0.8236 Epoch 39/100 98/98 [==============================] - 68s 641ms/step - loss: 0.3822 - accuracy: 0.8673 - val_loss: 0.4742 - val_accuracy: 0.8419 Epoch 40/100 98/98 [==============================] - 70s 664ms/step - loss: 0.3588 - accuracy: 0.8766 - val_loss: 0.4799 - val_accuracy: 0.8422 Epoch 41/100 98/98 [==============================] - 67s 633ms/step - loss: 0.3549 - accuracy: 0.8762 - val_loss: 0.4908 - val_accuracy: 0.8396 Epoch 42/100 98/98 [==============================] - 
70s 673ms/step - loss: 0.3447 - accuracy: 0.8787 - val_loss: 0.5453 - val_accuracy: 0.8311 Epoch 43/100 98/98 [==============================] - 68s 644ms/step - loss: 0.3497 - accuracy: 0.8782 - val_loss: 0.5162 - val_accuracy: 0.8385 Epoch 44/100 98/98 [==============================] - 67s 638ms/step - loss: 0.3273 - accuracy: 0.8856 - val_loss: 0.5166 - val_accuracy: 0.8308 Epoch 45/100 98/98 [==============================] - 68s 645ms/step - loss: 0.3143 - accuracy: 0.8917 - val_loss: 0.4778 - val_accuracy: 0.8448 Epoch 46/100 98/98 [==============================] - 67s 633ms/step - loss: 0.3052 - accuracy: 0.8931 - val_loss: 0.5501 - val_accuracy: 0.8382 Epoch 47/100 98/98 [==============================] - 67s 634ms/step - loss: 0.3007 - accuracy: 0.8947 - val_loss: 0.4818 - val_accuracy: 0.8509 Epoch 48/100 98/98 [==============================] - 68s 643ms/step - loss: 0.2832 - accuracy: 0.8985 - val_loss: 0.5162 - val_accuracy: 0.8439 Epoch 49/100 98/98 [==============================] - 68s 639ms/step - loss: 0.2729 - accuracy: 0.9048 - val_loss: 0.5436 - val_accuracy: 0.8367 20/20 [==============================] - 1s 30ms/step - loss: 0.4742 - accuracy: 0.8419 Test accuracy: 84.19%
# Persist the trained RandAugment model (TF SavedModel directory).
rand_aug_model.save("rand_aug_model")
INFO:tensorflow:Assets written to: rand_aug_model/assets
# Train the simple-augmentation baseline from the same initial weights.
simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial_weights.h5")
simple_aug_model.compile(loss="sparse_categorical_crossentropy",
                         optimizer="adam",
                         metrics=["accuracy"])
simple_aug_model.fit(train_ds_simple,
                     validation_data=test_ds,
                     epochs=EPOCHS,
                     callbacks=[es])
_, test_acc = simple_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
Epoch 1/100 98/98 [==============================] - 28s 202ms/step - loss: 2.3730 - accuracy: 0.2644 - val_loss: 7.0031 - val_accuracy: 0.1286 Epoch 2/100 98/98 [==============================] - 19s 183ms/step - loss: 1.2847 - accuracy: 0.5453 - val_loss: 1.7744 - val_accuracy: 0.4162 Epoch 3/100 98/98 [==============================] - 19s 181ms/step - loss: 1.0809 - accuracy: 0.6203 - val_loss: nan - val_accuracy: 0.0737 Epoch 4/100 98/98 [==============================] - 19s 183ms/step - loss: 0.9696 - accuracy: 0.6610 - val_loss: 1.0999 - val_accuracy: 0.6323 Epoch 5/100 98/98 [==============================] - 19s 183ms/step - loss: 0.9832 - accuracy: 0.6644 - val_loss: nan - val_accuracy: 0.1006 Epoch 6/100 98/98 [==============================] - 19s 183ms/step - loss: 1.0794 - accuracy: 0.6355 - val_loss: 12.5741 - val_accuracy: 0.1205 Epoch 7/100 98/98 [==============================] - 19s 182ms/step - loss: 0.8914 - accuracy: 0.6890 - val_loss: 0.8805 - val_accuracy: 0.7018 Epoch 8/100 98/98 [==============================] - 19s 181ms/step - loss: 0.6830 - accuracy: 0.7598 - val_loss: 0.7834 - val_accuracy: 0.7346 Epoch 9/100 98/98 [==============================] - 19s 182ms/step - loss: 0.5906 - accuracy: 0.7949 - val_loss: 0.7656 - val_accuracy: 0.7445 Epoch 10/100 98/98 [==============================] - 19s 184ms/step - loss: 0.5317 - accuracy: 0.8146 - val_loss: 0.7136 - val_accuracy: 0.7549 Epoch 11/100 98/98 [==============================] - 19s 182ms/step - loss: 0.4830 - accuracy: 0.8303 - val_loss: 0.7174 - val_accuracy: 0.7580 Epoch 12/100 98/98 [==============================] - 19s 183ms/step - loss: 0.4508 - accuracy: 0.8427 - val_loss: 0.6619 - val_accuracy: 0.7824 Epoch 13/100 98/98 [==============================] - 19s 181ms/step - loss: 0.4086 - accuracy: 0.8557 - val_loss: 0.7537 - val_accuracy: 0.7533 Epoch 14/100 98/98 [==============================] - 19s 183ms/step - loss: 0.3770 - accuracy: 0.8678 - val_loss: 0.6286 - 
val_accuracy: 0.7898 Epoch 15/100 98/98 [==============================] - 19s 182ms/step - loss: 0.3477 - accuracy: 0.8779 - val_loss: 0.6000 - val_accuracy: 0.8014 Epoch 16/100 98/98 [==============================] - 19s 184ms/step - loss: 0.3211 - accuracy: 0.8884 - val_loss: 0.6156 - val_accuracy: 0.8045 Epoch 17/100 98/98 [==============================] - 19s 183ms/step - loss: 0.2923 - accuracy: 0.8966 - val_loss: 0.8128 - val_accuracy: 0.7648 Epoch 18/100 98/98 [==============================] - 19s 182ms/step - loss: 0.2739 - accuracy: 0.9036 - val_loss: 0.6538 - val_accuracy: 0.7948 Epoch 19/100 98/98 [==============================] - 19s 185ms/step - loss: 0.2517 - accuracy: 0.9121 - val_loss: 0.6547 - val_accuracy: 0.8092 Epoch 20/100 98/98 [==============================] - 20s 186ms/step - loss: 0.2416 - accuracy: 0.9174 - val_loss: 0.6659 - val_accuracy: 0.8075 Epoch 21/100 98/98 [==============================] - 19s 184ms/step - loss: 0.2173 - accuracy: 0.9243 - val_loss: 0.6265 - val_accuracy: 0.8131 Epoch 22/100 98/98 [==============================] - 19s 184ms/step - loss: 0.2059 - accuracy: 0.9285 - val_loss: 0.6124 - val_accuracy: 0.8186 Epoch 23/100 98/98 [==============================] - 19s 185ms/step - loss: 0.1905 - accuracy: 0.9334 - val_loss: 0.6885 - val_accuracy: 0.8131 Epoch 24/100 98/98 [==============================] - 19s 185ms/step - loss: 0.1830 - accuracy: 0.9359 - val_loss: 0.6368 - val_accuracy: 0.8256 Epoch 25/100 98/98 [==============================] - 19s 182ms/step - loss: 0.1661 - accuracy: 0.9421 - val_loss: 0.7623 - val_accuracy: 0.8005 20/20 [==============================] - 1s 31ms/step - loss: 0.6000 - accuracy: 0.8014 Test accuracy: 80.14%
# Persist the trained simple-augmentation model (TF SavedModel directory).
simple_aug_model.save("simple_aug_model")
INFO:tensorflow:Assets written to: simple_aug_model/assets
# Load CIFAR-10-C (saturate corruption, severity 5) to probe robustness.
# The first run downloads ~2.7 GiB (roughly 10 minutes).
cifar_10_c = tfds.load(
    "cifar10_corrupted/saturate_5", split="test", as_supervised=True
)
cifar_10_c = cifar_10_c.batch(BATCH_SIZE).map(
    lambda x, y: (tf.image.resize(x, (72, 72)), y),
    num_parallel_calls=AUTO,
)
Downloading and preparing dataset 2.72 GiB (download: 2.72 GiB, generated: Unknown size, total: 2.72 GiB) to /home/jupyter/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0... Dataset cifar10_corrupted downloaded and prepared to /home/jupyter/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. Subsequent calls will reuse this data.
# Evaluate `rand_aug_model` on the corrupted test set.
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))
# Evaluate `simple_aug_model` on the same corrupted test set for comparison.
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))
Accuracy with RandAugment on CIFAR-10-C (saturate_5): 76.64% Accuracy with simple_aug on CIFAR-10-C (saturate_5): 64.80%