#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Model Zoo -- Convolutional ResNet and Residual Blocks

# Please note that this example does not implement a really deep ResNet as described in the literature but rather illustrates how the residual blocks described in He et al. [1] can be implemented in PyTorch.
#
# - [1] He, Kaiming, et al. "Deep residual learning for image recognition." *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*. 2016.

# ## Imports

# In[2]:


import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# ## Settings and Dataset

# In[3]:


##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Hyperparameters
random_seed = 123
learning_rate = 0.01
num_epochs = 10
batch_size = 128

# Architecture
num_classes = 10


##########################
### MNIST DATASET
##########################

# Note that transforms.ToTensor() scales the input images
# to the 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# ## ResNet with identity blocks

# The following code implements the residual blocks with skip connections such that the input passed via the shortcut matches the dimensions of the main path's output, which allows the network to learn identity functions. Such a residual block is illustrated below:
#
# ![](../images/resnets/resnet-ex-1-1.png)
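# Before the full model below, here is a minimal sketch of this idea (not part of the original notebook; the function name is purely illustrative): as long as the residual branch preserves the input's shape, its output can be added element-wise to the unmodified input, so the block can fall back to an (approximate) identity mapping.

# In[ ]:


def identity_residual_sketch(x, f):
    """Apply a residual function f and add the unmodified input back (illustrative only)."""
    return F.relu(f(x) + x)

# Quick shape check with a random 28x28 single-channel batch; a residual function
# that outputs zeros reduces the block to (the ReLU of) the identity mapping.
dummy = torch.randn(2, 1, 28, 28)
print(identity_residual_sketch(dummy, torch.zeros_like).shape)  # torch.Size([2, 1, 28, 28])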
# In[4]:


##########################
### MODEL
##########################

class ConvNet(torch.nn.Module):

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        #########################
        ### 1st residual block
        #########################
        # 28x28x1 => 28x28x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(1, 1),
                                      stride=(1, 1),
                                      padding=0)
        self.conv_1_bn = torch.nn.BatchNorm2d(4)

        # 28x28x4 => 28x28x1
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=1,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)
        self.conv_2_bn = torch.nn.BatchNorm2d(1)

        #########################
        ### 2nd residual block
        #########################
        # 28x28x1 => 28x28x4
        self.conv_3 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(1, 1),
                                      stride=(1, 1),
                                      padding=0)
        self.conv_3_bn = torch.nn.BatchNorm2d(4)

        # 28x28x4 => 28x28x1
        self.conv_4 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=1,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)
        self.conv_4_bn = torch.nn.BatchNorm2d(1)

        #########################
        ### Fully connected
        #########################
        self.linear_1 = torch.nn.Linear(28*28*1, num_classes)

    def forward(self, x):

        #########################
        ### 1st residual block
        #########################
        shortcut = x

        out = self.conv_1(x)
        out = self.conv_1_bn(out)
        out = F.relu(out)

        out = self.conv_2(out)
        out = self.conv_2_bn(out)

        out += shortcut
        out = F.relu(out)

        #########################
        ### 2nd residual block
        #########################
        shortcut = out

        out = self.conv_3(out)
        out = self.conv_3_bn(out)
        out = F.relu(out)

        out = self.conv_4(out)
        out = self.conv_4_bn(out)

        out += shortcut
        out = F.relu(out)

        #########################
        ### Fully connected
        #########################
        logits = self.linear_1(out.view(-1, 28*28*1))
        probas = F.softmax(logits, dim=1)
        return logits, probas


torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# ### Training

# In[5]:


def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


start_time = time.time()
for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    model = model.eval()  # eval mode to prevent updating batchnorm params during inference
    with torch.set_grad_enabled(False):  # save memory during inference
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# ### Evaluation

# In[6]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# ## ResNet with convolutional blocks for resizing

# The following code implements the residual blocks with skip connections such that the input passed via the shortcut is resized to match the dimensions of the main path's output. Such a residual block is illustrated below:
#
# ![](../images/resnets/resnet-ex-1-2.png)

# In[7]:


##########################
### MODEL
##########################

class ConvNet(torch.nn.Module):

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        #########################
        ### 1st residual block
        #########################
        # 28x28x1 => 14x14x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(3, 3),
                                      stride=(2, 2),
                                      padding=1)
        self.conv_1_bn = torch.nn.BatchNorm2d(4)

        # 14x14x4 => 14x14x8
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=8,
                                      kernel_size=(1, 1),
                                      stride=(1, 1),
                                      padding=0)
        self.conv_2_bn = torch.nn.BatchNorm2d(8)

        # 28x28x1 => 14x14x8
        self.conv_shortcut_1 = torch.nn.Conv2d(in_channels=1,
                                               out_channels=8,
                                               kernel_size=(1, 1),
                                               stride=(2, 2),
                                               padding=0)
        self.conv_shortcut_1_bn = torch.nn.BatchNorm2d(8)

        #########################
        ### 2nd residual block
        #########################
        # 14x14x8 => 7x7x16
        self.conv_3 = torch.nn.Conv2d(in_channels=8,
                                      out_channels=16,
                                      kernel_size=(3, 3),
                                      stride=(2, 2),
                                      padding=1)
        self.conv_3_bn = torch.nn.BatchNorm2d(16)

        # 7x7x16 => 7x7x32
        self.conv_4 = torch.nn.Conv2d(in_channels=16,
                                      out_channels=32,
                                      kernel_size=(1, 1),
                                      stride=(1, 1),
                                      padding=0)
        self.conv_4_bn = torch.nn.BatchNorm2d(32)

        # 14x14x8 => 7x7x32
        self.conv_shortcut_2 = torch.nn.Conv2d(in_channels=8,
                                               out_channels=32,
                                               kernel_size=(1, 1),
                                               stride=(2, 2),
                                               padding=0)
        self.conv_shortcut_2_bn = torch.nn.BatchNorm2d(32)

        #########################
        ### Fully connected
        #########################
        self.linear_1 = torch.nn.Linear(7*7*32, num_classes)

    def forward(self, x):

        #########################
        ### 1st residual block
        #########################
        shortcut = x

        out = self.conv_1(x)  # 28x28x1 => 14x14x4
        out = self.conv_1_bn(out)
        out = F.relu(out)

        out = self.conv_2(out)  # 14x14x4 => 14x14x8
        out = self.conv_2_bn(out)

        # match up dimensions using a linear function (no relu)
        shortcut = self.conv_shortcut_1(shortcut)
        shortcut = self.conv_shortcut_1_bn(shortcut)

        out += shortcut
        out = F.relu(out)

        #########################
        ### 2nd residual block
        #########################
        shortcut = out

        out = self.conv_3(out)  # 14x14x8 => 7x7x16
        out = self.conv_3_bn(out)
        out = F.relu(out)

        out = self.conv_4(out)  # 7x7x16 => 7x7x32
        out = self.conv_4_bn(out)

        # match up dimensions using a linear function (no relu)
        shortcut = self.conv_shortcut_2(shortcut)
        shortcut = self.conv_shortcut_2_bn(shortcut)

        out += shortcut
        out = F.relu(out)

        #########################
        ### Fully connected
        #########################
        logits = self.linear_1(out.view(-1, 7*7*32))
        probas = F.softmax(logits, dim=1)
        return logits, probas


torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
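# As an optional sanity check (not part of the original notebook), we can pass a dummy MNIST-sized batch through the freshly initialized model to confirm that the strided 1x1 shortcut convolutions produce feature maps that line up with the main path, and that the logits have the expected shape; the variable names below are only used for this check.

# In[ ]:


with torch.no_grad():
    dummy_batch = torch.randn(2, 1, 28, 28).to(device)
    dummy_logits, dummy_probas = model(dummy_batch)
    print(dummy_logits.shape)  # expected: torch.Size([2, 10])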
# ### Training

# In[8]:


def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    model = model.eval()  # eval mode to prevent updating batchnorm params during inference
    with torch.set_grad_enabled(False):  # save memory during inference
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))


# ### Evaluation

# In[9]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# ## ResNet with convolutional blocks for resizing (using a helper class)

# This is the same network as above but uses a `ResidualBlock` helper class.

# In[10]:


class ResidualBlock(torch.nn.Module):

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()

        self.conv_1 = torch.nn.Conv2d(in_channels=channels[0],
                                      out_channels=channels[1],
                                      kernel_size=(3, 3),
                                      stride=(2, 2),
                                      padding=1)
        self.conv_1_bn = torch.nn.BatchNorm2d(channels[1])

        self.conv_2 = torch.nn.Conv2d(in_channels=channels[1],
                                      out_channels=channels[2],
                                      kernel_size=(1, 1),
                                      stride=(1, 1),
                                      padding=0)
        self.conv_2_bn = torch.nn.BatchNorm2d(channels[2])

        self.conv_shortcut_1 = torch.nn.Conv2d(in_channels=channels[0],
                                               out_channels=channels[2],
                                               kernel_size=(1, 1),
                                               stride=(2, 2),
                                               padding=0)
        self.conv_shortcut_1_bn = torch.nn.BatchNorm2d(channels[2])

    def forward(self, x):
        shortcut = x

        out = self.conv_1(x)
        out = self.conv_1_bn(out)
        out = F.relu(out)

        out = self.conv_2(out)
        out = self.conv_2_bn(out)

        # match up dimensions using a linear function (no relu)
        shortcut = self.conv_shortcut_1(shortcut)
        shortcut = self.conv_shortcut_1_bn(shortcut)

        out += shortcut
        out = F.relu(out)

        return out


# In[11]:


##########################
### MODEL
##########################

class ConvNet(torch.nn.Module):

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        self.residual_block_1 = ResidualBlock(channels=[1, 4, 8])
        self.residual_block_2 = ResidualBlock(channels=[8, 16, 32])

        self.linear_1 = torch.nn.Linear(7*7*32, num_classes)

    def forward(self, x):

        # calling the blocks directly invokes their forward() via __call__
        out = self.residual_block_1(x)
        out = self.residual_block_2(out)

        logits = self.linear_1(out.view(-1, 7*7*32))
        probas = F.softmax(logits, dim=1)
        return logits, probas


torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
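# As a small illustration (not part of the original notebook), a single `ResidualBlock` halves the spatial resolution and expands the channels according to its `channels` argument; the variable names below are only used for this check.

# In[ ]:


with torch.no_grad():
    example_block = ResidualBlock(channels=[1, 4, 8]).to(device)
    example_input = torch.randn(2, 1, 28, 28).to(device)
    print(example_block(example_input).shape)  # expected: torch.Size([2, 8, 14, 14])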
# ### Training

# In[12]:


def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    model = model.eval()  # eval mode to prevent updating batchnorm params during inference
    with torch.set_grad_enabled(False):  # save memory during inference
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))


# ### Evaluation

# In[13]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# In[14]:


get_ipython().run_line_magic('watermark', '-iv')
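# As a final, optional illustration (not part of the original notebook): running the trained model on a single test image and reading off the predicted class; all variable names below are local to this cell.

# In[ ]:


model.eval()
with torch.no_grad():
    image, label = test_dataset[0]                           # a single 1x28x28 tensor and its integer label
    logits, probas = model(image.unsqueeze(0).to(device))    # add a batch dimension before the forward pass
    predicted_label = torch.argmax(probas, dim=1).item()
    print('True label: %d | Predicted label: %d' % (label, predicted_label))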