By the end of this chapter, you will know how to solve binary, multi-class, and multi-label problems with neural networks, all by solving problems like detecting fake dollar bills, deciding who threw which dart at a board, and building an intelligent system to water your farm. You will also be able to plot model training metrics, and to stop training and save your models when they no longer improve. This is the summary of the lecture "Introduction to Deep Learning with Keras", via DataCamp.
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['figure.figsize'] = (8, 8)
You will practice building classification models in Keras with the Banknote Authentication dataset.
Your goal is to distinguish between real and fake dollar bills. In order to do this, the dataset comes with 4 features: `variance`, `skewness`, `curtosis` and `entropy`. These features are calculated by applying mathematical operations over the dollar bill images. The labels are found in the dataframe's `class` column.
banknotes = pd.read_csv('./dataset/banknotes.csv')
banknotes.head()
variance | skewness | curtosis | entropy | class | |
---|---|---|---|---|---|
0 | 3.62160 | 8.6661 | -2.8073 | -0.44699 | 0 |
1 | 4.54590 | 8.1674 | -2.4586 | -1.46210 | 0 |
2 | 3.86600 | -2.6383 | 1.9242 | 0.10645 | 0 |
3 | 3.45660 | 9.5228 | -4.0112 | -3.59440 | 0 |
4 | 0.32924 | -4.4552 | 4.5718 | -0.98880 | 0 |
# Normalize the data by z-scoring each feature
X = banknotes.iloc[:, :4]
X = ((X - X.mean()) / X.std()).to_numpy()
y = banknotes['class'].to_numpy()
# Use pairplot and set the hue to be our class column
sns.pairplot(banknotes, hue='class');
# Describe the data
print('Dataset stats: \n', banknotes.describe())
# Count the number of observations per class
print('Observations per class: \n', banknotes['class'].value_counts())
Dataset stats:
            variance     skewness     curtosis      entropy        class
count    1372.000000  1372.000000  1372.000000  1372.000000  1372.000000
mean        0.433735     1.922353     1.397627    -1.191657     0.444606
std         2.842763     5.869047     4.310030     2.101013     0.497103
min        -7.042100   -13.773100    -5.286100    -8.548200     0.000000
25%        -1.773000    -1.708200    -1.574975    -2.413450     0.000000
50%         0.496180     2.319650     0.616630    -0.586650     0.000000
75%         2.821475     6.814625     3.179250     0.394810     1.000000
max         6.824800    12.951600    17.927400     2.449500     1.000000
Observations per class:
0    762
1    610
Name: class, dtype: int64
Your pairplot shows that there are features for which the classes spread out noticeably, which gives us the intuition that the classes should be easily separable. Let's build a model to find out what it can do!
Now that you know what the Banknote Authentication dataset looks like, we'll build a simple model to distinguish between real and fake bills.
You will perform binary classification by using a single neuron as an output. The input layer will have 4 neurons since we have 4 features in our dataset. The model's output will be a value constrained between 0 and 1.
We will interpret this output number as the probability of our input variables coming from a fake dollar bill, with 1 meaning we are certain it's a fake bill.
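As a quick aside, the sigmoid activation that produces this behavior is $\sigma(z) = 1 / (1 + e^{-z})$. Here is a minimal NumPy sketch (not part of the course code) showing how it squashes any real number into the open interval (0, 1):

# Sigmoid squashes any real-valued input into the open interval (0, 1)
z = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(1 / (1 + np.exp(-z)))  # approx. [0.007 0.269 0.5 0.731 0.993]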
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# Create a sequential model
model = Sequential()
# Add a dense layer
model.add(Dense(1, input_shape=(4, ), activation='sigmoid'))
# Compile your model
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
# Display a summary of your model
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 1) 5 ================================================================= Total params: 5 Trainable params: 5 Non-trainable params: 0 _________________________________________________________________
With just 5 parameters to learn (4 weights plus 1 bias), you are now ready to train your model and check how well it performs when classifying new bills!
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y)
# Train your model for 20 epochs
model.fit(X_train, y_train, epochs=20)
# Evaluate your model accuracy on the test set
accuracy = model.evaluate(X_test, y_test)[1]
# Print accuracy
print('Accuracy: ', accuracy)
Epoch 1/20
33/33 [==============================] - 0s 1ms/step - loss: 0.8882 - accuracy: 0.4665
Epoch 2/20
33/33 [==============================] - 0s 1ms/step - loss: 0.8308 - accuracy: 0.4791
...
Epoch 19/20
33/33 [==============================] - 0s 1ms/step - loss: 0.4249 - accuracy: 0.8076
Epoch 20/20
33/33 [==============================] - 0s 1ms/step - loss: 0.4144 - accuracy: 0.8192
11/11 [==============================] - 0s 853us/step - loss: 0.4147 - accuracy: 0.8192
Accuracy:  0.819242000579834
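If you want hard class predictions rather than just the accuracy score, one approach (a sketch, reusing the trained model and `X_test` from above) is to threshold the predicted probabilities at 0.5:

# Predict fake-bill probabilities for the test set; model.predict returns
# one sigmoid output per bill, with shape (n_samples, 1)
probs = model.predict(X_test)
# Threshold at 0.5: 1 = predicted fake, 0 = predicted real
pred_labels = (probs > 0.5).astype(int).flatten()
print(pred_labels[:10])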
You're going to build a model that predicts who threw which dart based only on where that dart landed (that is, the dart's x and y coordinates on the board).

This is a multi-class classification problem since each dart can only be thrown by one of 4 competitors; the classes/labels are mutually exclusive. We can therefore build an output layer with as many neurons as competitors and use the `softmax` activation function so that the probabilities over all competitors sum to 1.
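To see why softmax outputs can be read as probabilities, here is a small NumPy sketch of the function itself (illustrative logits, not course code):

# Softmax: exponentiate each raw score, then normalize by the total
logits = np.array([2.0, 1.0, 0.1, -1.0])   # one raw score per competitor
probs = np.exp(logits) / np.sum(np.exp(logits))
print(probs)        # four non-negative numbers, one per competitor
print(probs.sum())  # 1.0 -- a valid probability distribution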
darts = pd.read_csv('./dataset/darts.csv')
darts.head()
xCoord | yCoord | competitor | |
---|---|---|---|
0 | 0.196451 | -0.520341 | Steve |
1 | 0.476027 | -0.306763 | Susan |
2 | 0.003175 | -0.980736 | Michael |
3 | 0.294078 | 0.267566 | Kate |
4 | -0.051120 | 0.598946 | Steve |
sns.pairplot(darts, hue='competitor');
# Instantiate a sequential model
model = Sequential()
# Add 3 dense layers of 128, 64 and 32 neurons
model.add(Dense(128, input_shape=(2, ), activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))
# Add a dense layer with as many neurons as competitors
model.add(Dense(4, activation='softmax'))
# Compile your model using categorical_crossentropy loss
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
In the console you can check that your labels, `darts.competitor`, are not yet in a format your network can understand: they contain the names of the competitors as strings. You will first turn these competitors into unique numbers, then use the `to_categorical()` function from `tf.keras.utils` to turn these numbers into their one-hot encoded representation.

This is useful for multi-class classification problems, since there are as many output neurons as classes, and for every observation in our dataset we want exactly one of the neurons to be activated.
from tensorflow.keras.utils import to_categorical
# Transform into a categorical variable
darts.competitor = pd.Categorical(darts.competitor)
# Assign a number to each category (label encoding)
darts.competitor = darts.competitor.cat.codes
# Print the label encoded competitors
print('Label encoded competitors: \n', darts.competitor.head())
coordinates = darts.drop(['competitor'], axis=1)
# Use to_categorical on your labels
competitors = to_categorical(darts.competitor)
# Now print the one-hot encoded labels
print('One-hot encoded competitors: \n', competitors)
Label encoded competitors:
0    2
1    3
2    1
3    0
4    2
Name: competitor, dtype: int8
One-hot encoded competitors:
[[0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]
 ...
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]
Each competitor is now a vector of length 4, full of zeroes except at the position representing that competitor.
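If you ever need to go back the other way, `np.argmax()` along axis 1 recovers the label-encoded numbers from the one-hot vectors; a quick check with the `competitors` array from above:

# Reverse the one-hot encoding: take the index of the 1 in each row
recovered = np.argmax(competitors, axis=1)
print(recovered[:5])  # [2 3 1 0 2], matching the label encoded output above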
Your model is ready, and so is your dataset. It's time to train!

The `coordinates` features and `competitors` labels you just transformed have been partitioned into `coord_train`, `coord_test` and `competitors_train`, `competitors_test`.

Let's find out who threw which dart just by looking at the board!
coordinates = darts[['xCoord', 'yCoord']]
coordinates.head()
xCoord | yCoord | |
---|---|---|
0 | 0.196451 | -0.520341 |
1 | 0.476027 | -0.306763 |
2 | 0.003175 | -0.980736 |
3 | 0.294078 | 0.267566 |
4 | -0.051120 | 0.598946 |
coord_train, coord_test, competitors_train, competitors_test = \
train_test_split(coordinates, competitors, test_size=0.25, stratify=competitors)
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 128) 384 _________________________________________________________________ dense_2 (Dense) (None, 64) 8256 _________________________________________________________________ dense_3 (Dense) (None, 32) 2080 _________________________________________________________________ dense_4 (Dense) (None, 4) 132 ================================================================= Total params: 10,852 Trainable params: 10,852 Non-trainable params: 0 _________________________________________________________________
# Fit your model to the training data for 200 epochs
model.fit(coord_train, competitors_train, epochs=200)
# Evaluate your model accuracy on the test data
accuracy = model.evaluate(coord_test, competitors_test)[1]
# Print accuracy
print('Accuracy:', accuracy)
Epoch 1/200
19/19 [==============================] - 0s 1ms/step - loss: 1.3815 - accuracy: 0.3017
Epoch 2/200
19/19 [==============================] - 0s 1ms/step - loss: 1.3421 - accuracy: 0.3100
...
Epoch 199/200
19/19 [==============================] - 0s 1ms/step - loss: 0.4978 - accuracy: 0.8183
Epoch 200/200
19/19 [==============================] - 0s 1ms/step - loss: 0.4926 - accuracy: 0.8167
7/7 [==============================] - 0s 914us/step - loss: 0.6935 - accuracy: 0.7650
Accuracy: 0.7649999856948853
This model is generalizing well; that's why you got a high accuracy on the test set.
Since you used the `softmax` activation function, for every input of 2 coordinates provided to your model there's an output vector of 4 numbers. Each of these numbers encodes the probability of a given dart being thrown by one of the 4 possible competitors.

When computing accuracy with the model's `.evaluate()` method, your model takes the class with the highest probability as the prediction. `np.argmax()` can help you do this, since it returns the index of the highest value in an array.

Use the collection of test throws stored in `coords_small_test` and `np.argmax()` to check this out!
coords_small_test = pd.DataFrame({
'xCoord':[0.209048, 0.082103, 0.198165, -0.348660, 0.214726],
'yCoord':[-0.077398, -0.721407, -0.674646, 0.035086, 0.183894]
})
competitors_small_test = np.array([[0., 0., 1., 0.], [0., 0., 0., 1.],
[0., 0., 0., 1.], [1., 0., 0., 0.],
[0., 0., 1., 0.]])
# Predict on coords_small_test
preds = model.predict(coords_small_test)
# Print preds vs true values
print("{:45} | {}".format("Raw Model Predictions", "True labels"))
for i, pred in enumerate(preds):
print("{} | {}".format(pred, competitors_small_test[i]))
Raw Model Predictions                         | True labels
[0.40823606 0.01579    0.5673056  0.00866832] | [0. 0. 1. 0.]
[0.14185785 0.00279553 0.04308222 0.81226444] | [0. 0. 0. 1.]
[0.41373312 0.00399575 0.17039205 0.4118791 ] | [0. 0. 0. 1.]
[0.93831897 0.0479466  0.00533243 0.00840211] | [1. 0. 0. 0.]
[0.355474   0.01381436 0.62466073 0.00605091] | [0. 0. 1. 0.]
# Extract the position of highest probability from each pred vector
preds_chosen = [np.argmax(pred) for pred in preds]
# Print preds vs true values
print("{:10} | {}".format("Rounded Model Predictions", "True labels"))
for i, pred in enumerate(preds_chosen):
print("{:25} | {}".format(pred, competitors_small_test[i]))
Rounded Model Predictions | True labels
                        2 | [0. 0. 1. 0.]
                        3 | [0. 0. 0. 1.]
                        0 | [0. 0. 0. 1.]
                        0 | [1. 0. 0. 0.]
                        2 | [0. 0. 1. 0.]
As you've seen, you can easily interpret the softmax output. This can also help you spot observations where your network is less certain which class to predict, since you can see the probability distribution among classes for each prediction.
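For instance, you could flag the throws where even the most likely competitor gets a low probability. A quick sketch using the `preds` from above (the 0.6 cutoff is an arbitrary choice):

# Flag predictions whose highest class probability is below 60%
uncertain = np.max(preds, axis=1) < 0.6
print(uncertain)  # e.g. [ True False  True False False]: throws 1 and 3 are close calls

Let's learn how to solve new problems with neural networks!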
You're going to automate the watering of farm parcels by making an intelligent irrigation machine. Multi-label classification problems differ from multi-class problems in that each observation can be labeled with zero or more classes. So classes/labels are not mutuallyly exclusive: you could water all, none, or any combination of the farm parcels based on the inputs.

To account for this behavior, we use an output layer with as many neurons as classes, but this time, unlike in multi-class problems, each output neuron has a `sigmoid` activation function. This makes each neuron in the output layer able to output a number between 0 and 1 independently.
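Unlike softmax outputs, these sigmoid outputs are computed independently and need not sum to 1. A small NumPy sketch of the difference, with illustrative values:

# The same raw scores, read two ways
logits = np.array([1.2, 0.3, -0.8])
sig = 1 / (1 + np.exp(-logits))
print(sig, sig.sum())    # each value in (0, 1); the sum is NOT forced to be 1
soft = np.exp(logits) / np.sum(np.exp(logits))
print(soft, soft.sum())  # softmax would force the outputs to sum to exactly 1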
irrigation = pd.read_csv('./dataset/irrigation_machine.csv', index_col=0)
irrigation.head()
sensor_0 | sensor_1 | sensor_2 | sensor_3 | sensor_4 | sensor_5 | sensor_6 | sensor_7 | sensor_8 | sensor_9 | ... | sensor_13 | sensor_14 | sensor_15 | sensor_16 | sensor_17 | sensor_18 | sensor_19 | parcel_0 | parcel_1 | parcel_2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.0 | 2.0 | 1.0 | 7.0 | 0.0 | 1.0 | 1.0 | 4.0 | 0.0 | 3.0 | ... | 8.0 | 1.0 | 0.0 | 2.0 | 1.0 | 9.0 | 2.0 | 0 | 1 | 0 |
1 | 5.0 | 1.0 | 3.0 | 5.0 | 2.0 | 2.0 | 1.0 | 2.0 | 3.0 | 1.0 | ... | 4.0 | 5.0 | 5.0 | 2.0 | 2.0 | 2.0 | 7.0 | 0 | 0 | 0 |
2 | 3.0 | 1.0 | 4.0 | 3.0 | 4.0 | 0.0 | 1.0 | 6.0 | 0.0 | 2.0 | ... | 3.0 | 3.0 | 1.0 | 0.0 | 3.0 | 1.0 | 0.0 | 1 | 1 | 0 |
3 | 2.0 | 2.0 | 4.0 | 3.0 | 5.0 | 0.0 | 3.0 | 2.0 | 2.0 | 5.0 | ... | 4.0 | 1.0 | 1.0 | 4.0 | 1.0 | 3.0 | 2.0 | 0 | 0 | 0 |
4 | 4.0 | 3.0 | 3.0 | 2.0 | 5.0 | 1.0 | 3.0 | 1.0 | 1.0 | 2.0 | ... | 1.0 | 3.0 | 2.0 | 2.0 | 1.0 | 1.0 | 0.0 | 1 | 1 | 0 |
5 rows × 23 columns
# Instantiate a Sequential model
model = Sequential()
# Add a hidden layer of 64 neurons with a 20-neuron input layer
model.add(Dense(64, input_shape=(20, ), activation='relu'))
# Add an output layer of 3 neurons with sigmoid activation
model.add(Dense(3, activation='sigmoid'))
# Compile your model with binary crossentropy loss
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_5 (Dense) (None, 64) 1344 _________________________________________________________________ dense_6 (Dense) (None, 3) 195 ================================================================= Total params: 1,539 Trainable params: 1,539 Non-trainable params: 0 _________________________________________________________________
You've already built 3 models for 3 different problems! Hopefully you're starting to get a feel for how different problems can be modeled in the neural network realm.
An output of your multi-label model could look like this: `[0.76, 0.99, 0.66]`. If we round the probabilities higher than 0.5 up to 1, this observation will be classified as containing all 3 possible labels: `[1, 1, 1]`. For this particular problem, this means that, according to the network, watering all 3 parcels in your farm is the right thing to do given the input sensor measurements.
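In code, that rounding step is a single comparison; a sketch with the example vector from the text:

# Round probabilities above 0.5 up to 1 and the rest down to 0
output = np.array([0.76, 0.99, 0.66])
print((output > 0.5).astype(int))  # [1 1 1] -> water all three parcels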
You will now train and predict with the model you just built. `sensors_train`, `parcels_train`, `sensors_test` and `parcels_test` are already loaded for you to use.

Let's see how well your intelligent machine performs!
parcels = irrigation[['parcel_0', 'parcel_1', 'parcel_2']].to_numpy()
sensors = irrigation.drop(['parcel_0', 'parcel_1', 'parcel_2'], axis=1).to_numpy()
sensors_train, sensors_test, parcels_train, parcels_test = \
train_test_split(sensors, parcels, test_size=0.3, stratify=parcels)
# Train for 100 epochs using a validation split of 0.2
model.fit(sensors_train, parcels_train, epochs=100, validation_split=0.2)
# Predict on sensors_test and round the predictions to 0 or 1
preds = model.predict(sensors_test)
preds_rounded = np.round(preds)
# Print rounded preds
print('Rounded Predictions: \n', preds_rounded)
# Evaluate your model's accuracy on the test data
accuracy = model.evaluate(sensors_test, parcels_test)[1]
# Print accuracy
print('Accuracy:', accuracy)
Epoch 1/100
35/35 [==============================] - 0s 3ms/step - loss: 0.6201 - accuracy: 0.4518 - val_loss: 0.4988 - val_accuracy: 0.5179
Epoch 2/100
35/35 [==============================] - 0s 2ms/step - loss: 0.4524 - accuracy: 0.5866 - val_loss: 0.3890 - val_accuracy: 0.5857
...
Epoch 99/100
35/35 [==============================] - 0s 2ms/step - loss: 0.1329 - accuracy: 0.6143 - val_loss: 0.2027 - val_accuracy: 0.5643
Epoch 100/100
35/35 [==============================] - 0s 2ms/step - loss: 0.1338 - accuracy: 0.5946 - val_loss: 0.2086 - val_accuracy: 0.5536
Rounded Predictions:
[[1. 1. 1.]
 [1. 1. 0.]
 [1. 1. 0.]
 ...
 [1. 1. 0.]
 [1. 1. 1.]
 [1. 1. 0.]]
19/19 [==============================] - 0s 905us/step - loss: 0.2775 - accuracy: 0.5867
Accuracy: 0.5866666436195374
Great work on automating this farm! You can see how the `validation_split` argument is useful for evaluating how your model performs as it trains. Let's move on and improve your model training by using callbacks!
A history callback is returned by default every time you train a model with the `.fit()` method. To access its metrics, you read the `history` dictionary attribute of the returned `h_callback` object with the corresponding keys.

The irrigation machine model you built in the previous lesson is loaded for you to train, along with its features and labels, now loaded as `X_train`, `y_train`, `X_test`, `y_test`. This time you will store the model's `history` callback and use the `validation_data` parameter as it trains.

Let's see the behind-the-scenes of our training!
def plot_accuracy(acc, val_acc):
    # Plot training & validation accuracy values
    plt.figure()
    plt.plot(acc)
    plt.plot(val_acc)
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')

def plot_loss(loss, val_loss):
    # Plot training & validation loss values
    plt.figure()
    plt.plot(loss)
    plt.plot(val_loss)
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper right')
X_train, y_train = sensors_train, parcels_train
X_test, y_test = sensors_test, parcels_test
Note: in `tf.keras`, the history keys `'accuracy'` and `'val_accuracy'` hold the training and validation accuracy.
# Train your model and save its history
h_callback = model.fit(X_train, y_train, epochs=50, validation_data=(X_test, y_test))
Epoch 1/50
44/44 [==============================] - 0s 2ms/step - loss: 0.1466 - accuracy: 0.6171 - val_loss: 0.2714 - val_accuracy: 0.6450
Epoch 2/50
44/44 [==============================] - 0s 2ms/step - loss: 0.1450 - accuracy: 0.6014 - val_loss: 0.2698 - val_accuracy: 0.6250
...
Epoch 49/50
44/44 [==============================] - 0s 2ms/step - loss: 0.1081 - accuracy: 0.6257 - val_loss: 0.2987 - val_accuracy: 0.6833
Epoch 50/50
44/44 [==============================] - 0s 2ms/step - loss: 0.1065 - accuracy: 0.6264 - val_loss: 0.2981 - val_accuracy: 0.6417
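The `plot_loss()` and `plot_accuracy()` helpers used below are provided by the DataCamp exercise environment and are not defined in this post. Here is a minimal sketch of what they might look like, assuming they simply plot the per-epoch training and validation curves:

# Minimal sketches of the plotting helpers (assumed; not part of the original exercise code)
def plot_loss(loss, val_loss):
    plt.figure()
    plt.plot(loss)
    plt.plot(val_loss)
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper right')
    plt.show()

def plot_accuracy(acc, val_acc):
    plt.figure()
    plt.plot(acc)
    plt.plot(val_acc)
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()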
# Plot train vs test loss during training
plot_loss(h_callback.history['loss'], h_callback.history['val_loss'])

# Plot train vs test accuracy during training
plot_accuracy(h_callback.history['accuracy'], h_callback.history['val_accuracy'])
The early stopping callback is useful since it allows you to stop the model training when it no longer improves after a given number of epochs. To make use of this functionality, you need to pass the callback inside a list to the `callbacks` parameter of the `.fit()` method.

The model you built to detect fake dollar bills is loaded for you to train, this time with early stopping. `X_train`, `y_train`, `X_test` and `y_test` are also available for you to use.
from sklearn.model_selection import train_test_split

# Normalize the data
X = banknotes.iloc[:, :4]
X = ((X - X.mean()) / X.std()).to_numpy()
y = banknotes['class'].to_numpy()

# Create stratified train and test splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y)
# Create a sequential model
model = Sequential()
# Add a dense layer
model.add(Dense(1, input_shape=(4, ), activation='sigmoid'))
# Compile your model
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
from tensorflow.keras.callbacks import EarlyStopping
# Define a callback that monitors val_accuracy and stops training
# after 5 epochs without improvement
monitor_val_acc = EarlyStopping(monitor='val_accuracy', patience=5)
# Train your model using early stopping callback
model.fit(X_train, y_train, epochs=100, validation_data=(X_test, y_test),
callbacks=[monitor_val_acc]);
Epoch 1/100
33/33 [==============================] - 0s 3ms/step - loss: 1.1413 - accuracy: 0.4052 - val_loss: 1.0880 - val_accuracy: 0.4198
Epoch 2/100
33/33 [==============================] - 0s 2ms/step - loss: 0.9852 - accuracy: 0.4490 - val_loss: 0.9373 - val_accuracy: 0.4606
...
Epoch 33/100
33/33 [==============================] - 0s 2ms/step - loss: 0.2564 - accuracy: 0.9495 - val_loss: 0.2575 - val_accuracy: 0.9446
Epoch 34/100
33/33 [==============================] - 0s 2ms/step - loss: 0.2527 - accuracy: 0.9495 - val_loss: 0.2539 - val_accuracy: 0.9446
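Training stopped after 34 epochs instead of running for all 100, because `val_accuracy` had not improved for 5 consecutive epochs. If you want to double-check the final test performance yourself, a quick sketch using the standard `model.evaluate()` API:

# Evaluate the early-stopped model on the held-out test set
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {test_loss:.4f}, test accuracy: {test_accuracy:.4f}')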
Deep learning models can take a long time to train, especially when you move to deeper architectures and bigger datasets. Saving your model every time it improves, as well as stopping training when it no longer does, frees you from having to pick the number of epochs in advance. You can also restore a saved model at any time and resume training where you left off.

Use the `EarlyStopping()` and `ModelCheckpoint()` callbacks so that you can go eat a jar of cookies while you leave your computer to work!
from tensorflow.keras.callbacks import ModelCheckpoint

# Early stop on validation accuracy
monitor_val_acc = EarlyStopping(monitor='val_accuracy', patience=3)

# Save the best model as best_banknote_model.hdf5
# (by default, ModelCheckpoint tracks val_loss to decide which model is best)
modelCheckpoint = ModelCheckpoint('./best_banknote_model.hdf5', save_best_only=True)

# Fit your model for a stupid amount of epochs; the callbacks will stop it for you
h_callback = model.fit(X_train, y_train,
                       epochs=100000000000,
                       callbacks=[monitor_val_acc, modelCheckpoint],
                       validation_data=(X_test, y_test))
Epoch 1/100000000000
33/33 [==============================] - 0s 2ms/step - loss: 0.2492 - accuracy: 0.9504 - val_loss: 0.2504 - val_accuracy: 0.9475
Epoch 2/100000000000
33/33 [==============================] - 0s 2ms/step - loss: 0.2457 - accuracy: 0.9514 - val_loss: 0.2471 - val_accuracy: 0.9475
Epoch 3/100000000000
33/33 [==============================] - 0s 2ms/step - loss: 0.2425 - accuracy: 0.9524 - val_loss: 0.2438 - val_accuracy: 0.9475
Epoch 4/100000000000
33/33 [==============================] - 0s 2ms/step - loss: 0.2393 - accuracy: 0.9543 - val_loss: 0.2407 - val_accuracy: 0.9475
!ls | grep best_banknote*
best_banknote_model.hdf5
Now you always keep the model that performed best on the validation set, even if early stopping halted training at an epoch where the model was already performing worse.
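Since the best weights are stored on disk, you can reload them later to run inference or resume training, as mentioned above. A short sketch using the standard `load_model()` API, pointing at the checkpoint file saved earlier:

from tensorflow.keras.models import load_model

# Restore the best model saved by the ModelCheckpoint callback
best_model = load_model('./best_banknote_model.hdf5')

# Use it for inference...
preds = best_model.predict(X_test)

# ...or resume training from where the checkpoint left off
best_model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))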