#!/usr/bin/env python
# coding: utf-8

# <div>
# <img src="https://discuss.pytorch.org/uploads/default/original/2X/3/35226d9fbc661ced1c5d17e374638389178c3176.png" width="400" style="margin: 50px auto; display: block; position: relative; left: -30px;" />
# </div>
# 
# <!--NAVIGATION-->
# # < [CNN](6-CNN.ipynb) | Transfer Learning | [Pretrained models for NLP](8-Pretrained-models-for-NLP.ipynb) >

# ### An Example of Transfer Learning for Image Classification
# 
# Transfer Learning is the re-use of pre-trained models on new tasks. Most often, the two tasks are different but somehow related to each other. For example, a model which was trained on image classification might have learnt image features which can also be harnessed for other image related tasks. This technique became increasingly popular in the field of Deep Learning since it enables one to train a model on comparatively little data.

# ![one does not simply build models](figures/simply-build-model.jpeg)

# ### Table of Contents
# 
# #### 1. [Setup](#Setup)
# #### 2. [Data loading and pre-processing](#Data-loading-and-pre-processing)
# #### 3. [Loading a pre-trained model](#Loading-a-pre-trained-model)
# #### 4. [Training the last layer](#Training-the-last-layer)

# ---
# 
# # Setup
# 
# In this notebook, we consider the [Alien vs. Predator](https://www.kaggle.com/pmigdal/alien-vs-predator-images) task from [Kaggle](http://www.kaggle.com). We want to classify images as either 'alien' or 'predator'.
# 
# Because the dataset is relatively small, and we don't want to wait for hours, we use a model pre-trained on the very large ImageNet task.

# ![Transfer Learning Figure 1](figures/transfer-learning-1.png)

# To reduce the difficulty of training, we freeze the intermediate layers and only train a few layers close to the output.

# ![Transfer Learning Figure 2](figures/transfer-learning-2.png)
# Figures taken from https://www.kaggle.com/pmigdal/alien-vs-predator-images

# ---
# ### Requirements

# Execute this cell to install some dependencies and download the alien-vs-predator dataset. Google Colab may then offer to restart the kernel; please accept so that the newly installed packages are loaded.

# In[ ]:


# Shell commands for this cell, run in order through IPython's `system` hook
# (the programmatic form of a `!command` line in a notebook).
setup_commands = [
    # Install the workshop's Python dependencies.
    'wget -q https://raw.githubusercontent.com/theevann/amld-pytorch-workshop/master/binder/requirements.txt -O requirements.txt',
    'pip install -qr requirements.txt',
    # Download, unpack, clean up and list the alien-vs-predator dataset.
    'mkdir -p data',
    'curl -Lo alien-vs-predator.zip "https://docs.google.com/uc?export=download&id=1hct3PjRf14ZBp83ob3f6Uo_0mqrT9FGZ"',
    'unzip -oq alien-vs-predator.zip -d data/',
    'rm alien-vs-predator.zip',
    'ls -l data/alien-vs-predator/',
    # for PIL.Image
    'pip install --no-cache-dir -I Pillow==7.1.2',
]
for command in setup_commands:
    get_ipython().system(command)


# If all went well, you should be able to execute the following cell successfully.

# In[ ]:


import os

import torch
import torch.nn as nn

import torchvision

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Use a larger default font so labels in the figures below are easy to read.
matplotlib.rcParams["font.size"] = 16


# ---
# 
# # Data loading and pre-processing
# 
# Here we will create a `Dataset` and corresponding `DataLoader` that find training examples in the following directory structure:
# 
# ```
# data/alien-vs-predator
# │
# └───train
# │     │
# │     │───alien
# │     │    │   20.jpg
# │     │    │   104.jpg
# │     │    └   ...
# │     │
# │     └───predator
# │          │   1.jpg
# │          │   78.jpg
# │          └   ...
# │   
# └───validation
#       │
#       │───alien
#       │    │   233.jpg
#       │    │   12.jpg
#       │    └   ...
#       │
#       └───predator
#            │   22.jpg
#            │   77.jpg
#            └   ...
# ```

# `torchvision` datasets allow us to specify many different transformations to apply to the inputs. Random perturbations can improve the quality of your model by synthetically enlarging your dataset.

# In[ ]:


from torchvision.datasets import ImageFolder
from torchvision.transforms import (
    Compose,
    RandomResizedCrop,
    RandomHorizontalFlip,
    ToTensor,
    Resize,
    CenterCrop,
)
from torch.utils.data import DataLoader

# Create datasets
# Training set: every sample is randomly cropped/rescaled to 224x224 and
# randomly mirrored (data augmentation), then converted to a tensor.
# NOTE(review): no Normalize(...) step is applied; torchvision's ImageNet
# weights were reportedly trained on mean/std-normalized inputs — confirm
# whether normalization is wanted (it would affect the plots below).
train_data = ImageFolder(
    root=os.path.join(os.getcwd(), "data", "alien-vs-predator", "train"),
    transform=Compose([
        RandomResizedCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
    ]),
)

