- 🤖 See full list of Machine Learning Experiments on GitHub
- ▶️ Interactive Demo: try this model and other machine learning experiments in action
In this experiment we will build a Convolutional Neural Network (CNN) model using Tensorflow to recognize handwritten sketches by using a quick-draw dataset.
# Selecting Tensorflow version v2 (the command is relevant for Colab only).
# %tensorflow_version 2.x
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import math
import datetime
import platform
import pathlib
import random
# Report the interpreter and library versions this notebook was run with,
# so that results can be reproduced in a matching environment.
print('Python version:', platform.python_version())
print('Tensorflow version:', tf.__version__)
print('Keras version:', tf.keras.__version__)
Python version: 3.7.6 Tensorflow version: 2.1.0 Keras version: 2.2.4-tf
# Folder passed to tfds.load() as data_dir, so downloaded data is kept there.
cache_dir = 'tmp'

# Create the cache folder. Unlike a plain `mkdir`, this does not fail when the
# folder is already there (e.g. when the notebook is re-run).
pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
mkdir: tmp: File exists
# List all available datasets to see how the quick-draw dataset is called
# (the name used below is 'quickdraw_bitmap').
tfds.list_builders()
['abstract_reasoning', 'aeslc', 'aflw2k3d', 'amazon_us_reviews', 'arc', 'bair_robot_pushing_small', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'cmaterdb', 'cnn_dailymail', 'coco', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'cos_e', 'curated_breast_imaging_ddsm', 'cycle_gan', 'deep_weeds', 'definite_pronoun_resolution', 'diabetic_retinopathy_detection', 'dmlab', 'downsampled_imagenet', 'dsprites', 'dtd', 'duke_ultrasound', 'dummy_dataset_shared_generator', 'dummy_mnist', 'emnist', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'gap', 'gigaword', 'glue', 'groove', 'higgs', 'horses_or_humans', 'i_naturalist2017', 'image_label_folder', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet_resized', 'imagenette', 'imdb_reviews', 'iris', 'kitti', 'kmnist', 'lfw', 'lm1b', 'lost_and_found', 'lsun', 'malaria', 'math_dataset', 'mnist', 'mnist_corrupted', 'movie_rationales', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'newsroom', 'nsynth', 'omniglot', 'open_images_v4', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'pet_finder', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'quickdraw_bitmap', 'reddit_tifu', 'resisc45', 'rock_paper_scissors', 'rock_you', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'shapes3d', 'smallnorb', 'snli', 'so2sat', 'squad', 'stanford_dogs', 'stanford_online_products', 'starcraft_video', 'sun397', 'super_glue', 'svhn_cropped', 'ted_hrlr_translate', 'ted_multi_translate', 'tf_flowers', 'the300w_lp', 'titanic', 'trivia_qa', 'uc_merced', 'ucf101', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'wider_face', 'wikihow', 'wikipedia', 'wmt14_translate', 
'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'xnli', 'xsum']
# Name of the tfds builder to load.
DATASET_NAME = 'quickdraw_bitmap'
# Load the TRAIN split and keep the accompanying DatasetInfo (with_info=True).
# cache_dir is used as the tfds data directory.
dataset, dataset_info = tfds.load(
name=DATASET_NAME,
data_dir=cache_dir,
with_info=True,
split=tfds.Split.TRAIN,
)
print(dataset_info)
tfds.core.DatasetInfo( name='quickdraw_bitmap', version=3.0.0, description='The Quick Draw Dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The bitmap dataset contains these drawings converted from vector format into 28x28 grayscale images', homepage='https://github.com/googlecreativelab/quickdraw-dataset', features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=345), }), total_num_examples=50426266, splits={ 'train': 50426266, }, supervised_keys=('image', 'label'), citation="""@article{DBLP:journals/corr/HaE17, author = {David Ha and Douglas Eck}, title = {A Neural Representation of Sketch Drawings}, journal = {CoRR}, volume = {abs/1704.03477}, year = {2017}, url = {http://arxiv.org/abs/1704.03477}, archivePrefix = {arXiv}, eprint = {1704.03477}, timestamp = {Mon, 13 Aug 2018 16:48:30 +0200}, biburl = {https://dblp.org/rec/bib/journals/corr/HaE17}, bibsource = {dblp computer science bibliography, https://dblp.org} }""", redistribution_info=, )
# Pull the basic dataset facts out of the DatasetInfo object: the per-image
# shape, the number of label classes and the size of the 'train' split.
image_shape = dataset_info.features['image'].shape
num_classes = dataset_info.features['label'].num_classes
num_examples = dataset_info.splits['train'].num_examples
print('num_examples: ', num_examples)
print('image_shape: ', image_shape)
print('num_classes: ', num_classes)
num_examples: 50426266 image_shape: (28, 28, 1) num_classes: 345
# Function that maps an integer label index to its human-readable class name.
label_index_to_string = dataset_info.features['label'].int2str

# Class names ordered by label index, so classes[i] is the name of label i.
classes = [
    label_index_to_string(class_index)
    for class_index in range(num_classes)
]

