#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:


get_ipython().run_line_magic('load_ext', 'watermark')
get_ipython().run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# - Runs on CPU or GPU (if available)

# # Model Zoo -- Convolutional Conditional Variational Autoencoder
#
# ## (without labels in reconstruction loss)

# A simple convolutional conditional variational autoencoder that compresses
# 784-pixel (28x28) MNIST images down to a 50-dimensional latent vector
# representation.
#
# This implementation DOES NOT concatenate the inputs with the class labels
# when computing the reconstruction loss, in contrast to how it is commonly
# done in non-convolutional conditional variational autoencoders. Not
# considering class labels in the reconstruction loss leads to substantially
# better results compared to the implementation that does concatenate the
# labels with the inputs to compute the reconstruction loss.
# For reference, see the implementation [./autoencoder-cnn-cvae.ipynb](./autoencoder-cnn-cvae.ipynb)

# ## Imports

# In[2]:


import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms


if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True


# In[3]:


##########################
### SETTINGS
##########################

# Device
# The notebook was written for the second GPU ("cuda:1"); fall back to the
# first GPU when only one is visible so that single-GPU machines do not fail
# with an invalid device ordinal.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:%d" % gpu_index)
else:
    device = torch.device("cpu")
print('Device:', device)

# Hyperparameters
random_seed = 0
learning_rate = 0.001
num_epochs = 50
batch_size = 128

# Architecture
num_classes = 10
num_features = 784
num_latent = 50


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# ## Model

# In[4]:


##########################
### MODEL
##########################


def to_onehot(labels, num_classes, device):
    """Return a float one-hot matrix of shape (len(labels), num_classes).

    `labels` must hold integer class indices and live on `device`, because
    they are used as the scatter index into the freshly allocated matrix.
    """
    labels_onehot = torch.zeros(labels.size(0), num_classes, device=device)
    labels_onehot.scatter_(1, labels.view(-1, 1), 1)
    return labels_onehot


class ConditionalVariationalAutoencoder(torch.nn.Module):
    """Convolutional conditional VAE.

    The class label conditions the encoder (as extra one-hot input planes)
    and the decoder (as a one-hot vector appended to the latent code).
    The layer sizes are fixed for 1x28x28 inputs; `num_features` is accepted
    for interface compatibility but not used by the layers.
    """

    def __init__(self, num_features, num_latent, num_classes):
        super().__init__()

        self.num_classes = num_classes

        ###############
        # ENCODER
        ##############

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.enc_conv_1 = torch.nn.Conv2d(in_channels=1+self.num_classes,
                                          out_channels=16,
                                          kernel_size=(6, 6),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,
                                          out_channels=32,
                                          kernel_size=(4, 4),
                                          stride=(2, 2),
                                          padding=0)

        self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,
                                          out_channels=64,
                                          kernel_size=(2, 2),
                                          stride=(2, 2),
                                          padding=0)

        self.z_mean = torch.nn.Linear(64*2*2, num_latent)
        # In the original VAE formulation (Kingma & Welling) the encoder
        # outputs z_mean and z_var, but the problem is that z_var can be
        # negative, which would cause issues in the log later. Hence we
        # assume that the latent vector has a z_mean and a z_log_var
        # component, and when we need the regular variance or std_dev,
        # we simply use an exponential function.
        self.z_log_var = torch.nn.Linear(64*2*2, num_latent)

        ###############
        # DECODER
        ##############

        self.dec_linear_1 = torch.nn.Linear(num_latent+self.num_classes, 64*2*2)

        self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,
                                                     out_channels=32,
                                                     kernel_size=(2, 2),
                                                     stride=(2, 2),
                                                     padding=0)

        self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,
                                                     out_channels=16,
                                                     kernel_size=(4, 4),
                                                     stride=(3, 3),
                                                     padding=1)

        self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,
                                                     out_channels=1,
                                                     kernel_size=(6, 6),
                                                     stride=(3, 3),
                                                     padding=4)

    def reparameterize(self, z_mu, z_log_var):
        """Sample z = mu + eps * std with eps ~ N(0, I)."""
        # Draw epsilon with the shape/dtype/device of z_mu instead of
        # relying on the module-level `device` global.
        eps = torch.randn_like(z_mu)
        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev
        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)
        z = z_mu + eps * torch.exp(z_log_var/2.)
        return z

    def encoder(self, features, targets):
        """Encode images conditioned on labels.

        Returns (z_mean, z_log_var, encoded), where `encoded` is a sample
        drawn via the reparameterization trick.
        """
        ### Add condition
        # Broadcast each one-hot label to `num_classes` constant image
        # planes and stack them onto the image as extra input channels.
        onehot_targets = to_onehot(targets, self.num_classes, features.device)
        label_planes = onehot_targets.to(features.dtype).view(
            -1, self.num_classes, 1, 1).expand(
            -1, -1, features.size(2), features.size(3))

        x = torch.cat((features, label_planes), dim=1)

        # Spatial sizes for 28x28 inputs: 28 -> 12 -> 5 -> 2
        x = self.enc_conv_1(x)
        x = F.leaky_relu(x)

        x = self.enc_conv_2(x)
        x = F.leaky_relu(x)

        x = self.enc_conv_3(x)
        x = F.leaky_relu(x)

        flat = x.view(-1, 64*2*2)
        z_mean = self.z_mean(flat)
        z_log_var = self.z_log_var(flat)
        encoded = self.reparameterize(z_mean, z_log_var)

        return z_mean, z_log_var, encoded

    def decoder(self, encoded, targets):
        """Decode latent vectors conditioned on labels into images in (0, 1)."""
        ### Add condition
        onehot_targets = to_onehot(targets, self.num_classes, encoded.device)
        encoded = torch.cat((encoded, onehot_targets), dim=1)

        x = self.dec_linear_1(encoded)
        x = x.view(-1, 64, 2, 2)

        # Spatial sizes: 2 -> 4 -> 11 -> 28
        x = self.dec_deconv_1(x)
        x = F.leaky_relu(x)

        x = self.dec_deconv_2(x)
        x = F.leaky_relu(x)

        x = self.dec_deconv_3(x)
        x = F.leaky_relu(x)

        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, features, targets):
        """Run encoder and decoder; returns (z_mean, z_log_var, encoded, decoded)."""
        z_mean, z_log_var, encoded = self.encoder(features, targets)
        decoded = self.decoder(encoded, targets)

        return z_mean, z_log_var, encoded, decoded


