This example is from Stefan Wunsch (CERN IML TensorFlow and Keras workshop). See also the example on the Keras website.
The MNIST dataset is one of the most popular benchmark datasets in modern machine learning. It consists of 70000 images of handwritten digits and the associated labels, which can be used to train a neural network to perform image classification.
The following program presents the basic Keras workflow and the most important details of the API.
from os import environ
environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
np.random.seed(1234)
import matplotlib.pyplot as plt
The code below downloads the dataset and scales the pixel values of the images. Because the images are encoded as 8-bit unsigned integers, we scale these values to floating-point values in the range [0, 1] so that the inputs better match the activation range of the neurons.
from keras.datasets import mnist
from keras.utils import to_categorical
# Download dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# The data is loaded as arrays of shape (num_images, 28, 28);
# we need to add a channel dimension, giving the shape:
# (num_images, pixels_row, pixels_column, color_channels)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
# Convert the uint8 greyscale pixel values in range [0, 255]
# to floats in range [0, 1]
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
# Convert digits to one-hot vectors, e.g.,
# 2 -> [0 0 1 0 0 0 0 0 0 0]
# 0 -> [1 0 0 0 0 0 0 0 0 0]
# 9 -> [0 0 0 0 0 0 0 0 0 1]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
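For reference, the one-hot encoding performed by to_categorical can be reproduced with a few lines of NumPy (a sketch for illustration, not the library's actual implementation):

```python
import numpy as np

def one_hot(labels, num_classes):
    # Row k of the identity matrix is the one-hot vector for class k,
    # so indexing the rows by the label array yields the encoded batch.
    return np.eye(num_classes, dtype="float32")[labels]

print(one_hot(np.array([2, 0, 9]), 10))
# Each row contains a single 1 at the position of the digit.
```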
Additionally, we store some example images to disk to demonstrate the inference part of the Keras API later on.
import png
num_examples = 6
offset = 0
plt.figure(figsize=(num_examples*2, 2))
for i in range(num_examples):
    plt.subplot(1, num_examples, i+1)
    plt.axis('off')
    example = np.squeeze(np.array(x_test[offset+i]*255).astype("uint8"))
    plt.imshow(example, cmap="gray")
    w = png.Writer(28, 28, greyscale=True)
    w.write(open("mnist_example_{}.format".format(i+1), 'wb'), example)
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
A model in Keras can be defined using either the Sequential or the functional API. Shown here is the Sequential API, which stacks neural network layers on top of each other and is sufficient for most neural network models. In contrast, the functional API allows multiple inputs and outputs, offering maximum flexibility to build custom models.
from keras.models import Sequential
from keras.layers import Dense, Flatten, MaxPooling2D, Conv2D, Input, Dropout
# conv layer with 8 3x3 filters
model = Sequential(
[
Input(shape=input_shape),
Conv2D(8, kernel_size=(3, 3), activation="relu"),
MaxPooling2D(pool_size=(2, 2)),
Flatten(),
Dense(16, activation="relu"),
Dense(num_classes, activation="softmax"),
]
)
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 26, 26, 8)         80
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 13, 13, 8)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 1352)              0
_________________________________________________________________
dense_2 (Dense)              (None, 16)                21648
_________________________________________________________________
dense_3 (Dense)              (None, 10)                170
=================================================================
Total params: 21,898
Trainable params: 21,898
Non-trainable params: 0
_________________________________________________________________
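For comparison, the same architecture could be expressed with the functional API, where each layer is called on the output of the previous one and the model is built from explicit input and output tensors (a sketch; the auto-generated layer names in the summary will differ between Keras versions):

```python
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense

# Explicit input tensor; layers are applied as functions on tensors.
inputs = Input(shape=(28, 28, 1))
x = Conv2D(8, kernel_size=(3, 3), activation="relu")(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(16, activation="relu")(x)
outputs = Dense(10, activation="softmax")(x)

# The model is defined by its input and output tensors, which is what
# makes multiple inputs/outputs possible with this API.
model = Model(inputs=inputs, outputs=outputs)
model.summary()
```

The parameter count matches the Sequential version, since the layers are identical.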
In Keras, you have to compile the model, which means specifying the loss function, the optimizer algorithm, and validation metrics for your training setup.
model.compile(loss="categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
The cell below shows the training procedure of Keras using the model.fit(...) method. Besides typical options such as batch_size and epochs, which control the number of gradient steps of the training, Keras allows you to use callbacks during training. Callbacks are objects whose methods are called during training to perform tasks such as saving checkpoints of the model (ModelCheckpoint) or stopping the training early if a convergence criterion is met (EarlyStopping).
from keras.callbacks import ModelCheckpoint, EarlyStopping
checkpoint = ModelCheckpoint(
filepath="mnist_keras_model.h5",
save_best_only=True,
verbose=1)
early_stopping = EarlyStopping(patience=2)
history = model.fit(x_train, y_train, # Training data
batch_size=200, # Batch size
epochs=50, # Maximum number of training epochs
validation_split=0.5, # Use 50% of the train dataset for validation
callbacks=[checkpoint, early_stopping]) # Register callbacks
Epoch 1/50
150/150 [==============================] - 9s 51ms/step - loss: 1.5244 - accuracy: 0.5341 - val_loss: 0.3984 - val_accuracy: 0.8842
Epoch 00001: val_loss improved from inf to 0.39840, saving model to mnist_keras_model.h5
Epoch 2/50
150/150 [==============================] - 6s 43ms/step - loss: 0.3608 - accuracy: 0.8961 - val_loss: 0.2784 - val_accuracy: 0.9201
Epoch 00002: val_loss improved from 0.39840 to 0.27837, saving model to mnist_keras_model.h5
Epoch 3/50
150/150 [==============================] - 8s 51ms/step - loss: 0.2603 - accuracy: 0.9262 - val_loss: 0.2298 - val_accuracy: 0.9337
Epoch 00003: val_loss improved from 0.27837 to 0.22984, saving model to mnist_keras_model.h5
Epoch 4/50
150/150 [==============================] - 6s 40ms/step - loss: 0.2205 - accuracy: 0.9398 - val_loss: 0.1960 - val_accuracy: 0.9428
Epoch 00004: val_loss improved from 0.22984 to 0.19601, saving model to mnist_keras_model.h5
Epoch 5/50
150/150 [==============================] - 6s 39ms/step - loss: 0.1845 - accuracy: 0.9470 - val_loss: 0.1819 - val_accuracy: 0.9470
Epoch 00005: val_loss improved from 0.19601 to 0.18192, saving model to mnist_keras_model.h5
Epoch 6/50
150/150 [==============================] - 9s 63ms/step - loss: 0.1624 - accuracy: 0.9535 - val_loss: 0.1647 - val_accuracy: 0.9517
Epoch 00006: val_loss improved from 0.18192 to 0.16470, saving model to mnist_keras_model.h5
Epoch 7/50
150/150 [==============================] - 6s 41ms/step - loss: 0.1429 - accuracy: 0.9591 - val_loss: 0.1548 - val_accuracy: 0.9547
Epoch 00007: val_loss improved from 0.16470 to 0.15479, saving model to mnist_keras_model.h5
Epoch 8/50
150/150 [==============================] - 6s 41ms/step - loss: 0.1316 - accuracy: 0.9610 - val_loss: 0.1462 - val_accuracy: 0.9576
Epoch 00008: val_loss improved from 0.15479 to 0.14618, saving model to mnist_keras_model.h5
Epoch 9/50
150/150 [==============================] - 8s 50ms/step - loss: 0.1248 - accuracy: 0.9652 - val_loss: 0.1401 - val_accuracy: 0.9593
Epoch 00009: val_loss improved from 0.14618 to 0.14008, saving model to mnist_keras_model.h5
Epoch 10/50
150/150 [==============================] - 7s 47ms/step - loss: 0.1149 - accuracy: 0.9662 - val_loss: 0.1314 - val_accuracy: 0.9612
Epoch 00010: val_loss improved from 0.14008 to 0.13137, saving model to mnist_keras_model.h5
Epoch 11/50
150/150 [==============================] - 7s 46ms/step - loss: 0.1099 - accuracy: 0.9672 - val_loss: 0.1379 - val_accuracy: 0.9588
Epoch 00011: val_loss did not improve from 0.13137
Epoch 12/50
150/150 [==============================] - 7s 44ms/step - loss: 0.1005 - accuracy: 0.9711 - val_loss: 0.1314 - val_accuracy: 0.9604
Epoch 00012: val_loss did not improve from 0.13137
epochs = range(1, len(history.history["loss"])+1)
plt.figure(figsize=(12,5))
plt.subplot(1, 2, 1)
plt.plot(epochs, history.history["loss"], label="Training loss")
plt.plot(epochs, history.history["val_loss"], label="Validation loss")
plt.legend(fontsize=15)
plt.xlabel("Epochs", fontsize=15)
plt.ylabel("Loss", fontsize=15)
plt.subplot(1, 2, 2)
plt.plot(epochs, history.history["accuracy"], label="Training accuracy")
plt.plot(epochs, history.history["val_accuracy"], label="Validation accuracy")
plt.legend(fontsize=15)
plt.xlabel("Epochs", fontsize=15)
plt.ylabel("Accuracy", fontsize=15)
Prediction on unseen data is performed with the model.predict(inputs) call. Below, a basic test of the model is done by computing the accuracy on the test dataset.
# Get predictions on test dataset
y_pred = model.predict(x_test)
# Compare predictions with ground truth
test_accuracy = np.sum(
np.argmax(y_test, axis=1)==np.argmax(y_pred, axis=1))/float(x_test.shape[0])
print("Test accuracy: {}".format(test_accuracy))
Test accuracy: 0.9655
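The ModelCheckpoint callback above wrote the best model to mnist_keras_model.h5. The snippet below sketches how such a stored model can be reloaded for inference with load_model; it is deliberately self-contained (it builds, saves, and reloads a small untrained stand-in model and feeds it random images instead of the saved mnist_example_*.png files), so the predicted digits here are meaningless:

```python
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Input, Flatten, Dense

# Stand-in model; in the real workflow the HDF5 file is written by the
# ModelCheckpoint callback during training.
model = Sequential([
    Input(shape=(28, 28, 1)),
    Flatten(),
    Dense(10, activation="softmax"),
])
model.save("demo_model.h5")

# Reload the model from disk and run inference on a batch of images
# (random here; normally the preprocessed 28x28x1 test images).
restored = load_model("demo_model.h5")
images = np.random.rand(3, 28, 28, 1).astype("float32")
probs = restored.predict(images)
digits = np.argmax(probs, axis=1)
print(digits)
```

With the trained checkpoint and the example images stored earlier, the same load_model/predict calls apply, preceded by the same scaling and reshaping used for the training data.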