# Validation set: deterministic resize + center crop, so the images end up
# the same 224x224 size as the training images without random perturbation.
test_data = ImageFolder(
    root=os.path.join(os.getcwd(), "data", "alien-vs-predator", "validation"),
    transform=Compose([
        Resize(256),
        CenterCrop(224),
        ToTensor(),
    ]),
)

# Mini-batch loaders: shuffle only the training data.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

# Class names come from the sub-folder names (two classes here).
class_names = train_data.classes
class_names


# ### Data augmentation
# 
# Let's have a look at the effect of the transformations we specified for data augmentation

# In[ ]:


# Load the training folder *without* a transform, so items come back as PIL
# images and each transformation below can be applied to them by hand.
preview_data = ImageFolder(os.path.join(os.getcwd(), "data", "alien-vs-predator", "train"))
img, label = preview_data[2]

# Converts a PIL image into a (channels, height, width) float tensor.
# Defined before show_image, which uses it, so the helper never depends on a
# name that is only created later in the cell.
tensor_transformer = ToTensor()


def show_image(img, label):
    """Draw the PIL image ``img`` on the current axes with ``label`` underneath."""
    # permute turns (rgb, height, width) into (height, width, rgb), the layout imshow expects
    plt.imshow(tensor_transformer(img).permute(1, 2, 0))
    plt.xlabel(label)


# Let's inspect the effect of the various transformations
fig = plt.figure(figsize=(16, 9))

plt.subplot(1, 4, 1)
show_image(img, "Original")

resize_transformer = Resize((400, 400))
plt.subplot(1, 4, 2)
show_image(resize_transformer(img), "Resized (400x400)")

horizontal_flip_transformer = RandomHorizontalFlip()
plt.subplot(1, 4, 3)
show_image(horizontal_flip_transformer(img), "Random horizontal flip")

random_resize_crop_transformer = RandomResizedCrop(250, scale=(0.5, 1))
plt.subplot(1, 4, 4)
show_image(random_resize_crop_transformer(img), "Random resizing + cropping")


# Note that the flip and crop transformations above are random, so if you run the cell multiple times, you will see different results.

# ### Visualize some training samples

# In[ ]:


# Pull a single batch from the (shuffled, augmented) training loader and keep
# its first five samples for display.
batch_iterator = iter(train_loader)
data, labels = next(batch_iterator)
data, labels = data[:5], labels[:5]

# One row of five subplots, each labelled with its class name.
fig = plt.figure(figsize=(16, 9))
for position in range(1, 6):
    sample = data[position - 1]
    target = labels[position - 1]
    fig.add_subplot(1, 5, position)
    # Tensors are (channels, height, width); imshow wants (height, width, channels).
    plt.imshow(sample.permute(1, 2, 0))
    plt.xlabel(class_names[target])


# ---
# 
# # Loading a pre-trained model

# ### List of available pre-trained models

# `torchvision` includes many pre-trained models. Let's get a list and have a look.

# In[ ]:


# dir() returns every attribute of torchvision.models (it may include classes
# and submodules besides model builders); only the private ones are hidden.
for attr_name in dir(torchvision.models):
    if not attr_name.startswith("_"):
        print(f"- {attr_name}")


# ### We will use the ResNet-18 architecture:
# ![ResNet-picture](./figures/resnet.png)

# It's very simple to create a module that has this model with its weights pre-trained for ImageNet.

# In[ ]:


# Build a ResNet-18 with weights pre-trained on ImageNet, then display it.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of `weights=...`; kept as-is to match the requirements installed above.
model_ft = torchvision.models.resnet18(pretrained=True)
model_ft


# ### A closer look at the ResNet-18

# The last fully connected layer has 1000 output neurons, since it was trained on the ImageNet task, which has 1000 image classes.

# In[ ]:


# Display the current classification head (the ImageNet output layer).
model_ft.fc


# We would like to perform binary classification (alien/predator). Therefore, we have to replace the last fully-connected layer to suit our needs (two output units).

# In[ ]:


# Replace the ImageNet head with a fresh linear layer that has one output per
# class of our dataset. Reading the input width from the existing head (512 for
# ResNet-18) instead of hard-coding it keeps this cell valid if the backbone
# is swapped for another architecture.
model_ft.fc = nn.Linear(in_features=model_ft.fc.in_features, out_features=len(class_names))


# In[ ]:


# Display the new head; it should now have one output per class (2 here).
model_ft.fc


# ---
# 
# # Training the last layer

# ### Freeze all the layers except the last fully-connected one

# **First way**

# In[ ]:


# Freeze every parameter that is not part of the new `fc` head, so autograd
# no longer computes gradients for the pre-trained backbone.
for param_name, parameter in model_ft.named_parameters():
    if param_name in ("fc.weight", "fc.bias"):
        continue
    parameter.requires_grad = False


# **Second way**

# In[ ]:


# Freeze the whole network first, then re-enable gradients for the new head.
# The order matters: the first call also freezes `fc`.
model_ft.requires_grad_(False)
model_ft.fc.requires_grad_(True)


# **Third way**

