#!/usr/bin/env python
# coding: utf-8

# # Denoising Autoencoder with PyTorch
#
# Adapted from: https://github.com/Atcold/pytorch-Deep-Learning/blob/master/10-autoencoder.ipynb

# ### Required Python Packages
# - `torch`
# - `torchvision`
# - `matplotlib`
#
# Run the following cell to install the packages.

# In[ ]:


# # Required Packages
# Run this cell to install required packages.
# get_ipython().run_line_magic('pip', 'install "torch>=1.9" torchvision matplotlib')


# In[ ]:


def to_img(x):
    """Convert a batch of flattened, normalized images to 28x28 pictures.

    Maps values from [-1, 1] back to [0, 1] (the inverse of the
    ``Normalize((0.5,), (0.5,))`` transform applied when loading the data)
    and reshapes the batch to ``(batch, 28, 28)``.
    """
    x = 0.5 * (x + 1)
    x = x.view(x.size(0), 28, 28)
    return x


# In[ ]:


from matplotlib import pyplot as plt


def display_images(in_, out, n=1):
    """Plot input and output images in rows of four.

    For each of the ``n`` groups, one figure with four input images is drawn
    (skipped when ``in_`` is ``None``), followed by one figure with the four
    corresponding output images.

    Args:
        in_: Batch of input images, or ``None`` to plot outputs only.
        out: Batch of output images.
        n: Number of groups of four images to plot; the batches must hold
            at least ``4 * n`` images.
    """
    for group in range(n):
        if in_ is not None:
            # detach() drops any autograd history before plotting.
            in_pic = to_img(in_.detach().cpu())
            plt.figure(figsize=(18, 6))
            for i in range(4):
                plt.subplot(1, 4, i + 1)
                plt.imshow(in_pic[i + 4 * group])
                plt.axis("off")
        out_pic = to_img(out.detach().cpu())
        plt.figure(figsize=(18, 6))
        for i in range(4):
            plt.subplot(1, 4, i + 1)
            plt.imshow(out_pic[i + 4 * group])
            plt.axis("off")


# ## Load MNIST data

# In[ ]:


from torchvision import transforms
from torchvision.datasets import MNIST

# Normalize with mean 0.5 and std 0.5 so pixel values land in [-1, 1],
# matching the Tanh output range of the decoder.
img_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = MNIST("./data", transform=img_transform, download=True)


# ## Create dataloaders

# In[ ]:


from torch.utils.data import DataLoader

batch_size = 256

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# ## Define model

# In[ ]:


from torch import nn

# Size of the hidden (code) layer.
d = 30


class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder for flattened 28x28 images."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, d),
            nn.Tanh(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(d, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` to ``d`` features and decode back to 784 values."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# ## Create model

# In[ ]:


import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Autoencoder().to(device)
criterion = nn.MSELoss()


# ## Train model

# In[ ]:


learning_rate = 1e-3

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

num_epochs = 20
# Dropout applied to a tensor of ones yields the multiplicative noise mask:
# with the default p=0.5 each entry is either 0 (dropped) or 2 (kept and
# rescaled by 1 / (1 - p)).
do = nn.Dropout()
for epoch in range(num_epochs):
    for img, _ in dataloader:
        # Flatten each image to a 784-dimensional vector.
        img = img.to(device).view(img.size(0), -1)
        # Build the mask directly on the target device to avoid a
        # host-to-device copy on every batch.
        noise = do(torch.ones_like(img))
        img_bad = img * noise
        # Denoising objective: reconstruct the clean image from the
        # corrupted one.
        output = model(img_bad)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch.
    print(f"epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}")
    display_images(img_bad, output)