import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
# Load the MNIST train and test splits, each as an (images, labels) pair.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
plt.imshow(X_train[0]) # show the first digit image of the training split
plt.show()
print('Label: ', y_train[0])  # label that belongs to the image shown above
Label: 5
# Same visual check as for the training split, this time on the test split.
plt.imshow(X_test[0]) # show the first digit image of the test split
plt.show()
print('Label: ', y_test[0])  # label that belongs to the image shown above
Label: 7
# flatten every image into a single row vector: (n, 28, 28) => (n, 784)
X_train, X_test = (images.reshape(len(images), -1) for images in (X_train, X_test))
# one-hot encode the integer class labels
y_train, y_test = to_categorical(y_train), to_categorical(y_test)
# sanity check: print the resulting shapes of all four arrays
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
(60000, 784) (10000, 784) (60000, 10) (10000, 10)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras import optimizers
# Baseline network: 784 inputs -> four hidden Dense layers of 50 units, each
# followed by a sigmoid activation -> 10-unit softmax output. The Dense layers
# use the default weight initializer (none is passed).
model = Sequential()
model.add(Dense(50, input_shape=(784,)))
model.add(Activation('sigmoid'))
model.add(Dense(50))
model.add(Activation('sigmoid'))
model.add(Dense(50))
model.add(Activation('sigmoid'))
model.add(Dense(50))
model.add(Activation('sigmoid'))
model.add(Dense(10))
model.add(Activation('softmax'))
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is rejected by current Keras optimizers.
sgd = optimizers.SGD(learning_rate=0.001)
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
# Train for 100 epochs with mini-batches of 256 samples, holding out 30% of the
# training data for validation; the returned History feeds the plots below.
history = model.fit(X_train, y_train, batch_size=256, validation_split=0.3, epochs=100, verbose=1)
Epoch 1/100 165/165 [==============================] - 1s 7ms/step - loss: 2.4838 - accuracy: 0.0995 - val_loss: 2.4486 - val_accuracy: 0.0966 Epoch 2/100 165/165 [==============================] - 1s 5ms/step - loss: 2.4176 - accuracy: 0.0995 - val_loss: 2.3965 - val_accuracy: 0.0966 Epoch 3/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3761 - accuracy: 0.0995 - val_loss: 2.3634 - val_accuracy: 0.0966 Epoch 4/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3496 - accuracy: 0.0995 - val_loss: 2.3422 - val_accuracy: 0.0966 Epoch 5/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3327 - accuracy: 0.0995 - val_loss: 2.3283 - val_accuracy: 0.0966 Epoch 6/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3217 - accuracy: 0.0997 - val_loss: 2.3193 - val_accuracy: 0.0970 Epoch 7/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3145 - accuracy: 0.1019 - val_loss: 2.3135 - val_accuracy: 0.0989 Epoch 8/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3099 - accuracy: 0.0827 - val_loss: 2.3096 - val_accuracy: 0.1059 Epoch 9/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3070 - accuracy: 0.1140 - val_loss: 2.3072 - val_accuracy: 0.1079 Epoch 10/100 165/165 [==============================] - 1s 4ms/step - loss: 2.3050 - accuracy: 0.1143 - val_loss: 2.3055 - val_accuracy: 0.1079 Epoch 11/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3038 - accuracy: 0.1143 - val_loss: 2.3044 - val_accuracy: 0.1079 Epoch 12/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3029 - accuracy: 0.1143 - val_loss: 2.3036 - val_accuracy: 0.1079 Epoch 13/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3023 - accuracy: 0.1143 - val_loss: 2.3031 - val_accuracy: 0.1079 Epoch 14/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3019 - accuracy: 0.1143 - val_loss: 2.3028 - val_accuracy: 
0.1079 Epoch 15/100 165/165 [==============================] - 1s 4ms/step - loss: 2.3016 - accuracy: 0.1143 - val_loss: 2.3025 - val_accuracy: 0.1079 Epoch 16/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3014 - accuracy: 0.1143 - val_loss: 2.3023 - val_accuracy: 0.1079 Epoch 17/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3013 - accuracy: 0.1143 - val_loss: 2.3021 - val_accuracy: 0.1079 Epoch 18/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3011 - accuracy: 0.1143 - val_loss: 2.3020 - val_accuracy: 0.1079 Epoch 19/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3010 - accuracy: 0.1143 - val_loss: 2.3018 - val_accuracy: 0.1079 Epoch 20/100 165/165 [==============================] - 1s 4ms/step - loss: 2.3009 - accuracy: 0.1143 - val_loss: 2.3017 - val_accuracy: 0.1079 Epoch 21/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3008 - accuracy: 0.1143 - val_loss: 2.3017 - val_accuracy: 0.1079 Epoch 22/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3007 - accuracy: 0.1143 - val_loss: 2.3016 - val_accuracy: 0.1079 Epoch 23/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3007 - accuracy: 0.1143 - val_loss: 2.3015 - val_accuracy: 0.1079 Epoch 24/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3006 - accuracy: 0.1143 - val_loss: 2.3014 - val_accuracy: 0.1079 Epoch 25/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3005 - accuracy: 0.1143 - val_loss: 2.3013 - val_accuracy: 0.1079 Epoch 26/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3004 - accuracy: 0.1143 - val_loss: 2.3012 - val_accuracy: 0.1079 Epoch 27/100 165/165 [==============================] - 1s 4ms/step - loss: 2.3003 - accuracy: 0.1143 - val_loss: 2.3011 - val_accuracy: 0.1079 Epoch 28/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3002 - accuracy: 0.1143 - val_loss: 2.3010 
- val_accuracy: 0.1079 Epoch 29/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.3009 - val_accuracy: 0.1079 Epoch 30/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.3008 - val_accuracy: 0.1079 Epoch 31/100 165/165 [==============================] - 1s 5ms/step - loss: 2.3000 - accuracy: 0.1143 - val_loss: 2.3007 - val_accuracy: 0.1079 Epoch 32/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2999 - accuracy: 0.1143 - val_loss: 2.3007 - val_accuracy: 0.1079 Epoch 33/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2998 - accuracy: 0.1143 - val_loss: 2.3006 - val_accuracy: 0.1079 Epoch 34/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2997 - accuracy: 0.1143 - val_loss: 2.3005 - val_accuracy: 0.1079 Epoch 35/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2996 - accuracy: 0.1143 - val_loss: 2.3004 - val_accuracy: 0.1079 Epoch 36/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2995 - accuracy: 0.1143 - val_loss: 2.3003 - val_accuracy: 0.1079 Epoch 37/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2995 - accuracy: 0.1143 - val_loss: 2.3002 - val_accuracy: 0.1079 Epoch 38/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2994 - accuracy: 0.1143 - val_loss: 2.3002 - val_accuracy: 0.1079 Epoch 39/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2993 - accuracy: 0.1143 - val_loss: 2.3001 - val_accuracy: 0.1079 Epoch 40/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2992 - accuracy: 0.1143 - val_loss: 2.3000 - val_accuracy: 0.1079 Epoch 41/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2991 - accuracy: 0.1143 - val_loss: 2.2999 - val_accuracy: 0.1079 Epoch 42/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2991 - accuracy: 0.1143 - 
val_loss: 2.2998 - val_accuracy: 0.1079 Epoch 43/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2990 - accuracy: 0.1143 - val_loss: 2.2998 - val_accuracy: 0.1079 Epoch 44/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2989 - accuracy: 0.1143 - val_loss: 2.2997 - val_accuracy: 0.1079 Epoch 45/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2988 - accuracy: 0.1143 - val_loss: 2.2996 - val_accuracy: 0.1079 Epoch 46/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2987 - accuracy: 0.1143 - val_loss: 2.2995 - val_accuracy: 0.1079 Epoch 47/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2987 - accuracy: 0.1143 - val_loss: 2.2994 - val_accuracy: 0.1079 Epoch 48/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2986 - accuracy: 0.1143 - val_loss: 2.2994 - val_accuracy: 0.1079 Epoch 49/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2985 - accuracy: 0.1143 - val_loss: 2.2993 - val_accuracy: 0.1079 Epoch 50/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2984 - accuracy: 0.1143 - val_loss: 2.2992 - val_accuracy: 0.1079 Epoch 51/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2983 - accuracy: 0.1143 - val_loss: 2.2991 - val_accuracy: 0.1079 Epoch 52/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2983 - accuracy: 0.1143 - val_loss: 2.2991 - val_accuracy: 0.1079 Epoch 53/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2982 - accuracy: 0.1143 - val_loss: 2.2990 - val_accuracy: 0.1079 Epoch 54/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2981 - accuracy: 0.1143 - val_loss: 2.2989 - val_accuracy: 0.1079 Epoch 55/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2980 - accuracy: 0.1143 - val_loss: 2.2988 - val_accuracy: 0.1079 Epoch 56/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2979 - 
accuracy: 0.1143 - val_loss: 2.2988 - val_accuracy: 0.1079 Epoch 57/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2979 - accuracy: 0.1143 - val_loss: 2.2987 - val_accuracy: 0.1079 Epoch 58/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2978 - accuracy: 0.1143 - val_loss: 2.2986 - val_accuracy: 0.1079 Epoch 59/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2977 - accuracy: 0.1143 - val_loss: 2.2985 - val_accuracy: 0.1079 Epoch 60/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2976 - accuracy: 0.1143 - val_loss: 2.2984 - val_accuracy: 0.1079 Epoch 61/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2975 - accuracy: 0.1143 - val_loss: 2.2983 - val_accuracy: 0.1079 Epoch 62/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2975 - accuracy: 0.1143 - val_loss: 2.2982 - val_accuracy: 0.1079 Epoch 63/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2974 - accuracy: 0.1143 - val_loss: 2.2981 - val_accuracy: 0.1079 Epoch 64/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2973 - accuracy: 0.1143 - val_loss: 2.2981 - val_accuracy: 0.1079 Epoch 65/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2972 - accuracy: 0.1143 - val_loss: 2.2980 - val_accuracy: 0.1079 Epoch 66/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2971 - accuracy: 0.1143 - val_loss: 2.2979 - val_accuracy: 0.1079 Epoch 67/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2971 - accuracy: 0.1143 - val_loss: 2.2978 - val_accuracy: 0.1079 Epoch 68/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2970 - accuracy: 0.1143 - val_loss: 2.2977 - val_accuracy: 0.1079 Epoch 69/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2969 - accuracy: 0.1143 - val_loss: 2.2976 - val_accuracy: 0.1079 Epoch 70/100 165/165 [==============================] - 1s 5ms/step 
- loss: 2.2968 - accuracy: 0.1143 - val_loss: 2.2976 - val_accuracy: 0.1079 Epoch 71/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2967 - accuracy: 0.1143 - val_loss: 2.2975 - val_accuracy: 0.1079 Epoch 72/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2967 - accuracy: 0.1143 - val_loss: 2.2974 - val_accuracy: 0.1079 Epoch 73/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2966 - accuracy: 0.1143 - val_loss: 2.2973 - val_accuracy: 0.1079 Epoch 74/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2965 - accuracy: 0.1143 - val_loss: 2.2972 - val_accuracy: 0.1079 Epoch 75/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2964 - accuracy: 0.1143 - val_loss: 2.2972 - val_accuracy: 0.1079 Epoch 76/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2963 - accuracy: 0.1143 - val_loss: 2.2971 - val_accuracy: 0.1079 Epoch 77/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2963 - accuracy: 0.1143 - val_loss: 2.2970 - val_accuracy: 0.1079 Epoch 78/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2962 - accuracy: 0.1143 - val_loss: 2.2969 - val_accuracy: 0.1079 Epoch 79/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2961 - accuracy: 0.1143 - val_loss: 2.2969 - val_accuracy: 0.1079 Epoch 80/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2960 - accuracy: 0.1143 - val_loss: 2.2968 - val_accuracy: 0.1079 Epoch 81/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2959 - accuracy: 0.1143 - val_loss: 2.2967 - val_accuracy: 0.1079 Epoch 82/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2966 - val_accuracy: 0.1079 Epoch 83/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2965 - val_accuracy: 0.1079 Epoch 84/100 165/165 
[==============================] - 1s 5ms/step - loss: 2.2957 - accuracy: 0.1143 - val_loss: 2.2964 - val_accuracy: 0.1079 Epoch 85/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2956 - accuracy: 0.1143 - val_loss: 2.2963 - val_accuracy: 0.1079 Epoch 86/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2955 - accuracy: 0.1143 - val_loss: 2.2963 - val_accuracy: 0.1079 Epoch 87/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2954 - accuracy: 0.1143 - val_loss: 2.2962 - val_accuracy: 0.1079 Epoch 88/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2953 - accuracy: 0.1143 - val_loss: 2.2961 - val_accuracy: 0.1079 Epoch 89/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2952 - accuracy: 0.1143 - val_loss: 2.2960 - val_accuracy: 0.1079 Epoch 90/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2952 - accuracy: 0.1143 - val_loss: 2.2959 - val_accuracy: 0.1079 Epoch 91/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2951 - accuracy: 0.1143 - val_loss: 2.2958 - val_accuracy: 0.1079 Epoch 92/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2950 - accuracy: 0.1143 - val_loss: 2.2958 - val_accuracy: 0.1079 Epoch 93/100 165/165 [==============================] - 1s 4ms/step - loss: 2.2949 - accuracy: 0.1143 - val_loss: 2.2957 - val_accuracy: 0.1079 Epoch 94/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2948 - accuracy: 0.1143 - val_loss: 2.2956 - val_accuracy: 0.1079 Epoch 95/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2947 - accuracy: 0.1143 - val_loss: 2.2955 - val_accuracy: 0.1079 Epoch 96/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2946 - accuracy: 0.1143 - val_loss: 2.2954 - val_accuracy: 0.1079 Epoch 97/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2945 - accuracy: 0.1143 - val_loss: 2.2953 - val_accuracy: 0.1079 
Epoch 98/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2945 - accuracy: 0.1143 - val_loss: 2.2952 - val_accuracy: 0.1079 Epoch 99/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2944 - accuracy: 0.1143 - val_loss: 2.2951 - val_accuracy: 0.1079 Epoch 100/100 165/165 [==============================] - 1s 5ms/step - loss: 2.2943 - accuracy: 0.1143 - val_loss: 2.2950 - val_accuracy: 0.1079
# Draw the four recorded training curves on one set of axes, in the same
# order as the legend entries below.
for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[metric])
plt.legend(
    ['train acc', 'valid acc', 'train loss', 'valid loss'],
    loc='upper left',
)
plt.show()
Training and validation accuracy barely improve: both plateau at about 11% (chance level for 10 classes) after roughly 10 epochs, while the loss decreases only marginally.
# Score the trained model on the held-out test split; the result is indexed
# below, with index 1 reported as the accuracy.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 2.2944 - accuracy: 0.1135
# results[1] is reported as the accuracy; presumably results[0] is the loss,
# matching the compile() call (loss first, then metrics) — confirm.
print('Test accuracy: ', results[1])
Test accuracy: 0.11349999904632568
# from now on, create a function to generate (return) models
def mlp_model():
    """Build and compile the MLP with He-normal weight initialization.

    The network takes 784 inputs, has four hidden Dense layers of 50 units
    (each followed by a sigmoid activation) and a 10-unit softmax output.
    Every Dense layer uses ``kernel_initializer='he_normal'``.

    Returns:
        The Sequential model, compiled with SGD (learning rate 0.001),
        categorical cross-entropy loss and the accuracy metric.
    """
    model = Sequential()
    # first hidden layer also declares the input shape
    model.add(Dense(50, input_shape=(784,), kernel_initializer='he_normal'))
    model.add(Activation('sigmoid'))
    # three more identical hidden layers
    for _ in range(3):
        model.add(Dense(50, kernel_initializer='he_normal'))
        model.add(Activation('sigmoid'))
    model.add(Dense(10, kernel_initializer='he_normal'))
    model.add(Activation('softmax'))
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and is rejected by current Keras optimizers.
    sgd = optimizers.SGD(learning_rate=0.001)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
    return model
# Build a new model from mlp_model() and train it for 100 epochs, holding out
# 30% of the training data for validation.
model = mlp_model()
# NOTE(review): no batch_size is passed here, whereas the first fit() call in
# this file uses batch_size = 256, so the two runs differ in more than the
# initializer — confirm this is intended when comparing their results.
history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 1)
Epoch 1/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.3656 - accuracy: 0.1143 - val_loss: 2.3063 - val_accuracy: 0.1079 Epoch 2/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.2984 - val_accuracy: 0.1079 Epoch 3/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2955 - val_accuracy: 0.1079 Epoch 4/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2929 - accuracy: 0.1143 - val_loss: 2.2927 - val_accuracy: 0.1079 Epoch 5/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2902 - accuracy: 0.1143 - val_loss: 2.2900 - val_accuracy: 0.1079 Epoch 6/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2876 - accuracy: 0.1143 - val_loss: 2.2872 - val_accuracy: 0.1079 Epoch 7/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2849 - accuracy: 0.1143 - val_loss: 2.2844 - val_accuracy: 0.1079 Epoch 8/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2821 - accuracy: 0.1143 - val_loss: 2.2815 - val_accuracy: 0.1079 Epoch 9/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2791 - accuracy: 0.1143 - val_loss: 2.2785 - val_accuracy: 0.1080 Epoch 10/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2760 - accuracy: 0.1176 - val_loss: 2.2755 - val_accuracy: 0.1084 Epoch 11/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2731 - accuracy: 0.1170 - val_loss: 2.2724 - val_accuracy: 0.1083 Epoch 12/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2700 - accuracy: 0.1160 - val_loss: 2.2692 - val_accuracy: 0.1257 Epoch 13/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2668 - accuracy: 0.1228 - val_loss: 2.2659 - val_accuracy: 0.1266 Epoch 14/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2634 - accuracy: 0.1285 - 
val_loss: 2.2626 - val_accuracy: 0.1390 Epoch 15/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2599 - accuracy: 0.1586 - val_loss: 2.2588 - val_accuracy: 0.1309 Epoch 16/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2561 - accuracy: 0.1680 - val_loss: 2.2549 - val_accuracy: 0.1638 Epoch 17/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2520 - accuracy: 0.1872 - val_loss: 2.2507 - val_accuracy: 0.2122 Epoch 18/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2476 - accuracy: 0.2439 - val_loss: 2.2460 - val_accuracy: 0.1938 Epoch 19/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2429 - accuracy: 0.2310 - val_loss: 2.2410 - val_accuracy: 0.2599 Epoch 20/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2377 - accuracy: 0.2752 - val_loss: 2.2356 - val_accuracy: 0.2886 Epoch 21/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2321 - accuracy: 0.2968 - val_loss: 2.2296 - val_accuracy: 0.3260 Epoch 22/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2260 - accuracy: 0.3393 - val_loss: 2.2234 - val_accuracy: 0.3254 Epoch 23/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2194 - accuracy: 0.3507 - val_loss: 2.2164 - val_accuracy: 0.3616 Epoch 24/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2121 - accuracy: 0.3837 - val_loss: 2.2088 - val_accuracy: 0.3750 Epoch 25/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.2041 - accuracy: 0.3929 - val_loss: 2.2003 - val_accuracy: 0.4189 Epoch 26/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1953 - accuracy: 0.4090 - val_loss: 2.1911 - val_accuracy: 0.4568 Epoch 27/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1858 - accuracy: 0.4468 - val_loss: 2.1812 - val_accuracy: 0.4503 Epoch 28/100 1313/1313 [==============================] - 2s 
2ms/step - loss: 2.1753 - accuracy: 0.4504 - val_loss: 2.1702 - val_accuracy: 0.4673 Epoch 29/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1637 - accuracy: 0.4713 - val_loss: 2.1582 - val_accuracy: 0.4663 Epoch 30/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1507 - accuracy: 0.4746 - val_loss: 2.1445 - val_accuracy: 0.4826 Epoch 31/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1362 - accuracy: 0.4836 - val_loss: 2.1293 - val_accuracy: 0.5059 Epoch 32/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1204 - accuracy: 0.4987 - val_loss: 2.1129 - val_accuracy: 0.5090 Epoch 33/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.1032 - accuracy: 0.5108 - val_loss: 2.0953 - val_accuracy: 0.5013 Epoch 34/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.0844 - accuracy: 0.5070 - val_loss: 2.0756 - val_accuracy: 0.5133 Epoch 35/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.0637 - accuracy: 0.5155 - val_loss: 2.0542 - val_accuracy: 0.5234 Epoch 36/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.0411 - accuracy: 0.5282 - val_loss: 2.0305 - val_accuracy: 0.5218 Epoch 37/100 1313/1313 [==============================] - 2s 2ms/step - loss: 2.0164 - accuracy: 0.5228 - val_loss: 2.0052 - val_accuracy: 0.5376 Epoch 38/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.9896 - accuracy: 0.5318 - val_loss: 1.9774 - val_accuracy: 0.5453 Epoch 39/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.9607 - accuracy: 0.5416 - val_loss: 1.9474 - val_accuracy: 0.5428 Epoch 40/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.9295 - accuracy: 0.5428 - val_loss: 1.9155 - val_accuracy: 0.5483 Epoch 41/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.8963 - accuracy: 0.5511 - val_loss: 1.8813 - val_accuracy: 0.5532 Epoch 42/100 
1313/1313 [==============================] - 2s 2ms/step - loss: 1.8610 - accuracy: 0.5515 - val_loss: 1.8453 - val_accuracy: 0.5653 Epoch 43/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.8240 - accuracy: 0.5595 - val_loss: 1.8076 - val_accuracy: 0.5727 Epoch 44/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.7854 - accuracy: 0.5665 - val_loss: 1.7686 - val_accuracy: 0.5749 Epoch 45/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.7452 - accuracy: 0.5699 - val_loss: 1.7278 - val_accuracy: 0.5812 Epoch 46/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.7038 - accuracy: 0.5786 - val_loss: 1.6858 - val_accuracy: 0.5847 Epoch 47/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.6616 - accuracy: 0.5838 - val_loss: 1.6430 - val_accuracy: 0.5901 Epoch 48/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.6185 - accuracy: 0.5881 - val_loss: 1.5999 - val_accuracy: 0.5997 Epoch 49/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.5749 - accuracy: 0.5949 - val_loss: 1.5558 - val_accuracy: 0.6107 Epoch 50/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.5313 - accuracy: 0.6037 - val_loss: 1.5124 - val_accuracy: 0.6128 Epoch 51/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.4881 - accuracy: 0.6118 - val_loss: 1.4686 - val_accuracy: 0.6141 Epoch 52/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.4447 - accuracy: 0.6153 - val_loss: 1.4259 - val_accuracy: 0.6274 Epoch 53/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.4026 - accuracy: 0.6263 - val_loss: 1.3840 - val_accuracy: 0.6333 Epoch 54/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.3612 - accuracy: 0.6324 - val_loss: 1.3434 - val_accuracy: 0.6470 Epoch 55/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.3206 - accuracy: 0.6457 - 
val_loss: 1.3039 - val_accuracy: 0.6506 Epoch 56/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.2821 - accuracy: 0.6517 - val_loss: 1.2672 - val_accuracy: 0.6568 Epoch 57/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.2454 - accuracy: 0.6589 - val_loss: 1.2306 - val_accuracy: 0.6644 Epoch 58/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.2104 - accuracy: 0.6664 - val_loss: 1.1962 - val_accuracy: 0.6756 Epoch 59/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.1771 - accuracy: 0.6752 - val_loss: 1.1638 - val_accuracy: 0.6823 Epoch 60/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.1453 - accuracy: 0.6855 - val_loss: 1.1338 - val_accuracy: 0.6892 Epoch 61/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.1154 - accuracy: 0.6917 - val_loss: 1.1042 - val_accuracy: 0.7019 Epoch 62/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.0874 - accuracy: 0.7005 - val_loss: 1.0775 - val_accuracy: 0.7031 Epoch 63/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.0615 - accuracy: 0.7063 - val_loss: 1.0526 - val_accuracy: 0.7127 Epoch 64/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.0366 - accuracy: 0.7143 - val_loss: 1.0286 - val_accuracy: 0.7214 Epoch 65/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.0128 - accuracy: 0.7223 - val_loss: 1.0067 - val_accuracy: 0.7251 Epoch 66/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.9912 - accuracy: 0.7271 - val_loss: 0.9850 - val_accuracy: 0.7341 Epoch 67/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.9702 - accuracy: 0.7352 - val_loss: 0.9644 - val_accuracy: 0.7395 Epoch 68/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.9500 - accuracy: 0.7442 - val_loss: 0.9455 - val_accuracy: 0.7449 Epoch 69/100 1313/1313 [==============================] - 2s 
2ms/step - loss: 0.9313 - accuracy: 0.7496 - val_loss: 0.9282 - val_accuracy: 0.7540 Epoch 70/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.9141 - accuracy: 0.7563 - val_loss: 0.9115 - val_accuracy: 0.7555 Epoch 71/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8965 - accuracy: 0.7618 - val_loss: 0.8964 - val_accuracy: 0.7647 Epoch 72/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8794 - accuracy: 0.7692 - val_loss: 0.8789 - val_accuracy: 0.7761 Epoch 73/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8636 - accuracy: 0.7744 - val_loss: 0.8655 - val_accuracy: 0.7739 Epoch 74/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8482 - accuracy: 0.7787 - val_loss: 0.8504 - val_accuracy: 0.7805 Epoch 75/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8342 - accuracy: 0.7845 - val_loss: 0.8375 - val_accuracy: 0.7871 Epoch 76/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8214 - accuracy: 0.7887 - val_loss: 0.8230 - val_accuracy: 0.7866 Epoch 77/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.8069 - accuracy: 0.7934 - val_loss: 0.8115 - val_accuracy: 0.7932 Epoch 78/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7945 - accuracy: 0.7967 - val_loss: 0.8010 - val_accuracy: 0.7973 Epoch 79/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7816 - accuracy: 0.8010 - val_loss: 0.7875 - val_accuracy: 0.8012 Epoch 80/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7692 - accuracy: 0.8051 - val_loss: 0.7758 - val_accuracy: 0.8077 Epoch 81/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7577 - accuracy: 0.8087 - val_loss: 0.7662 - val_accuracy: 0.8058 Epoch 82/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7457 - accuracy: 0.8120 - val_loss: 0.7550 - val_accuracy: 0.8089 Epoch 83/100 
1313/1313 [==============================] - 2s 2ms/step - loss: 0.7345 - accuracy: 0.8165 - val_loss: 0.7434 - val_accuracy: 0.8163 Epoch 84/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7244 - accuracy: 0.8196 - val_loss: 0.7346 - val_accuracy: 0.8154 Epoch 85/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7130 - accuracy: 0.8231 - val_loss: 0.7243 - val_accuracy: 0.8217 Epoch 86/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.7017 - accuracy: 0.8275 - val_loss: 0.7152 - val_accuracy: 0.8234 Epoch 87/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6929 - accuracy: 0.8316 - val_loss: 0.7069 - val_accuracy: 0.8229 Epoch 88/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6835 - accuracy: 0.8341 - val_loss: 0.6979 - val_accuracy: 0.8276 Epoch 89/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6726 - accuracy: 0.8356 - val_loss: 0.6890 - val_accuracy: 0.8304 Epoch 90/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6645 - accuracy: 0.8394 - val_loss: 0.6823 - val_accuracy: 0.8307 Epoch 91/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6549 - accuracy: 0.8445 - val_loss: 0.6707 - val_accuracy: 0.8353 Epoch 92/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6459 - accuracy: 0.8457 - val_loss: 0.6651 - val_accuracy: 0.8388 Epoch 93/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6375 - accuracy: 0.8502 - val_loss: 0.6558 - val_accuracy: 0.8389 Epoch 94/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6291 - accuracy: 0.8533 - val_loss: 0.6479 - val_accuracy: 0.8426 Epoch 95/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6208 - accuracy: 0.8541 - val_loss: 0.6406 - val_accuracy: 0.8453 Epoch 96/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.6119 - accuracy: 0.8574 - 
val_loss: 0.6317 - val_accuracy: 0.8479 Epoch 97/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.6037 - accuracy: 0.8609 - val_loss: 0.6270 - val_accuracy: 0.8495 Epoch 98/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.5967 - accuracy: 0.8618 - val_loss: 0.6185 - val_accuracy: 0.8507 Epoch 99/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.5882 - accuracy: 0.8646 - val_loss: 0.6135 - val_accuracy: 0.8517 Epoch 100/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.5817 - accuracy: 0.8655 - val_loss: 0.6051 - val_accuracy: 0.8533
# Draw the four recorded training curves on one set of axes, in the same
# order as the legend entries below.
for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[metric])
plt.legend(
    ['train acc', 'valid acc', 'train loss', 'valid loss'],
    loc='upper left',
)
plt.show()
Training and validation accuracy start to improve after around 10 epochs and then climb steadily, reaching about 86% (training) and 85% (validation) by epoch 100.
# Score the trained model on the held-out test split; the result is indexed
# below, with index 1 reported as the accuracy.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 0.5939 - accuracy: 0.8625
# results[1] is reported as the accuracy; presumably results[0] is the loss,
# matching the compile() call (loss first, then metrics) — confirm.
print('Test accuracy: ', results[1])
Test accuracy: 0.862500011920929
def mlp_model():
    """Build and compile an MLP classifier that uses ReLU hidden activations.

    The network takes 784-dimensional inputs, stacks four Dense(50) hidden
    layers, each followed by a ReLU activation, and ends with a Dense(10)
    softmax output layer.

    Returns:
        The compiled Sequential model (SGD optimizer with learning rate 0.001,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Only the first layer needs the input shape.
    model.add(Dense(50, input_shape = (784, )))
    model.add(Activation('relu'))    # use relu
    model.add(Dense(50))
    model.add(Activation('relu'))    # use relu
    model.add(Dense(50))
    model.add(Activation('relu'))    # use relu
    model.add(Dense(50))
    model.add(Activation('relu'))    # use relu
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # `learning_rate` is the supported keyword; the `lr` alias is deprecated.
    sgd = optimizers.SGD(learning_rate = 0.001)
    model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
# Build a fresh network and train it, holding out 30% of the training data for validation.
model = mlp_model()
history = model.fit(
    X_train,
    y_train,
    validation_split = 0.3,
    epochs = 100,
    verbose = 1,
)
Epoch 1/100 1313/1313 [==============================] - 2s 2ms/step - loss: 1.1128 - accuracy: 0.7314 - val_loss: 0.5504 - val_accuracy: 0.8466 Epoch 2/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.4489 - accuracy: 0.8700 - val_loss: 0.4168 - val_accuracy: 0.8827 Epoch 3/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.3434 - accuracy: 0.8988 - val_loss: 0.3719 - val_accuracy: 0.8918 Epoch 4/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.2895 - accuracy: 0.9141 - val_loss: 0.3231 - val_accuracy: 0.9103 Epoch 5/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.2528 - accuracy: 0.9248 - val_loss: 0.2951 - val_accuracy: 0.9153 Epoch 6/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.2266 - accuracy: 0.9321 - val_loss: 0.2940 - val_accuracy: 0.9170 Epoch 7/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.2094 - accuracy: 0.9368 - val_loss: 0.2848 - val_accuracy: 0.9194 Epoch 8/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1935 - accuracy: 0.9423 - val_loss: 0.2735 - val_accuracy: 0.9215 Epoch 9/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1820 - accuracy: 0.9458 - val_loss: 0.2632 - val_accuracy: 0.9277 Epoch 10/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1702 - accuracy: 0.9486 - val_loss: 0.2506 - val_accuracy: 0.9304 Epoch 11/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1605 - accuracy: 0.9518 - val_loss: 0.2467 - val_accuracy: 0.9319 Epoch 12/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1527 - accuracy: 0.9535 - val_loss: 0.2390 - val_accuracy: 0.9327 Epoch 13/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1447 - accuracy: 0.9558 - val_loss: 0.2948 - val_accuracy: 0.9189 Epoch 14/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1380 - accuracy: 0.9576 - 
val_loss: 0.2315 - val_accuracy: 0.9362 Epoch 15/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1324 - accuracy: 0.9604 - val_loss: 0.2351 - val_accuracy: 0.9369 Epoch 16/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1266 - accuracy: 0.9624 - val_loss: 0.2313 - val_accuracy: 0.9373 Epoch 17/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1217 - accuracy: 0.9626 - val_loss: 0.2308 - val_accuracy: 0.9385 Epoch 18/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1172 - accuracy: 0.9646 - val_loss: 0.2286 - val_accuracy: 0.9398 Epoch 19/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1130 - accuracy: 0.9659 - val_loss: 0.2267 - val_accuracy: 0.9398 Epoch 20/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1095 - accuracy: 0.9664 - val_loss: 0.2253 - val_accuracy: 0.9408 Epoch 21/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1049 - accuracy: 0.9676 - val_loss: 0.2281 - val_accuracy: 0.9411 Epoch 22/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.1017 - accuracy: 0.9695 - val_loss: 0.2252 - val_accuracy: 0.9417 Epoch 23/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0984 - accuracy: 0.9697 - val_loss: 0.2267 - val_accuracy: 0.9412 Epoch 24/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0946 - accuracy: 0.9718 - val_loss: 0.2310 - val_accuracy: 0.9422 Epoch 25/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0921 - accuracy: 0.9720 - val_loss: 0.2272 - val_accuracy: 0.9419 Epoch 26/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0891 - accuracy: 0.9733 - val_loss: 0.2260 - val_accuracy: 0.9438 Epoch 27/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0860 - accuracy: 0.9749 - val_loss: 0.2251 - val_accuracy: 0.9429 Epoch 28/100 1313/1313 [==============================] - 2s 
2ms/step - loss: 0.0823 - accuracy: 0.9757 - val_loss: 0.2363 - val_accuracy: 0.9436 Epoch 29/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0800 - accuracy: 0.9759 - val_loss: 0.2431 - val_accuracy: 0.9423 Epoch 30/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0776 - accuracy: 0.9765 - val_loss: 0.2289 - val_accuracy: 0.9447 Epoch 31/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0759 - accuracy: 0.9773 - val_loss: 0.2300 - val_accuracy: 0.9434 Epoch 32/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0735 - accuracy: 0.9784 - val_loss: 0.2301 - val_accuracy: 0.9450 Epoch 33/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0715 - accuracy: 0.9792 - val_loss: 0.2318 - val_accuracy: 0.9441 Epoch 34/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0692 - accuracy: 0.9794 - val_loss: 0.2324 - val_accuracy: 0.9446 Epoch 35/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0676 - accuracy: 0.9804 - val_loss: 0.2306 - val_accuracy: 0.9457 Epoch 36/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0651 - accuracy: 0.9807 - val_loss: 0.2379 - val_accuracy: 0.9436 Epoch 37/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0639 - accuracy: 0.9815 - val_loss: 0.2304 - val_accuracy: 0.9456 Epoch 38/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0613 - accuracy: 0.9825 - val_loss: 0.2417 - val_accuracy: 0.9455 Epoch 39/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0600 - accuracy: 0.9825 - val_loss: 0.2418 - val_accuracy: 0.9459 Epoch 40/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0588 - accuracy: 0.9832 - val_loss: 0.2460 - val_accuracy: 0.9445 Epoch 41/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0572 - accuracy: 0.9828 - val_loss: 0.2414 - val_accuracy: 0.9452 Epoch 42/100 
1313/1313 [==============================] - 2s 2ms/step - loss: 0.0552 - accuracy: 0.9837 - val_loss: 0.2456 - val_accuracy: 0.9441 Epoch 43/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0544 - accuracy: 0.9844 - val_loss: 0.2437 - val_accuracy: 0.9461 Epoch 44/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0532 - accuracy: 0.9843 - val_loss: 0.2438 - val_accuracy: 0.9466 Epoch 45/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0511 - accuracy: 0.9855 - val_loss: 0.2429 - val_accuracy: 0.9460 Epoch 46/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0492 - accuracy: 0.9862 - val_loss: 0.2568 - val_accuracy: 0.9448 Epoch 47/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0485 - accuracy: 0.9865 - val_loss: 0.2469 - val_accuracy: 0.9459 Epoch 48/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0469 - accuracy: 0.9863 - val_loss: 0.2509 - val_accuracy: 0.9459 Epoch 49/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0467 - accuracy: 0.9863 - val_loss: 0.2481 - val_accuracy: 0.9472 Epoch 50/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0447 - accuracy: 0.9870 - val_loss: 0.2511 - val_accuracy: 0.9464 Epoch 51/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0442 - accuracy: 0.9868 - val_loss: 0.2573 - val_accuracy: 0.9456 Epoch 52/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0425 - accuracy: 0.9875 - val_loss: 0.2563 - val_accuracy: 0.9455 Epoch 53/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0419 - accuracy: 0.9881 - val_loss: 0.2541 - val_accuracy: 0.9458 Epoch 54/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0404 - accuracy: 0.9886 - val_loss: 0.2624 - val_accuracy: 0.9472 Epoch 55/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0396 - accuracy: 0.9885 - 
val_loss: 0.2640 - val_accuracy: 0.9445 Epoch 56/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0388 - accuracy: 0.9888 - val_loss: 0.2718 - val_accuracy: 0.9448 Epoch 57/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0375 - accuracy: 0.9895 - val_loss: 0.2643 - val_accuracy: 0.9468 Epoch 58/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0361 - accuracy: 0.9897 - val_loss: 0.2728 - val_accuracy: 0.9439 Epoch 59/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0360 - accuracy: 0.9897 - val_loss: 0.2633 - val_accuracy: 0.9461 Epoch 60/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0347 - accuracy: 0.9905 - val_loss: 0.2640 - val_accuracy: 0.9468 Epoch 61/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0336 - accuracy: 0.9906 - val_loss: 0.2716 - val_accuracy: 0.9468 Epoch 62/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0330 - accuracy: 0.9908 - val_loss: 0.2772 - val_accuracy: 0.9449 Epoch 63/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0325 - accuracy: 0.9907 - val_loss: 0.2772 - val_accuracy: 0.9453 Epoch 64/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0320 - accuracy: 0.9911 - val_loss: 0.2770 - val_accuracy: 0.9471 Epoch 65/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0309 - accuracy: 0.9917 - val_loss: 0.2720 - val_accuracy: 0.9482 Epoch 66/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0302 - accuracy: 0.9918 - val_loss: 0.2828 - val_accuracy: 0.9473 Epoch 67/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0288 - accuracy: 0.9926 - val_loss: 0.2858 - val_accuracy: 0.9460 Epoch 68/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0287 - accuracy: 0.9923 - val_loss: 0.2855 - val_accuracy: 0.9459 Epoch 69/100 1313/1313 [==============================] - 2s 
2ms/step - loss: 0.0277 - accuracy: 0.9925 - val_loss: 0.2984 - val_accuracy: 0.9451 Epoch 70/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0269 - accuracy: 0.9929 - val_loss: 0.2891 - val_accuracy: 0.9447 Epoch 71/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0264 - accuracy: 0.9929 - val_loss: 0.2865 - val_accuracy: 0.9466 Epoch 72/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0251 - accuracy: 0.9933 - val_loss: 0.2896 - val_accuracy: 0.9472 Epoch 73/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0245 - accuracy: 0.9934 - val_loss: 0.2961 - val_accuracy: 0.9467 Epoch 74/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0248 - accuracy: 0.9934 - val_loss: 0.2961 - val_accuracy: 0.9462 Epoch 75/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0243 - accuracy: 0.9937 - val_loss: 0.2974 - val_accuracy: 0.9469 Epoch 76/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0226 - accuracy: 0.9943 - val_loss: 0.2977 - val_accuracy: 0.9474 Epoch 77/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0223 - accuracy: 0.9946 - val_loss: 0.3038 - val_accuracy: 0.9480 Epoch 78/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0220 - accuracy: 0.9940 - val_loss: 0.3051 - val_accuracy: 0.9472 Epoch 79/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0216 - accuracy: 0.9944 - val_loss: 0.3116 - val_accuracy: 0.9462 Epoch 80/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0207 - accuracy: 0.9950 - val_loss: 0.3021 - val_accuracy: 0.9489 Epoch 81/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0203 - accuracy: 0.9945 - val_loss: 0.3098 - val_accuracy: 0.9484 Epoch 82/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0196 - accuracy: 0.9952 - val_loss: 0.3139 - val_accuracy: 0.9468 Epoch 83/100 
1313/1313 [==============================] - 2s 2ms/step - loss: 0.0193 - accuracy: 0.9953 - val_loss: 0.3157 - val_accuracy: 0.9466 Epoch 84/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0188 - accuracy: 0.9954 - val_loss: 0.3180 - val_accuracy: 0.9458 Epoch 85/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0191 - accuracy: 0.9953 - val_loss: 0.3196 - val_accuracy: 0.9485 Epoch 86/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0182 - accuracy: 0.9953 - val_loss: 0.3267 - val_accuracy: 0.9452 Epoch 87/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0176 - accuracy: 0.9960 - val_loss: 0.3224 - val_accuracy: 0.9472 Epoch 88/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0177 - accuracy: 0.9957 - val_loss: 0.3195 - val_accuracy: 0.9476 Epoch 89/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0167 - accuracy: 0.9961 - val_loss: 0.3250 - val_accuracy: 0.9476 Epoch 90/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0163 - accuracy: 0.9964 - val_loss: 0.3281 - val_accuracy: 0.9472 Epoch 91/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0155 - accuracy: 0.9966 - val_loss: 0.3259 - val_accuracy: 0.9470 Epoch 92/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0154 - accuracy: 0.9964 - val_loss: 0.3387 - val_accuracy: 0.9447 Epoch 93/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0152 - accuracy: 0.9968 - val_loss: 0.3323 - val_accuracy: 0.9466 Epoch 94/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0147 - accuracy: 0.9968 - val_loss: 0.3295 - val_accuracy: 0.9470 Epoch 95/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0147 - accuracy: 0.9970 - val_loss: 0.3338 - val_accuracy: 0.9472 Epoch 96/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0143 - accuracy: 0.9970 - 
val_loss: 0.3401 - val_accuracy: 0.9477 Epoch 97/100 1313/1313 [==============================] - 3s 2ms/step - loss: 0.0143 - accuracy: 0.9968 - val_loss: 0.3434 - val_accuracy: 0.9459 Epoch 98/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0136 - accuracy: 0.9971 - val_loss: 0.3385 - val_accuracy: 0.9469 Epoch 99/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0130 - accuracy: 0.9975 - val_loss: 0.3472 - val_accuracy: 0.9472 Epoch 100/100 1313/1313 [==============================] - 2s 2ms/step - loss: 0.0127 - accuracy: 0.9977 - val_loss: 0.3448 - val_accuracy: 0.9450
# Overlay the per-epoch training/validation accuracy and loss curves on one set of axes.
for curve in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[curve])
plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')
plt.show()
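Training and validation accuracy improve immediately, but validation accuracy reaches a plateau (about 0.94–0.95) after around 30 epochs; training accuracy keeps rising while validation loss creeps up, which is a sign of overfitting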
# Score the trained model on the test set; results[1] is reported as the test accuracy below.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 0.3152 - accuracy: 0.9488
# The evaluate log lists loss first and accuracy second, so index 1 is the accuracy.
print('Test accuracy: ', results[1])
Test accuracy: 0.9488000273704529
def mlp_model():
    """Build and compile a sigmoid MLP classifier trained with the Adam optimizer.

    The network takes 784-dimensional inputs, stacks four Dense(50) hidden
    layers, each followed by a sigmoid activation, and ends with a Dense(10)
    softmax output layer.

    Returns:
        The compiled Sequential model (Adam optimizer with learning rate 0.001,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Only the first layer needs the input shape.
    model.add(Dense(50, input_shape = (784, )))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # use Adam optimizer; `learning_rate` is the supported keyword, the `lr` alias is deprecated.
    adam = optimizers.Adam(learning_rate = 0.001)
    model.compile(optimizer = adam, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
# Build a fresh network and train it silently, holding out 30% of the training data for validation.
model = mlp_model()
history = model.fit(
    X_train,
    y_train,
    validation_split = 0.3,
    epochs = 100,
    verbose = 0,
)

# Overlay the per-epoch training/validation accuracy and loss curves on one set of axes.
for curve in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[curve])
plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')
plt.show()
Training and validation accuracy improve immediately, but reach a plateau after around 50 epochs
# Score the trained model on the test set; results[1] is reported as the test accuracy below.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 0.1801 - accuracy: 0.9465
# The evaluate log lists loss first and accuracy second, so index 1 is the accuracy.
print('Test accuracy: ', results[1])
Test accuracy: 0.9465000033378601
A batch normalization layer is usually inserted after a dense/convolution layer and before the nonlinearity
# Import from tensorflow.keras so the layer comes from the same package as
# Sequential, Dense and Activation used in the rest of this file.
from tensorflow.keras.layers import BatchNormalization


def mlp_model():
    """Build and compile a sigmoid MLP classifier with batch normalization.

    The network takes 784-dimensional inputs and stacks four hidden stages of
    Dense(50) -> BatchNormalization -> sigmoid, followed by a Dense(10)
    softmax output layer.

    Returns:
        The compiled Sequential model (SGD optimizer with learning rate 0.001,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Only the first layer needs the input shape.
    model.add(Dense(50, input_shape = (784, )))
    model.add(BatchNormalization())    # Add Batchnorm layer before Activation
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(BatchNormalization())    # Add Batchnorm layer before Activation
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(BatchNormalization())    # Add Batchnorm layer before Activation
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(BatchNormalization())    # Add Batchnorm layer before Activation
    model.add(Activation('sigmoid'))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # `learning_rate` is the supported keyword; the `lr` alias is deprecated.
    sgd = optimizers.SGD(learning_rate = 0.001)
    model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
# Build a fresh network and train it silently, holding out 30% of the training data for validation.
model = mlp_model()
history = model.fit(
    X_train,
    y_train,
    validation_split = 0.3,
    epochs = 100,
    verbose = 0,
)

# Overlay the per-epoch training/validation accuracy and loss curves on one set of axes.
for curve in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[curve])
plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')
plt.show()
Training and validation accuracy improve consistently, but reach a plateau after around 60 epochs
# Score the trained model on the test set; results[1] is reported as the test accuracy below.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 0.1866 - accuracy: 0.9481
# The evaluate log lists loss first and accuracy second, so index 1 is the accuracy.
print('Test accuracy: ', results[1])
Test accuracy: 0.9480999708175659
# Import from tensorflow.keras so the layer comes from the same package as
# Sequential, Dense and Activation used in the rest of this file.
from tensorflow.keras.layers import Dropout


def mlp_model():
    """Build and compile a sigmoid MLP classifier with dropout regularization.

    The network takes 784-dimensional inputs and stacks four hidden stages of
    Dense(50) -> sigmoid -> Dropout(0.2), followed by a Dense(10) softmax
    output layer.

    Returns:
        The compiled Sequential model (SGD optimizer with learning rate 0.001,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Only the first layer needs the input shape.
    model.add(Dense(50, input_shape = (784, )))
    model.add(Activation('sigmoid'))
    model.add(Dropout(0.2))    # Dropout layer after Activation
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dropout(0.2))    # Dropout layer after Activation
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dropout(0.2))    # Dropout layer after Activation
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dropout(0.2))    # Dropout layer after Activation
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # `learning_rate` is the supported keyword; the `lr` alias is deprecated.
    sgd = optimizers.SGD(learning_rate = 0.001)
    model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
# Build a fresh network and train it silently, holding out 30% of the training data for validation.
model = mlp_model()
history = model.fit(
    X_train,
    y_train,
    validation_split = 0.3,
    epochs = 100,
    verbose = 0,
)

# Overlay the per-epoch training/validation accuracy and loss curves on one set of axes.
for curve in ('accuracy', 'val_accuracy', 'loss', 'val_loss'):
    plt.plot(history.history[curve])
plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')
plt.show()
Validation results do not improve, because the model had not yet shown signs of overfitting.
Hence, the key takeaway is to apply dropout when you see signs of overfitting.
# Score the trained model on the test set; results[1] is reported as the test accuracy below.
results = model.evaluate(X_test, y_test)
313/313 [==============================] - 0s 1ms/step - loss: 1.6782 - accuracy: 0.4227
# The evaluate log lists loss first and accuracy second, so index 1 is the accuracy.
print('Test accuracy: ', results[1])
Test accuracy: 0.4226999878883362
import numpy as np
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import accuracy_score
# Collapse the one-hot label matrices back to integer class indices (position of the max per row).
y_train = y_train.argmax(axis = 1)
y_test = y_test.argmax(axis = 1)
def mlp_model():
    """Build and compile the sigmoid MLP used as the ensemble's base model.

    The network takes 784-dimensional inputs, stacks four Dense(50) hidden
    layers, each followed by a sigmoid activation, and ends with a Dense(10)
    softmax output layer. It takes no arguments, so it can be passed as
    `build_fn` to KerasClassifier.

    Returns:
        The compiled Sequential model (SGD optimizer with learning rate 0.001,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Only the first layer needs the input shape.
    model.add(Dense(50, input_shape = (784, )))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(50))
    model.add(Activation('sigmoid'))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # `learning_rate` is the supported keyword; the `lr` alias is deprecated.
    sgd = optimizers.SGD(learning_rate = 0.001)
    model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
# Three independently initialised copies of the same network, wrapped for scikit-learn.
model1, model2, model3 = (
    KerasClassifier(build_fn = mlp_model, epochs = 100, verbose = 0)
    for _ in range(3)
)

# Tag each wrapper as a classifier -- presumably required for VotingClassifier
# to accept the Keras wrappers; confirm against the scikit-learn version in use.
model1._estimator_type = "classifier"
model2._estimator_type = "classifier"
model3._estimator_type = "classifier"

# Soft voting ensemble over the three wrapped networks.
ensemble_clf = VotingClassifier(
    estimators = [('model1', model1), ('model2', model2), ('model3', model3)],
    voting = 'soft',
)
ensemble_clf.fit(X_train, y_train)
VotingClassifier(estimators=[('model1', <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x7f3b1c1f9438>), ('model2', <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x7f3b1c1ce7b8>), ('model3', <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x7f3b1c10d978>)], flatten_transform=True, n_jobs=None, voting='soft', weights=None)
# Predict test-set labels with the fitted voting ensemble.
y_pred = ensemble_clf.predict(X_test)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/wrappers/scikit_learn.py:264: Sequential.predict_proba (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01. Instructions for updating: Please use `model.predict()` instead.
# accuracy_score expects (y_true, y_pred); pass the ground-truth labels first.
print('Test accuracy:', accuracy_score(y_test, y_pred))
Test accuracy: 0.9002
The table below summarizes the evaluation results so far. It turns out that all methods improve the test performance on the MNIST dataset compared to the baseline. Why don't we try them all together?
Model | Baseline | Weight initialization | Activation function | Optimizer | Batchnormalization | Regularization | Ensemble |
---|---|---|---|---|---|---|---|
Test Accuracy | 0.1134 | 0.8625 | 0.9488 | 0.9465 | 0.9481 | 0.4227 | 0.9002 |