print('classes num:', len(classes))
print('classes:\n\n', classes)
classes num: 345 classes: ['aircraft carrier', 'airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball bat', 'baseball', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden hose', 'garden', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house plant', 'house', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 
'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'rifle', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 'syringe', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
# Show the dataset element structure (feature names, shapes and dtypes).
print(dataset)
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
# Render a few sample images with tfds' built-in preview helper.
# NOTE(review): the (dataset_info, dataset) argument order is what this notebook
# was run with; other tfds releases may expect a different order - confirm
# against the installed version.
fig = tfds.show_examples(dataset_info, dataset)
def dataset_preview(dataset, image_shape, preview_images_num=100):
    """Plot the first examples of the raw dataset in a grid.

    Each cell shows one image captioned with "<class name> (<class index>)".
    Class names are looked up in the module-level `classes` list.

    Args:
        dataset: dataset whose elements are dicts with 'image' and 'label'
            entries; the label must expose `.numpy()` returning the class index.
        image_shape: shape of a single image; the first two items are used as
            height and width (images do not have to be square).
        preview_images_num: how many examples to draw.
    """
    # Smallest square grid that fits all the requested previews.
    num_cells = math.ceil(math.sqrt(preview_images_num))
    plt.figure(figsize=(17, 17))
    image_height, image_width = image_shape[0], image_shape[1]

    for image_index, example in enumerate(dataset.take(preview_images_num)):
        image = example['image']
        class_index = example['label'].numpy()
        class_name = classes[class_index]

        plt.subplot(num_cells, num_cells, image_index + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # Drop the channel axis so imshow receives a 2D array.
        plt.imshow(
            np.reshape(image, (image_height, image_width)),
            cmap=plt.cm.binary
        )
        plt.xlabel('{} ({})'.format(class_name, class_index))

    plt.show()
def dataset_normalized_preview(dataset, image_shape, preview_images_num=100):
    """Plot the first examples of the normalized dataset in a grid.

    Works on the (image, one-hot label) tuples produced by the normalization
    step: the class index is recovered with argmax over the label vector and
    the caption is "<class name> (<class index>)".

    Args:
        dataset: dataset whose elements are (image, label) pairs, where label
            is a vector whose argmax is the class index.
        image_shape: shape of a single image; the first two items are used as
            height and width (images do not have to be square).
        preview_images_num: how many examples to draw.
    """
    # Smallest square grid that fits all the requested previews.
    num_cells = math.ceil(math.sqrt(preview_images_num))
    plt.figure(figsize=(17, 17))
    image_height, image_width = image_shape[0], image_shape[1]

    for image_index, example in enumerate(dataset.take(preview_images_num)):
        image = example[0]
        # Labels are one-hot encoded here, so argmax gives the class index back.
        class_index = tf.math.argmax(example[1]).numpy()
        class_name = label_index_to_string(class_index)

        plt.subplot(num_cells, num_cells, image_index + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # Drop the channel axis so imshow receives a 2D array.
        plt.imshow(
            np.reshape(image, (image_height, image_width)),
            cmap=plt.cm.binary
        )
        plt.xlabel('{} ({})'.format(class_name, class_index))

    plt.show()
def dataset_head(ds):
    """Print the label, the shape and the raw pixel matrix of the first example.

    Args:
        ds: dataset whose elements are dicts with 'image' and 'label' entries,
            both exposing `.numpy()`.
    """
    for example in ds.take(1):
        image = example['image']
        class_index = example['label'].numpy()
        class_name = label_index_to_string(class_index)

        print('{} ({})'.format(class_name, class_index), '\n')
        print('Image shape: ', image.shape, '\n')

        # Flatten the channel axis using the image's own height and width
        # instead of assuming a fixed 28x28 size.
        pixels = image.numpy()
        print(np.reshape(pixels, (pixels.shape[0], pixels.shape[1])), '\n')
# Inspect the first raw example: label, image shape and pixel values.
dataset_head(dataset)
backpack (12) Image shape: (28, 28, 1) [[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 3 31 63 95 115 86 50 2 0 0 0 0 0 0 0] [ 0 0 0 0 0 75 183 176 152 10 54 182 222 250 255 255 255 255 255 255 231 151 61 0 0 0 0 0] [ 0 0 0 0 129 255 216 202 255 98 220 222 157 125 93 60 28 8 38 82 162 239 254 108 0 0 0 0] [ 0 0 0 33 250 158 4 0 204 209 255 236 105 0 0 0 0 0 0 0 0 6 164 254 88 0 0 0] [ 0 0 0 158 243 30 155 235 211 255 254 185 255 97 0 0 0 0 0 0 0 0 3 193 229 10 0 0] [ 0 0 38 251 134 153 251 162 252 245 225 1 180 241 6 0 0 0 0 0 0 0 0 69 255 62 0 0] [ 0 0 150 240 26 245 145 0 116 255 149 0 75 255 58 0 0 0 0 0 0 0 0 29 255 94 0 0] [ 0 0 209 173 77 255 62 0 62 255 68 0 14 251 122 0 0 0 0 0 0 0 0 2 249 128 0 0] [ 0 0 240 138 151 236 4 0 76 255 46 0 0 221 157 0 0 0 0 0 0 0 0 0 218 160 0 0] [ 0 0 251 127 188 188 0 0 76 255 46 0 0 213 164 0 4 45 0 0 0 0 0 0 186 193 0 0] [ 0 5 255 117 199 176 0 0 76 255 46 0 0 205 173 0 75 251 9 0 44 154 2 0 158 226 0 0] [ 0 14 255 108 209 166 0 0 76 255 46 0 0 160 243 102 44 61 34 35 98 199 132 170 251 252 6 0] [ 0 23 255 99 220 156 0 0 76 255 46 0 0 120 249 255 255 255 255 255 255 255 247 214 202 255 24 0] [ 0 17 255 108 221 162 0 0 72 255 51 0 0 129 247 31 82 85 85 85 62 28 1 0 108 255 13 0] [ 0 0 237 143 159 232 3 0 38 255 87 0 0 129 247 0 0 0 0 0 0 0 0 0 121 253 2 0] [ 0 0 198 199 73 255 78 0 4 248 127 0 0 129 247 0 0 0 20 63 39 8 0 0 134 242 0 0] [ 0 0 96 255 101 220 233 96 1 213 166 0 0 129 247 0 95 204 253 255 255 255 230 199 218 232 0 0] [ 0 0 1 180 252 136 189 255 106 173 207 0 0 129 247 0 239 228 116 63 84 115 147 185 255 231 0 0] [ 0 0 0 8 161 255 163 76 22 133 246 2 0 129 247 0 198 177 0 0 0 0 0 49 255 220 0 0] [ 0 0 0 0 0 113 251 242 196 240 255 37 0 129 247 0 208 168 0 0 0 0 0 130 254 190 0 0] [ 0 0 0 0 0 0 56 141 184 165 248 159 0 126 250 0 174 224 13 0 0 0 0 219 255 115 0 0] [ 0 0 0 0 0 0 0 0 0 0 127 254 60 116 255 5 69 252 230 96 2 0 46 255 253 30 0 0] [ 0 0 0 0 0 0 0 0 0 0 9 215 241 
161 255 15 0 46 190 255 206 169 244 255 117 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 23 203 255 255 229 217 204 191 227 255 255 251 156 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 1 108 255 183 159 170 185 170 112 65 4 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 3 97 23 0 0 0 0 0 0 0 0 0 0 0 0] [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
# Draw a grid of raw (not yet normalized) examples.
dataset_preview(dataset, image_shape)
def normalize_example(example):
    """Turn a raw dataset element into an (image, label) training pair.

    The label is one-hot encoded over `len(classes)` positions and the image
    pixels are divided by 255.

    Args:
        example: dict with 'image' and 'label' entries.

    Returns:
        Tuple of (scaled image, one-hot label).
    """
    one_hot_label = tf.one_hot(example['label'], len(classes))
    scaled_image = tf.math.divide(example['image'], 255)
    return scaled_image, one_hot_label
def augment_example(image, label):
    """Randomly mirror the image left-to-right; the label is passed through.

    Returns:
        Tuple of (possibly flipped image, label).
    """
    flipped_image = tf.image.random_flip_left_right(image)
    return flipped_image, label
# Normalization first, random augmentation second.
normalized_only = dataset.map(normalize_example)
dataset_normalized = normalized_only.map(augment_example)

# Print the first normalized example the same way dataset_head() prints a raw one.
for image, label in dataset_normalized.take(1):
    # The label is one-hot encoded, so argmax recovers the class index.
    class_index = tf.math.argmax(label).numpy()
    class_name = label_index_to_string(class_index)

    print('{} ({})'.format(class_name, class_index), '\n')
    print('Image shape: ', image.shape, '\n')
    print(np.reshape(image.numpy(), (28, 28)), '\n')
backpack (12) Image shape: (28, 28, 1) [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.01176471 0.12156863 0.24705882 0.37254903 0.4509804 0.3372549 0.19607843 0.00784314 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0.29411766 0.7176471 0.6901961 0.59607846 0.03921569 0.21176471 0.7137255 0.87058824 0.98039216 1. 1. 1. 1. 1. 1. 0.90588236 0.5921569 0.23921569 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0.5058824 1. 0.84705883 0.7921569 1. 0.38431373 0.8627451 0.87058824 0.6156863 0.49019608 0.3647059 0.23529412 0.10980392 0.03137255 0.14901961 0.32156864 0.63529414 0.9372549 0.99607843 0.42352942 0. 0. 0. 0. ] [0. 0. 0. 0.12941177 0.98039216 0.61960787 0.01568628 0. 0.8 0.81960785 1. 0.9254902 0.4117647 0. 0. 0. 0. 0. 0. 0. 0. 0.02352941 0.6431373 0.99607843 0.34509805 0. 0. 0. ] [0. 0. 0. 0.61960787 0.9529412 0.11764706 0.60784316 0.92156863 0.827451 1. 0.99607843 0.7254902 1. 0.38039216 0. 0. 0. 0. 0. 0. 0. 0. 0.01176471 0.75686276 0.8980392 0.03921569 0. 0. ] [0. 0. 0.14901961 0.9843137 0.5254902 0.6 0.9843137 0.63529414 0.9882353 0.9607843 0.88235295 0.00392157 0.7058824 0.94509804 0.02352941 0. 0. 0. 0. 0. 0. 0. 0. 0.27058825 1. 0.24313726 0. 0. ] [0. 0. 0.5882353 0.9411765 0.10196079 0.9607843 0.5686275 0. 0.45490196 1. 0.58431375 0. 0.29411766 1. 0.22745098 0. 0. 0. 0. 0. 0. 0. 0. 0.11372549 1. 0.36862746 0. 0. ] [0. 0. 0.81960785 0.6784314 0.3019608 1. 0.24313726 0. 0.24313726 1. 0.26666668 0. 0.05490196 0.9843137 0.47843137 0. 0. 0. 0. 0. 0. 0. 0. 0.00784314 0.9764706 0.5019608 0. 0. ] [0. 0. 0.9411765 0.5411765 0.5921569 0.9254902 0.01568628 0. 0.29803923 1. 0.18039216 0. 0. 0.8666667 0.6156863 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.85490197 0.627451 0. 0. ] [0. 0. 0.9843137 0.49803922 0.7372549 0.7372549 0. 0. 0.29803923 1. 0.18039216 0. 0. 0.8352941 0.6431373 0. 0.01568628 0.1764706 0. 0. 0. 0. 0. 0. 0.7294118 0.75686276 0. 0. ] [0. 0.01960784 1. 0.45882353 0.78039217 0.6901961 0. 0. 0.29803923 1. 
0.18039216 0. 0. 0.8039216 0.6784314 0. 0.29411766 0.9843137 0.03529412 0. 0.17254902 0.6039216 0.00784314 0. 0.61960787 0.8862745 0. 0. ] [0. 0.05490196 1. 0.42352942 0.81960785 0.6509804 0. 0. 0.29803923 1. 0.18039216 0. 0. 0.627451 0.9529412 0.4 0.17254902 0.23921569 0.13333334 0.13725491 0.38431373 0.78039217 0.5176471 0.6666667 0.9843137 0.9882353 0.02352941 0. ] [0. 0.09019608 1. 0.3882353 0.8627451 0.6117647 0. 0. 0.29803923 1. 0.18039216 0. 0. 0.47058824 0.9764706 1. 1. 1. 1. 1. 1. 1. 0.96862745 0.8392157 0.7921569 1. 0.09411765 0. ] [0. 0.06666667 1. 0.42352942 0.8666667 0.63529414 0. 0. 0.28235295 1. 0.2 0. 0. 0.5058824 0.96862745 0.12156863 0.32156864 0.33333334 0.33333334 0.33333334 0.24313726 0.10980392 0.00392157 0. 0.42352942 1. 0.05098039 0. ] [0. 0. 0.92941177 0.56078434 0.62352943 0.9098039 0.01176471 0. 0.14901961 1. 0.34117648 0. 0. 0.5058824 0.96862745 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.4745098 0.99215686 0.00784314 0. ] [0. 0. 0.7764706 0.78039217 0.28627452 1. 0.30588236 0. 0.01568628 0.972549 0.49803922 0. 0. 0.5058824 0.96862745 0. 0. 0. 0.07843138 0.24705882 0.15294118 0.03137255 0. 0. 0.5254902 0.9490196 0. 0. ] [0. 0. 0.3764706 1. 0.39607844 0.8627451 0.9137255 0.3764706 0.00392157 0.8352941 0.6509804 0. 0. 0.5058824 0.96862745 0. 0.37254903 0.8 0.99215686 1. 1. 1. 0.9019608 0.78039217 0.85490197 0.9098039 0. 0. ] [0. 0. 0.00392157 0.7058824 0.9882353 0.53333336 0.7411765 1. 0.41568628 0.6784314 0.8117647 0. 0. 0.5058824 0.96862745 0. 0.9372549 0.89411765 0.45490196 0.24705882 0.32941177 0.4509804 0.5764706 0.7254902 1. 0.90588236 0. 0. ] [0. 0. 0. 0.03137255 0.6313726 1. 0.6392157 0.29803923 0.08627451 0.52156866 0.9647059 0.00784314 0. 0.5058824 0.96862745 0. 0.7764706 0.69411767 0. 0. 0. 0. 0. 0.19215687 1. 0.8627451 0. 0. ] [0. 0. 0. 0. 0. 0.44313726 0.9843137 0.9490196 0.76862746 0.9411765 1. 0.14509805 0. 0.5058824 0.96862745 0. 0.8156863 0.65882355 0. 0. 0. 0. 0. 0.50980395 0.99607843 0.74509805 0. 0. ] [0. 0. 0. 0. 0. 0. 
0.21960784 0.5529412 0.72156864 0.64705884 0.972549 0.62352943 0. 0.49411765 0.98039216 0. 0.68235296 0.8784314 0.05098039 0. 0. 0. 0. 0.85882354 1. 0.4509804 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.49803922 0.99607843 0.23529412 0.45490196 1. 0.01960784 0.27058825 0.9882353 0.9019608 0.3764706 0.00784314 0. 0.18039216 1. 0.99215686 0.11764706 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.03529412 0.84313726 0.94509804 0.6313726 1. 0.05882353 0. 0.18039216 0.74509805 1. 0.80784315 0.6627451 0.95686275 1. 0.45882353 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.09019608 0.79607844 1. 1. 0.8980392 0.8509804 0.8 0.7490196 0.8901961 1. 1. 0.9843137 0.6117647 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.00392157 0.42352942 1. 0.7176471 0.62352943 0.6666667 0.7254902 0.6666667 0.4392157 0.25490198 0.01568628 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.01176471 0.38039216 0.09019608 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]]
# Draw a grid of normalized (and randomly flipped) examples.
dataset_normalized_preview(dataset_normalized, image_shape)
# A quick example of how we're going to split the dataset for train/test/validation subsets.
# take(n) keeps the first n elements and skip(n) drops them, so chaining the two
# carves one dataset into consecutive, non-overlapping parts.
tmp_ds = tf.data.Dataset.range(10)
print('tmp_ds:', list(tmp_ds.as_numpy_iterator()))
# First 2 elements -> test.
tmp_ds_test = tmp_ds.take(2)
print('tmp_ds_test:', list(tmp_ds_test.as_numpy_iterator()))
# Next 3 elements -> validation.
tmp_ds_val = tmp_ds.skip(2).take(3)
print('tmp_ds_val:', list(tmp_ds_val.as_numpy_iterator()))
# Everything after the first 2 + 3 elements -> train.
tmp_ds_train = tmp_ds.skip(2 + 3)
print('tmp_ds_train:', list(tmp_ds_train.as_numpy_iterator()))
tmp_ds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] tmp_ds_test: [0, 1] tmp_ds_val: [2, 3, 4] tmp_ds_train: [5, 6, 7, 8, 9]
# Dataset split
# Number of batches (not examples) reserved for the test and validation subsets.
test_dataset_batches = 1
val_dataset_batches = 1
# Dataset batching
# Examples per batch, and how many batches the training pipeline prefetches.
batch_size = 2000
prefetch_buffer_batches = 10
# Training
# With the values above one epoch covers steps_per_epoch * batch_size
# = 500 * 2000 = 1,000,000 examples.
epochs = 40
steps_per_epoch = 500
# Group the normalized examples into batches; all three subsets below are cut
# from this batched dataset, so their sizes are expressed in batches.
# NOTE(review): dataset_normalized already includes the random flip applied by
# augment_example, so the test and validation batches are augmented as well -
# confirm this is intended.
dataset_batched = dataset_normalized.batch(batch_size=batch_size)

# TEST dataset: the first `test_dataset_batches` batches.
dataset_test = dataset_batched \
    .take(test_dataset_batches)

# VALIDATION dataset: the batches right after the test ones.
dataset_val = dataset_batched \
    .skip(test_dataset_batches) \
    .take(val_dataset_batches)

# TRAIN dataset: everything after the test and validation batches, repeated
# indefinitely (model.fit() is bounded by steps_per_epoch). prefetch() is the
# last step of the pipeline so that the buffer keeps being filled across
# repetitions instead of being rebuilt for every pass over the data.
dataset_train = dataset_batched \
    .skip(test_dataset_batches + val_dataset_batches) \
    .repeat() \
    .prefetch(buffer_size=prefetch_buffer_batches)
# Sanity check: print the label and image batch shapes of the first batch of
# each subset.
for image_test, label_test in dataset_test.take(1):
    print('label_test.shape: ', label_test.shape)
    print('image_test.shape: ', image_test.shape)

print()

for image_val, label_val in dataset_val.take(1):
    print('label_val.shape: ', label_val.shape)
    print('image_val.shape: ', image_val.shape)

print()

for image_train, label_train in dataset_train.take(1):
    print('label_train.shape: ', label_train.shape)
    print('image_train.shape: ', image_train.shape)
label_test.shape: (2000, 345) image_test.shape: (2000, 28, 28, 1) label_val.shape: (2000, 345) image_val.shape: (2000, 28, 28, 1) label_train.shape: (2000, 345) image_train.shape: (2000, 28, 28, 1)
# Calculate how many times the network will "see" each class during one epoch of training
# given specific dataset (batches) and number of steps per epoch.
def get_dataset_classes_hist(dataset, classes, batches_num):
    """Count how many examples of each class occur in the first batches.

    Args:
        dataset: batched dataset yielding (examples, labels) pairs, where
            labels is a 2D batch of one-hot (or score) vectors, one row per
            example.
        classes: class names ordered by label index.
        batches_num: number of batches to inspect.

    Returns:
        Dict mapping every class name to the number of examples whose label
        argmax points at that class (0 for classes that never occur).
    """
    mentions = {class_name: 0 for class_name in classes}

    for _, labels in dataset.take(batches_num):
        # Resolve the whole batch at once instead of converting every single
        # label to numpy separately.
        class_indices = np.argmax(np.asarray(labels), axis=1)
        batch_counts = np.bincount(class_indices, minlength=len(classes))
        for class_index, count in enumerate(batch_counts):
            mentions[classes[class_index]] += int(count)

    return mentions
# Per-class example counts for one training epoch (steps_per_epoch batches).
mentions = get_dataset_classes_hist(dataset_train, classes, batches_num=steps_per_epoch)

# One "<class name>: <count>" line per class, with the name padded to 15 characters.
for class_name, class_mentions in mentions.items():
    print('{:15s}: {}'.format(class_name, class_mentions))
aircraft carrier: 2295 airplane : 3030 alarm clock : 2458 ambulance : 2842 angel : 2926 animal migration: 2706 ant : 2548 anvil : 2553 apple : 2904 arm : 2373 asparagus : 3351 axe : 2473 backpack : 2429 banana : 6052 bandage : 2935 barn : 2956 baseball bat : 2381 baseball : 2668 basket : 2278 basketball : 2662 bat : 2371 bathtub : 3429 beach : 2515 bear : 2679 beard : 3318 bed : 2346 bee : 2417 belt : 3876 bench : 2531 bicycle : 2501 binoculars : 2432 bird : 2619 birthday cake : 3036 blackberry : 2570 blueberry : 2537 book : 2389 boomerang : 2918 bottlecap : 3091 bowtie : 2549 bracelet : 2394 brain : 2770 bread : 2356 bridge : 2638 broccoli : 2613 broom : 2347 bucket : 2423 bulldozer : 3645 bus : 3293 bush : 2425 butterfly : 2333 cactus : 2663 cake : 2415 calculator : 2543 calendar : 6369 camel : 2429 camera : 2521 camouflage : 3377 campfire : 2694 candle : 2764 cannon : 2768 canoe : 2500 car : 3560 carrot : 2643 castle : 2380 cat : 2385 ceiling fan : 2282 cell phone : 2415 cello : 3012 chair : 4409 chandelier : 3392 church : 3305 circle : 2414 clarinet : 2518 clock : 2371 cloud : 2465 coffee cup : 3613 compass : 2589 computer : 2450 cookie : 2664 cooler : 5334 couch : 2439 cow : 2458 crab : 2519 crayon : 2650 crocodile : 2568 crown : 2745 cruise ship : 2515 cup : 2624 diamond : 2617 dishwasher : 3346 diving board : 5817 dog : 3000 dolphin : 2446 donut : 2806 door : 2388 dragon : 2434 dresser : 2469 drill : 2673 drums : 2746 duck : 2658 dumbbell : 3085 ear : 2474 elbow : 2553 elephant : 2467 envelope : 2644 eraser : 2339 eye : 2579 eyeglasses : 4420 face : 3250 fan : 2657 feather : 2352 fence : 2584 finger : 3272 fire hydrant : 2710 fireplace : 3101 firetruck : 4426 fish : 2589 flamingo : 2463 flashlight : 4687 flip flops : 2406 floor lamp : 3254 flower : 2948 flying saucer : 2939 foot : 3890 fork : 2543 frog : 3120 frying pan : 2392 garden hose : 2384 garden : 3225 giraffe : 2511 goatee : 3803 golf club : 3867 grapes : 3037 grass : 2417 guitar : 2380 hamburger : 
2604 hammer : 2313 hand : 5825 harp : 5799 hat : 4451 headphones : 2320 hedgehog : 2352 helicopter : 3123 helmet : 2452 hexagon : 2874 hockey puck : 4059 hockey stick : 2494 horse : 3555 hospital : 3317 hot air balloon: 2532 hot dog : 3639 hot tub : 2438 hourglass : 2705 house plant : 2422 house : 2700 hurricane : 2659 ice cream : 2488 jacket : 4193 jail : 2410 kangaroo : 3356 key : 3121 keyboard : 3697 knee : 5312 knife : 3459 ladder : 2419 lantern : 2968 laptop : 5222 leaf : 2486 leg : 2329 light bulb : 2441 lighter : 2353 lighthouse : 3259 lightning : 3075 line : 2831 lion : 2396 lipstick : 2607 lobster : 2780 lollipop : 2525 mailbox : 2579 map : 2366 marker : 6420 matches : 2820 megaphone : 2790 mermaid : 3598 microphone : 2374 microwave : 2629 monkey : 2488 moon : 2394 mosquito : 2402 motorbike : 3379 mountain : 2463 mouse : 3548 moustache : 3572 mouth : 2591 mug : 3131 mushroom : 2815 nail : 3209 necklace : 2360 nose : 3790 ocean : 2555 octagon : 3217 octopus : 2965 onion : 2659 oven : 4101 owl : 3321 paint can : 2341 paintbrush : 3673 palm tree : 2384 panda : 2177 pants : 2892 paper clip : 2511 parachute : 2499 parrot : 3648 passport : 2981 peanut : 2518 pear : 2332 peas : 3167 pencil : 2523 penguin : 5005 piano : 2270 pickup truck : 2634 picture frame : 2455 pig : 3771 pillow : 2338 pineapple : 2504 pizza : 2636 pliers : 3528 police car : 2724 pond : 2381 pool : 2637 popsicle : 2486 postcard : 2515 potato : 6675 power outlet : 3284 purse : 2421 rabbit : 3032 raccoon : 2289 radio : 2653 rain : 2624 rainbow : 2489 rake : 3160 remote control : 2320 rhinoceros : 3760 rifle : 3453 river : 2629 roller coaster : 2885 rollerskates : 2340 sailboat : 2706 sandwich : 2607 saw : 2456 saxophone : 2244 school bus : 2417 scissors : 2451 scorpion : 3334 screwdriver : 2344 sea turtle : 2347 see saw : 2599 shark : 2529 sheep : 2519 shoe : 2435 shorts : 2460 shovel : 2338 sink : 4133 skateboard : 2572 skull : 2513 skyscraper : 3606 sleeping bag : 2278 smiley face : 2512 snail 
: 2616 snake : 2410 snorkel : 3065 snowflake : 2312 snowman : 6759 soccer ball : 2457 sock : 4093 speedboat : 3768 spider : 4084 spoon : 2531 spreadsheet : 3312 square : 2414 squiggle : 2370 squirrel : 2943 stairs : 2599 star : 2706 steak : 2410 stereo : 2518 stethoscope : 3000 stitches : 2519 stop sign : 2339 stove : 2348 strawberry : 2351 streetlight : 2533 string bean : 2411 submarine : 2522 suitcase : 2513 sun : 2716 swan : 3096 sweater : 2426 swing set : 2379 sword : 2489 syringe : 2654 t-shirt : 2458 table : 2561 teapot : 2525 teddy-bear : 3568 telephone : 2559 television : 2456 tennis racquet : 4605 tent : 2589 The Eiffel Tower: 2668 The Great Wall of China: 3839 The Mona Lisa : 2320 tiger : 2382 toaster : 2400 toe : 2973 toilet : 2559 tooth : 2495 toothbrush : 2478 toothpaste : 2552 tornado : 2821 tractor : 2248 traffic light : 2466 train : 2559 tree : 2732 triangle : 2414 trombone : 3699 truck : 2603 trumpet : 3398 umbrella : 2510 underwear : 2502 van : 3241 vase : 2455 violin : 4292 washing machine: 2434 watermelon : 2609 waterslide : 3768 whale : 2301 wheel : 2708 windmill : 2360 wine bottle : 2473 wine glass : 2712 wristwatch : 3311 yoga : 5547 zebra : 2917 zigzag : 2319
# Bar chart of how many examples of each class the network sees per epoch.
# Both axes are derived from `classes`, so bar i always belongs to class index i.
mentions_x = list(range(len(classes)))
mentions_bars = [mentions[class_name] for class_name in classes]

plt.bar(mentions_x, mentions_bars)
plt.xlabel('Class index')
plt.ylabel('Items per class')
plt.show()
# CNN: three convolution + max-pooling stages followed by two dense layers.
# The last layer has one softmax output per class.
model = tf.keras.models.Sequential()

# Stage 1: 5x5 convolution with 32 filters on the input image.
model.add(tf.keras.layers.Convolution2D(
    input_shape=image_shape,
    kernel_size=5,
    filters=32,
    padding='same',
    activation=tf.keras.activations.relu
))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2, strides=2))

# Stage 2: 3x3 convolution with 32 filters.
model.add(tf.keras.layers.Convolution2D(
    kernel_size=3,
    filters=32,
    padding='same',
    activation=tf.keras.activations.relu,
))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2, strides=2))

