#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Gradient Clipping

# Certain types of deep neural networks, especially simple ones without any other form of regularization and with a relatively large number of layers, can suffer from exploding gradients. The exploding gradient problem is a scenario where large loss gradients accumulate during backpropagation, eventually resulting in very large weight updates during training. As a consequence, the updates become unstable and fluctuate a lot, which often causes severe problems during training. This is a particular concern for unbounded activation functions such as ReLU.
#
# One common, classic technique for avoiding exploding gradients is gradient clipping. Here, we simply set gradient values above or below a certain threshold to a user-specified min or max value. In PyTorch, there are several ways to perform gradient clipping.
#
# **1 - Basic Clipping**
#
# The simplest approach to gradient clipping in PyTorch is to use the [`torch.nn.utils.clip_grad_value_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_value_) function. For example, if we have instantiated a PyTorch model from a model class based on `torch.nn.Module` (as usual), we can add the following line of code to clip the gradients to the [-1, 1] range:
#
# ```python
# torch.nn.utils.clip_grad_value_(parameters=model.parameters(),
#                                 clip_value=1.)
# ```
#
# However, notice that with this approach we can only specify a single clip value, which is used for both the upper and the lower bound, such that gradients are clipped to the range [-`clip_value`, `clip_value`].
#
#
# **2 - Custom Lower and Upper Bounds**
#
# If we want to clip the gradients to an asymmetric interval around zero, say [-0.1, 1.0], we can take a different approach and register a backward hook:
#
# ```python
# for param in model.parameters():
#     param.register_hook(lambda gradient: torch.clamp(gradient, -0.1, 1.0))
# ```
#
# This backward hook only needs to be registered once after instantiating the model. Then, each time `backward` is called, the hook clips the gradients as they are computed, i.e., before `optimizer.step()` is run.
#
# **3 - Norm-clipping**
#
# Lastly, there is a third clipping option, [`torch.nn.utils.clip_grad_norm_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_norm_), which clips the gradients using a vector norm as follows:
#
# > `torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)`
#
# > Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
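#
# The following minimal sketch illustrates what this means in practice (the toy linear layer and random inputs here are arbitrary placeholders, used purely for illustration): `clip_grad_norm_` returns the total gradient norm computed *before* clipping and rescales all gradients jointly so that their combined norm does not exceed `max_norm`:
#
# ```python
# import torch
#
# torch.manual_seed(0)
# toy_layer = torch.nn.Linear(10, 1)  # arbitrary toy model for illustration
# loss = toy_layer(torch.randn(8, 10)).pow(2).sum()
# loss.backward()
#
# # clip_grad_norm_ returns the total gradient norm *before* clipping
# total_norm = torch.nn.utils.clip_grad_norm_(toy_layer.parameters(),
#                                             max_norm=1., norm_type=2)
#
# # after the call, the combined norm of all gradients is at most max_norm
# clipped_norm = torch.sqrt(sum(p.grad.pow(2).sum()
#                               for p in toy_layer.parameters()))
# print(total_norm, clipped_norm)
# ```
#
# If the total norm is already smaller than `max_norm`, the gradients are left unchanged.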

# ## Imports

# In[2]:


import time
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# ## Settings and Dataset

# In[3]:


##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Hyperparameters
random_seed = 1
learning_rate = 0.01
num_epochs = 10
batch_size = 64

# Architecture
num_features = 784
num_hidden_1 = 256
num_hidden_2 = 128
num_hidden_3 = 64
num_hidden_4 = 32
num_classes = 10


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# In[4]:


def compute_accuracy(net, data_loader):
    net.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.view(-1, 28*28).to(device)
            targets = targets.to(device)
            logits, probas = net(features)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


# In[5]:


##########################
### MODEL
##########################

class MultilayerPerceptron(torch.nn.Module):

    def __init__(self, num_features, num_classes):
        super(MultilayerPerceptron, self).__init__()

        ### 1st hidden layer
        self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)

        ### 2nd hidden layer
        self.linear_2 = torch.nn.Linear(num_hidden_1, num_hidden_2)

        ### 3rd hidden layer
        self.linear_3 = torch.nn.Linear(num_hidden_2, num_hidden_3)

        ### 4th hidden layer
        self.linear_4 = torch.nn.Linear(num_hidden_3, num_hidden_4)

        ### Output layer
        self.linear_out = torch.nn.Linear(num_hidden_4, num_classes)

    def forward(self, x):
        out = self.linear_1(x)
        out = F.relu(out)
        out = self.linear_2(out)
        out = F.relu(out)
        out = self.linear_3(out)
        out = F.relu(out)
        out = self.linear_4(out)
        out = F.relu(out)
        logits = self.linear_out(out)
        probas = F.log_softmax(logits, dim=1)
        return logits, probas


# ## 1 - Basic Clipping

# In[6]:


torch.manual_seed(random_seed)
model = MultilayerPerceptron(num_features=num_features,
                             num_classes=num_classes)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

###################################################################

start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS

        #########################################################
        #########################################################
        ### GRADIENT CLIPPING
        torch.nn.utils.clip_grad_value_(model.parameters(), 1.)
        #########################################################
        #########################################################

        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    with torch.set_grad_enabled(False):
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# In[7]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# ## 2 - Custom Lower and Upper Bounds

# In[8]:


torch.manual_seed(random_seed)
model = MultilayerPerceptron(num_features=num_features,
                             num_classes=num_classes)

#########################################################
#########################################################
### GRADIENT CLIPPING
for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -0.1, 1.0))
#########################################################
#########################################################

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

###################################################################

start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    with torch.set_grad_enabled(False):
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# In[9]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# ## 3 - Norm-clipping

# In[10]:


torch.manual_seed(random_seed)
model = MultilayerPerceptron(num_features=num_features,
                             num_classes=num_classes)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

###################################################################

start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS

        #########################################################
        #########################################################
        ### GRADIENT CLIPPING
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2)
        #########################################################
        #########################################################

        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    with torch.set_grad_enabled(False):
        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
              epoch+1, num_epochs,
              compute_accuracy(model, train_loader)))
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# In[11]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# In[12]:


get_ipython().run_line_magic('watermark', '-iv')