%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
from fastai.vision.gan import *
For this lesson, we'll be using the bedroom images from the LSUN dataset. The full dataset is a bit too large, so we'll use a sample from Kaggle.
path = untar_data(URLs.LSUN_BEDROOMS)
We then grab all the images in the folder with the data block API. We don't create a validation set here, for reasons we'll explain later. Each item consists of a random noise vector (of size 100 by default; this can be changed below) as the input and an image of a bedroom as the target. That's why we pass tfm_y=True to the transforms, and why we apply the normalization to the ys and not the xs.
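def get_data(bs, size):
    return (GANItemList.from_folder(path, noise_sz=100)  # x: noise vectors of size 100
               .split_none()                            # no validation set
               .label_from_func(noop)                   # y: the bedroom images themselves
               .transform(tfms=[[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], []], size=size, tfm_y=True)
               .databunch(bs=bs)
               .normalize(stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])], do_x=False, do_y=True))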
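The normalization stats above use a mean and standard deviation of 0.5 per channel. The plain-Python sketch below (no fastai) shows that this maps pixel values in [0, 1] to [-1, 1]. That is the output range of a generator whose last layer is a tanh, which is typical of DCGAN-style generators. The helpers normalize, denormalize and noise_input are hypothetical. Drawing the noise from a standard normal distribution is an assumption about how the inputs are generated.

```python
import random

# Per-channel stats passed to normalize() in get_data: mean 0.5, std 0.5
MEAN, STD = 0.5, 0.5


def normalize(pixel):
    # Map a pixel value in [0, 1] to [-1, 1]
    return (pixel - MEAN) / STD


def denormalize(value):
    # Inverse mapping: turn a value in [-1, 1] (e.g. a tanh output) back into [0, 1]
    return value * STD + MEAN


def noise_input(noise_sz=100, seed=None):
    # One input item: a random vector of length noise_sz
    # (sampling from a standard normal is an assumption of this sketch)
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(noise_sz)]


print(normalize(0.0), normalize(1.0))  # -1.0 1.0
print(len(noise_input()))              # 100
```

Because only the ys are normalized, displaying a generated image means applying the inverse mapping (denormalize here) to the generator's output.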
We'll begin with a small size and use gradual resizing.
data = get_data(128, 64)
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, and the critic tries to distinguish real images from the ones the generator produces. The generator returns images; the critic returns a single number (usually 0. for fake images and 1. for real ones).
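The sketch below makes the 0/1 convention concrete. It writes the classic GAN losses in plain Python for a single pair of critic outputs, assuming the critic's output is a probability in (0, 1). The generator loss shown is the common non-saturating variant. The Wasserstein GAN used in this lesson replaces both losses, and its critic outputs an unbounded score rather than a probability. The function names are hypothetical.

```python
import math


def bce(pred, target):
    # Binary cross-entropy for a single prediction pred in (0, 1)
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))


def critic_loss(d_real, d_fake):
    # The critic is rewarded for outputting 1 on a real image and 0 on a fake one
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake):
    # The generator is rewarded when the critic scores its fake as real (target 1)
    return bce(d_fake, 1.0)


print(critic_loss(0.9, 0.1) < critic_loss(0.5, 0.5))  # True: a confident, correct critic has lower loss
print(generator_loss(0.9) < generator_loss(0.1))      # True: fooling the critic lowers the generator loss
```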
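We train them against each other in the sense that, at each step (more or less), we alternate between two phases. First, we freeze the generator and update the critic so it gets better at telling real images from fake ones. Then we freeze the critic and update the generator so that its fakes get scored more like real images.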
Here, we'll use the Wasserstein GAN, from the paper of the same name.
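The alternating scheme can be illustrated with a deliberately tiny, hypothetical example in plain Python. The setup has three parts:

- a one-dimensional "dataset" drawn from a normal distribution with mean 3;
- a generator g(z) = theta + z that can only shift its noise;
- a linear critic f(x) = w * x.

The example uses Wasserstein-style losses: the critic minimizes mean f(fake) - mean f(real), and the generator minimizes -mean f(fake). It also uses the weight clipping from the WGAN paper. However, it relies on plain SGD with hand-derived gradients instead of a real optimizer. All names and hyperparameters are illustrative assumptions and do not reflect GANLearner's internals.

```python
import random


def mean(xs):
    return sum(xs) / len(xs)


def toy_wgan(steps=400, n_critic=5, clip=0.01, lr_critic=0.5, lr_gen=5.0,
             batch_size=64, real_mean=3.0, seed=0):
    """Toy 1-D WGAN: real data ~ N(real_mean, 1), generator g(z) = theta + z,
    linear critic f(x) = w * x whose weight is clipped to [-clip, clip]."""
    rng = random.Random(seed)
    theta = 0.0  # generator parameter
    w = 0.0      # critic parameter

    def real_batch():
        return [rng.gauss(real_mean, 1.0) for _ in range(batch_size)]

    def fake_batch():
        # Generator: shift standard normal noise z by theta
        return [theta + rng.gauss(0.0, 1.0) for _ in range(batch_size)]

    for _ in range(steps):
        # 1) Generator frozen: update the critic n_critic times.
        #    critic loss = mean f(fake) - mean f(real) = w * (mean(fake) - mean(real)),
        #    so d(loss)/dw = mean(fake) - mean(real).
        for _ in range(n_critic):
            grad_w = mean(fake_batch()) - mean(real_batch())
            w -= lr_critic * grad_w
            w = max(-clip, min(clip, w))  # weight clipping bounds the critic's slope
        # 2) Critic frozen: update the generator once.
        #    generator loss = -mean f(fake) = -w * (theta + mean(z)),
        #    so d(loss)/dtheta = -w.
        theta -= lr_gen * (-w)
    return theta, w


theta, w = toy_wgan()
print(f"theta = {theta:.2f} (real mean is 3.0), critic weight w = {w:+.3f}")
```

With these settings the critic's weight quickly saturates at plus or minus clip. The generator therefore moves theta by lr_gen * clip per step toward the real mean, then oscillates around it.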
We create a generator and a critic that we pass to GANLearner.wgan. The noise size (noise_sz, set to 100 in get_data) is the size of the random vector from which our generator creates images.
# Both models are built for 64x64 RGB images, matching data = get_data(128, 64)
generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic    = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=0.)
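The spatial-size arithmetic for DCGAN-style layers shows why in_size=64 fixes the depth of both networks. This sketch assumes that basic_generator and basic_critic use kernel 4, stride 2, padding 1 for every resizing layer, plus a first (generator) or last (critic) layer with kernel 4, stride 1, padding 0. We believe this matches basic_generator and basic_critic, but have not verified it. We also assume the layers added by n_extra_layers use stride 1 and keep the size unchanged, so they are not counted. The function names are hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    # Output size of a transposed convolution (no output padding, dilation 1)
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    # Output size of a regular convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def dcgan_generator_sizes(in_size=64):
    # The noise vector is treated as a 1x1 "image"; the first layer maps it to 4x4
    sizes = [conv_transpose_out(1, kernel=4, stride=1, padding=0)]
    # Each following layer doubles the resolution until we reach in_size
    while sizes[-1] < in_size:
        sizes.append(conv_transpose_out(sizes[-1], kernel=4, stride=2, padding=1))
    return sizes


def dcgan_critic_sizes(in_size=64):
    sizes = [in_size]
    # Each layer halves the resolution down to 4x4
    while sizes[-1] > 4:
        sizes.append(conv_out(sizes[-1], kernel=4, stride=2, padding=1))
    # A final 4x4 convolution with no padding leaves a 1x1 map: one score per image
    sizes.append(conv_out(sizes[-1], kernel=4, stride=1, padding=0))
    return sizes


print(dcgan_generator_sizes(64))  # [4, 8, 16, 32, 64]
print(dcgan_critic_sizes(64))     # [64, 32, 16, 8, 4, 1]
```

The critic mirrors the generator. It ends with a 1x1 map, which is the single number per image described earlier.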
epoch | train_loss | gen_loss | disc_loss |
1 | -0.842719 | 0.542895 | -1.086206 |
2 | -0.799776 | 0.539448 | -1.067940 |
3 | -0.738768 | 0.538581 | -1.015152 |
4 | -0.718174 | 0.484403 | -0.943485 |
5 | -0.570070 | 0.428915 | -0.777247 |
6 | -0.545130 | 0.413026 | -0.749381 |
7 | -0.541453 | 0.389443 | -0.719322 |
8 | -0.469548 | 0.356602 | -0.642670 |
9 | -0.434924 | 0.329100 | -0.598782 |
10 | -0.416448 | 0.301526 | -0.558442 |
11 | -0.389224 | 0.292355 | -0.532662 |
12 | -0.361795 | 0.266539 | -0.494872 |
13 | -0.363674 | 0.245725 | -0.475951 |
14 | -0.318343 | 0.227446 | -0.432148 |
15 | -0.309221 | 0.203628 | -0.417945 |
16 | -0.300667 | 0.213194 | -0.401034 |
17 | -0.282622 | 0.187128 | -0.381643 |
18 | -0.283902 | 0.156653 | -0.374541 |
19 | -0.267852 | 0.159612 | -0.346919 |
20 | -0.257258 | 0.163018 | -0.344198 |
21 | -0.242090 | 0.159207 | -0.323443 |
22 | -0.255733 | 0.129341 | -0.322228 |
23 | -0.235854 | 0.143768 | -0.305106 |
24 | -0.220397 | 0.115682 | -0.289971 |
25 | -0.233782 | 0.135361 | -0.294088 |
26 | -0.202050 | 0.142435 | -0.279994 |
27 | -0.196104 | 0.119580 | -0.265333 |
28 | -0.204124 | 0.119595 | -0.266063 |
29 | -0.204096 | 0.131431 | -0.264097 |
30 | -0.183655 | 0.128817 | -0.254156 |
learn.show_results(ds_type=DatasetType.Train, rows=16, figsize=(8,8))