#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Basic Graph Neural Network with Edge Prediction on MNIST

# Implementing a very basic graph neural network (GNN) that uses a subnetwork for edge prediction.
#
# Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of each node is simply its pixel intensity in the range [0, 1].
#
# In the related notebook, [gnn-basic-1.ipynb], the adjacency matrix of the pixels was determined by the pixel neighborhood: using a Gaussian filter, pixels were connected based on their Euclidean distance in the grid. In **this notebook**, the edges are predicted via a separate neural network model:
#
#
# ```python
# self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),
#                                   nn.ReLU(),
#                                   nn.Linear(32, 1),
#                                   nn.Tanh())
# ```
#
#
# Using the resulting adjacency matrix $A$, we can compute the output of a layer as
#
# $$X^{(l+1)}=A X^{(l)} W^{(l)}.$$
#
# Here, $A$ is the $N \times N$ adjacency matrix, and $X$ is the $N \times C$ feature matrix, where $N$ is the total number of pixels -- $28 \times 28 = 784$ in MNIST -- and $C = 1$, since each node's feature is just its pixel intensity. (The $N \times 2$ array of 2D pixel coordinates is only used as input to the edge-prediction network above.) $W$ is the weight matrix of shape $N \times P$, where $P$ would represent the number of classes if we have only a single hidden layer; note that the model below flattens $A X^{(l)}$ into a length-$N$ vector before applying $W$.
#
#
# - Inspired by and based on Boris Knyazev's tutorial at https://medium.com/@BorisAKnyazev/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-1-3d9fada3b80d.
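# To make the shapes concrete, here is a minimal sketch of the single-layer
# computation above, using randomly filled stand-in tensors for $A$, $X$, and
# $W$ (illustrative values only):
#
# ```python
# import torch
#
# N, C, P = 784, 1, 10       # nodes (pixels), features per node, classes
# A = torch.rand(N, N)       # adjacency matrix from the edge predictor
# X = torch.rand(N, C)       # node features (pixel intensities)
# W = torch.rand(N, P)       # single-layer weight matrix
#
# out = (A @ X).reshape(1, N) @ W   # flatten A X, then apply W
# print(out.shape)                  # torch.Size([1, 10]) -- one score per class
# ```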
# ## Imports

# In[2]:


import time
import numpy as np
from scipy.spatial.distance import cdist

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# In[3]:


get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt


# ## Settings and Dataset

# In[4]:


##########################
### SETTINGS
##########################

# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0005
NUM_EPOCHS = 50
BATCH_SIZE = 128
IMG_SIZE = 28

# Architecture
NUM_CLASSES = 10


# ## MNIST Dataset

# In[5]:


train_indices = torch.arange(0, 59000)
valid_indices = torch.arange(59000, 60000)

custom_transform = transforms.Compose([transforms.ToTensor()])

train_and_valid = datasets.MNIST(root='data',
                                 train=True,
                                 transform=custom_transform,
                                 download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=custom_transform,
                              download=True)

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=False)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# ## Model

# In[6]:


##########################
### MODEL
##########################


def make_coordinate_array(img_size, out_size=4):
    ### Make 2D coordinate array (for MNIST: 784x2)
    n_rows = img_size * img_size
    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
    coord = np.stack((col, row), axis=2).reshape(-1, 2)
    # Standardize the grid coordinates
    coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
    coord = torch.from_numpy(coord).float()

    ### Reshape to [N, N, out_size]: pair the coordinates of every node
    ### with those of every other node (for the default out_size=4,
    ### int(out_size/2-1) evaluates to 1, i.e., a single repeat)
    coord = torch.cat((coord.unsqueeze(0).repeat(n_rows, 1, int(out_size/2-1)),
                       coord.unsqueeze(1).repeat(1, n_rows, 1)), dim=2)
    return coord
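# A minimal sketch of what `make_coordinate_array` (defined above) returns:
# entry [i, j] of the result holds the standardized 2D coordinates of node j
# followed by those of node i, which is the 4-dimensional input the edge
# predictor consumes for each node pair.
#
# ```python
# coord = make_coordinate_array(img_size=28)
# print(coord.shape)  # torch.Size([784, 784, 4])
# print(coord[0, 1])  # coords of node 1, then coords of node 0
# ```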
class GraphNet(nn.Module):
    def __init__(self, img_size=28, coord_features=4, num_classes=10):
        super(GraphNet, self).__init__()

        n_rows = img_size**2
        self.fc = nn.Linear(n_rows, num_classes, bias=False)

        coord = make_coordinate_array(img_size, coord_features)
        self.register_buffer('coord', coord)

        ##########
        # Edge Predictor
        self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),  # coord -> hidden
                                          nn.ReLU(),
                                          nn.Linear(32, 1),  # hidden -> edge
                                          nn.Tanh())

    def forward(self, x):
        B = x.size(0)

        ### Predict edges
        self.A = self.pred_edge_fc(self.coord).squeeze()

        ### Reshape Adjacency Matrix
        # [N, N] Adj. matrix -> [1, N, N] Adj. tensor where N = HxW
        A_tensor = self.A.unsqueeze(0)
        # [1, N, N] Adj. tensor -> [B, N, N] tensor
        A_tensor = A_tensor.expand(B, -1, -1)

        ### Reshape inputs
        # [B, C, H, W] => [B, H*W, 1]
        x_reshape = x.view(B, -1, 1)

        # bmm = batch matrix product to sum the neighbor features
        # Input: [B, N, N] x [B, N, 1]
        # Output: [B, N]
        avg_neighbor_features = (torch.bmm(A_tensor,
                                           x_reshape).view(B, -1))

        logits = self.fc(avg_neighbor_features)
        probas = F.softmax(logits, dim=1)
        return logits, probas


# In[7]:


torch.manual_seed(RANDOM_SEED)
model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


# ## Training

# In[8]:


def compute_acc(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    for features, targets in data_loader:
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100


start_time = time.time()

cost_list = []
train_acc_list, valid_acc_list = [], []


for epoch in range(NUM_EPOCHS):

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()

        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        #################################################
        ### CODE ONLY FOR LOGGING BEYOND THIS POINT
        #################################################
        cost_list.append(cost.item())
        if not batch_idx % 150:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} | '
                  f'Cost: {cost:.4f}')

    model.eval()
    with torch.set_grad_enabled(False):  # save memory during inference

        train_acc = compute_acc(model, train_loader, device=DEVICE)
        valid_acc = compute_acc(model, valid_loader, device=DEVICE)

        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\n'
              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')

        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')


# ## Evaluation

# In[9]:


# last adjacency matrix
plt.imshow(model.A.detach().to('cpu'));


# In[10]:


plt.plot(cost_list, label='Minibatch cost')
plt.plot(np.convolve(cost_list,
                     np.ones(200,)/200, mode='valid'),
         label='Running average')

plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()


# In[11]:


plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')
plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


# In[12]:


with torch.set_grad_enabled(False):
    test_acc = compute_acc(model=model,
                           data_loader=test_loader,
                           device=DEVICE)

    valid_acc = compute_acc(model=model,
                            data_loader=valid_loader,
                            device=DEVICE)

print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')


# In[13]:


get_ipython().run_line_magic('watermark', '-iv')
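# A minimal sketch of single-image inference, assuming the trained `model` and
# the `test_dataset` defined above:
#
# ```python
# image, label = test_dataset[0]
#
# with torch.no_grad():
#     logits, probas = model(image.unsqueeze(0).to(DEVICE))  # add batch dim
#
# print('True label:', label)
# print('Predicted label:', torch.argmax(probas, dim=1).item())
# ```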