# A third way could be to pass to the optimiser only the parameters of the last linear layer.  
# However, this is not as good as the previous methods. Can you see why?  
# Hint: All the gradients are still computed...

# ### Define the train and accuracy functions

# Now that the architecture has two output units, we can use it to perform binary classification.
# 
# The *train* and *accuracy* functions are almost identical to the ones we used when training the CNN. This again nicely demonstrates the modularity of PyTorch and its simple interface.

# In[ ]:


import colorama

def train(
    model,
    train_loader,
    test_loader,
    device,
    num_epochs=3,
    learning_rate=0.1,
    decay_learning_rate=False,
):
    """Train the final fully-connected layer (``model.fc``) and report progress.

    Only ``model.fc``'s parameters are handed to the optimizer, so the rest of
    the network is never updated here. Freezing it beforehand additionally
    avoids computing its gradients.

    Args:
        model: network exposing an ``fc`` sub-module; assumed to already be on ``device``.
        train_loader: yields ``(data, labels)`` batches used for optimisation.
        test_loader: yields ``(data, labels)`` batches used only for validation accuracy.
        device: device every batch is moved to before the forward pass.
        num_epochs: number of full passes over ``train_loader``.
        learning_rate: initial Adam learning rate.
        decay_learning_rate: if True, multiply the learning rate by 0.85 after every epoch.

    After the last epoch the model is left in evaluation mode, because
    ``accuracy`` switches it there.
    """
    # Adam with cross-entropy loss; only the new head's parameters are optimised.
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Optional exponential decay: lr <- 0.85 * lr after each epoch.
    scheduler = (
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)
        if decay_learning_rate
        else None
    )

    # We make multiple passes over the dataset
    for epoch in range(num_epochs):
        print("=" * 40, "Starting epoch %d" % (epoch + 1), "=" * 40)

        # Some layers (Dropout, BatchNorm) behave differently in training and
        # evaluation. accuracy() below switches the model to eval mode, so
        # training mode must be re-enabled at the start of *every* epoch.
        model.train()

        total_epoch_loss = 0.0
        # Make one pass in batches
        for batch_number, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            total_epoch_loss += loss.item()

            if batch_number % 5 == 0:
                print("Batch %d/%d" % (batch_number, len(train_loader)))

        # The scheduler must be stepped *after* the optimizer (PyTorch >= 1.1);
        # stepping it first would skip the initial learning rate entirely.
        if scheduler is not None:
            scheduler.step()

        train_acc = accuracy(model, train_loader, device)
        test_acc = accuracy(model, test_loader, device)

        # CrossEntropyLoss already averages over each batch, so the epoch loss
        # is the mean over batches of *this* loader (not the global sample count).
        mean_epoch_loss = total_epoch_loss / max(len(train_loader), 1)

        print(
            colorama.Fore.GREEN
            + "\nEpoch %d/%d, Loss=%.4f, Train-Acc=%d%%, Valid-Acc=%d%%"
            % (
                epoch + 1,
                num_epochs,
                mean_epoch_loss,
                100 * train_acc,
                100 * test_acc,
            ),
            colorama.Fore.RESET,
        )


# In[ ]:


def accuracy(model, data_loader, device):
    """Return the fraction of samples in ``data_loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode and left there; callers that
    continue training must call ``model.train()`` again.
    """
    model.eval()

    correct, seen = 0, 0
    # Autograd is not needed for evaluation; disabling it saves memory and time.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # The predicted class is the index of the largest output.
            predicted = model(inputs).argmax(dim=1)
            correct += int((predicted == targets).sum())
            seen += predicted.shape[0]

    return correct / seen


# ### Launch training

# In[ ]:


# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# nn.Module.to modifies the module in place, so no reassignment is needed.
model_ft.to(device)

train(
    model_ft,
    train_loader,
    test_loader,
    device,
    num_epochs=2,
)


# ### Inspecting the model's predictions

# In[ ]:


# Draw a few random validation images and compare the model's predictions
# with the ground-truth labels.
data, labels = next(iter(DataLoader(test_data, batch_size=5, shuffle=True)))

# Evaluate explicitly in eval mode (BatchNorm must use, and not update, its
# running statistics) and without building an autograd graph; do not rely on
# whatever mode the previous cell happened to leave the model in.
model_ft.eval()
with torch.no_grad():
    # `data` itself stays on the CPU for plotting; only a copy goes to `device`.
    predictions = torch.argmax(model_ft(data.to(device)), 1).cpu()

plt.figure(figsize=(16, 9))
for i in range(len(predictions)):
    img = data[i]  # (channels, height, width)
    plt.subplot(1, 5, i + 1)
    plt.imshow(img.permute(1, 2, 0))
    plt.xlabel(
        "prediction = %s\n (gt: %s)"
        % (test_data.classes[predictions[i].item()], test_data.classes[labels[i]]),
        fontsize=14,
    )
    plt.xticks([])
    plt.yticks([])


# <!--NAVIGATION-->
# # < [CNN](6-CNN.ipynb) | Transfer Learning | [Pretrained models for NLP](8-Pretrained-models-for-NLP.ipynb) >

# In[ ]: