#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Model Zoo -- Weight Sharing Within a Layer

# For some exotic research projects, you may want to share the weights in certain layers. For this example, suppose you want to share the weights across all output units but want a unique bias unit for each output unit.
#
# The illustration below shows the last hidden layer and the output layer of a regular multilayer neural network:
#
# ![](../images/weight-sharing/weight-sharing-1.png)

# What we are trying to achieve is to have the same weight for each output unit, i.e.,
#
# ![](../images/weight-sharing/weight-sharing-2.png)

# One approach to achieve this is to share the weight columns in the weight matrix of the hidden layer that connects to the output layer. A more efficient approach is to replace the matrix-matrix multiplication with shared weights by a matrix-vector multiplication that produces a single output unit, which we can then duplicate before adding the bias vector.
#
# In other words, the first step is to modify the hidden layer such that it only contains a single weight vector:
#
# ```python
#
# # Replace this by the uncommented code below:
# #self.linear_1 = torch.nn.Linear(7*7*8, num_classes)
#
# # Use only a weight vector instead of a weight matrix:
# self.linear_1 = torch.nn.Linear(7*7*8, 1, bias=False)
#
# # Define the bias manually:
# self.linear_1_bias = torch.nn.Parameter(torch.zeros(num_classes,
#                                                     dtype=self.linear_1.weight.dtype))
# ```
#
# Next, in the `forward` method, we compute the single output unit and then add the bias vector; broadcasting the `(batch_size, 1)` output against the `(num_classes,)` bias duplicates the shared output over all output units:
#
# ```python
#
# # Compute the single shared output unit
# logits = self.linear_1(out.view(-1, 7*7*8))
#
# # then manually add the bias; broadcasting duplicates the
# # shared output over all output units
# logits = logits + self.linear_1_bias
# ```
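# As a quick sanity check of the broadcasting step, the following sketch (a standalone illustration with made-up example values, not part of the model code) shows that adding a `(num_classes,)` bias to a `(batch_size, 1)` output yields a `(batch_size, num_classes)` logit matrix whose columns differ only by the bias terms:
#
# ```python
# import torch
#
# num_classes = 10
# shared_out = torch.randn(2, 1)    # one shared output per example (batch_size=2)
# bias = torch.randn(num_classes)   # one bias per output unit
#
# logits = shared_out + bias        # broadcasts to shape (2, num_classes)
# print(logits.shape)               # torch.Size([2, 10])
#
# # Each column equals the shared output plus that column's bias:
# print(torch.allclose(logits, shared_out.expand(-1, num_classes) + bias))  # True
# ```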
# The following code in this notebook illustrates this using a convnet and the 10-class MNIST dataset.
#
# **The classification performance will obviously be poor, because weight sharing is not ideal in this case; this is meant more as a technical reference/demo than as a real-world use case for this dataset.**

# ## Imports

# In[2]:


import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# ## Settings and Dataset

# In[3]:


##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

# Hyperparameters
random_seed = 1
learning_rate = 0.1
num_epochs = 10
batch_size = 128

# Architecture
num_classes = 10


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# ## Model

# In[4]:


##########################
### MODEL
##########################


class ConvNet(torch.nn.Module):

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        # 28x28x1 => 28x28x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)  # (1(28-1) - 28 + 3) / 2 = 1
        # 28x28x4 => 14x14x4
        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         padding=0)  # (2(14-1) - 28 + 2) = 0
        # 14x14x4 => 14x14x8
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=8,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)  # (1(14-1) - 14 + 3) / 2 = 1
        # 14x14x8 => 7x7x8
        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         padding=0)  # (2(7-1) - 14 + 2) = 0

        ##############################################################################
        ### WEIGHT SHARING IN LAST LAYER
        #self.linear_1 = torch.nn.Linear(7*7*8, num_classes)

        # Use only a weight vector instead of a weight matrix:
        self.linear_1 = torch.nn.Linear(7*7*8, 1, bias=False)

        # Define the bias manually:
        self.linear_1_bias = torch.nn.Parameter(torch.zeros(num_classes,
                                                            dtype=self.linear_1.weight.dtype))
        ##############################################################################

    def forward(self, x):
        out = self.conv_1(x)
        out = F.relu(out)
        out = self.pool_1(out)

        out = self.conv_2(out)
        out = F.relu(out)
        out = self.pool_2(out)

        ##############################################################################
        ### WEIGHT SHARING IN LAST LAYER

        # Compute the single shared output unit
        logits = self.linear_1(out.view(-1, 7*7*8))

        # then manually add the bias; broadcasting duplicates the
        # shared output over all output units
        logits = logits + self.linear_1_bias
        ##############################################################################

        probas = F.softmax(logits, dim=1)
        return logits, probas


torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)

model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
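# As a quick check of the weight-sharing setup (a sketch, assuming the `model` and `num_classes` defined above; the `regular` layer is a hypothetical comparison layer, not part of this notebook's model), we can compare the shared output layer's parameter count with that of a regular fully connected output layer:
#
# ```python
# # Shared output layer: one weight vector of length 7*7*8 plus num_classes bias units
# n_shared = model.linear_1.weight.numel() + model.linear_1_bias.numel()  # 392 + 10 = 402
#
# # Regular fully connected output layer for comparison (hypothetical):
# regular = torch.nn.Linear(7*7*8, num_classes)
# n_regular = sum(p.numel() for p in regular.parameters())  # 392*10 + 10 = 3930
#
# print(model.linear_1.weight.shape)  # torch.Size([1, 392])
# print(n_shared, n_regular)          # 402 3930
# ```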
# ## Training

# In[5]:


def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    for features, targets in data_loader:
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


start_time = time.time()
for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    model = model.eval()
    with torch.set_grad_enabled(False):  # save memory during inference
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# Check that the bias units updated correctly (they should all be different):

# In[6]:


model.linear_1_bias


# ## Evaluation

# In[7]:


with torch.set_grad_enabled(False):  # save memory during inference
    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# **The classification performance is obviously poor, because weight sharing is not ideal in this case; this is meant more as a technical reference/demo than as a real-world use case for this dataset.**
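# Since all output units share the same weight vector, the logits for any given input differ across classes only by the bias terms. The following sketch verifies this (assuming the trained `model`, `device`, and `test_loader` from above):
#
# ```python
# with torch.no_grad():
#     features, _ = next(iter(test_loader))
#     logits, _ = model(features.to(device))
#
#     # Removing the per-class bias should leave identical columns
#     shared = logits - model.linear_1_bias
#     print(torch.allclose(shared, shared[:, :1].expand_as(shared)))  # expected: True
# ```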