In this notebook we will use the Clustergrammer-Widget to visualize the MNIST dataset. The MNIST dataset contains 70,000 handwritten digits. Each handwritten digit image is 28x28 pixels in size, so each digit can be thought of as a 784-dimensional vector.
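To make the "784-dimensional vector" view concrete, here is a minimal sketch of flattening a 28x28 image with NumPy. The image is random stand-in data, and the row-by-row pixel ordering is an assumption; the notebook does not state how pixels are ordered in the data file.

```python
import numpy as np

# Hypothetical 28x28 grayscale image (random values stand in for a real digit)
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28))

# Flatten row by row into a single 784-dimensional vector
pixel_vector = image.reshape(-1)

# With row-by-row ordering, the pixel at (row, col) lands at index row * 28 + col
row, col = 5, 12
flat_index = row * 28 + col

# Reshaping the vector recovers the original image
restored = pixel_vector.reshape(28, 28)
```

Stacking one such vector per digit as columns would give a matrix with 784 rows and one column per digit, which matches the orientation of the DataFrame used in this notebook.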
# import Pandas and Clustergrammer-Widget
import pandas as pd
from clustergrammer_widget import *
net = Network(clustergrammer_widget)
# load data and store in DataFrame
net.load_file('../processed_MNIST/MNIST_row_labels.txt')
mnist_df = net.export_df()
print(mnist_df.shape)
(784, 70000)
Here we are manually setting the category colors of some of the digits.
# assign the same color to each digit and to its majority-digit category
digit_colors = [
    ('Zero', 'yellow'), ('One', 'blue'), ('Two', 'orange'), ('Three', 'aqua'),
    ('Four', 'lime'), ('Six', 'purple'), ('Eight', 'red'), ('Nine', 'black'),
]
for prefix in ['digit', 'Majority-digit']:
    for digit, color in digit_colors:
        net.set_cat_color('col', 1, prefix + ': ' + digit, color)
We cannot directly visualize all 70,000 handwritten digits in the MNIST dataset. Instead we will take two approaches to visualizing the MNIST data: 1) random subsampling from the dataset, 2) downsampling using K-means.
Here we will randomly subsample 300 digits from the dataset. We will then keep the top 500 pixels ranked by their sum, which removes pixels from the corners of the images that are always, or almost always, zero.
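The two steps described above can be approximated in plain pandas. This is only a sketch on toy data to show the idea; it is not Clustergrammer's own implementation, and `toy_df` and the sizes used are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for mnist_df: 20 pixels (rows) by 50 digits (columns)
rng = np.random.default_rng(0)
toy_df = pd.DataFrame(rng.integers(0, 256, size=(20, 50)))
toy_df.iloc[:5, :] = 0  # pretend the first five pixels are always-zero corner pixels

# Randomly keep 10 digits (columns); random_state makes the draw repeatable
sampled = toy_df.sample(n=10, axis=1, random_state=99)

# Keep the 12 pixels (rows) with the largest sums across the sampled digits
top_rows = sampled.sum(axis=1).nlargest(12).index
filtered = sampled.loc[top_rows]
```

In this toy example the always-zero rows have a sum of zero, so they fall out of the top-ranked set, which is the same effect the sum filter has on corner pixels.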
net.load_df(mnist_df)
net.random_sample(300, axis='col', random_state=99)
net.filter_N_top('row', rank_type='sum', N_top=500)
net.cluster()
net.widget()
Above we see a heatmap with digits as columns and pixels as rows. We see that digits of the same type tend to cluster together, e.g. the Ones (blue).
Each pixel has a value-based category, 'Center', which is highest for pixels near the center of the image. Reordering based on the 'Center' category highlights broad patterns in pixel distributions; for example, Zeros and Sevens generally have low values for pixels near the center of the image.
We can use the "Top rows sum" and "Top rows variance" sliders to filter out rows (pixels) based on sum and variance and observe how this affects clustering. Filtering based on sum reduces clustering quality more than filtering based on variance.
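The difference between the two ranking criteria can be seen on a tiny hand-built example. The sketch below uses a hypothetical helper, `top_rows`, written for illustration only; it does not reproduce the widget's sliders, and it does not by itself demonstrate the effect on clustering quality.

```python
import numpy as np

# Toy data: 4 pixels (rows) by 100 digits (columns)
data = np.zeros((4, 100))
data[0] = 200.0                          # always bright: largest sum, zero variance
data[1] = np.tile([0.0, 255.0], 50)      # on/off pixel: large sum, largest variance
data[2] = np.tile([90.0, 110.0], 50)     # mildly varying pixel
# data[3] stays at zero, like a corner pixel

def top_rows(matrix, n_top, rank_type):
    """Return the indices of the n_top rows ranked by 'sum' or 'var' (hypothetical helper)."""
    if rank_type == 'sum':
        scores = matrix.sum(axis=1)
    else:
        scores = matrix.var(axis=1)
    # argsort is ascending, so reverse it to put the largest scores first
    return np.argsort(scores)[::-1][:n_top]

by_sum = top_rows(data, 2, 'sum')
by_var = top_rows(data, 2, 'var')
```

Ranking by sum keeps the constant bright pixel, which carries no information for telling digits apart, while ranking by variance drops it in favor of pixels whose values change from digit to digit.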
Here we will use K-means clustering as a way to downsample our dataset. We will generate 300 clusters from our 70,000 digits and visualize these clusters using hierarchical clustering. Note that each digit-cluster (column) is labeled by the majority digit present in the cluster, and the 'number in clust' value-based category shows how many digits are in each cluster (cluster sizes range from 50 to 500). This method gives us a broad overview of the entire MNIST dataset.
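The idea of K-means downsampling with majority labels can be sketched in a few lines of NumPy. This is a plain Lloyd's-algorithm illustration under stated assumptions, not Clustergrammer's implementation: the function name `kmeans_downsample`, the evenly spaced initial centers, and the toy data are all hypothetical.

```python
import numpy as np
from collections import Counter

def kmeans_downsample(data, labels, num_clusters, num_iter=20):
    """data has shape (n_features, n_samples); labels holds one label per column."""
    points = data.T.astype(float)  # one row per sample
    # Simplification: start from evenly spaced samples instead of a random draw
    start = np.linspace(0, len(points) - 1, num_clusters).astype(int)
    centers = points[start].copy()

    def assign_points():
        # Squared Euclidean distance from every sample to every center
        dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return dists.argmin(axis=1)

    for _ in range(num_iter):
        assign = assign_points()
        for k in range(num_clusters):
            members = points[assign == k]
            if len(members):  # keep the old center if a cluster is empty
                centers[k] = members.mean(axis=0)
    assign = assign_points()  # final assignment against the final centers

    # Label each cluster by its most common member label and record its size
    summary = []
    for k in range(num_clusters):
        member_labels = [labels[i] for i in np.flatnonzero(assign == k)]
        majority = Counter(member_labels).most_common(1)[0][0] if member_labels else None
        summary.append((majority, len(member_labels)))
    return centers.T, summary  # centers as columns, like the input

# Toy data: two features, eight samples forming two well-separated groups
toy = np.array([[0, 0, 1, 1, 10, 10, 11, 11],
                [0, 1, 0, 1, 10, 11, 10, 11]])
toy_labels = ['Zero', 'Zero', 'Zero', 'Six', 'One', 'One', 'One', 'One']
centers, summary = kmeans_downsample(toy, toy_labels, num_clusters=2)
```

Each returned column is a cluster average that stands in for its members, and each `summary` entry pairs the majority label with the cluster size, mirroring the majority-digit label and the 'number in clust' category described above. Note that the lone 'Six' is absorbed into a cluster labeled 'Zero', which is how minority digits can hide inside a majority-labeled column.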
net.load_df(mnist_df)
net.downsample(axis='col', ds_type='kmeans', num_samples=300)
net.filter_N_top('row', rank_type='sum', N_top=500)
net.cluster()
net.widget()
Again, with the downsampled data we see that digits tend to cluster together. We see clear clustering of Ones (blue), Zeros (yellow), and Twos (orange). Using the dendrogram, we also see mixing of digits that have similar shape like
Reordering based on pixel 'Center' value again shows us overall trends in the pixel distributions of different digits. We can also use the sliders to observe the effects of dimensionality reduction on clustering. For instance, we can retain fairly good clustering of Zeros, Sixes, and Ones when keeping only the top 50 most variable pixels.