#!/usr/bin/env python
# coding: utf-8

# Notebook export: train a 22-layer Wide Residual Network (WRN-22) on the PNG
# version of MNIST with ktrain, then use, save and reload a Predictor.
# NOTE: the get_ipython() magics below only work inside IPython/Jupyter.

# In[1]:

get_ipython().run_line_magic('reload_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')
get_ipython().run_line_magic('matplotlib', 'inline')
import os

# Enumerate GPUs by PCI bus ID so that index "0" is stable, then expose only it.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys


# In[2]:

# import ktrain and ktrain.vision modules
import ktrain
from ktrain import vision


# Download a PNG version of the **MNIST** dataset from
# [here](https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz) and set
# DATADIR to the extracted folder.

# In[3]:

# Load the data with some modest data augmentation.
# We load as RGB even though the images are grayscale,
# since some models only support RGB images.
DATADIR = 'data/mnist_png'
# The 'testing' split doubles as the validation set (see train_test_names
# below); prediction examples further down read images from it, so derive
# their paths from DATADIR instead of hard-coding a machine-specific path.
TEST_DIR = os.path.join(DATADIR, 'testing')
# Where the trained predictor is saved and later reloaded from.
PREDICTOR_PATH = '/tmp/mypredictor'

data_aug = vision.get_data_aug(featurewise_center=True,
                               featurewise_std_normalization=True,
                               rotation_range=15,
                               zoom_range=0.1,
                               width_shift_range=0.1,
                               height_shift_range=0.1)
(train_data, val_data, preproc) = vision.images_from_folder(
    datadir=DATADIR,
    data_aug=data_aug,
    train_test_names=['training', 'testing'],
    target_size=(32, 32),
    color_mode='rgb')


# In[4]:

# get a pre-canned 22-layer Wide Residual Network model
model = vision.image_classifier('wrn22', train_data, val_data)


# In[5]:

# get a Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data,
                             workers=8, use_multiprocessing=True, batch_size=64)


# In[6]:

# find a good learning rate
learner.lr_find()


# In[7]:

learner.lr_plot()


# In[8]:

# train the WRN-22 model for a single epoch at a max learning rate of 2e-3
learner.autofit(2e-3, 1)


# In[9]:

# get a Predictor object that we can use to classify (potentially unlabeled) images
predictor = ktrain.get_predictor(learner.model, preproc)


# In[10]:

# Print the class labels and their indices. Outside a notebook, a bare
# expression's value would be discarded, so results are printed explicitly.
print(predictor.get_classes())


# In[11]:

# predict an image depicting a 7
print(predictor.predict_filename(os.path.join(TEST_DIR, '7', '7021.png')))


# In[12]:

# predict an image showing a 0 and return probabilities for all classes
print(predictor.predict_filename(os.path.join(TEST_DIR, '0', '101.png'), return_proba=True))


# In[13]:

# predict all images showing a 3 in our validation set (first 10 results shown);
# the trailing '' keeps the trailing path separator the original path had
print(predictor.predict_folder(os.path.join(TEST_DIR, '3', ''))[:10])


# In[14]:

# save the predictor for possible later deployment in an application
predictor.save(PREDICTOR_PATH)


# In[15]:

# reload the predictor from disk
predictor = ktrain.load_predictor(PREDICTOR_PATH)


# In[16]:

# verify the reloaded predictor still works correctly
print(predictor.predict_filename(os.path.join(TEST_DIR, '7', '7021.png')))