Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch,torchvision
Sebastian Raschka CPython 3.6.8 IPython 7.2.0 torch 1.0.1.post2 torchvision 0.2.2
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
#######################################
### PRE-TRAINED MODELS AVAILABLE HERE
## https://pytorch.org/docs/stable/torchvision/models.html
from torchvision import models
#######################################
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
In this example, we work with CIFAR-10 because it is familiar and much smaller than ImageNet. Note, however, that the pretrained torchvision models expect input images with a height and width of at least 224 pixels, so for a "real-world application" images of at least that size are recommended. Here, we simply upscale the 32x32 CIFAR-10 images to 224x224 as a workaround.
##########################
### SETTINGS
##########################
# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)
NUM_CLASSES = 10
# Hyperparameters
random_seed = 1
learning_rate = 0.0001
num_epochs = 10
batch_size = 128
##########################
### CIFAR-10 DATASET
##########################
custom_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
## Note that this particular normalization scheme is
## necessary since it was used for pre-training
## the network on ImageNet.
## These are the channel-means and standard deviations
## for z-score normalization.
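As a quick illustration of what transforms.Normalize does (the pixel value below is made up; the statistics are the ImageNet red-channel values from above), z-score normalization simply subtracts the channel mean and divides by the channel standard deviation:
# hypothetical red-channel pixel value after ToTensor(), i.e., already scaled to [0, 1]
x = 0.8
mean, std = 0.485, 0.229   # ImageNet red-channel mean and std from the transform above
z = (x - mean) / std       # the per-channel operation applied by transforms.Normalize
print(round(z, 4))         # 1.3755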
train_dataset = datasets.CIFAR10(root='data',
train=True,
transform=custom_transform,
download=True)
test_dataset = datasets.CIFAR10(root='data',
train=False,
transform=custom_transform)
train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
num_workers=8,
shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
num_workers=8,
shuffle=False)
# Checking the dataset
for images, labels in train_loader:
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
break
Device: cuda:0
Files already downloaded and verified
Image batch dimensions: torch.Size([128, 3, 224, 224])
Image label dimensions: torch.Size([128])
model = models.vgg16(pretrained=True)
model
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
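The printed architecture also shows why inputs of at least 224x224 are expected: the five max-pooling stages each halve the spatial resolution, so a 224x224 input yields a 7x7x512 feature map, i.e., the 25088 input features of the first classifier layer. A minimal sanity check (assuming the model object loaded above, run before moving the model to the GPU):
with torch.no_grad():
    dummy = torch.zeros(1, 3, 224, 224)          # one fake RGB image
    fmap = model.avgpool(model.features(dummy))  # convolutional feature extractor + avgpool
    print(fmap.shape)                            # torch.Size([1, 512, 7, 7])
    print(fmap.numel())                          # 25088, matches the classifier input size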
for param in model.parameters():
param.requires_grad = False
Assume we additionally want to fine-tune the penultimate fully connected layer (classifier[3]). Note that requires_grad has to be set on the layer's parameter tensors; setting it on the module object itself has no effect:
for param in model.classifier[3].parameters():
    param.requires_grad = True
Now, replace the original 1000-class output layer with your own output layer; here, we actually replace it with a small two-layer head that ends in NUM_CLASSES outputs:
model.classifier[6] = nn.Sequential(
nn.Linear(4096, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, NUM_CLASSES))
model = model.to(DEVICE)
# note: the learning_rate setting defined above is not passed here,
# so Adam uses its default learning rate of 0.001
optimizer = torch.optim.Adam(model.parameters())
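Since only the new head and the unfrozen layer should be updated, it is worth double-checking which parameters actually have requires_grad=True before training. A minimal sanity-check sketch (not part of the original code):
# count trainable vs. frozen parameters and list the trainable ones
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print('Trainable parameters:', num_trainable)
print('Frozen parameters:', num_frozen)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
Passing all parameters to Adam still works because frozen parameters never receive gradients; alternatively, only the trainable subset could be passed, e.g., filter(lambda p: p.requires_grad, model.parameters()).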
def compute_accuracy(model, data_loader):
model.eval()
correct_pred, num_examples = 0, 0
for i, (features, targets) in enumerate(data_loader):
features = features.to(DEVICE)
targets = targets.to(DEVICE)
logits = model(features)
_, predicted_labels = torch.max(logits, 1)
num_examples += targets.size(0)
correct_pred += (predicted_labels == targets).sum()
return correct_pred.float()/num_examples * 100
def compute_epoch_loss(model, data_loader):
model.eval()
curr_loss, num_examples = 0., 0
with torch.no_grad():
for features, targets in data_loader:
features = features.to(DEVICE)
targets = targets.to(DEVICE)
logits = model(features)
loss = F.cross_entropy(logits, targets, reduction='sum')
num_examples += targets.size(0)
curr_loss += loss
curr_loss = curr_loss / num_examples
return curr_loss
start_time = time.time()
for epoch in range(num_epochs):
model.train()
for batch_idx, (features, targets) in enumerate(train_loader):
features = features.to(DEVICE)
targets = targets.to(DEVICE)
### FORWARD AND BACK PROP
logits = model(features)
cost = F.cross_entropy(logits, targets)
optimizer.zero_grad()
cost.backward()
### UPDATE MODEL PARAMETERS
optimizer.step()
### LOGGING
if not batch_idx % 50:
print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
%(epoch+1, num_epochs, batch_idx,
len(train_loader), cost))
model.eval()
with torch.set_grad_enabled(False): # save memory during inference
print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (
epoch+1, num_epochs,
compute_accuracy(model, train_loader),
compute_epoch_loss(model, train_loader)))
print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
Epoch: 001/010 | Batch 0000/0391 | Cost: 2.3338
Epoch: 001/010 | Batch 0050/0391 | Cost: 0.6228
Epoch: 001/010 | Batch 0100/0391 | Cost: 0.6328
Epoch: 001/010 | Batch 0150/0391 | Cost: 0.6408
Epoch: 001/010 | Batch 0200/0391 | Cost: 0.7256
Epoch: 001/010 | Batch 0250/0391 | Cost: 0.6532
Epoch: 001/010 | Batch 0300/0391 | Cost: 0.7125
Epoch: 001/010 | Batch 0350/0391 | Cost: 0.6446
Epoch: 001/010 | Train: 82.850% | Loss: 0.485
Time elapsed: 10.31 min
Epoch: 002/010 | Batch 0000/0391 | Cost: 0.6978
Epoch: 002/010 | Batch 0050/0391 | Cost: 0.6642
Epoch: 002/010 | Batch 0100/0391 | Cost: 0.5250
Epoch: 002/010 | Batch 0150/0391 | Cost: 0.5881
Epoch: 002/010 | Batch 0200/0391 | Cost: 0.5950
Epoch: 002/010 | Batch 0250/0391 | Cost: 0.6431
Epoch: 002/010 | Batch 0300/0391 | Cost: 0.6382
Epoch: 002/010 | Batch 0350/0391 | Cost: 0.6368
Epoch: 002/010 | Train: 84.794% | Loss: 0.439
Time elapsed: 20.62 min
Epoch: 003/010 | Batch 0000/0391 | Cost: 0.5730
Epoch: 003/010 | Batch 0050/0391 | Cost: 0.4823
Epoch: 003/010 | Batch 0100/0391 | Cost: 0.5051
Epoch: 003/010 | Batch 0150/0391 | Cost: 0.4477
Epoch: 003/010 | Batch 0200/0391 | Cost: 0.5656
Epoch: 003/010 | Batch 0250/0391 | Cost: 0.6305
Epoch: 003/010 | Batch 0300/0391 | Cost: 0.6265
Epoch: 003/010 | Batch 0350/0391 | Cost: 0.6465
Epoch: 003/010 | Train: 86.038% | Loss: 0.415
Time elapsed: 30.93 min
Epoch: 004/010 | Batch 0000/0391 | Cost: 0.5001
Epoch: 004/010 | Batch 0050/0391 | Cost: 0.5504
Epoch: 004/010 | Batch 0100/0391 | Cost: 0.5089
Epoch: 004/010 | Batch 0150/0391 | Cost: 0.4440
Epoch: 004/010 | Batch 0200/0391 | Cost: 0.5091
Epoch: 004/010 | Batch 0250/0391 | Cost: 0.5236
Epoch: 004/010 | Batch 0300/0391 | Cost: 0.4528
Epoch: 004/010 | Batch 0350/0391 | Cost: 0.5743
Epoch: 004/010 | Train: 86.296% | Loss: 0.390
Time elapsed: 41.26 min
Epoch: 005/010 | Batch 0000/0391 | Cost: 0.5132
Epoch: 005/010 | Batch 0050/0391 | Cost: 0.4700
Epoch: 005/010 | Batch 0100/0391 | Cost: 0.4637
Epoch: 005/010 | Batch 0150/0391 | Cost: 0.5554
Epoch: 005/010 | Batch 0200/0391 | Cost: 0.4977
Epoch: 005/010 | Batch 0250/0391 | Cost: 0.4567
Epoch: 005/010 | Batch 0300/0391 | Cost: 0.5479
Epoch: 005/010 | Batch 0350/0391 | Cost: 0.5718
Epoch: 005/010 | Train: 87.178% | Loss: 0.367
Time elapsed: 51.57 min
Epoch: 006/010 | Batch 0000/0391 | Cost: 0.4161
Epoch: 006/010 | Batch 0050/0391 | Cost: 0.5224
Epoch: 006/010 | Batch 0100/0391 | Cost: 0.4122
Epoch: 006/010 | Batch 0150/0391 | Cost: 0.3939
Epoch: 006/010 | Batch 0200/0391 | Cost: 0.6094
Epoch: 006/010 | Batch 0250/0391 | Cost: 0.4580
Epoch: 006/010 | Batch 0300/0391 | Cost: 0.6227
Epoch: 006/010 | Batch 0350/0391 | Cost: 0.3779
Epoch: 006/010 | Train: 87.560% | Loss: 0.373
Time elapsed: 61.88 min
Epoch: 007/010 | Batch 0000/0391 | Cost: 0.4695
Epoch: 007/010 | Batch 0050/0391 | Cost: 0.4915
Epoch: 007/010 | Batch 0100/0391 | Cost: 0.5667
Epoch: 007/010 | Batch 0150/0391 | Cost: 0.6295
Epoch: 007/010 | Batch 0200/0391 | Cost: 0.5733
Epoch: 007/010 | Batch 0250/0391 | Cost: 0.4504
Epoch: 007/010 | Batch 0300/0391 | Cost: 0.5983
Epoch: 007/010 | Batch 0350/0391 | Cost: 0.5212
Epoch: 007/010 | Train: 87.934% | Loss: 0.355
Time elapsed: 72.19 min
Epoch: 008/010 | Batch 0000/0391 | Cost: 0.3749
Epoch: 008/010 | Batch 0050/0391 | Cost: 0.4332
Epoch: 008/010 | Batch 0100/0391 | Cost: 0.5650
Epoch: 008/010 | Batch 0150/0391 | Cost: 0.4598
Epoch: 008/010 | Batch 0200/0391 | Cost: 0.5086
Epoch: 008/010 | Batch 0250/0391 | Cost: 0.6177
Epoch: 008/010 | Batch 0300/0391 | Cost: 0.5695
Epoch: 008/010 | Batch 0350/0391 | Cost: 0.3319
Epoch: 008/010 | Train: 88.056% | Loss: 0.349
Time elapsed: 82.51 min
Epoch: 009/010 | Batch 0000/0391 | Cost: 0.4925
Epoch: 009/010 | Batch 0050/0391 | Cost: 0.4697
Epoch: 009/010 | Batch 0100/0391 | Cost: 0.6494
Epoch: 009/010 | Batch 0150/0391 | Cost: 0.4746
Epoch: 009/010 | Batch 0200/0391 | Cost: 0.6954
Epoch: 009/010 | Batch 0250/0391 | Cost: 0.3513
Epoch: 009/010 | Batch 0300/0391 | Cost: 0.5205
Epoch: 009/010 | Batch 0350/0391 | Cost: 0.4902
Epoch: 009/010 | Train: 88.864% | Loss: 0.331
Time elapsed: 92.78 min
Epoch: 010/010 | Batch 0000/0391 | Cost: 0.6645
Epoch: 010/010 | Batch 0050/0391 | Cost: 0.7019
Epoch: 010/010 | Batch 0100/0391 | Cost: 0.3686
Epoch: 010/010 | Batch 0150/0391 | Cost: 0.5411
Epoch: 010/010 | Batch 0200/0391 | Cost: 0.4750
Epoch: 010/010 | Batch 0250/0391 | Cost: 0.5674
Epoch: 010/010 | Batch 0300/0391 | Cost: 0.5198
Epoch: 010/010 | Batch 0350/0391 | Cost: 0.3849
Epoch: 010/010 | Train: 88.656% | Loss: 0.325
Time elapsed: 103.06 min
Total Training Time: 103.06 min
with torch.set_grad_enabled(False): # save memory during inference
print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))
Test accuracy: 84.23%
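To reuse the fine-tuned model later without retraining, a common approach is to save its state_dict (a minimal sketch; the file name is arbitrary):
# save only the learned weights, not the full Python object
torch.save(model.state_dict(), 'vgg16-cifar10-finetuned.pt')
# to restore, rebuild the same modified VGG16 architecture first, then:
# model.load_state_dict(torch.load('vgg16-cifar10-finetuned.pt', map_location=DEVICE))
# model.eval()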
%matplotlib inline
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# grab the first batch from the test loader for visualization
for features, targets in test_loader:
    break
logits = model(features.to(DEVICE))
_, predicted_labels = torch.max(logits, 1)
def unnormalize(tensor, mean, std):
for t, m, s in zip(tensor, mean, std):
t.mul_(s).add_(m)
return tensor
n_images = 10
fig, axes = plt.subplots(nrows=1, ncols=n_images,
sharex=True, sharey=True, figsize=(20, 2.5))
orig_images = features[:n_images]
for i in range(n_images):
curr_img = orig_images[i].detach().to(torch.device('cpu'))
curr_img = unnormalize(curr_img,
torch.tensor([0.485, 0.456, 0.406]),
torch.tensor([0.229, 0.224, 0.225]))
curr_img = curr_img.permute((1, 2, 0))
axes[i].imshow(curr_img)
axes[i].set_title(classes[predicted_labels[i]])
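The panel titles above show only the predicted class. As a small optional extension (not part of the original code), the ground-truth labels can be displayed alongside the predictions for easier comparison:
# show 'predicted (true)' in each panel title instead of the prediction alone
for i in range(n_images):
    axes[i].set_title('%s (%s)' % (classes[predicted_labels[i]],
                                   classes[targets[i]]))
    axes[i].axis('off')
plt.show()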