# Stage 3: 3x3 convolution with 64 filters.
model.add(tf.keras.layers.Convolution2D(
    kernel_size=3,
    filters=64,
    padding='same',
    activation=tf.keras.activations.relu
))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2, strides=2))

# Classifier head.
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=512, activation=tf.keras.activations.relu))
model.add(tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax))

model.summary()
Model: "sequential_19" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_57 (Conv2D) (None, 28, 28, 32) 832 _________________________________________________________________ max_pooling2d_57 (MaxPooling (None, 14, 14, 32) 0 _________________________________________________________________ conv2d_58 (Conv2D) (None, 14, 14, 32) 9248 _________________________________________________________________ max_pooling2d_58 (MaxPooling (None, 7, 7, 32) 0 _________________________________________________________________ conv2d_59 (Conv2D) (None, 7, 7, 64) 18496 _________________________________________________________________ max_pooling2d_59 (MaxPooling (None, 3, 3, 64) 0 _________________________________________________________________ flatten_19 (Flatten) (None, 576) 0 _________________________________________________________________ dense_43 (Dense) (None, 512) 295424 _________________________________________________________________ dense_44 (Dense) (None, 345) 176985 ================================================================= Total params: 500,985 Trainable params: 500,985 Non-trainable params: 0 _________________________________________________________________
# Draw the model graph, annotated with layer names and input/output shapes.
tf.keras.utils.plot_model(
model,
show_shapes=True,
show_layer_names=True,
)
# Candidate optimizers. Only adam_optimizer is passed to compile() below; the
# other two are created but not used in this block.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.003)
rms_prop_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
# Categorical cross-entropy matches the one-hot labels produced by
# normalize_example and the softmax output layer.
model.compile(
optimizer=adam_optimizer,
loss=tf.keras.losses.categorical_crossentropy,
metrics=['accuracy']
)
# Early-stopping callback that watches validation accuracy with a patience of
# 5 epochs and is configured to restore the best weights.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
patience=5,
monitor='val_accuracy',
restore_best_weights=True,
verbose=1
)
# Train the model. dataset_train repeats indefinitely, so steps_per_epoch
# defines how many batches make up one epoch.
training_history = model.fit(
x=dataset_train,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
validation_data=dataset_val,
callbacks=[
early_stopping_callback
]
)
# Renders the charts for training accuracy and loss.
def render_training_history(training_history):
    """Plot loss and accuracy per epoch for the training and validation sets.

    Args:
        training_history: object whose `.history` dict holds the 'loss',
            'val_loss', 'accuracy' and 'val_accuracy' series (as returned by
            model.fit() in this notebook).
    """
    history = training_history.history
    # (chart title, training series, validation series) for each subplot.
    charts = (
        ('Loss', history['loss'], history['val_loss']),
        ('Accuracy', history['accuracy'], history['val_accuracy']),
    )

    plt.figure(figsize=(14, 4))

    for chart_index, (title, train_values, val_values) in enumerate(charts, start=1):
        plt.subplot(1, 2, chart_index)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.plot(train_values, label='Training set')
        # The val_* series come from the validation data, not the test set.
        plt.plot(val_values, label='Validation set', linestyle='--')
        plt.legend()
        plt.grid(linestyle='--', linewidth=1, alpha=0.5)

    plt.show()
# Plot the loss/accuracy curves recorded by model.fit().
render_training_history(training_history)
%%capture
# Evaluate on a single batch of the (repeating) training dataset.
train_loss, train_accuracy = model.evaluate(dataset_train.take(1))
print('Train loss: ', '{:.2f}'.format(train_loss))
print('Train accuracy: ', '{:.2f}'.format(train_accuracy))
Train loss: 1.36 Train accuracy: 0.66
%%capture
# Evaluate on the validation dataset (the same data passed to model.fit()).
val_loss, val_accuracy = model.evaluate(dataset_val)
print('Validation loss: ', '{:.2f}'.format(val_loss))
print('Validation accuracy: ', '{:.2f}'.format(val_accuracy))
Validation loss: 1.43 Validation accuracy: 0.65
%%capture
# Evaluate on the test dataset, which was not passed to model.fit().
test_loss, test_accuracy = model.evaluate(dataset_test)
print('Test loss: ', '{:.2f}'.format(test_loss))
print('Test accuracy: ', '{:.2f}'.format(test_accuracy))
Test loss: 1.40 Test accuracy: 0.67
def visualize_predictions(model, dataset):
    """Show up to 64 images from the first batch with true and predicted labels.

    Every cell is captioned "<correct label> --> <predicted label>" and is
    drawn in green when the prediction is correct, in red otherwise.

    Args:
        model: model exposing `predict()` that returns one score vector per image.
        dataset: batched dataset yielding (images, one-hot labels) pairs.
    """
    numbers_to_display = 64
    num_cells = math.ceil(math.sqrt(numbers_to_display))
    plt.figure(figsize=(15, 15))

    for x, y in dataset.take(1):
        # Predict on the very tensors that are displayed. Iterating the dataset
        # a second time would re-run the random augmentation of the input
        # pipeline and could pair a prediction with a differently flipped image.
        predictions = tf.math.argmax(model.predict(x), axis=1).numpy()

        # Never index past the end of a batch smaller than the grid.
        images_num = min(numbers_to_display, x.shape[0])

        for image_index in range(images_num):
            pixels = np.reshape(x[image_index].numpy(), (28, 28))

            y_correct = tf.math.argmax(y[image_index]).numpy()
            y_predicted = predictions[image_index]
            correct_label = classes[y_correct]
            predicted_label = classes[y_predicted]

            color_map = 'Greens' if y_correct == y_predicted else 'Reds'

            # Select the subplot first so the tick/grid settings below apply to
            # this cell rather than to the previously active axes.
            plt.subplot(num_cells, num_cells, image_index + 1)
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(pixels, cmap=color_map)
            plt.xlabel(correct_label + ' --> ' + predicted_label)

    plt.subplots_adjust(hspace=1, wspace=0.5)
    plt.show()
# Compare predictions on a training batch and on the test batch.
visualize_predictions(model, dataset_train)
visualize_predictions(model, dataset_test)
We will save the entire model to an HDF5 file. The .h5 extension of the file indicates that the model should be saved in Keras format as an HDF5 file. To use this model on the front-end we will convert it (later in this notebook) to a format that JavaScript can understand (tfjs_layers_model, with .json and .bin files) using tensorflowjs_converter, as specified in the main README.
# Save the model to a single file; save_format='h5' requests the HDF5 format.
model_name = 'sketch_recognition_cnn.h5'
model.save(model_name, save_format='h5')
To use this model on the web we need to convert it into a format that tensorflowjs can understand. To do so we may use tfjs-converter as follows:
tensorflowjs_converter --input_format keras \
./experiments/sketch_recognition_cnn/sketch_recognition_cnn.h5 \
./demos/public/models/sketch_recognition_cnn
You can find this experiment in the Demo app and play around with it right in your browser to see how the model performs in real life.