Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
In this notebook, we give an introduction to training an image classification model using fast.ai. Using a small dataset with images of four different beverage containers, we demonstrate how to train and evaluate a CNN image classification model. We also cover one of the most common ways to store data on a file system for this type of problem.
# Ensure edits to libraries are loaded and plotting is shown in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
Import all the functions we need.
import sys
sys.path.append("../../")
import numpy as np
from pathlib import Path
import scrapbook as sb
# fastai and torch
import fastai
from fastai.metrics import accuracy
from fastai.vision import (
models, ImageList, imagenet_stats, partial, cnn_learner, ClassificationInterpretation, to_np,
)
# local modules
from utils_cv.classification.model import TrainMetricsRecorder
from utils_cv.classification.plot import plot_pr_roc_curves
from utils_cv.classification.widget import ResultsWidget
from utils_cv.classification.data import Urls
from utils_cv.common.data import unzip_url
from utils_cv.common.gpu import db_num_workers, which_processor
print(f"Fast.ai version = {fastai.__version__}")
which_processor()
Fast.ai version = 1.0.48 Fast.ai (Torch) is using GPU: Tesla V100-PCIE-16GB
This shows your machine's GPUs (if it has any) and the computing device that fastai/torch is using. We suggest using an Azure DSVM Standard NC6 as a GPU compute resource.
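If you want to verify GPU availability with torch directly, the following is a minimal sketch; it assumes only that torch is installed and does not use any of the helpers above.
import torch
# Report whether a CUDA-capable GPU is visible to torch, and its name if so.
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU found; training will fall back to the CPU and be noticeably slower.")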
Next, set some model runtime parameters. We use the unzip_url
helper function to download and unzip the data used in this example notebook.
DATA_PATH = unzip_url(Urls.fridge_objects_path, exist_ok=True)
EPOCHS = 10
LEARNING_RATE = 1e-4
IM_SIZE = 300
BATCH_SIZE = 16
ARCHITECTURE = models.resnet18
In this notebook, we use a toy dataset called Fridge Objects, which consists of 134 images across 4 classes of beverage container {can, carton, milk bottle, water bottle}, photographed against different backgrounds. The helper function downloads and unzips the dataset to the ComputerVision/data directory.
Set that directory in the path variable for ease of use throughout the notebook.
path = Path(DATA_PATH)
path.ls()
[PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/models'), PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/milk_bottle'), PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/can'), PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/fast_inference'), PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/water_bottle'), PosixPath('/data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects/carton')]
You'll notice that we have four different folders inside:
/water_bottle
/milk_bottle
/carton
/can
This is the most common data format for multiclass image classification. Each folder name corresponds to the label of the images it contains:
/images
+-- can (class 1)
| +-- image1.jpg
| +-- image2.jpg
| +-- ...
+-- carton (class 2)
| +-- image31.jpg
| +-- image32.jpg
| +-- ...
+-- ...
We have already arranged the data in this structure.
In fast.ai, an ImageDataBunch feeds the model mini-batches of images during training. We create the ImageDataBunch using the data_block API.
For training and validation, we randomly split the data in an 8:2 ratio, holding out 80% of the data for training and 20% for validation. One can also create dedicated train/validation splits, e.g. by placing the image structure shown above into parent folders "train" and "valid" and then using .split_by_folder() instead of .split_by_rand_pct() below; a sketch of this alternative follows the next cell.
data = (
ImageList.from_folder(path)
.split_by_rand_pct(valid_pct=0.2, seed=10)
.label_from_folder()
.transform(size=IM_SIZE)
.databunch(bs=BATCH_SIZE, num_workers = db_num_workers())
.normalize(imagenet_stats)
)
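For reference, the folder-based split mentioned above would look roughly as follows. This is only a sketch: it assumes you have reorganized the images into "train" and "valid" parent folders, which the Fridge Objects download does not provide out of the box.
# Hypothetical alternative: use dedicated train/valid folders instead of a random split.
# Assumes a layout like path/train/<class>/... and path/valid/<class>/...
data_fixed_split = (
    ImageList.from_folder(path)
    .split_by_folder(train="train", valid="valid")
    .label_from_folder()
    .transform(size=IM_SIZE)
    .databunch(bs=BATCH_SIZE, num_workers=db_num_workers())
    .normalize(imagenet_stats)
)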
We examine some sample data using the databunch
we created.
data.show_batch(rows=3, figsize=(15,11))
Show all available classes:
print(f'number of classes: {data.c}')
print(data.classes)
number of classes: 4 ['can', 'carton', 'milk_bottle', 'water_bottle']
Show the number of images in the training and validation set.
data.batch_stats
<bound method ImageDataBunch.batch_stats of ImageDataBunch; Train: LabelList (108 items) x: ImageList Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300) y: CategoryList milk_bottle,milk_bottle,milk_bottle,milk_bottle,milk_bottle Path: /data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects; Valid: LabelList (26 items) x: ImageList Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300),Image (3, 300, 300) y: CategoryList can,can,can,carton,carton Path: /data/home/pabuehle/Desktop/ComputerVision/data/fridgeObjects; Test: None>
In a standard analysis, we would split the data into train, validation, and test sets. For this example, we do not use a test set, but one could be added using the add_test method. Note that in the fast.ai framework, test sets do not include labels, as they represent the unknown data to be predicted. The validation set does include labels, so it can be used to measure model performance on new observations that were not used to train the model.
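As a sketch of how a test set could be added, the data_block chain above can be extended with add_test_folder. This is illustrative only: it assumes a hypothetical unlabeled "test" folder under the data directory, which is not part of the Fridge Objects download.
# Hypothetical example: extend the pipeline with an unlabeled test folder.
data_with_test = (
    ImageList.from_folder(path)
    .split_by_rand_pct(valid_pct=0.2, seed=10)
    .label_from_folder()
    .add_test_folder("test")  # assumes a path/test folder containing unlabeled images
    .transform(size=IM_SIZE)
    .databunch(bs=BATCH_SIZE, num_workers=db_num_workers())
    .normalize(imagenet_stats)
)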
For this image classifier, we use a ResNet18 convolutional neural network (CNN) architecture. You can find more details about ResNet here.
When training a CNN, there are almost infinitely many ways to construct the model architecture. We need to determine how many layers to include, what type they should be, and how many nodes make up each layer. Other hyperparameters that control the training of those layers are also important and add to the overall complexity of neural net methods. With fast.ai, we use the cnn_learner
function to specify the model architecture and performance metric. We take a transfer learning approach: we reuse the CNN architecture and initialize the model parameters with weights pretrained on ImageNet.
In this work, we use a custom callback, TrainMetricsRecorder,
to track the model accuracy on the training set as we tune the model. This is for instruction only; the standard fast.ai
Recorder class only tracks model accuracy on the validation set.
learn = cnn_learner(
data,
ARCHITECTURE,
metrics=[accuracy],
callback_fns=[partial(TrainMetricsRecorder, show_graph=True)]
)
Use the unfreeze
method to allow us to retrain all the CNN layers with the Fridge Objects data set.
learn.unfreeze()
The fit
function trains the CNN using the parameters specified above.
learn.fit(EPOCHS, LEARNING_RATE)
epoch | train_loss | valid_loss | train_accuracy | valid_accuracy | time |
---|---|---|---|---|---|
0 | 1.235351 | 0.997030 | 0.385417 | 0.500000 | 00:01 |
1 | 0.863196 | 0.842252 | 0.843750 | 0.769231 | 00:01 |
2 | 0.685307 | 0.691107 | 0.854167 | 0.846154 | 00:01 |
3 | 0.540183 | 0.631638 | 0.958333 | 0.846154 | 00:01 |
4 | 0.442337 | 0.536175 | 0.989583 | 0.846154 | 00:01 |
5 | 0.363073 | 0.525139 | 1.000000 | 0.846154 | 00:01 |
6 | 0.310642 | 0.495950 | 0.989583 | 0.807692 | 00:01 |
7 | 0.269349 | 0.475910 | 1.000000 | 0.846154 | 00:01 |
8 | 0.229974 | 0.430902 | 1.000000 | 0.846154 | 00:01 |
9 | 0.199667 | 0.402672 | 1.000000 | 0.846154 | 00:01 |
# You can plot loss by using the default callback Recorder.
learn.recorder.plot_losses()
To validate the model, calculate the model accuracy using the validation set.
_, validation_accuracy = learn.validate(learn.data.valid_dl, metrics=[accuracy])
print(f'Accuracy on validation set: {100*float(validation_accuracy):3.2f}')
Accuracy on validation set: 84.62
The ClassificationInterpretation
module is used to analyze the model classification results.
interp = ClassificationInterpretation.from_learner(learn)
# Get prediction scores. We convert tensors to numpy array to plot them later.
pred_scores = to_np(interp.preds)
To see these details, use the widget helper class ResultsWidget.
The widget shows the validation images along with the ground truth label and the model's prediction score. With this tool, it's possible to see how the model predicts each image and to debug the model if needed.
w_results = ResultsWidget(
dataset=learn.data.valid_ds,
y_score=pred_scores,
y_label=[data.classes[x] for x in np.argmax(pred_scores, axis=1)],
)
display(w_results.show())
Tab(children=(VBox(children=(HBox(children=(Button(description='Previous', layout=Layout(width='80px'), style=…
Aside from accuracy, precision and recall are other important metrics in classification settings. These linked metrics quantify how well the model classifies an image against a known label, and where it fails. Since they are linked, there is a trade-off between optimizing for precision and optimizing for recall. Plotting them against each other shows this relationship graphically.
In multiclass settings, we plot precision-recall and ROC curves for each class. In this example, the dataset is not complex and the accuracy is close to 100%. In more difficult settings, these figures will be more interesting.
# True labels of the validation set. We convert to numpy array for plotting.
true_labels = to_np(interp.y_true)
plot_pr_roc_curves(true_labels, pred_scores, data.classes)
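To complement the curves with single numbers, per-class precision and recall can also be computed from the same predictions. The sketch below assumes scikit-learn is installed; it is not imported elsewhere in this notebook.
# Sketch: per-class precision and recall on the validation set (assumes scikit-learn).
from sklearn.metrics import classification_report
pred_labels = np.argmax(pred_scores, axis=1)
print(classification_report(true_labels, pred_labels, target_names=data.classes))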
A confusion matrix details the number of images the model classified correctly and incorrectly. For each class, correct classifications appear along the diagonal and incorrect ones off the diagonal. This gives a detailed look at which classes the model confuses with one another.
interp.plot_confusion_matrix()
When evaluating our results, we want to see where the model makes mistakes and if we can help it improve.
interp.plot_top_losses(9, figsize=(15,11))
# Preserve some of the notebook outputs
training_losses = [x.numpy().ravel()[0] for x in learn.recorder.losses]
training_accuracies = [x[0].numpy().ravel()[0] for x in learn.recorder.metrics]
sb.glue("training_losses", training_losses)
sb.glue("training_accuracies", training_accuracies)
sb.glue("validation_accuracy", 100 * float(validation_accuracy))
Using the concepts introduced in this notebook, you can bring your own dataset and train an image classifier to detect objects of interest for your specific setting.