이 노트북에서 LeNet-5와 비슷한, MNIST 손글씨 숫자를 분류하는 심층 합성곱 신경망을 만듭니다.
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Flatten, Conv2D, MaxPooling2D # new!
# Load MNIST: 60,000 training and 10,000 validation grayscale 28x28 digit images.
(X_train, y_train), (X_valid, y_valid) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 0s 0us/step
# Preprocess: add a single channel axis for Conv2D, cast to float32,
# scale pixel values to [0, 1], and one-hot encode the labels.
# Using -1 lets NumPy infer the sample count instead of hard-coding
# 60000/10000, so this also works on subsets of the data.
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32')
X_valid = X_valid.reshape(-1, 28, 28, 1).astype('float32')
X_train /= 255
X_valid /= 255
n_classes = 10  # digits 0-9
y_train = keras.utils.to_categorical(y_train, n_classes)
y_valid = keras.utils.to_categorical(y_valid, n_classes)
# LeNet-5-style convnet: two conv layers extract features, max-pooling
# downsamples, dropout regularizes, and two dense layers classify.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    Conv2D(64, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(n_classes, activation='softmax'),  # one probability per digit class
])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 26, 26, 32) 320 conv2d_1 (Conv2D) (None, 24, 24, 64) 18496 max_pooling2d (MaxPooling2D (None, 12, 12, 64) 0 ) dropout (Dropout) (None, 12, 12, 64) 0 flatten (Flatten) (None, 9216) 0 dense (Dense) (None, 128) 1179776 dropout_1 (Dropout) (None, 128) 0 dense_1 (Dense) (None, 10) 1290 ================================================================= Total params: 1,199,882 Trainable params: 1,199,882 Non-trainable params: 0 _________________________________________________________________
# Categorical cross-entropy matches the one-hot labels and softmax output.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Train for 10 epochs, reporting accuracy on the held-out validation set each epoch.
model.fit(X_train, y_train, batch_size=128, epochs=10, verbose=1, validation_data=(X_valid, y_valid))
Epoch 1/10 469/469 [==============================] - 15s 11ms/step - loss: 0.2413 - accuracy: 0.9275 - val_loss: 0.0489 - val_accuracy: 0.9849 Epoch 2/10 469/469 [==============================] - 5s 10ms/step - loss: 0.0879 - accuracy: 0.9739 - val_loss: 0.0349 - val_accuracy: 0.9869 Epoch 3/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0648 - accuracy: 0.9804 - val_loss: 0.0334 - val_accuracy: 0.9889 Epoch 4/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0548 - accuracy: 0.9836 - val_loss: 0.0331 - val_accuracy: 0.9901 Epoch 5/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0445 - accuracy: 0.9861 - val_loss: 0.0292 - val_accuracy: 0.9899 Epoch 6/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0383 - accuracy: 0.9886 - val_loss: 0.0246 - val_accuracy: 0.9920 Epoch 7/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0330 - accuracy: 0.9894 - val_loss: 0.0279 - val_accuracy: 0.9906 Epoch 8/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0313 - accuracy: 0.9900 - val_loss: 0.0270 - val_accuracy: 0.9918 Epoch 9/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0285 - accuracy: 0.9902 - val_loss: 0.0267 - val_accuracy: 0.9923 Epoch 10/10 469/469 [==============================] - 4s 9ms/step - loss: 0.0258 - accuracy: 0.9915 - val_loss: 0.0272 - val_accuracy: 0.9916
<keras.callbacks.History at 0x7f89d037f7c0>