#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Model Zoo -- Standardizing Images

# This notebook provides an example for working with standardized images, that is, images that are scaled such that the pixel values of each channel have zero mean and unit variance over the training set.
# 
# The general equation for z-score standardization is
# 
# $$x_i' = \frac{x_i - \mu}{\sigma}$$
# 
# where $\mu$ and $\sigma$ are the mean and the standard deviation of the training set, respectively, $x_i$ is an original feature value, and $x_i'$ is the scaled feature value.
# 
# I.e., for grayscale images, we obtain 1 mean and 1 standard deviation. For RGB images (3 color channels), we obtain 3 mean values and 3 standard deviations.

# ## Imports

# In[2]:


import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# ## Settings and Dataset

# In[3]:


##########################
### SETTINGS
##########################

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
random_seed = 1
learning_rate = 0.05
num_epochs = 10
batch_size = 128

# Architecture
num_classes = 10


# ### Compute the Mean and Standard Deviation for Normalization

# First, we need to determine the mean and standard deviation for each color channel in the training set. Since we assume that the entire dataset does not fit into memory all at once, we do this in an incremental fashion, as shown below. (Note that averaging the per-batch standard deviations only approximates the standard deviation over the full training set, which is usually close enough for normalization purposes.)

# In[4]:


##############################
### PRELIMINARY DATALOADER
##############################

train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=False)

train_mean = []
train_std = []

for images, _ in train_loader:
    numpy_image = images.numpy()

    # Per-channel statistics for this batch
    batch_mean = np.mean(numpy_image, axis=(0, 2, 3))
    batch_std = np.std(numpy_image, axis=(0, 2, 3))

    train_mean.append(batch_mean)
    train_std.append(batch_std)

# Average the per-batch statistics
train_mean = torch.tensor(np.mean(train_mean, axis=0))
train_std = torch.tensor(np.mean(train_std, axis=0))

print('Mean:', train_mean)
print('Std Dev:', train_std)


# **Note that**
# 
# - For RGB images (3 color channels), we would get 3 means and 3 standard deviations.
# - The transforms.ToTensor() method converts images to the [0, 1] range, which is why the mean and standard deviation values are below 1.

# ### Standardized Dataset Loader

# Now we can use a custom transform function to standardize the dataset according to the mean and standard deviation we computed above.
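# As a quick sanity check before wiring the transform into new data loaders, the channel-wise operation that `transforms.Normalize` performs can be reproduced manually: subtract the channel mean and divide by the channel standard deviation. Below is a minimal sketch (`raw_batch` and `manual` are illustrative names; `train_loader` is still the preliminary, ToTensor()-only loader from above):

# In[ ]:


# Manually standardize one batch -- equivalent to what transforms.Normalize does
raw_batch, _ = next(iter(train_loader))       # shape: [batch_size, 1, 28, 28]
manual = ((raw_batch - train_mean.view(1, -1, 1, 1))
          / train_std.view(1, -1, 1, 1))

# The resulting batch statistics should be close to 0 and 1, respectively
print('Manually standardized mean: %.4f' % manual.mean().item())
print('Manually standardized std: %.4f' % manual.std().item())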
# In[5]:


custom_transform = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize(mean=train_mean,
                                                            std=train_std)])


# In[6]:


##########################
### MNIST DATASET
##########################

# Note that transforms.ToTensor() already scales input images
# to the 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=custom_transform,
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=custom_transform)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)


# Check that the dataset can be loaded:

# In[7]:


# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# For the given batch, check that the channel means and standard deviations are roughly 0 and 1, respectively:

# In[8]:


print('Channel mean:', torch.mean(images[:, 0, :, :]))
print('Channel std:', torch.std(images[:, 0, :, :]))


# ## Model

# In[9]:


##########################
### MODEL
##########################


class ConvNet(torch.nn.Module):

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        # Calculate 'same' padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        # 28x28x1 => 28x28x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)  # (1(28-1) - 28 + 3) / 2 = 1
        # 28x28x4 => 14x14x4
        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         padding=0)  # (2(14-1) - 28 + 2) / 2 = 0
        # 14x14x4 => 14x14x8
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=8,
                                      kernel_size=(3, 3),
                                      stride=(1, 1),
                                      padding=1)  # (1(14-1) - 14 + 3) / 2 = 1
        # 14x14x8 => 7x7x8
        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),
                                         stride=(2, 2),
                                         padding=0)  # (2(7-1) - 14 + 2) / 2 = 0

        self.linear_1 = torch.nn.Linear(7*7*8, num_classes)

    def forward(self, x):
        out = self.conv_1(x)
        out = F.relu(out)
        out = self.pool_1(out)

        out = self.conv_2(out)
        out = F.relu(out)
        out = self.pool_2(out)

        logits = self.linear_1(out.view(-1, 7*7*8))
        probas = F.softmax(logits, dim=1)
        return logits, probas


torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


# ## Training

# In[10]:


def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


start_time = time.time()
for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost))

    model = model.eval()
    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
          epoch+1, num_epochs,
          compute_accuracy(model, train_loader)))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
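# Since the inputs flowing through the network are standardized, mapping a batch back to the original [0, 1] pixel range is useful for visualization. Inverting the z-score amounts to multiplying by $\sigma$ and adding $\mu$; a minimal sketch (`std_batch` and `recovered` are illustrative names):

# In[ ]:


# Invert the standardization for one test batch: x = x' * sigma + mu
# recovers the original [0, 1] range produced by ToTensor()
std_batch, _ = next(iter(test_loader))
recovered = (std_batch * train_std.view(1, -1, 1, 1)
             + train_mean.view(1, -1, 1, 1))

print('Recovered pixel min: %.4f' % recovered.min().item())  # approx. 0
print('Recovered pixel max: %.4f' % recovered.max().item())  # approx. 1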
# ## Evaluation

# In[11]:


print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))


# In[12]:


get_ipython().run_line_magic('watermark', '-iv')
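# As a final illustration, the trained model can be applied to a single test image. The training-set statistics baked into `custom_transform` are applied automatically when indexing `test_dataset`, so no extra preprocessing is needed; a minimal sketch (the index 0 is arbitrary):

# In[ ]:


model.eval()
image, label = test_dataset[0]  # custom_transform is applied here
with torch.no_grad():
    # unsqueeze(0) adds the batch dimension expected by the model
    logits, probas = model(image.unsqueeze(0).to(device))

print('True label: %d | Predicted label: %d'
      % (label, probas.argmax(dim=1).item()))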