In this notebook, we continue with the MNIST analysis after our initial exploration.
Let's again begin by reading in the MNIST dataset.
import conx as cx
Using TensorFlow backend. ConX, version 3.6.5
mnist = cx.Dataset.get('mnist')
mnist.info()
Dataset: MNIST
Original source: http://yann.lecun.com/exdb/mnist/
The MNIST dataset contains 70,000 images of handwritten digits (zero to nine) that have been size-normalized and centered in a square grid of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers representing grayscale intensities ranging from 0 (black) to 1 (white). The target data consists of one-hot binary vectors of size 10, corresponding to the digit classification categories zero through nine.
Again, we build a CNN.
cnn = cx.Network("MNIST_CNN_Visualize")
cnn.add(cx.Layer("input", (28, 28, 1), colormap="gray"),
        cx.Conv2DLayer("conv2D_1", 16, (5, 5), activation="relu", dropout=0.20),
        cx.MaxPool2DLayer("maxpool1", (2, 2)),
        cx.Conv2DLayer("conv2D_2", 32, (5, 5), activation="relu", dropout=0.20),
        cx.MaxPool2DLayer("maxpool2", (2, 2)),
        cx.FlattenLayer("flat"),
        cx.Layer("hidden", 30, activation="relu"),
        cx.Layer("output", 10, activation="softmax"))
cnn.connect()
cnn.get_dataset("MNIST")
cnn.dataset.split(10000)
cnn.dataset.summary()
_________________________________________________________________
MNIST:
Patterns    Shape                 Range
=================================================================
inputs      (28, 28, 1)           (0.0, 1.0)
targets     (10,)                 (0.0, 1.0)
=================================================================
Total patterns: 70000
   Training patterns: 60000
   Testing patterns: 10000
_________________________________________________________________
Again, we will use the RMSprop optimizer. It adapts the learning rate of each weight separately by dividing that weight's gradient by the root-mean-square of its recent gradients.
cnn.compile(error='categorical_crossentropy', optimizer='RMSprop')
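The RMSprop update rule itself fits in a few lines of NumPy. This sketch uses the textbook form of the algorithm and is not ConX or Keras source code. The learning rate, decay rate rho, and epsilon values are assumptions chosen for a toy problem. The usage below applies the update to f(w) = sum(w**2):

```python
import numpy as np

def rmsprop_step(w, grad, cache, lr=0.001, rho=0.9, eps=1e-7):
    """One RMSprop update (sketch; hyperparameter defaults are assumptions)."""
    # Exponential running average of squared gradients, one entry per weight.
    cache = rho * cache + (1.0 - rho) * grad ** 2
    # Scale each weight's step by the root of its own running average, so
    # weights with consistently large gradients take proportionally smaller steps.
    w = w - lr * grad / (np.sqrt(cache) + eps)
    return w, cache

# Toy usage: minimize f(w) = sum(w**2), whose gradient is 2*w.
w = np.array([5.0, -3.0])
cache = np.zeros_like(w)
for _ in range(2000):
    w, cache = rmsprop_step(w, 2 * w, cache, lr=0.01)
print(w)  # both entries end up close to 0
```

Because the step is normalized by the gradient's own running magnitude, each weight moves by roughly lr per step, regardless of how steep the loss is along that direction.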
cnn.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (None, 28, 28, 1)         0
_________________________________________________________________
conv2D_1 (Conv2D)            (None, 24, 24, 16)        416
_________________________________________________________________
dropout_1 (Dropout)          (None, 24, 24, 16)        0
_________________________________________________________________
maxpool1 (MaxPooling2D)      (None, 12, 12, 16)        0
_________________________________________________________________
conv2D_2 (Conv2D)            (None, 8, 8, 32)          12832
_________________________________________________________________
dropout_2 (Dropout)          (None, 8, 8, 32)          0
_________________________________________________________________
maxpool2 (MaxPooling2D)      (None, 4, 4, 32)          0
_________________________________________________________________
flat (Flatten)               (None, 512)               0
_________________________________________________________________
hidden (Dense)               (None, 30)                15390
_________________________________________________________________
output (Dense)               (None, 10)                310
=================================================================
Total params: 28,948
Trainable params: 28,948
Non-trainable params: 0
_________________________________________________________________
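The output shapes and parameter counts in this summary can be checked by hand. The sketch assumes "valid" convolutions with no padding and stride 1, plus 2 × 2 pooling; these assumptions are consistent with the shapes shown. Dropout and pooling layers contribute no parameters:

```python
def conv_output_size(n, k):
    # "valid" convolution, stride 1: an n x n input and k x k kernel give n - k + 1
    return n - k + 1

def conv_params(k, in_ch, out_ch):
    # one k x k x in_ch kernel per output channel, plus one bias per channel
    return k * k * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    # full weight matrix plus one bias per output unit
    return n_in * n_out + n_out

side = 28
side = conv_output_size(side, 5)   # conv2D_1 -> 24
p1 = conv_params(5, 1, 16)         # 416
side //= 2                         # maxpool1 -> 12
side = conv_output_size(side, 5)   # conv2D_2 -> 8
p2 = conv_params(5, 16, 32)        # 12832
side //= 2                         # maxpool2 -> 4
flat = side * side * 32            # flat -> 512
p3 = dense_params(flat, 30)        # hidden -> 15390
p4 = dense_params(30, 10)          # output -> 310
total = p1 + p2 + p3 + p4
print(side, flat, total)           # 4 512 28948
```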
cnn.dashboard()
Training our CNN on the full dataset would take a long time. For brevity, we shrink the dataset by 80% using the chop method, which deletes the specified fraction of the data from the dataset.
print('Starting with', len(cnn.dataset), 'patterns')
cnn.dataset.chop(0.80)
print(len(cnn.dataset), 'patterns left after chop')
Starting with 70000 patterns
14000 patterns left after chop
WARNING: dataset split reset to 0
We then reserve 25% of the remaining data for testing.
cnn.dataset.split(0.25)
cnn.dataset.split()
(10500, 3500)
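The pattern counts above follow from simple arithmetic: 20% of 70,000 is 14,000, and 25% of that is 3,500. The chop and split functions below are hypothetical NumPy stand-ins used only to illustrate those counts. They are not ConX code. In particular, it is our assumption that both operations take patterns from the end of the dataset:

```python
import numpy as np

def chop(data, fraction):
    """Hypothetical stand-in for Dataset.chop: drop `fraction` of the patterns
    (assumed here to be dropped from the end)."""
    keep = len(data) - int(round(len(data) * fraction))
    return data[:keep]

def split(data, fraction):
    """Hypothetical stand-in for Dataset.split: hold out `fraction` of the
    patterns for testing (assumed here to be the last ones)."""
    n_test = int(round(len(data) * fraction))
    return data[:len(data) - n_test], data[len(data) - n_test:]

patterns = np.arange(70000)        # stand-in indices for 70,000 MNIST patterns
patterns = chop(patterns, 0.80)    # keep 20%
train, test = split(patterns, 0.25)
print(len(patterns), len(train), len(test))  # 14000 10500 3500
```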
cnn.reset()
cnn.train(epochs=10, record=True)
========================================================
       |  Training |  Training |  Validate |  Validate
Epochs |     Error |  Accuracy |     Error |  Accuracy
------ | --------- | --------- | --------- | ---------
#   10 |   0.03064 |   0.99057 |   0.06226 |   0.97857
Once the network is trained, we can find the Principal Components of the space of representations at the hidden layer.
First, we create a list of the representations of the inputs at the hidden layer. We use just the first 1,000 to speed things up.
states = [cnn.propagate_to("hidden", i) for i in cnn.dataset.train_inputs[:1000]]
Next, we find the Principal Components of these representations:
pca = cx.PCA(states)
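Conceptually, finding the principal components involves three steps. We center the hidden states, take a singular value decomposition, and keep the directions with the largest singular values. The sketch below is a from-scratch NumPy version of that idea and may differ from how cx.PCA is implemented. The fake_states array is random data standing in for the 30-unit hidden representations:

```python
import numpy as np

def pca_fit(states, n_components=2):
    """From-scratch PCA sketch: return the mean, top principal directions,
    and fraction of variance each direction explains."""
    X = np.asarray(states, dtype=float)
    mean = X.mean(axis=0)
    # Rows of Vt are the principal directions, ordered by singular value.
    _, S, Vt = np.linalg.svd(X - mean, full_matrices=False)
    explained = S ** 2 / np.sum(S ** 2)
    return mean, Vt[:n_components], explained[:n_components]

def pca_project(states, mean, components):
    # Coordinates of each (centered) state along the principal directions.
    return (np.asarray(states, dtype=float) - mean) @ components.T

rng = np.random.default_rng(0)
fake_states = rng.normal(size=(1000, 30))   # 1,000 fake 30-unit hidden states
mean, comps, explained = pca_fit(fake_states)
xy_raw = pca_project(fake_states, mean, comps)
print(xy_raw.shape)                          # (1000, 2)
```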
We can then transform these representations into the first 2 Principal Component dimensions, scaling them between 0 and 1:
xy = pca.transform(states, scale=True)
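Scaling the projected coordinates into [0, 1] can be done with min-max normalization. Whether scale=True scales each axis independently, as in this sketch, or both axes jointly is our assumption rather than something confirmed here:

```python
import numpy as np

def minmax_scale(xy):
    """Rescale each column of `xy` to [0, 1] (assumed per-axis scaling)."""
    xy = np.asarray(xy, dtype=float)
    lo, hi = xy.min(axis=0), xy.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)   # avoid dividing by zero on a constant axis
    return (xy - lo) / span

scaled = minmax_scale([[-2.0, 10.0], [0.0, 20.0], [2.0, 30.0]])
print(scaled.tolist())   # [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
```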
Finally, we can plot the inputs in this transformed space to get an idea of how the network has learned to re-represent them:
cx.scatter(["Hidden states", xy])
OK. However, that plot does not show how the items group together, because every point is drawn in the same color. To see the groups, we need to plot each category separately.
First, we make a dictionary that maps each category to the indices of its first 100 instances. Limiting each category to 100 speeds things up and avoids overwhelming the scatter plot:
groups = {category: cnn.dataset.inputs.select(lambda i, ds: ds.targets[i] == cx.onehot(category, 10),
                                              slice=100,
                                              index=True)
          for category in range(10)}
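The same per-category selection can be expressed directly in NumPy from one-hot targets. Here, argmax recovers each class label, and we keep the first n indices of each class. The first_n_per_class helper is hypothetical and only illustrates the idea; it does not replace the ConX select call:

```python
import numpy as np

def first_n_per_class(targets, n=100, num_classes=10):
    """Map each class to the indices of its first `n` examples,
    given one-hot targets (hypothetical helper, not ConX code)."""
    labels = np.argmax(np.asarray(targets), axis=1)   # one-hot -> class id
    return {c: np.flatnonzero(labels == c)[:n].tolist() for c in range(num_classes)}

# Toy one-hot targets whose classes cycle 0, 1, ..., 9, 0, 1, ...
toy_targets = np.eye(10)[np.arange(50) % 10]
toy_groups = first_n_per_class(toy_targets, n=3)
print(toy_groups[0], toy_groups[9])   # [0, 10, 20] [9, 19, 29]
```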
Now we can do the same as above, but plotting each category separately:
def scatter_all(net, epoch):
    # Plot the selected inputs of each digit category in the 2-D PCA space.
    lines = []
    for category in range(10):
        # Hidden-layer representations of this category's selected inputs
        states = [net.propagate_to("hidden", v)
                  for v in net.dataset.train_inputs[groups[category]]]
        xy = pca.transform(states, scale=True)
        lines.append([str(category), xy])
    return cx.scatter(lines,
                      title="Hidden State after %s epochs" % epoch,
                      ymin=0, ymax=1,
                      xmin=0, xmax=1,
                      format="image")
scatter_all(cnn, 10)
Because we trained with record=True, the network's weights were saved during training. We can therefore pass this function to the playback and movie methods to watch the 10 categories form clusters as training proceeds:
cnn.playback(scatter_all)
cnn.movie(scatter_all)
That is useful, but it does not show how the individual inputs vary within a cluster. To see this, we can draw each input as an image at its PCA location instead of plotting a colored dot there.
First, we need a list of pictures and their coordinates in PCA space. We use the first 1,000 training inputs so that the pictures match the states used above:
pics = [cx.array_to_image(v) for v in cnn.dataset.train_inputs[:1000]]
states = [cnn.propagate_to("hidden", i) for i in cnn.dataset.train_inputs[:1000]]
xy = pca.transform(states, scale=True)
Now we can use ConX's scatter_images function, which takes a list of input images and their (x, y) locations in PCA space:
cx.scatter_images(pics, xy, size=(1500,1500), scale=1.0)
You may need to increase the size of the plot or adjust the scale so that the clusters show minimal overlap without leaving too much space between the images.
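If you want to experiment with this kind of plot outside ConX, matplotlib can place images at arbitrary coordinates using OffsetImage and AnnotationBbox. The sketch below is our own illustration of the idea and is not how scatter_images is implemented. In it, zoom plays a role similar to scale. The random images and coordinates are placeholders, and the output filename is hypothetical:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")                      # render off-screen
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

def scatter_images_mpl(images, xy, zoom=1.0, figsize=(10, 10)):
    """Draw each image centered at its (x, y) location (sketch, not ConX)."""
    fig, ax = plt.subplots(figsize=figsize)
    for img, (x, y) in zip(images, xy):
        box = OffsetImage(img, zoom=zoom, cmap="gray")
        ax.add_artist(AnnotationBbox(box, (x, y), frameon=False))
    # Coordinates are assumed to be already scaled into [0, 1].
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    return fig, ax

rng = np.random.default_rng(0)
images = rng.random((20, 28, 28))          # placeholder 28 x 28 grayscale "digits"
xy_demo = rng.random((20, 2))              # placeholder PCA coordinates
fig, ax = scatter_images_mpl(images, xy_demo, zoom=0.5)
fig.savefig("scatter_images.png")          # hypothetical output path
```

A smaller zoom reduces overlap between neighboring images, and a larger figsize spreads the points further apart. These are the same trade-offs described above for the size and scale arguments.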