# Adapted from: https://github.com/Atcold/pytorch-Deep-Learning/blob/master/10-autoencoder.ipynb
#
# Required Packages
# Run this cell to install required packages.
#
%pip install "torch>=1.9" "tqdm>=4.64"
def to_img(x):
    """Turn a batch of flat, [-1, 1]-scaled images into 28x28 images in [0, 1].

    Args:
        x: tensor whose first dimension is the batch; each sample must hold
            28 * 28 elements (values presumably in [-1, 1]).

    Returns:
        Tensor of shape ``(batch, 28, 28)`` holding ``0.5 * (x + 1)``.
    """
    batch = x.size(0)
    # Inverse of Normalize((0.5,), (0.5,)): maps [-1, 1] back onto [0, 1].
    rescaled = 0.5 * (x + 1)
    return rescaled.view(batch, 28, 28)
from matplotlib import pyplot as plt
def _show_row(pics, offset):
    """Draw pics[offset] .. pics[offset + 3] side by side in a new 18x6 figure."""
    plt.figure(figsize=(18, 6))
    for col in range(4):
        plt.subplot(1, 4, col + 1)
        plt.imshow(pics[offset + col])
        plt.axis("off")


def display_images(in_, out, n=1):
    """Plot ``n`` rows of four images from ``out`` (and optionally ``in_``).

    For each row index ``N`` in ``range(n)`` this opens one figure with the
    four input images ``in_[4*N : 4*N + 4]`` (only when ``in_`` is not None),
    followed by one figure with the four matching images from ``out``.

    Args:
        in_: batch of flat images to show as inputs, or None to show only
            the outputs.
        out: batch of flat images to show as outputs.
        n: number of four-image rows; each batch must contain at least
            ``4 * n`` samples, otherwise indexing fails.

    Figures are created via ``plt.figure`` and are not closed here; showing
    and closing them is left to the caller / plotting backend.
    """
    if n < 1:
        return  # nothing to draw (matches the original's empty loop)
    # Convert once, outside the loop: the conversion does not depend on the row.
    # .detach() is the recommended replacement for the discouraged .data.
    in_pic = None if in_ is None else to_img(in_.detach().cpu())
    out_pic = to_img(out.detach().cpu())
    for row in range(n):
        offset = 4 * row
        if in_pic is not None:
            _show_row(in_pic, offset)
        _show_row(out_pic, offset)
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) then applies
# v -> (v - 0.5) / 0.5, i.e. rescales them to [-1, 1].
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# download=True fetches the MNIST files into ./data if they are not there yet.
dataset = MNIST("./data", transform=img_transform, download=True)

# Mini-batches of 256 samples; shuffle=True reshuffles the data every epoch.
batch_size = 256
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
from torch import nn
d = 30  # default size of the autoencoder's code (latent) layer


class Autoencoder(nn.Module):
    """Fully connected autoencoder with one code layer.

    encoder: Linear(in_features -> latent_dim) followed by Tanh
    decoder: Linear(latent_dim -> in_features) followed by Tanh

    The closing Tanh bounds reconstructions to (-1, 1), the same range the
    inputs have after ``Normalize((0.5,), (0.5,))``.

    Args:
        latent_dim: width of the code layer. None (the default) uses the
            module-level ``d`` as it is when the model is constructed, so
            existing ``Autoencoder()`` calls behave exactly as before.
        in_features: length of each flattened input vector (28 * 28 for MNIST).
    """

    def __init__(self, latent_dim=None, in_features=28 * 28):
        super().__init__()
        if latent_dim is None:
            latent_dim = d  # read the global at construction time, as before
        self.encoder = nn.Sequential(
            nn.Linear(in_features, latent_dim),
            nn.Tanh(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, in_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` (last dimension ``in_features``) and decode it back."""
        return self.decoder(self.encoder(x))
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Autoencoder().to(device)
criterion = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
num_epochs = 20
# nn.Dropout() (p=0.5, in training mode by default) applied to a tensor of
# ones gives a random mask of 0s and 2s (kept entries scaled by 1 / (1 - p));
# multiplying the input by it produces the corrupted image.
do = nn.Dropout()
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.to(device)
        img = img.view(img.size(0), -1)  # flatten each image into one vector
        # Build the mask directly on `device` rather than on the CPU followed
        # by a host-to-device copy on every batch.
        noise = do(torch.ones_like(img))
        img_bad = img * noise
        # Denoising objective: the model sees the corrupted `img_bad` and is
        # scored against the clean `img` -- the same pairing that
        # display_images(img_bad, output) shows below.
        output = model(img_bad)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Reports the loss of the epoch's last mini-batch (not an epoch average).
    print(f"epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}")
    display_images(img_bad, output)
    # Flush this epoch's figures now instead of letting two figures per epoch
    # pile up past matplotlib's 20-open-figures warning (the notebook inline
    # backend also closes figures on show).
    plt.show()
epoch [1/20], loss:0.2061
epoch [2/20], loss:0.1627
epoch [3/20], loss:0.1209
epoch [4/20], loss:0.1133
epoch [5/20], loss:0.0988
epoch [6/20], loss:0.0902
epoch [7/20], loss:0.0856
epoch [8/20], loss:0.0848
epoch [9/20], loss:0.0730
epoch [10/20], loss:0.0742
epoch [11/20], loss:0.0640
C:\Users\KIMBYO~1\AppData\Local\Temp\mrxlink-component-mmthaqcp\mrxlink_component_db988663.py:20: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). plt.figure(figsize=(18, 6))
epoch [12/20], loss:0.0678
epoch [13/20], loss:0.0641
epoch [14/20], loss:0.0617
epoch [15/20], loss:0.0584
epoch [16/20], loss:0.0583
epoch [17/20], loss:0.0580
epoch [18/20], loss:0.0554
epoch [19/20], loss:0.0555
epoch [20/20], loss:0.0559