Transfer Learning is the re-use of pre-trained models on new tasks. Most often, the two tasks are different but somehow related to each other. For example, a model which was trained on image classification might have learnt image features which can also be harnessed for other image related tasks. This technique became increasingly popular in the field of Deep Learning since it enables one to train a model on comparatively little data.
In this notebook, we consider the Alien vs. Predator task from Kaggle. We want to classify images as either 'aliens' or 'predators'.
Because the dataset is relatively small and we don't want to wait for hours, we use a model pre-trained on the very large ImageNet dataset.
To make training easier, we freeze the pre-trained layers and only train the layer closest to the output.
Figures taken from https://www.kaggle.com/pmigdal/alien-vs-predator-images
Execute this cell to download the alien-vs-predator dataset and install some dependencies. Google Colab will then offer to restart the runtime. Please do so.
!wget -q https://raw.githubusercontent.com/theevann/amld-pytorch-workshop/master/binder/requirements.txt -O requirements.txt
!pip install -qr requirements.txt
!mkdir -p data
!curl -Lo alien-vs-predator.zip "https://docs.google.com/uc?export=download&id=1hct3PjRf14ZBp83ob3f6Uo_0mqrT9FGZ"
!unzip -oq alien-vs-predator.zip -d data/
!rm alien-vs-predator.zip
!ls -l data/alien-vs-predator/
# for PIL.Image
!pip install --no-cache-dir -I Pillow==7.1.2
If all went well, you should be able to execute the following cell successfully.
import os
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rc('font', size=16)
Here we will create a Dataset and a corresponding DataLoader that find the training examples in the following directory structure:
data/alien-vs-predator
│
├───train
│   │
│   ├───alien
│   │   │   20.jpg
│   │   │   104.jpg
│   │   └   ...
│   │
│   └───predator
│       │   1.jpg
│       │   78.jpg
│       └   ...
│
└───validation
    │
    ├───alien
    │   │   233.jpg
    │   │   12.jpg
    │   └   ...
    │
    └───predator
        │   22.jpg
        │   77.jpg
        └   ...
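ImageFolder infers the labels from this layout: each sub-directory of train (or validation) is one class. To our understanding, torchvision sorts the class folder names alphabetically and numbers them in that order, so alien gets index 0 and predator index 1. The pure-Python sketch below mimics that mapping on a throwaway folder of empty files; the helpers find_classes and make_samples and the extension list are our own re-implementation for illustration, not torchvision's code.

```python
import os
import tempfile

# Assumption: a small subset of the file extensions ImageFolder accepts
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")


def find_classes(directory):
    """Return the sorted sub-directory names and a name -> index mapping."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    return classes, {name: i for i, name in enumerate(classes)}


def make_samples(directory):
    """Return (path, class_index) pairs, one per image file, class by class."""
    classes, class_to_idx = find_classes(directory)
    samples = []
    for name in classes:
        class_dir = os.path.join(directory, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples


# Build a tiny fake "train" folder with empty files laid out like the real one
with tempfile.TemporaryDirectory() as root:
    layout = {"predator": ["1.jpg", "78.jpg"], "alien": ["20.jpg", "104.jpg"]}
    for name, files in layout.items():
        os.makedirs(os.path.join(root, name))
        for fname in files:
            open(os.path.join(root, name, fname), "w").close()

    classes, class_to_idx = find_classes(root)
    samples = make_samples(root)

print(classes)       # ['alien', 'predator']
print(class_to_idx)  # {'alien': 0, 'predator': 1}
print([(os.path.basename(path), label) for path, label in samples])
```

File names are sorted as strings here, so 104.jpg comes before 20.jpg within the alien class.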
torchvision datasets allow us to specify many different transformations of the inputs. Random perturbations can improve the quality of your model by synthetically enlarging your dataset.
from torchvision.datasets import ImageFolder
from torchvision.transforms import (
Compose,
RandomResizedCrop,
RandomHorizontalFlip,
ToTensor,
Resize,
CenterCrop,
)
from torch.utils.data import DataLoader
# Create datasets
train_data = ImageFolder(
os.path.join(os.getcwd(), "data", "alien-vs-predator", "train"),
transform=Compose(
[RandomResizedCrop(224), RandomHorizontalFlip(), ToTensor()] # data augmentation
),
)
test_data = ImageFolder(
os.path.join(os.getcwd(), "data", "alien-vs-predator", "validation"),
transform=Compose([Resize(256), CenterCrop(224), ToTensor()]), # give images the same size as the train images
)
# Specify corresponding batched data loaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
# Our datasets have two classes:
class_names = train_data.classes
class_names
Let's have a look at the effect of the transformations we specified for data augmentation.
# We load one image from the dataset (no transform, so it is a PIL image)
preview_data = ImageFolder(os.path.join(os.getcwd(), "data", "alien-vs-predator", "train"))
img, label = preview_data[2]
# Converts PIL images to (rgb, height, width) tensors with values in [0, 1]
tensor_transformer = ToTensor()
def show_image(img, label):
    # permute turns (rgb, height, width) into (height, width, rgb)
    plt.imshow(tensor_transformer(img).permute(1, 2, 0))
    plt.xlabel(label)
# Let's inspect the effect of the various transformations
fig = plt.figure(figsize=(16, 9))
plt.subplot(1, 4, 1)
show_image(img, "Original")
resize_transformer = Resize((400, 400))
plt.subplot(1, 4, 2)
show_image(resize_transformer(img), "Resized (400x400)")
horizontal_flip_transformer = RandomHorizontalFlip()
plt.subplot(1, 4, 3)
show_image(horizontal_flip_transformer(img), "Random horizontal flip")
random_resize_crop_transformer = RandomResizedCrop(250, scale=(0.5, 1))
plt.subplot(1, 4, 4)
show_image(random_resize_crop_transformer(img), "Random resizing + cropping")
Note that the transformations above are random, so if you run the cell multiple times, you will see different results.
# Take one batch from the train loader
data, labels = next(iter(train_loader))
data, labels = data[0:5], labels[0:5]
# Plot the images
fig = plt.figure(figsize=(16, 9))
for i in range(0, 5):
    fig.add_subplot(1, 5, i + 1)
    plt.imshow(data[i].permute(1, 2, 0))
    plt.xlabel(class_names[labels[i]])
torchvision includes many pre-trained models. Let's get a list and have a look.
for model in dir(torchvision.models):
    if model.startswith("_"):
        continue  # Skip private properties
    print(f"- {model}")
It's very simple to instantiate one of these models, here ResNet-18, with its weights pre-trained on ImageNet.
model_ft = torchvision.models.resnet18(pretrained=True)
model_ft
The last fully connected layer has 1000 output neurons, since the model was trained on the ImageNet task, which has 1000 image classes.
model_ft.fc
We would like to perform binary classification (alien/predator). Therefore, we have to replace the last fully-connected layer to suit our needs (two output units).
model_ft.fc = nn.Linear(in_features=model_ft.fc.in_features, out_features=2)  # 512 inputs for ResNet-18
model_ft.fc
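Next, we freeze all pre-trained parameters, so that only the new fully connected layer is trained. There are several ways to do this.
First way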
for name, param in model_ft.named_parameters():
    if name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False
Second way
model_ft.requires_grad_(False)
model_ft.fc.requires_grad_(True)
Third way
A third way could be to pass only the parameters of the last linear layer to the optimiser.
However, this is less efficient than the previous methods. Do you see why?
Hint: all the gradients are still computed...
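The hint can be checked directly on a toy model, here a small nn.Sequential of our own standing in for the pre-trained network: when only the optimizer is restricted, backward() still computes gradients for every layer; once the layers are frozen, autograd skips them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for "pre-trained backbone + new head"
backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
head = nn.Linear(8, 2)
model = nn.Sequential(backbone, head)
inputs, targets = torch.rand(16, 4), torch.randint(0, 2, (16,))
criterion = nn.CrossEntropyLoss()

# Third way: nothing is frozen, only the optimizer is restricted to the head
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
optimizer.zero_grad()
criterion(model(inputs), targets).backward()
grad_computed_third_way = backbone[0].weight.grad is not None
print(grad_computed_third_way)  # True: computed, but never used by the optimizer

# First/second way: frozen parameters get no gradient at all
backbone.requires_grad_(False)
backbone[0].weight.grad = None  # clear the gradient left over from the previous pass
criterion(model(inputs), targets).backward()
grad_computed_frozen = backbone[0].weight.grad is not None
print(grad_computed_frozen)  # False: autograd skipped the backbone
```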
Now that the architecture has two output units, we can use it to perform binary classification.
The train and accuracy functions are almost identical to the functions we used when training the CNN. This again nicely demonstrates the modularity of PyTorch and its simple interface.
import colorama
def train(
    model,
    train_loader,
    test_loader,
    device,
    num_epochs=3,
    learning_rate=0.1,
    decay_learning_rate=False,
):
    # We use the Adam optimizer with cross-entropy loss.
    # Only the parameters of the new last layer are given to the optimizer.
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    if decay_learning_rate:
        # Multiply the learning rate by 0.85 after every epoch
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)
    # We make multiple passes over the dataset
    for epoch in range(num_epochs):
        print("=" * 40, "Starting epoch %d" % (epoch + 1), "=" * 40)
        # Some models behave differently in training and testing mode (Dropout, BatchNorm),
        # so it is good practice to specify which behavior you want. We do it every epoch
        # because accuracy() below switches the model to evaluation mode.
        model.train()
        total_epoch_loss = 0.0
        # Make one pass in batches
        for batch_number, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            total_epoch_loss += loss.item()
            if batch_number % 5 == 0:
                print("Batch %d/%d" % (batch_number, len(train_loader)))
        if decay_learning_rate:
            scheduler.step()  # since PyTorch 1.1, call this after optimizer.step()
        train_acc = accuracy(model, train_loader, device)
        test_acc = accuracy(model, test_loader, device)
        print(
            colorama.Fore.GREEN
            + "\nEpoch %d/%d, Loss=%.4f, Train-Acc=%d%%, Valid-Acc=%d%%"
            % (
                epoch + 1,
                num_epochs,
                total_epoch_loss / len(train_loader),  # mean of the per-batch losses
                100 * train_acc,
                100 * test_acc,
            ),
            colorama.Fore.RESET,
        )
def accuracy(model, data_loader, device):
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():  # deactivates autograd, reduces memory usage and speeds up computations
        for data, labels in data_loader:
            data, labels = data.to(device), labels.to(device)
            predictions = torch.argmax(model(data), 1)  # find the class number with the largest output
            num_correct += (predictions == labels).sum().item()
            num_samples += len(predictions)
    return num_correct / num_samples
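With decay_learning_rate=True, train() uses StepLR(optimizer, 1, 0.85), which multiplies the learning rate by 0.85 each time scheduler.step() is called, i.e. once per epoch. A short sketch with a dummy parameter (standing in for model.fc) prints the resulting schedule for the default learning rate of 0.1.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, stands in for model.fc
optimizer = torch.optim.Adam([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)

learning_rates = []
for epoch in range(4):
    learning_rates.append(optimizer.param_groups[0]["lr"])
    # ... one epoch of training would run here ...
    optimizer.step()  # PyTorch >= 1.1 expects optimizer.step() before scheduler.step()
    scheduler.step()  # decay once per epoch

print([round(lr, 5) for lr in learning_rates])  # [0.1, 0.085, 0.07225, 0.06141]
```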
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ft.to(device)
train(model_ft, train_loader, test_loader, device, num_epochs=2)
# Take 5 random validation images and predict their class
data, labels = next(iter(DataLoader(test_data, batch_size=5, shuffle=True)))
model_ft.eval()
with torch.no_grad():
    predictions = torch.argmax(model_ft(data.to(device)), 1)
predictions = predictions.cpu()  # put it back on CPU for visualization
plt.figure(figsize=(16, 9))
for i in range(5):
    img = data[i]
    plt.subplot(1, 5, i + 1)
    plt.imshow(img.permute(1, 2, 0))
    plt.xlabel(
        "prediction = %s\n (gt: %s)"
        % (test_data.classes[predictions[i].item()], test_data.classes[labels[i].item()]),
        fontsize=14,
    )
    plt.xticks([])
    plt.yticks([])
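The model outputs two unnormalised scores (logits) per image, and argmax simply picks the larger one. If you also want a confidence for each prediction, a softmax turns the scores into probabilities that sum to one; CrossEntropyLoss applies the same log-softmax internally during training. A sketch on made-up scores (not actual outputs of model_ft):

```python
import torch

class_names = ["alien", "predator"]
# Made-up scores for 3 images, shape (batch, num_classes)
logits = torch.tensor([[2.0, 0.0], [0.5, 1.5], [-1.0, 1.0]])

probabilities = torch.softmax(logits, dim=1)  # each row sums to 1
confidence, predictions = probabilities.max(dim=1)  # same classes as torch.argmax(logits, 1)
for p, c in zip(predictions.tolist(), confidence.tolist()):
    print(f"{class_names[p]}: {c:.2f}")
# prints alien: 0.88, predator: 0.73, predator: 0.88
```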