To run any of Eden's notebooks, please check the guides on our Wiki page.
There you will find instructions on how to deploy the notebooks on your local system, on Google Colab, or on MyBinder, as well as other useful links, troubleshooting tips, and more.
Note: If you encounter any issues while executing the notebook, don't hesitate to open an issue on GitHub. We will try to reply as soon as possible.
In recent years, the standard approach to transfer learning has been to directly fine-tune ImageNet weights on the new domain, where (usually) few training images are available. Recently, however, a new trend has emerged: before fine-tuning, there is an additional pre-training step in which the neural network trains on the target images to learn to place images of the same class as close as possible in the latent space and, conversely, images of different classes as far apart as possible. Depending on whether this pre-training phase uses labels, the stage is supervised or unsupervised.
Supervised Contrastive Learning [1] is a training approach that may outperform supervised training with the traditional cross-entropy loss function on classification tasks. Training an image classifier under this approach has two phases:
1. (Pre-)Training an encoder (e.g. ResNet or EfficientNet) to produce vector representations of input images, such that representations of images in the same class are more similar than representations of images in different classes. During this phase, the supervised contrastive loss is used as an alternative to cross-entropy.
2. Training a classifier on top of the frozen encoder. The classifier takes advantage of the clusters of points belonging to the same class, which were pulled together in the embedding space.
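For reference, the supervised contrastive loss introduced in [1] (the "L_out" variant) over a batch with L2-normalized embeddings can be written as:

```latex
\mathcal{L}^{sup} = \sum_{i \in I} \frac{-1}{|P(i)|}
  \sum_{p \in P(i)} \log
  \frac{\exp(z_i \cdot z_p / \tau)}
       {\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
```

where $I$ indexes the batch, $A(i)$ is the batch without anchor $i$, $P(i)$ is the set of positives of $i$ (other samples with the same label), $z$ are the normalized embeddings, and $\tau$ is the temperature. The implementation used in this notebook relies on `tfa.losses.npairs_loss`, a closely related formulation.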
In this notebook, we implement this technique, using this Keras tutorial as a source: https://keras.io/examples/vision/supervised-contrastive-learning/
!pip install tensorflow-addons
Requirement already satisfied: tensorflow-addons in /home/beast/anaconda3/envs/image_classification_v2/lib/python3.7/site-packages (0.15.0) Requirement already satisfied: typeguard>=2.7 in /home/beast/anaconda3/envs/image_classification_v2/lib/python3.7/site-packages (from tensorflow-addons) (2.13.2)
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import cv2
import os
import csv
import gc
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob
from pathlib import Path
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import applications
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
Check the docstrings for more information.
# Function for plotting images.
def plot_samples(dataset):
    sample_images, sample_labels = next(iter(dataset))
    plt.figure(figsize=(8, 8))
    for ix, sample_image in enumerate(sample_images[:9]):
        plt.subplot(3, 3, ix + 1)
        plt.imshow(sample_image.numpy().astype(np.uint8))
        plt.axis("off")
    plt.show()
def read_data(path_list, im_size=(224, 224)):
    """
    Given the list of paths where the images are stored <path_list>,
    and the size for image resizing <im_size>, it returns two NumPy arrays
    with the images and labels. It also prints the dictionary with the
    mapping between classes and folders, which will be used later for
    displaying the predicted labels.
    Parameters:
        path_list (List[str]): The list of paths to the images.
        im_size (Tuple[int, int]): The height and width values.
    Returns:
        X (ndarray): Images.
        y (ndarray): Labels.
    """
    X = []
    y = []
    # Extract the folder names of the datasets we read and create a label dictionary.
    # os.path.sep is OS agnostic (either '/' or '\'); [-2] grabs the folder name.
    tag2idx = {tag.split(os.path.sep)[-2]: i for i, tag in enumerate(path_list)}
    print(tag2idx)
    for path in path_list:
        for im_file in tqdm(glob(path + '*/*')):  # Read all files in path.
            try:
                # [-3] grabs the dataset folder name from the full image path.
                label = im_file.split(os.path.sep)[-3]
                im = cv2.imread(im_file, cv2.IMREAD_COLOR)
                # By default OpenCV reads in BGR format; convert back to RGB.
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                # Resize to the target dimensions. You can try different interpolation methods.
                im = cv2.resize(im, im_size, interpolation=cv2.INTER_AREA)
                X.append(im)
                y.append(tag2idx[label])  # Append the label index to y.
            except Exception:
                # Skip annotations or metadata files.
                print("Not a picture")
    X = np.array(X)  # Convert list to NumPy array.
    y = np.array(y).astype(np.uint8)
    return X, y
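As a quick illustration of the label mapping built above, here is the same dictionary comprehension applied to hypothetical dataset paths (the folder names are made up for this example):

```python
import os

# Hypothetical dataset paths, mirroring the structure expected by read_data:
# <base>/<dataset_folder>/images
path_list = [
    os.path.join("eden_library_datasets", "Cucumber", "images"),
    os.path.join("eden_library_datasets", "Tomato", "images"),
]

# [-2] grabs the dataset folder name ("Cucumber", "Tomato").
tag2idx = {p.split(os.path.sep)[-2]: i for i, p in enumerate(path_list)}
print(tag2idx)  # {'Cucumber': 0, 'Tomato': 1}
```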
# Callbacks are used for saving the best weights and
# early stopping.
def get_callbacks(patience):
    """
    Callbacks are used for saving the best weights and early stopping.
    Given the patience value, it creates the callbacks that will be used
    by Keras after each epoch. The best weights are saved to
    'best-weights.h5'.
    Parameters:
        patience (int): Number of epochs without improvement to wait.
    Returns:
        callbacks (List[Callback]): Configured callbacks ready to use.
    """
    return [
        # If val_loss doesn't improve for a number of epochs set with the
        # 'patience' var, training will stop to avoid overfitting.
        EarlyStopping(monitor="val_loss",
                      mode="min",
                      patience=patience,
                      restore_best_weights=True,
                      verbose=1),
        ModelCheckpoint(filepath='best-weights.h5',
                        verbose=1,
                        monitor='val_loss',
                        save_weights_only=True,
                        save_best_only=True)
    ]
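Early stopping itself is simple bookkeeping; here is a minimal pure-Python sketch of the idea (hypothetical `val_losses` values, not Keras internals):

```python
def early_stop_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop: the first
    epoch after 'patience' epochs without a new best val_loss, or the
    last epoch if the patience threshold is never reached."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses) - 1

# Validation loss improves twice, then plateaus for three epochs.
losses = [0.9, 0.7, 0.8, 0.75, 0.72]
print(early_stop_epoch(losses, patience=3))  # 4
```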
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.02),
    layers.RandomWidth(0.1),
    layers.RandomHeight(0.1),
])
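As a quick illustration of what the horizontal random flip does, here is a NumPy sketch (not the Keras layer, just the underlying array operation):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_horizontal_flip(image, p=0.5):
    """Flip an HWC image left-right with probability p (NumPy sketch)."""
    if rng.random() < p:
        return image[:, ::-1, :]
    return image

img = np.arange(12).reshape(2, 2, 3)
flipped = random_horizontal_flip(img, p=1.0)  # p=1.0 always flips
# Flipping twice recovers the original image.
assert np.array_equal(flipped[:, ::-1, :], img)
# p=0.0 never flips.
assert np.array_equal(random_horizontal_flip(img, p=0.0), img)
```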
def create_encoder():
    mobilenet = keras.applications.MobileNetV3Small(
        include_top=False,
        input_shape=INPUT_SHAPE,
        weights="imagenet",
        pooling="avg"
    )
    inputs = layers.Input(shape=INPUT_SHAPE)
    augmented = data_augmentation(inputs)
    # Apply MobileNetV3-specific preprocessing inside the graph.
    features = layers.Lambda(keras.applications.mobilenet_v3.preprocess_input)(augmented)
    outputs = mobilenet(features)
    model = keras.Model(inputs, outputs, name="encoder")
    return model
def create_classifier(encoder, trainable=True):
    # Freeze or unfreeze the encoder depending on the training phase.
    for layer in encoder.layers:
        layer.trainable = trainable
    inputs = layers.Input(shape=INPUT_SHAPE)
    features = encoder(inputs)
    features = layers.Dense(HIDDEN_UNITS, activation="relu")(features)
    features = layers.Dropout(DROPOUT_RATE)(features)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(features)
    model = keras.Model(inputs, outputs, name="classifier")
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        metrics=["accuracy"]
    )
    return model
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize the feature vectors so the dot products are cosine similarities.
        normalized_feature_vectors = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute the pairwise similarity matrix, scaled by the temperature.
        logits = tf.divide(
            tf.matmul(normalized_feature_vectors,
                      tf.transpose(normalized_feature_vectors)),
            self.temperature
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
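`tfa.losses.npairs_loss` treats the similarity matrix as logits and the row-normalized label-equality matrix as soft targets. Here is a NumPy sketch of that idea (an illustration of the loss, not TFA's implementation), showing that embeddings clustered by class yield a lower loss:

```python
import numpy as np

def supcon_npairs_loss(labels, embeddings, temperature=0.05):
    """Softmax cross-entropy between the row-normalized label-equality
    matrix and the temperature-scaled cosine-similarity matrix."""
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    logits = z @ z.T / temperature
    # Targets: 1 where labels match, normalized so each row sums to 1.
    targets = (labels[:, None] == labels[None, :]).astype(float)
    targets /= targets.sum(axis=1, keepdims=True)
    # Row-wise softmax cross-entropy (with max subtraction for stability).
    logits -= logits.max(axis=1, keepdims=True)
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -(targets * log_softmax).sum(axis=1).mean()

labels = np.array([0, 0, 1, 1])
# Embeddings already clustered by class yield a lower loss...
tight = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
# ...than embeddings where same-class samples point in different directions.
mixed = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]])
assert supcon_npairs_loss(labels, tight) < supcon_npairs_loss(labels, mixed)
```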
def add_projection_head(encoder):
    inputs = keras.Input(shape=INPUT_SHAPE)
    features = encoder(inputs)
    outputs = layers.Dense(PROJECTION_UNITS, activation="relu")(features)
    model = keras.Model(inputs, outputs, name="encoder-with-projection")
    return model
PRETRAINING_EPOCHS = 50
HIDDEN_UNITS = 64
PROJECTION_UNITS = 128
DROPOUT_RATE = 0.25
TEMPERATURE = 0.05
INPUT_SHAPE = (224, 224, 3)
IM_SIZE = (224, 224)
NUM_EPOCHS = 20
BATCH_SIZE = 32
TEST_SPLIT = 0.2
VAL_SPLIT = 0.2
RANDOM_STATE = 2022
#WEIGHTS_FILE = "weights.h5"# File that stores updated weights
LEARNING_RATE = 1e-3
BASE_PATH = "eden_library_datasets" + os.path.sep
AUTO = tf.data.AUTOTUNE
# Datasets' paths we want to work on.
PATH_LIST = [
    BASE_PATH + "Cucumber-Cucumis_sativus-Healthy-PRX-RGB-NA-20210607/images",
    BASE_PATH + "Watermelon-Citrullus_lanatus-Healthy-PRX-RGB-NA-20210410/images",
    BASE_PATH + "Tomato-Solanum_lycopersicum-Healthy-PRX-RGB-NA-20210607/images",
    BASE_PATH + "Processing_tomato-Lycopersicum_esculentum-Healthy-PRX-RGB-NA-20200731/images",
]
tf.random.set_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
# Define paths in an OS-agnostic way.
PATH_LIST = [str(Path(Path.cwd()).parents[0].joinpath(path)) for path in PATH_LIST]
x, y = read_data(PATH_LIST, IM_SIZE)
{'Cucumber-Cucumis_sativus-Healthy-PRX-RGB-NA-20210607': 0, 'Watermelon-Citrullus_lanatus-Healthy-PRX-RGB-NA-20210410': 1, 'Tomato-Solanum_lycopersicum-Healthy-PRX-RGB-NA-20210607': 2, 'Processing_tomato-Lycopersicum_esculentum-Healthy-PRX-RGB-NA-20200731': 3}
100%|███████████████████████████████████████████████████████████████████████████████████| 145/145 [01:01<00:00, 2.38it/s] 3%|██▎ | 4/147 [00:01<01:04, 2.21it/s]Corrupt JPEG data: 2446 extraneous bytes before marker 0xd9 100%|███████████████████████████████████████████████████████████████████████████████████| 147/147 [01:05<00:00, 2.23it/s] 100%|███████████████████████████████████████████████████████████████████████████████████| 186/186 [01:18<00:00, 2.37it/s] 100%|███████████████████████████████████████████████████████████████████████████████████| 422/422 [01:16<00:00, 5.51it/s]
NUM_CLASSES = len(np.unique(y))
NUM_CLASSES
4
x_train, x_test, y_train, y_test = train_test_split(x, y,
                                                    test_size=TEST_SPLIT,
                                                    shuffle=True,
                                                    stratify=y,
                                                    random_state=RANDOM_STATE)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train,
                                                  test_size=VAL_SPLIT,
                                                  shuffle=True,
                                                  stratify=y_train,
                                                  random_state=RANDOM_STATE)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
train_ds = (
    # A shuffle buffer as large as the training set gives a full shuffle;
    # a buffer of BATCH_SIZE would only shuffle within a small sliding window.
    train_ds.shuffle(len(x_train), seed=RANDOM_STATE)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
val_ds = (
    val_ds.batch(BATCH_SIZE)
    .prefetch(AUTO)
)
test_ds = (
    test_ds.batch(BATCH_SIZE)
    .prefetch(AUTO)
)
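The batching step of the pipeline above can be sketched in plain Python (a toy stand-in for tf.data batching, ignoring shuffling and prefetching):

```python
def batched(items, batch_size):
    """Yield successive batches of size batch_size (the last may be smaller)."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

batches = list(batched(list(range(10)), 4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```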
plot_samples(train_ds)
encoder = create_encoder()
classifier = create_classifier(encoder)
classifier.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=get_callbacks(patience=NUM_EPOCHS//2)
)
Epoch 1/20 18/18 [==============================] - ETA: 0s - loss: 0.2771 - accuracy: 0.9062 Epoch 1: val_loss improved from inf to 0.72599, saving model to best-weights.h5 18/18 [==============================] - 6s 107ms/step - loss: 0.2771 - accuracy: 0.9062 - val_loss: 0.7260 - val_accuracy: 0.7708 Epoch 2/20 18/18 [==============================] - ETA: 0s - loss: 0.0285 - accuracy: 0.9878 Epoch 2: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 59ms/step - loss: 0.0285 - accuracy: 0.9878 - val_loss: 0.9224 - val_accuracy: 0.7431 Epoch 3/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0402 - accuracy: 0.9908 Epoch 3: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 55ms/step - loss: 0.0380 - accuracy: 0.9913 - val_loss: 1.2759 - val_accuracy: 0.7569 Epoch 4/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0024 - accuracy: 1.0000 Epoch 4: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 54ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 1.2769 - val_accuracy: 0.8194 Epoch 5/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0150 - accuracy: 0.9963 Epoch 5: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 60ms/step - loss: 0.0142 - accuracy: 0.9965 - val_loss: 1.4518 - val_accuracy: 0.7847 Epoch 6/20 17/18 [===========================>..] 
- ETA: 0s - loss: 0.0480 - accuracy: 0.9835 Epoch 6: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 60ms/step - loss: 0.0464 - accuracy: 0.9844 - val_loss: 1.9092 - val_accuracy: 0.8264 Epoch 7/20 18/18 [==============================] - ETA: 0s - loss: 0.0151 - accuracy: 0.9965 Epoch 7: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 62ms/step - loss: 0.0151 - accuracy: 0.9965 - val_loss: 1.9572 - val_accuracy: 0.7986 Epoch 8/20 18/18 [==============================] - ETA: 0s - loss: 0.0031 - accuracy: 0.9983 Epoch 8: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 62ms/step - loss: 0.0031 - accuracy: 0.9983 - val_loss: 1.3774 - val_accuracy: 0.8125 Epoch 9/20 18/18 [==============================] - ETA: 0s - loss: 0.0010 - accuracy: 1.0000 Epoch 9: val_loss did not improve from 0.72599 18/18 [==============================] - 1s 54ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.9177 - val_accuracy: 0.8472 Epoch 10/20 18/18 [==============================] - ETA: 0s - loss: 0.0019 - accuracy: 1.0000 Epoch 10: val_loss improved from 0.72599 to 0.60299, saving model to best-weights.h5 18/18 [==============================] - 1s 60ms/step - loss: 0.0019 - accuracy: 1.0000 - val_loss: 0.6030 - val_accuracy: 0.8681 Epoch 11/20 18/18 [==============================] - ETA: 0s - loss: 7.3683e-04 - accuracy: 1.0000 Epoch 11: val_loss improved from 0.60299 to 0.46597, saving model to best-weights.h5 18/18 [==============================] - 1s 63ms/step - loss: 7.3683e-04 - accuracy: 1.0000 - val_loss: 0.4660 - val_accuracy: 0.9028 Epoch 12/20 18/18 [==============================] - ETA: 0s - loss: 3.6376e-04 - accuracy: 1.0000 Epoch 12: val_loss improved from 0.46597 to 0.35918, saving model to best-weights.h5 18/18 [==============================] - 1s 62ms/step - loss: 3.6376e-04 - accuracy: 1.0000 - val_loss: 0.3592 - val_accuracy: 0.9097 Epoch 13/20 17/18 
[===========================>..] - ETA: 0s - loss: 3.3247e-04 - accuracy: 1.0000 Epoch 13: val_loss improved from 0.35918 to 0.28947, saving model to best-weights.h5 18/18 [==============================] - 1s 60ms/step - loss: 3.1768e-04 - accuracy: 1.0000 - val_loss: 0.2895 - val_accuracy: 0.9167 Epoch 14/20 18/18 [==============================] - ETA: 0s - loss: 5.6739e-04 - accuracy: 1.0000 Epoch 14: val_loss did not improve from 0.28947 18/18 [==============================] - 1s 53ms/step - loss: 5.6739e-04 - accuracy: 1.0000 - val_loss: 0.3160 - val_accuracy: 0.9097 Epoch 15/20 18/18 [==============================] - ETA: 0s - loss: 3.9405e-04 - accuracy: 1.0000 Epoch 15: val_loss did not improve from 0.28947 18/18 [==============================] - 1s 60ms/step - loss: 3.9405e-04 - accuracy: 1.0000 - val_loss: 0.2972 - val_accuracy: 0.9167 Epoch 16/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0057 - accuracy: 0.9982 Epoch 16: val_loss improved from 0.28947 to 0.17355, saving model to best-weights.h5 18/18 [==============================] - 1s 62ms/step - loss: 0.0054 - accuracy: 0.9983 - val_loss: 0.1736 - val_accuracy: 0.9236 Epoch 17/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0161 - accuracy: 0.9963 Epoch 17: val_loss did not improve from 0.17355 18/18 [==============================] - 1s 54ms/step - loss: 0.0152 - accuracy: 0.9965 - val_loss: 0.6960 - val_accuracy: 0.8542 Epoch 18/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0028 - accuracy: 0.9982 Epoch 18: val_loss did not improve from 0.17355 18/18 [==============================] - 1s 53ms/step - loss: 0.0026 - accuracy: 0.9983 - val_loss: 0.8840 - val_accuracy: 0.8542 Epoch 19/20 17/18 [===========================>..] 
- ETA: 0s - loss: 2.6565e-04 - accuracy: 1.0000 Epoch 19: val_loss did not improve from 0.17355 18/18 [==============================] - 1s 53ms/step - loss: 2.6334e-04 - accuracy: 1.0000 - val_loss: 0.6347 - val_accuracy: 0.8750 Epoch 20/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0042 - accuracy: 0.9982 Epoch 20: val_loss improved from 0.17355 to 0.04853, saving model to best-weights.h5 18/18 [==============================] - 1s 69ms/step - loss: 0.0051 - accuracy: 0.9983 - val_loss: 0.0485 - val_accuracy: 0.9861
<keras.callbacks.History at 0x7f4a6c1fced0>
classifier.load_weights('best-weights.h5')
print()
print("*"*50)
accuracy = classifier.evaluate(test_ds, verbose=0)[1]
print(round(accuracy, 2))
print("*"*50)
print()
************************************************** 0.97 **************************************************
encoder_pretrained = create_encoder()
encoder_with_projection_head = add_projection_head(encoder_pretrained)
encoder_with_projection_head.compile(
    loss=SupervisedContrastiveLoss(temperature=TEMPERATURE),
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE)
)
encoder_with_projection_head.fit(train_ds, epochs=PRETRAINING_EPOCHS)
Epoch 1/50 18/18 [==============================] - 5s 50ms/step - loss: 2.6139 Epoch 2/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2725 Epoch 3/50 18/18 [==============================] - 1s 50ms/step - loss: 2.2763 Epoch 4/50 18/18 [==============================] - 1s 52ms/step - loss: 2.2764 Epoch 5/50 18/18 [==============================] - 1s 53ms/step - loss: 2.2672 Epoch 6/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2646 Epoch 7/50 18/18 [==============================] - 1s 46ms/step - loss: 2.2528 Epoch 8/50 18/18 [==============================] - 1s 54ms/step - loss: 2.2516 Epoch 9/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2695 Epoch 10/50 18/18 [==============================] - 1s 44ms/step - loss: 2.2594 Epoch 11/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2662 Epoch 12/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2647 Epoch 13/50 18/18 [==============================] - 1s 54ms/step - loss: 2.2605 Epoch 14/50 18/18 [==============================] - 1s 48ms/step - loss: 2.2709 Epoch 15/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2587 Epoch 16/50 18/18 [==============================] - 1s 48ms/step - loss: 2.2650 Epoch 17/50 18/18 [==============================] - 1s 50ms/step - loss: 2.2599 Epoch 18/50 18/18 [==============================] - 1s 46ms/step - loss: 2.2593 Epoch 19/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2549 Epoch 20/50 18/18 [==============================] - 1s 44ms/step - loss: 2.2610 Epoch 21/50 18/18 [==============================] - 1s 48ms/step - loss: 2.2567 Epoch 22/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2467 Epoch 23/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2759 Epoch 24/50 18/18 [==============================] - 1s 52ms/step - loss: 2.2502 Epoch 25/50 18/18 [==============================] - 1s 
48ms/step - loss: 2.2424 Epoch 26/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2494 Epoch 27/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2573 Epoch 28/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2654 Epoch 29/50 18/18 [==============================] - 1s 48ms/step - loss: 2.2570 Epoch 30/50 18/18 [==============================] - 1s 43ms/step - loss: 2.2617 Epoch 31/50 18/18 [==============================] - 1s 46ms/step - loss: 2.2468 Epoch 32/50 18/18 [==============================] - 1s 44ms/step - loss: 2.2446 Epoch 33/50 18/18 [==============================] - 1s 44ms/step - loss: 2.2678 Epoch 34/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2716 Epoch 35/50 18/18 [==============================] - 1s 43ms/step - loss: 2.2905 Epoch 36/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2531 Epoch 37/50 18/18 [==============================] - 1s 41ms/step - loss: 2.2681 Epoch 38/50 18/18 [==============================] - 1s 41ms/step - loss: 2.2512 Epoch 39/50 18/18 [==============================] - 1s 50ms/step - loss: 2.2674 Epoch 40/50 18/18 [==============================] - 1s 46ms/step - loss: 2.2518 Epoch 41/50 18/18 [==============================] - 1s 43ms/step - loss: 2.2732 Epoch 42/50 18/18 [==============================] - 1s 41ms/step - loss: 2.2624 Epoch 43/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2824 Epoch 44/50 18/18 [==============================] - 1s 49ms/step - loss: 2.2811 Epoch 45/50 18/18 [==============================] - 1s 41ms/step - loss: 2.3021 Epoch 46/50 18/18 [==============================] - 1s 47ms/step - loss: 2.2661 Epoch 47/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2783 Epoch 48/50 18/18 [==============================] - 1s 45ms/step - loss: 2.2665 Epoch 49/50 18/18 [==============================] - 1s 42ms/step - loss: 2.2631 Epoch 50/50 18/18 
[==============================] - 1s 41ms/step - loss: 2.2491
<keras.callbacks.History at 0x7f49f4506f50>
classifier_pretrained = create_classifier(encoder_pretrained, trainable=False)
classifier_pretrained.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=get_callbacks(patience=NUM_EPOCHS//2)
)
Epoch 1/20 18/18 [==============================] - ETA: 0s - loss: 1.1587 - accuracy: 0.8368 Epoch 1: val_loss improved from inf to 0.20861, saving model to best-weights.h5 18/18 [==============================] - 4s 70ms/step - loss: 1.1587 - accuracy: 0.8368 - val_loss: 0.2086 - val_accuracy: 0.9444 Epoch 2/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0196 - accuracy: 0.9982 Epoch 2: val_loss improved from 0.20861 to 0.12942, saving model to best-weights.h5 18/18 [==============================] - 1s 32ms/step - loss: 0.0187 - accuracy: 0.9983 - val_loss: 0.1294 - val_accuracy: 0.9861 Epoch 3/20 16/18 [=========================>....] - ETA: 0s - loss: 0.0149 - accuracy: 0.9980 Epoch 3: val_loss improved from 0.12942 to 0.08812, saving model to best-weights.h5 18/18 [==============================] - 1s 29ms/step - loss: 0.0135 - accuracy: 0.9983 - val_loss: 0.0881 - val_accuracy: 0.9861 Epoch 4/20 15/18 [========================>.....] - ETA: 0s - loss: 0.0134 - accuracy: 0.9958 Epoch 4: val_loss improved from 0.08812 to 0.05538, saving model to best-weights.h5 18/18 [==============================] - 0s 27ms/step - loss: 0.0129 - accuracy: 0.9965 - val_loss: 0.0554 - val_accuracy: 0.9861 Epoch 5/20 18/18 [==============================] - ETA: 0s - loss: 0.0033 - accuracy: 1.0000 Epoch 5: val_loss improved from 0.05538 to 0.04286, saving model to best-weights.h5 18/18 [==============================] - 1s 28ms/step - loss: 0.0033 - accuracy: 1.0000 - val_loss: 0.0429 - val_accuracy: 0.9861 Epoch 6/20 18/18 [==============================] - ETA: 0s - loss: 0.0087 - accuracy: 0.9983 Epoch 6: val_loss did not improve from 0.04286 18/18 [==============================] - 0s 21ms/step - loss: 0.0087 - accuracy: 0.9983 - val_loss: 0.0454 - val_accuracy: 0.9861 Epoch 7/20 16/18 [=========================>....] 
- ETA: 0s - loss: 0.0042 - accuracy: 1.0000 Epoch 7: val_loss improved from 0.04286 to 0.03729, saving model to best-weights.h5 18/18 [==============================] - 1s 29ms/step - loss: 0.0037 - accuracy: 1.0000 - val_loss: 0.0373 - val_accuracy: 0.9861 Epoch 8/20 16/18 [=========================>....] - ETA: 0s - loss: 0.0107 - accuracy: 0.9980 Epoch 8: val_loss improved from 0.03729 to 0.02254, saving model to best-weights.h5 18/18 [==============================] - 0s 28ms/step - loss: 0.0097 - accuracy: 0.9983 - val_loss: 0.0225 - val_accuracy: 1.0000 Epoch 9/20 18/18 [==============================] - ETA: 0s - loss: 0.0039 - accuracy: 1.0000 Epoch 9: val_loss improved from 0.02254 to 0.01731, saving model to best-weights.h5 18/18 [==============================] - 1s 29ms/step - loss: 0.0039 - accuracy: 1.0000 - val_loss: 0.0173 - val_accuracy: 1.0000 Epoch 10/20 15/18 [========================>.....] - ETA: 0s - loss: 0.0055 - accuracy: 1.0000 Epoch 10: val_loss improved from 0.01731 to 0.01556, saving model to best-weights.h5 18/18 [==============================] - 0s 28ms/step - loss: 0.0047 - accuracy: 1.0000 - val_loss: 0.0156 - val_accuracy: 1.0000 Epoch 11/20 18/18 [==============================] - ETA: 0s - loss: 0.0011 - accuracy: 1.0000 Epoch 11: val_loss did not improve from 0.01556 18/18 [==============================] - 0s 21ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0156 - val_accuracy: 1.0000 Epoch 12/20 18/18 [==============================] - ETA: 0s - loss: 0.0010 - accuracy: 1.0000 Epoch 12: val_loss improved from 0.01556 to 0.01495, saving model to best-weights.h5 18/18 [==============================] - 1s 29ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 1.0000 Epoch 13/20 15/18 [========================>.....] 
- ETA: 0s - loss: 0.0019 - accuracy: 1.0000 Epoch 13: val_loss improved from 0.01495 to 0.01411, saving model to best-weights.h5 18/18 [==============================] - 1s 29ms/step - loss: 0.0030 - accuracy: 1.0000 - val_loss: 0.0141 - val_accuracy: 1.0000 Epoch 14/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0123 - accuracy: 0.9963 Epoch 14: val_loss improved from 0.01411 to 0.00908, saving model to best-weights.h5 18/18 [==============================] - 0s 28ms/step - loss: 0.0117 - accuracy: 0.9965 - val_loss: 0.0091 - val_accuracy: 1.0000 Epoch 15/20 15/18 [========================>.....] - ETA: 0s - loss: 6.6987e-04 - accuracy: 1.0000 Epoch 15: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 20ms/step - loss: 6.1664e-04 - accuracy: 1.0000 - val_loss: 0.0095 - val_accuracy: 1.0000 Epoch 16/20 18/18 [==============================] - ETA: 0s - loss: 0.0074 - accuracy: 0.9965 Epoch 16: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 21ms/step - loss: 0.0074 - accuracy: 0.9965 - val_loss: 0.0194 - val_accuracy: 0.9861 Epoch 17/20 18/18 [==============================] - ETA: 0s - loss: 0.0025 - accuracy: 1.0000 Epoch 17: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 21ms/step - loss: 0.0025 - accuracy: 1.0000 - val_loss: 0.0191 - val_accuracy: 0.9861 Epoch 18/20 17/18 [===========================>..] - ETA: 0s - loss: 8.4984e-04 - accuracy: 1.0000 Epoch 18: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 22ms/step - loss: 8.0332e-04 - accuracy: 1.0000 - val_loss: 0.0168 - val_accuracy: 0.9931 Epoch 19/20 16/18 [=========================>....] 
- ETA: 0s - loss: 0.0017 - accuracy: 1.0000 Epoch 19: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 20ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0150 - val_accuracy: 1.0000 Epoch 20/20 17/18 [===========================>..] - ETA: 0s - loss: 0.0015 - accuracy: 1.0000 Epoch 20: val_loss did not improve from 0.00908 18/18 [==============================] - 0s 21ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 0.0148 - val_accuracy: 1.0000
<keras.callbacks.History at 0x7f4a6c1fc4d0>
classifier_pretrained.load_weights('best-weights.h5')
print()
print("*"*50)
accuracy = classifier_pretrained.evaluate(test_ds, verbose=0)[1]
print(round(accuracy, 2))
print("*"*50)
print()
************************************************** 1.0 **************************************************
Supervised contrastive learning can improve the performance of the image classifier over the traditional cross-entropy approach. At the very least, after pre-training, the model starts the fine-tuning phase from a better position, which can lead to a more efficient training procedure.
[1] Khosla, P., Teterwak, P., Wang, C., Sarna, A., Tian, Y., Isola, P., Maschinot, A., Liu, C., & Krishnan, D. (2020). Supervised Contrastive Learning. ArXiv, abs/2004.11362.