torch.manual_seed(random_seed)
model = ConditionalVariationalAutoencoder(num_features,
                                          num_latent,
                                          num_classes)
model = model.to(device)


##########################
### COST AND OPTIMIZER
##########################

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# ## Training

# In[5]:


start_time = time.time()

for epoch in range(num_epochs):
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(device)
        targets = targets.to(device)

        ### FORWARD AND BACK PROP
        z_mean, z_log_var, encoded, decoded = model(features, targets)

        # cost = reconstruction loss + Kullback-Leibler divergence
        kl_divergence = (0.5 * (z_mean**2 +
                                torch.exp(z_log_var) - z_log_var - 1)).sum()

        ### Compute loss
        # The reconstruction target is the plain image. Concatenating the
        # one-hot label planes to the target (as done in
        # ./autoencoder-cnn-cvae.ipynb) is deliberately disabled here
        # because it gives poor results.
        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')
        cost = kl_divergence + pixelwise_bce

        ### UPDATE MODEL PARAMETERS
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                   %(epoch+1, num_epochs, batch_idx,
                     len(train_loader), cost.item()))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


# # Evaluation

# ### Reconstruction

# In[6]:


get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt

##########################
### VISUALIZATION
##########################

n_images = 15
image_width = 28

fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))
# Uses the last training batch left over from the loop above:
# top row = originals, bottom row = reconstructions.
orig_images = features[:n_images]
decoded_images = decoded[:n_images, 0]

for i in range(n_images):
    for ax, img in zip(axes, [orig_images, decoded_images]):
        cpu_img = img[i].detach().to(torch.device('cpu'))
        ax[i].imshow(cpu_img.view((image_width, image_width)), cmap='binary')


# ### New random-conditional images

# In[7]:


for i in range(10):

    ##########################
    ### RANDOM SAMPLE
    ##########################

    labels = torch.tensor([i]*10).to(device)
    n_images = labels.size()[0]
    rand_features = torch.randn(n_images, num_latent).to(device)
    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        new_images = model.decoder(rand_features, labels)

    ##########################
    ### VISUALIZATION
    ##########################

    image_width = 28

    fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)
    decoded_images = new_images[:n_images, 0]

    print('Class Label %d' % i)

    for ax, img in zip(axes, decoded_images):
        cpu_img = img.detach().to(torch.device('cpu'))
        ax.imshow(cpu_img.view((image_width, image_width)), cmap='binary')

    plt.show()


# In[8]:


# `watermark -iv` is an IPython automagic; as plain Python it would evaluate
# `watermark - iv` and raise NameError, so call the line magic explicitly.
get_ipython().run_line_magic('watermark', '-iv')