#!/usr/bin/env python
# coding: utf-8
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# Licensed under the MIT License.
# # Training an Image Classification Model
#
# In this notebook, we give an introduction to training an image classification model using [fast.ai](https://www.fast.ai/). Using a small dataset of four different beverage packages, we demonstrate how to train and evaluate a CNN image classification model. We also cover one of the most common ways to store data on a file system for this type of problem.
#
# ## Initialization
# In[1]:
# Ensure edits to libraries are loaded and plotting is shown in the notebook.
get_ipython().run_line_magic('reload_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')
get_ipython().run_line_magic('matplotlib', 'inline')
# Import all the functions we need.
# In[2]:
import sys
sys.path.append("../../")
import numpy as np
from pathlib import Path
import scrapbook as sb
# fastai and torch
import fastai
from fastai.metrics import accuracy
from fastai.vision import (
models, ImageList, imagenet_stats, partial, cnn_learner, ClassificationInterpretation, to_np,
)
# local modules
from utils_cv.classification.model import TrainMetricsRecorder
from utils_cv.classification.plot import plot_pr_roc_curves
from utils_cv.classification.widget import ResultsWidget
from utils_cv.classification.data import Urls
from utils_cv.common.data import unzip_url
from utils_cv.common.gpu import db_num_workers, which_processor
print(f"Fast.ai version = {fastai.__version__}")
which_processor()
# This shows your machine's GPUs (if it has any) and the computing device `fastai/torch` is using. We suggest using an [Azure DSVM](https://azure.microsoft.com/en-us/services/virtual-machines/data-science-virtual-machines/) Standard NC6 as a GPU compute resource.
# Next, set some model runtime parameters. We use the `unzip_url` helper function to download and unzip the data used in this example notebook.
# In[3]:
DATA_PATH = unzip_url(Urls.fridge_objects_path, exist_ok=True)
EPOCHS = 10
LEARNING_RATE = 1e-4
IM_SIZE = 300
BATCH_SIZE = 16
ARCHITECTURE = models.resnet18
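# The internals of `unzip_url` are not shown in this notebook. As a rough, stdlib-only sketch of what a download-and-unzip helper with an `exist_ok` flag might do, the hypothetical `download_and_unzip` below fetches a zip archive, extracts it, and reuses an existing extraction when `exist_ok=True`. It assumes the archive holds a single top-level folder named after the zip file; it is not the `utils_cv` implementation.

```python
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path


def download_and_unzip(url: str, dest_dir: str, exist_ok: bool = True) -> Path:
    """Hypothetical helper: download a zip from `url` and extract it into `dest_dir`.

    Assumes the archive contains one top-level folder named after the zip file,
    e.g. fridgeObjects.zip -> fridgeObjects/.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    zip_name = Path(urllib.parse.urlparse(url).path).name
    extracted = dest / Path(zip_name).stem

    if extracted.exists():
        if exist_ok:
            return extracted  # reuse data extracted on a previous run
        raise FileExistsError(f"{extracted} already exists")

    zip_path = dest / zip_name
    urllib.request.urlretrieve(url, zip_path)  # also accepts file:// URLs
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    zip_path.unlink()  # delete the archive once it has been extracted
    return extracted
```

# With `exist_ok=True`, re-running the cell skips the download, which keeps repeated notebook runs fast.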
# ---
#
# # Prepare Image Classification Dataset
#
# In this notebook, we use a toy dataset called *Fridge Objects*, which consists of 134 photos of beverage containers in 4 classes, `{can, carton, milk bottle, water bottle}`, taken against different backgrounds. The helper function downloads and unzips the dataset to the `ComputerVision/data` directory.
#
# Set that directory in the `path` variable for ease of use throughout the notebook.
# In[4]:
path = Path(DATA_PATH)
path.ls()
# You'll notice that we have four different folders inside:
# - `/water_bottle`
# - `/milk_bottle`
# - `/carton`
# - `/can`
#
# This is the most common data format for multiclass image classification. Each folder name corresponds to the label of the images it contains:
#
# ```
# /images
# +-- can (class 1)
# | +-- image1.jpg
# | +-- image2.jpg
# | +-- ...
# +-- carton (class 2)
# | +-- image31.jpg
# | +-- image32.jpg
# | +-- ...
# +-- ...
# ```
#
# The Fridge Objects data is already organized in this structure.
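# To make the folder-per-class convention concrete, here is a minimal sketch (not fast.ai's implementation) that pairs every image with the name of its parent folder, which is conceptually what `label_from_folder()` does in the next section. The function names and the set of accepted file extensions are assumptions for illustration.

```python
from pathlib import Path
from typing import List, Tuple

# Assumed set of image file extensions; adjust for your own data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def items_from_folders(root: str) -> List[Tuple[Path, str]]:
    """Return (image_path, label) pairs, labeling each image by its parent folder name."""
    items = []
    for path in sorted(Path(root).rglob("*")):
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
            items.append((path, path.parent.name))
    return items


def class_names(items: List[Tuple[Path, str]]) -> List[str]:
    """Sorted unique labels, one per class folder."""
    return sorted({label for _, label in items})
```

# Files that are not images (for example a stray `notes.txt`) are skipped, so they never receive a label.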
# # Load Images
#
# In fast.ai, an `ImageDataBunch` wraps the training and validation data and feeds it to the model in mini-batches during training. We create the `ImageDataBunch` using the [data block API](https://docs.fast.ai/data_block.html).
#
# For training and validation, we randomly split the data in an `8:2` ratio, holding 80% of the data for training and 20% for validation. Alternatively, you can create a dedicated train/validation split, e.g. by placing the image structure shown above into parent folders "train" and "valid" and then using [`.split_by_folder()`](https://docs.fast.ai/data_block.html#ItemList.split_by_folder) instead of `.split_by_rand_pct()` below.
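# A random percentage split can be sketched in plain Python as follows. This hypothetical `split_by_rand_pct` function follows the same idea as fast.ai's method (shuffle the indices with a fixed seed, then take the first `valid_pct` fraction as validation), but it uses Python's `random` module, so it will not reproduce the exact images fast.ai assigns to each set.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_rand_pct(
    items: Sequence[T], valid_pct: float = 0.2, seed: Optional[int] = None
) -> Tuple[List[T], List[T]]:
    """Randomly split items into (train, valid), putting int(valid_pct * n) items in valid."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    idx = list(range(len(items)))
    rng.shuffle(idx)
    cut = int(valid_pct * len(items))  # number of validation items
    valid = [items[i] for i in idx[:cut]]
    train = [items[i] for i in idx[cut:]]
    return train, valid
```

# With 134 images and `valid_pct=0.2`, this sketch puts `int(0.2 * 134) = 26` images in validation and 108 in training; fixing `seed` makes the split repeatable across runs.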
#
# In[5]:
data = (
    ImageList.from_folder(path)
    .split_by_rand_pct(valid_pct=0.2, seed=10)
    .label_from_folder()
    .transform(size=IM_SIZE)
    .databunch(bs=BATCH_SIZE, num_workers=db_num_workers())
    .normalize(imagenet_stats)
)
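# The final `.normalize(imagenet_stats)` call standardizes each color channel using the per-channel mean and standard deviation of ImageNet, matching the input distribution the pretrained weights were trained on. The arithmetic for a single RGB pixel (values scaled to [0, 1]) is sketched below; `normalize_pixel` is a hypothetical helper, whereas fast.ai applies the same operation to whole batches of tensors.

```python
from typing import List, Sequence

# ImageNet per-channel statistics (R, G, B) for pixel values scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(
    rgb: Sequence[float],
    mean: Sequence[float] = IMAGENET_MEAN,
    std: Sequence[float] = IMAGENET_STD,
) -> List[float]:
    """Standardize each channel independently: (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]
```

# A pixel equal to the ImageNet mean maps to zero in every channel, and a value one standard deviation above the mean maps to 1.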
# We examine some sample data using the `databunch` we created.
# In[6]:
data.show_batch(rows=3, figsize=(15,11))
# Show all available classes:
# In[7]:
print(f'number of classes: {data.c}')
print(data.classes)
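# Show the number of images in the training and validation sets.
# In[8]:
print(f'number of training images: {len(data.train_ds)}')
print(f'number of validation images: {len(data.valid_ds)}')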
# In a standard analysis, we would split the data into training, validation, and test sets. For this example, we do not use a test set, but one could be added using the [add_test](https://docs.fast.ai/data_block.html#LabelLists.add_test) method. Note that in the fast.ai framework, test sets do not include labels, since they represent the unknown data to be predicted. The validation set, by contrast, is labeled, so it can be used to measure model performance on new observations not used to train the model.
# # Train a Model
# For this image classifier, we use a **ResNet18** convolutional neural network (CNN) architecture, set via `ARCHITECTURE` above. You can find more details about ResNet in the [original paper](https://arxiv.org/abs/1512.03385).
#
# When training a CNN, there are almost infinitely many ways to construct the model architecture: we need to decide how many and what type of layers to include and how many nodes make up each layer. Other hyperparameters that control the training of those layers are also important and add to the overall complexity of neural net methods. With fast.ai, we use the `cnn_learner` function to specify the model architecture and performance metric. We use a transfer learning approach: we reuse the CNN architecture and initialize the model with parameters pretrained on [ImageNet](http://www.image-net.org/).
#
# In this work, we use a custom callback, `TrainMetricsRecorder`, to track the model accuracy on the training set as we tune the model. This is for instruction only, since the standard fast.ai [`Recorder` class](https://docs.fast.ai/basic_train.html#Recorder) only tracks metrics on the validation set.
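# The idea behind a metrics-recording callback can be sketched without fast.ai: the training loop calls hooks at fixed points (start of an epoch, end of each batch, end of an epoch), and the callback accumulates statistics in between. The `TrainAccuracyRecorder` class and `run_epochs` loop below are hypothetical and do not follow fast.ai's `Callback` API; they only illustrate how training-set accuracy can be collected while the model trains.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[List[int], List[int]]  # (predicted labels, true labels)


class TrainAccuracyRecorder:
    """Hypothetical callback that records the mean training accuracy of each epoch."""

    def __init__(self) -> None:
        self.history: List[float] = []
        self._correct = 0
        self._total = 0

    def on_epoch_begin(self) -> None:
        # Reset the running counts at the start of every epoch.
        self._correct = 0
        self._total = 0

    def on_batch_end(self, preds: List[int], targets: List[int]) -> None:
        # Accumulate the number of correct predictions in this mini-batch.
        self._correct += sum(p == t for p, t in zip(preds, targets))
        self._total += len(targets)

    def on_epoch_end(self) -> None:
        # Store the accuracy over all training batches seen in this epoch.
        self.history.append(self._correct / self._total if self._total else 0.0)


def run_epochs(epochs: Sequence[Sequence[Batch]], callbacks: Sequence[TrainAccuracyRecorder]) -> None:
    """Toy training loop: each epoch is a list of (preds, targets) batches."""
    for batches in epochs:
        for cb in callbacks:
            cb.on_epoch_begin()
        for preds, targets in batches:
            # A real loop would run the forward and backward passes here.
            for cb in callbacks:
                cb.on_batch_end(preds, targets)
        for cb in callbacks:
            cb.on_epoch_end()
```

# Keeping the bookkeeping in a callback leaves the training loop unchanged: adding or removing a recorder only changes the `callbacks` list.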
# In[9]:
learn = cnn_learner(
    data,
    ARCHITECTURE,
    metrics=[accuracy],
    callback_fns=[partial(TrainMetricsRecorder, show_graph=True)],
)
# By default, `cnn_learner` freezes the pretrained layers. Use the `unfreeze` method to make all CNN layers trainable, so the entire network is retrained on the Fridge Objects dataset.
# In[10]:
learn.unfreeze()
# The `fit` method trains the CNN for `EPOCHS` epochs with the learning rate `LEARNING_RATE` set above.
# In[11]:
learn.fit(EPOCHS, LEARNING_RATE)
# In[12]:
# Plot the training and validation losses using the default `Recorder` callback.
learn.recorder.plot_losses()
# # Validate the model
#
# To validate the model, calculate the model accuracy using the validation set.
# In[13]:
_, validation_accuracy = learn.validate(learn.data.valid_dl, metrics=[accuracy])
print(f'Accuracy on validation set: {100*float(validation_accuracy):3.2f}')
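# Accuracy here is the fraction of images whose highest-scoring class matches the true label. A plain-Python sketch of that computation, assuming `scores` holds one row of class scores per image and `targets` holds integer class indices (the toy values below are made up for illustration):

```python
from typing import Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose argmax equals the target class index."""
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    correct = sum(argmax(row) == t for row, t in zip(scores, targets))
    return correct / len(targets)


scores = [
    [0.10, 0.70, 0.10, 0.10],  # predicted class 1
    [0.60, 0.20, 0.10, 0.10],  # predicted class 0
    [0.20, 0.20, 0.50, 0.10],  # predicted class 2
    [0.40, 0.30, 0.20, 0.10],  # predicted class 0
]
targets = [1, 0, 3, 0]
print(f"Accuracy: {100 * accuracy(scores, targets):3.2f}")  # 3 of 4 correct
```

# For these four toy images, three predictions are correct, so the formatted accuracy is `75.00`.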
# The `ClassificationInterpretation` module is used to analyze the model classification results.
# In[14]:
interp = ClassificationInterpretation.from_learner(learn)
# Get prediction scores. We convert the tensor to a NumPy array so we can plot it later.
pred_scores = to_np(interp.preds)
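# One common way to analyze the prediction scores is a confusion matrix, which counts how often each true class is predicted as each class. The hypothetical `confusion_matrix` function below sketches that bookkeeping in plain Python for `pred_scores`-style input (one row of class scores per image); for the real model, fast.ai's `ClassificationInterpretation` also provides a `plot_confusion_matrix` method.

```python
from typing import List, Sequence


def confusion_matrix(
    scores: Sequence[Sequence[float]], targets: Sequence[int], n_classes: int
) -> List[List[int]]:
    """cm[t][p] counts images with true class t that were predicted as class p."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for row, t in zip(scores, targets):
        p = max(range(n_classes), key=lambda i: row[i])  # predicted class = argmax
        cm[t][p] += 1
    return cm
```

# Diagonal entries count correct predictions; off-diagonal entries show which pairs of classes the model confuses.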
# To see these details, use the widget helper class `ResultsWidget`. The widget shows the validation images along with their ground-truth labels and model prediction scores. With this tool, you can see how the model predicts each image and debug the model if needed.
#
#
#