In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
from fastai.vision import *
from fastai.vision.gan import *

LSUN bedroom data

For this lesson, we'll be using the bedrooms from the LSUN dataset. The full dataset is a bit too large, so we'll use a sample from Kaggle.

In [3]:
path = untar_data(URLs.LSUN_BEDROOMS)

We then grab all the images in the folder with the data block API. We don't create a validation set here, for reasons we'll explain later. Each item has a random noise vector as input (of size 100 by default, set by noise_sz below) and a bedroom image as target. That's why we pass tfm_y=True to the transforms and apply the normalization to the ys and not the xs.

In [4]:
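def get_data(bs, size):
    return (GANItemList.from_folder(path, noise_sz=100)  # inputs: random noise vectors of size 100
               .no_split()                               # no validation set
               .label_from_func(noop)                    # targets: the images themselves
               # random crop (padding if needed) to size x size, applied to the targets
               .transform(tfms=[[crop_pad(size=size, row_pct=(0, 1), col_pct=(0, 1))], []], size=size, tfm_y=True)
               .databunch(bs=bs)
               # mean 0.5 / std 0.5 maps target pixels from [0, 1] to [-1, 1]; inputs are left alone
               .normalize(stats=[torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5])], do_x=False, do_y=True))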
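To make the data layout concrete, here is a minimal pure-Python sketch of one training item as described above (make_item, normalize and denormalize are illustrative names, not fastai functions). It assumes the noise is drawn from a standard normal distribution, the usual choice, and treats an image as a flat list of pixel values in [0, 1]. Normalizing with mean 0.5 and std 0.5 maps those pixels to [-1, 1], the output range of a tanh layer, a common final activation for GAN generators.

```python
import random

MEAN, STD = 0.5, 0.5  # the stats passed to normalize() in get_data

def normalize(pixel):
    """Map a pixel value in [0, 1] to [-1, 1]: (p - 0.5) / 0.5."""
    return (pixel - MEAN) / STD

def denormalize(value):
    """Inverse of normalize(): map [-1, 1] back to [0, 1] for display."""
    return value * STD + MEAN

def make_item(image_pixels, noise_sz=100, rng=random):
    """Hypothetical helper: one (x, y) pair.
    x is a noise vector drawn from N(0, 1) (assumed); y is the normalized image."""
    x = [rng.gauss(0.0, 1.0) for _ in range(noise_sz)]
    y = [normalize(p) for p in image_pixels]
    return x, y

x, y = make_item([0.0, 0.25, 1.0], rng=random.Random(0))
print(len(x), y)  # 100 [-1.0, -0.5, 1.0]
```

Only the target goes through normalize(); the noise input is already centered, which mirrors do_x=False, do_y=True above.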

We'll begin with a small size and use gradual resizing.

In [5]:
data = get_data(128, 64)
In [6]:
data.show_batch(rows=5)

Models

GAN stands for Generative Adversarial Nets; GANs were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, and the critic tries to tell real images apart from the ones the generator produces. The generator returns images, the critic a single number (usually 0. for fake images and 1. for real ones in a classic GAN; the WGAN critic used below outputs an unconstrained score instead).

We train them against each other in the sense that at each step (more or less), we:

  1. Freeze the generator and train the critic for one step by:
    • getting one batch of true images (let's call that real)
    • generating one batch of fake images (let's call that fake)
    • having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the critic for scoring real images as real and fake images as fake
    • updating the weights of the critic with the gradients of this loss
  2. Freeze the critic and train the generator for one step by:
    • generating one batch of fake images
    • evaluating the critic on it
    • computing a loss that rewards the generator when the critic thinks those are real images
    • updating the weights of the generator with the gradients of this loss
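The steps above don't pin down a particular loss. As a point of comparison with the WGAN loss used below, here is a minimal sketch of the classic GAN losses, assuming the critic outputs a probability that an image is real. The names bce, critic_loss and generator_loss are illustrative, not fastai functions, and the generator loss is the "non-saturating" variant suggested in the original GAN paper.

```python
import math

def bce(p, target, eps=1e-7):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    p = min(max(p, eps), 1 - eps)  # clamp so log() stays finite
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def critic_loss(real_probs, fake_probs):
    """Critic step: push scores of real images toward 1 and of fake images toward 0."""
    losses = [bce(p, 1.0) for p in real_probs] + [bce(p, 0.0) for p in fake_probs]
    return sum(losses) / len(losses)

def generator_loss(fake_probs):
    """Generator step (non-saturating form): low when the critic scores fakes as real."""
    return sum(bce(p, 1.0) for p in fake_probs) / len(fake_probs)

# A confident, correct critic has a low loss, and the generator a high one.
print(round(critic_loss([0.99, 0.98], [0.01, 0.02]), 3))  # 0.015
print(round(generator_loss([0.01, 0.02]), 3))             # 4.259
```

The two losses pull in opposite directions on the same fake scores, which is the "adversarial" part of the training.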

Here, we'll use the Wasserstein GAN (WGAN).
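A WGAN replaces the classification loss with a score difference: the critic minimizes mean(critic(fake)) - mean(critic(real)), the generator minimizes -mean(critic(fake)), and the critic's weights are clipped after each update to keep it Lipschitz. The sketch below runs the alternating loop on a deliberately tiny, deterministic toy that is not fastai code: one-dimensional "images", a linear critic f(x) = w * x, and a generator that only learns an offset b. The original WGAN paper clips to [-0.01, 0.01]; the toy uses 0.1 so it moves faster.

```python
def clip(value, c):
    """Weight clipping: keep a critic weight inside [-c, c]."""
    return max(-c, min(c, value))

def train_toy_wgan(steps=1000, lr=0.05, c=0.1):
    real = [2.0, 3.0, 4.0]    # fixed "real" batch (mean 3.0)
    noise = [-1.0, 0.0, 1.0]  # fixed noise batch (mean 0.0)
    w = 0.0                   # critic f(x) = w * x
    b = 0.0                   # generator g(z) = z + b; only b is learned
    for _ in range(steps):
        # 1. Critic step, generator frozen:
        #    minimize mean(f(fake)) - mean(f(real)) = w * (mean(fake) - mean(real))
        fake = [z + b for z in noise]
        grad_w = sum(fake) / len(fake) - sum(real) / len(real)
        w = clip(w - lr * grad_w, c)
        # 2. Generator step, critic frozen:
        #    minimize -mean(f(fake)) = -w * (mean(noise) + b), whose gradient in b is -w
        b = b - lr * (-w)
    return w, b

w, b = train_toy_wgan()
print(f"critic weight {w:.3f}, generator offset {b:.2f}")
```

With these settings b climbs to about 3.0 (the mean of the real batch) and then keeps circling it with a small amplitude instead of settling exactly, a miniature version of why GAN training is hard to read. A linear critic can only compare means; real critics are deep networks.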

We create a generator and a critic that we pass to GANLearner.wgan. The noise_sz (100, set in get_data above) is the size of the random vector from which our generator creates images.

  • basic_generator: a basic generator from noise_sz to images n_channels x in_size x in_size.
  • basic_critic: a basic critic for images n_channels x in_size x in_size.
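To see how a noise vector becomes a 64x64 image and back, here is the spatial-size arithmetic for a DCGAN-style layout: the generator grows a 1x1 input to 4x4 with a stride-1 transposed convolution, then doubles the size with 4x4, stride-2, padding-1 transposed convolutions; the critic mirrors this with strided convolutions down to a single score. We assume basic_generator and basic_critic follow roughly this pattern (their exact layers and where the extra layers sit may differ); the formulas are the standard ones for convolutions without dilation or output padding, and the layouts assume in_size is a power of 2 of at least 4.

```python
def conv_out(size, kernel, stride, pad):
    """Spatial size after a regular convolution (no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, pad):
    """Spatial size after a transposed convolution (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * pad + kernel

def generator_sizes(in_size, n_extra_layers=0):
    """Assumed DCGAN-style generator: 1x1 noise -> 4x4 -> doubling up to in_size."""
    sizes = [1, conv_transpose_out(1, 4, 1, 0)]                # 1 -> 4
    while sizes[-1] < in_size:
        sizes.append(conv_transpose_out(sizes[-1], 4, 2, 1))  # doubles the size
    for _ in range(n_extra_layers):
        sizes.append(conv_transpose_out(sizes[-1], 3, 1, 1))  # 3x3, stride 1: size unchanged
    return sizes

def critic_sizes(in_size, n_extra_layers=0):
    """Assumed DCGAN-style critic: in_size -> halving down to 4x4 -> 1x1 score."""
    sizes = [in_size]
    for _ in range(n_extra_layers):
        sizes.append(conv_out(sizes[-1], 3, 1, 1))            # 3x3, stride 1: size unchanged
    while sizes[-1] > 4:
        sizes.append(conv_out(sizes[-1], 4, 2, 1))            # halves the size
    sizes.append(conv_out(sizes[-1], 4, 1, 0))                # 4x4 -> 1x1
    return sizes

print(generator_sizes(64, 1))  # [1, 4, 8, 16, 32, 64, 64]
print(critic_sizes(64, 1))     # [64, 64, 32, 16, 8, 4, 1]
```

The extra 3x3 layers add depth without changing the resolution, which is why n_extra_layers can be chosen independently of in_size.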
In [7]:
# 64x64 RGB images for both models, each with one extra layer
generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic    = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
In [8]:
# Create a WGAN from `data`, `generator` and `critic`.
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
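Two optimizer choices here are worth unpacking. wd=0. turns weight decay off, and betas=(0., 0.99) turns off Adam's momentum: with beta1 = 0 the first-moment estimate is just the current gradient, so the update behaves like RMSProp with bias correction. The WGAN paper reported unstable training with momentum-based optimizers and used RMSProp, which is a plausible reason for this setting, though the notebook doesn't state it. A minimal single-parameter Adam step (an illustrative adam_step, not PyTorch's implementation):

```python
import math

def adam_step(param, grad, state, lr=2e-4, betas=(0.0, 0.99), eps=1e-8):
    """One Adam update for a single scalar parameter, with bias correction."""
    beta1, beta2 = betas
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad         # first moment (momentum)
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad  # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)

state = {"t": 0, "m": 0.0, "v": 0.0}
p = adam_step(1.0, 0.5, state)
print(state["m"], p)  # with beta1 = 0, m is exactly the current gradient
```

With beta1 = 0 nothing from earlier gradients carries over into the step direction; only the scale (through v) has memory.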
In [9]:
%%time

learn.fit(30, 2e-4)
Total time: 7:24:39

epoch train_loss gen_loss disc_loss
1 -0.819376 0.552723 -1.092979
2 -0.718442 0.526980 -0.972343
3 -0.680821 0.462584 -0.917101
4 -0.595675 0.422231 -0.790659
5 -0.565563 0.423978 -0.768163
6 -0.540192 0.394471 -0.740738
7 -0.494420 0.357299 -0.659527
8 -0.442884 0.333854 -0.604446
9 -0.425071 0.304099 -0.579192
10 -0.406816 0.286110 -0.543247
11 -0.394037 0.259015 -0.524132
12 -0.363659 0.246012 -0.487526
13 -0.349113 0.219589 -0.457561
14 -0.317694 0.217000 -0.423149
15 -0.298355 0.210969 -0.407974
16 -0.286983 0.199303 -0.385608
17 -0.279147 0.159232 -0.367452
18 -0.271880 0.164967 -0.356202
19 -0.247997 0.165707 -0.330360
20 -0.236673 0.149570 -0.320379
21 -0.238884 0.139290 -0.314305
22 -0.233499 0.144022 -0.301759
23 -0.232004 0.126653 -0.297298
24 -0.211484 0.136557 -0.285208
25 -0.206046 0.141741 -0.274359
26 -0.204353 0.124318 -0.265893
27 -0.186395 0.131770 -0.264553
28 -0.201367 0.105065 -0.263350
29 -0.181199 0.118817 -0.249614
30 -0.183528 0.107711 -0.245990
CPU times: user 3h 47min 41s, sys: 51min, total: 4h 38min 42s
Wall time: 7h 24min 39s
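Under the WGAN interpretation, the negated critic loss estimates (up to a scale set by the weight clipping) the Wasserstein distance between real and generated images, which is one reason a WGAN's losses are easier to read than a classic GAN's. A minimal sketch that reads a few rows of the log above this way; disc_log and wasserstein_estimates are illustrative names:

```python
# (epoch, disc_loss) pairs copied from the training log above
disc_log = [(1, -1.092979), (10, -0.543247), (20, -0.320379), (30, -0.245990)]

def wasserstein_estimates(rows):
    """Negate each critic loss to get a (scaled) Wasserstein-distance estimate."""
    return [(epoch, -loss) for epoch, loss in rows]

for epoch, estimate in wasserstein_estimates(disc_log):
    print(f"epoch {epoch:2d}: {estimate:.3f}")
```

The estimate drops from about 1.09 to about 0.25 over the 30 epochs, which is consistent with the generated images getting closer to the real ones, though it is only a rough signal and no substitute for looking at samples.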
In [10]:
learn.save('wgan_epoch_30')
In [15]:
learn.gan_trainer.switch(gen_mode=True)
In [16]:
learn.show_results(ds_type=DatasetType.Train, rows=16, figsize=(8,8))
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-16-71ed10d3c2cd> in <module>
----> 1 learn.show_results(ds_type=DatasetType.Train, rows=16, figsize=(8,8))

~/anaconda3/envs/fastai-v1/lib/python3.6/site-packages/fastai/basic_train.py in show_results(self, ds_type, rows, **kwargs)
    301                 preds = self.data.denorm(preds, do_x=True)
    302         analyze_kwargs,kwargs = split_kwargs_by_func(kwargs, ds.y.analyze_pred)
--> 303         preds = [ds.y.analyze_pred(grab_idx(preds, i), **analyze_kwargs) for i in range(rows)]
    304         xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
    305         if has_arg(ds.y.reconstruct, 'x'):

~/anaconda3/envs/fastai-v1/lib/python3.6/site-packages/fastai/basic_train.py in <listcomp>(.0)
    301                 preds = self.data.denorm(preds, do_x=True)
    302         analyze_kwargs,kwargs = split_kwargs_by_func(kwargs, ds.y.analyze_pred)
--> 303         preds = [ds.y.analyze_pred(grab_idx(preds, i), **analyze_kwargs) for i in range(rows)]
    304         xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
    305         if has_arg(ds.y.reconstruct, 'x'):

~/anaconda3/envs/fastai-v1/lib/python3.6/site-packages/fastai/torch_core.py in grab_idx(x, i, batch_first)
    268 def grab_idx(x,i,batch_first:bool=True):
    269     "Grab the `i`-th batch in `x`, `batch_first` stating the batch dimension."
--> 270     if batch_first: return ([o[i].cpu() for o in x]   if is_listy(x) else x[i].cpu())
    271     else:           return ([o[:,i].cpu() for o in x] if is_listy(x) else x[:,i].cpu())
    272 

IndexError: index 128 is out of bounds for dimension 0 with size 128
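The traceback asks for index 128 from a batch of 128 predictions. That is consistent with this fastai version drawing a rows x rows grid for image outputs, so rows=16 asks for 256 images from one batch of 128; we assume that's the cause. A minimal check, with illustrative helper names:

```python
import math

def grid_fits(rows, batch_size):
    """True if a rows x rows grid of results fits in one batch."""
    return rows * rows <= batch_size

def max_square_rows(batch_size):
    """Largest rows whose rows x rows grid fits in one batch."""
    return math.isqrt(batch_size)

print(grid_fits(16, 128), grid_fits(8, 128), max_square_rows(128))  # False True 11
```

So any rows up to 11 should fit with bs=128; rows=8 gives a 64-image grid.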
In [ ]:
learn.gan_trainer.switch(gen_mode=True)
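# rows=16 failed above: show_results appears to draw rows**2 = 256 images from one batch of 128; 8**2 = 64 fits
learn.show_results(ds_type=DatasetType.Train, rows=8, figsize=(8,8))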