%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
from fastai.dataset import *
import gzip
torch.cuda.set_device(3)
Download the LSUN scene classification dataset bedroom category, unzip it, and convert it to jpg files (the scripts folder is here in the dl2
folder):
curl 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag=latest&category=bedroom&set=train' -o bedroom.zip
unzip bedroom.zip
pip install lmdb
python lsun-data.py {PATH}/bedroom_train_lmdb --out_dir {PATH}/bedroom
This isn't tested on Windows - if it doesn't work, you could use a Linux box to convert the files, then copy them over. Alternatively, you can download this 20% sample from Kaggle datasets.
PATH = Path('data/lsun/')
IMG_PATH = PATH/'bedroom'
CSV_PATH = PATH/'files.csv'
TMP_PATH = PATH/'tmp'
TMP_PATH.mkdir(exist_ok=True)
files = PATH.glob('bedroom/**/*.jpg')
with CSV_PATH.open('w') as fo:
for f in files: fo.write(f'{f.relative_to(IMG_PATH)},0\n')
# Optional - sampling a subset of files
CSV_PATH = PATH/'files_sample.csv'
files = PATH.glob('bedroom/**/*.jpg')
with CSV_PATH.open('w') as fo:
for f in files:
if random.random()<0.1: fo.write(f'{f.relative_to(IMG_PATH)},0\n')
class ConvBlock(nn.Module):
def __init__(self, ni, no, ks, stride, bn=True, pad=None):
super().__init__()
if pad is None: pad = ks//2//stride
self.conv = nn.Conv2d(ni, no, ks, stride, padding=pad, bias=False)
self.bn = nn.BatchNorm2d(no)
self.relu = nn.LeakyReLU(0.2, inplace=True)
def forward(self, x):
return self.bn(self.relu(self.conv(x)))
class DCGAN_D(nn.Module):
def __init__(self, isize, nc, ndf, n_extra_layers=0):
super().__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
self.initial = ConvBlock(nc, ndf, 4, 2, bn=False)
csize,cndf = isize/2,ndf
self.extra = nn.Sequential(*[ConvBlock(cndf, cndf, 3, 1)
for t in range(n_extra_layers)])
pyr_layers = []
while csize > 4:
pyr_layers.append(ConvBlock(cndf, cndf*2, 4, 2))
cndf *= 2; csize /= 2
self.pyramid = nn.Sequential(*pyr_layers)
self.final = nn.Conv2d(cndf, 1, 4, padding=0, bias=False)
def forward(self, input):
x = self.initial(input)
x = self.extra(x)
x = self.pyramid(x)
return self.final(x).mean(0).view(1)
class DeconvBlock(nn.Module):
def __init__(self, ni, no, ks, stride, pad, bn=True):
super().__init__()
self.conv = nn.ConvTranspose2d(ni, no, ks, stride, padding=pad, bias=False)
self.bn = nn.BatchNorm2d(no)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.bn(self.relu(self.conv(x)))
class DCGAN_G(nn.Module):
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
super().__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf//2, 4
while tisize!=isize: cngf*=2; tisize*=2
self.initial = DeconvBlock(nz, cngf, 4, 1, 0)
csize, cndf = 4, cngf
pyr_layers = []
while csize < isize//2:
pyr_layers.append(DeconvBlock(cngf, cngf//2, 4, 2, 1))
cngf //= 2; csize *= 2
self.pyramid = nn.Sequential(*pyr_layers)
self.extra = nn.Sequential(*[DeconvBlock(cngf, cngf, 3, 1, 1)
for t in range(n_extra_layers)])
self.final = nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False)
def forward(self, input):
x = self.initial(input)
x = self.pyramid(x)
x = self.extra(x)
x = self.final(x)
return F.tanh(x)
bs,sz,nz = 64,64,100
tfms = tfms_from_stats(inception_stats, sz)
md = ImageClassifierData.from_csv(PATH, 'bedroom', CSV_PATH, tfms=tfms,
skip_header=False, continuous=True)
md = md.resize(128)
A Jupyter Widget
x,_ = next(iter(md.val_dl))
plt.imshow(md.trn_ds.denorm(x)[0]);
netG = DCGAN_G(sz, nz, 3, 64, 1).cuda()
netD = DCGAN_D(sz, 3, 64, 1).cuda()
def create_noise(b): return V(torch.zeros(b, nz, 1, 1).normal_(0, 1))
preds = netG(create_noise(4))
pred_ims = md.trn_ds.denorm(preds)
fig, axes = plt.subplots(2, 2, figsize=(6, 6))
for i,ax in enumerate(axes.flat): ax.imshow(pred_ims[i])
def gallery(x, nc=3):
n,h,w,c = x.shape
nr = n//nc
assert n == nr*nc
return (x.reshape(nr, nc, h, w, c)
.swapaxes(1,2)
.reshape(h*nr, w*nc, c))
netD.load_state_dict(torch.load(TMP_PATH/'netD_2.h5'))
netG.load_state_dict(torch.load(TMP_PATH/'netG_2.h5'))
optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4)
optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)
def train(niter, first=True):
gen_iterations = 0
for epoch in trange(niter):
netD.train(); netG.train()
data_iter = iter(md.trn_dl)
i,n = 0,len(md.trn_dl)
while i < n:
set_trainable(netD, True)
set_trainable(netG, False)
d_iters = 100 if (first and (gen_iterations < 25) or (gen_iterations % 500 == 0)) else 5
j = 0
while (j < d_iters) and (i < n):
j += 1; i += 1
for p in netD.parameters(): p.data.clamp_(-0.01, 0.01)
real = V(next(data_iter)[0])
real_loss = netD(real)
fake = netG(create_noise(real.size(0)))
fake_loss = netD(V(fake.data))
netD.zero_grad()
lossD = real_loss-fake_loss
lossD.backward()
optimizerD.step()
set_trainable(netD, False)
set_trainable(netG, True)
netG.zero_grad()
lossG = netD(netG(create_noise(bs))).mean(0).view(1)
lossG.backward()
optimizerG.step()
gen_iterations += 1
print(f'Loss_D {to_np(lossD)}; Loss_G {to_np(lossG)}; '
f'D_real {to_np(real_loss)}; Loss_D_fake {to_np(fake_loss)}')
torch.backends.cudnn.benchmark=True
train(50, False)
# train(50, True)
0%| | 0/50 [00:00<?, ?it/s]Loss_D [-0.741]; Loss_G [-0.36428]; D_real [-0.14867]; Loss_D_fake [ 0.59233] 2%|▏ | 1/50 [50:43<41:25:46, 3043.82s/it]Loss_D [-0.61727]; Loss_G [-0.12435]; D_real [-0.08988]; Loss_D_fake [ 0.52739] 4%|▍ | 2/50 [1:18:10<31:16:21, 2345.45s/it]Loss_D [-0.80041]; Loss_G [-0.16529]; D_real [-0.39839]; Loss_D_fake [ 0.40202] 6%|▌ | 3/50 [1:44:47<27:21:45, 2095.87s/it]Loss_D [-0.96696]; Loss_G [ 0.06946]; D_real [-0.43684]; Loss_D_fake [ 0.53012] 8%|▊ | 4/50 [2:10:40<25:02:41, 1960.03s/it]Loss_D [-0.35137]; Loss_G [-0.18598]; D_real [ 0.13466]; Loss_D_fake [ 0.48603] 10%|█ | 5/50 [2:37:11<23:34:45, 1886.34s/it]Loss_D [-0.61718]; Loss_G [ 0.06927]; D_real [-0.4987]; Loss_D_fake [ 0.11848] 12%|█▏ | 6/50 [3:08:02<22:59:01, 1880.50s/it]Loss_D [-0.76736]; Loss_G [-0.03992]; D_real [-0.46265]; Loss_D_fake [ 0.30471] 14%|█▍ | 7/50 [3:36:19<22:08:49, 1854.18s/it]Loss_D [-0.8863]; Loss_G [-0.04472]; D_real [-0.52889]; Loss_D_fake [ 0.35741] 16%|█▌ | 8/50 [4:07:16<21:38:11, 1854.56s/it]Loss_D [-0.86096]; Loss_G [-0.13653]; D_real [-0.26505]; Loss_D_fake [ 0.59591] 18%|█▊ | 9/50 [4:39:19<21:12:27, 1862.12s/it]Loss_D [-0.9806]; Loss_G [ 0.12125]; D_real [-0.55699]; Loss_D_fake [ 0.42361] 20%|██ | 10/50 [5:11:12<20:44:51, 1867.28s/it]Loss_D [-0.77137]; Loss_G [-0.11307]; D_real [-0.20695]; Loss_D_fake [ 0.56441] 22%|██▏ | 11/50 [5:44:16<20:20:36, 1877.85s/it]Loss_D [-0.52718]; Loss_G [-0.19094]; D_real [ 0.02439]; Loss_D_fake [ 0.55157] 24%|██▍ | 12/50 [6:10:15<19:32:28, 1851.27s/it]Loss_D [-0.73584]; Loss_G [-0.03732]; D_real [-0.24378]; Loss_D_fake [ 0.49206] 26%|██▌ | 13/50 [6:38:03<18:52:55, 1837.16s/it]Loss_D [-0.94337]; Loss_G [-0.04098]; D_real [-0.46667]; Loss_D_fake [ 0.4767] 28%|██▊ | 14/50 [7:03:18<18:08:29, 1814.16s/it]Loss_D [-0.43215]; Loss_G [-0.09195]; D_real [ 0.17661]; Loss_D_fake [ 0.60876] 30%|███ | 15/50 [7:28:24<17:26:17, 1793.64s/it]Loss_D [-0.59091]; Loss_G [-0.08218]; D_real [-0.20211]; Loss_D_fake [ 0.38881] 32%|███▏ | 16/50 [7:59:36<16:59:10, 1798.56s/it]Loss_D [-0.31101]; Loss_G [-0.12482]; D_real [ 0.12555]; Loss_D_fake [ 0.43656] 34%|███▍ | 17/50 [8:31:41<16:33:16, 1805.97s/it]Loss_D [-0.65346]; Loss_G [-0.08354]; D_real [-0.1924]; Loss_D_fake [ 0.46106] 36%|███▌ | 18/50 [8:59:27<15:59:03, 1798.22s/it]Loss_D [-0.89509]; Loss_G [ 0.0186]; D_real [-0.57834]; Loss_D_fake [ 0.31675] 38%|███▊ | 19/50 [9:24:06<15:20:22, 1781.39s/it]Loss_D [-0.58589]; Loss_G [-0.02102]; D_real [-0.5239]; Loss_D_fake [ 0.06199] 40%|████ | 20/50 [9:49:41<14:44:32, 1769.08s/it]Loss_D [-0.42573]; Loss_G [-0.19511]; D_real [-0.06038]; Loss_D_fake [ 0.36535] 42%|████▏ | 21/50 [10:15:09<14:09:30, 1757.59s/it]Loss_D [-0.87549]; Loss_G [-0.22496]; D_real [-0.43741]; Loss_D_fake [ 0.43808] 44%|████▍ | 22/50 [10:41:11<13:36:03, 1748.71s/it]Loss_D [-0.92783]; Loss_G [-0.1044]; D_real [-0.49211]; Loss_D_fake [ 0.43573] 46%|████▌ | 23/50 [11:06:43<13:02:40, 1739.29s/it]Loss_D [-1.09537]; Loss_G [-0.04373]; D_real [-0.5675]; Loss_D_fake [ 0.52787] 48%|████▊ | 24/50 [11:32:48<12:30:32, 1732.03s/it]Loss_D [-0.53628]; Loss_G [-0.30188]; D_real [-0.00368]; Loss_D_fake [ 0.5326] 50%|█████ | 25/50 [11:58:33<11:58:33, 1724.55s/it]Loss_D [-0.68356]; Loss_G [-0.18487]; D_real [-0.04316]; Loss_D_fake [ 0.6404] 52%|█████▏ | 26/50 [12:23:25<11:26:14, 1715.59s/it]Loss_D [-0.91418]; Loss_G [-0.11698]; D_real [-0.39415]; Loss_D_fake [ 0.52003] 54%|█████▍ | 27/50 [12:48:40<10:54:48, 1708.18s/it]Loss_D [-0.98313]; Loss_G [-0.17385]; D_real [-0.48685]; Loss_D_fake [ 0.49627] 56%|█████▌ | 28/50 [13:13:48<10:23:42, 1701.01s/it]Loss_D [-0.70201]; Loss_G [ 0.16602]; D_real [-0.51006]; Loss_D_fake [ 0.19195] 58%|█████▊ | 29/50 [13:39:06<9:53:08, 1694.69s/it] Loss_D [-0.68617]; Loss_G [-0.1526]; D_real [-0.19117]; Loss_D_fake [ 0.495] 60%|██████ | 30/50 [14:04:16<9:22:50, 1688.53s/it]Loss_D [-0.65424]; Loss_G [-0.24224]; D_real [-0.04156]; Loss_D_fake [ 0.61268] 62%|██████▏ | 31/50 [14:28:49<8:52:30, 1681.60s/it]Loss_D [-0.49606]; Loss_G [-0.22956]; D_real [ 0.08758]; Loss_D_fake [ 0.58364] 64%|██████▍ | 32/50 [14:53:32<8:22:36, 1675.38s/it]Loss_D [-1.00359]; Loss_G [-0.06996]; D_real [-0.4014]; Loss_D_fake [ 0.60219] 66%|██████▌ | 33/50 [15:18:14<7:53:02, 1669.53s/it]Loss_D [-0.87005]; Loss_G [ 0.19799]; D_real [-0.46416]; Loss_D_fake [ 0.4059] 68%|██████▊ | 34/50 [15:42:49<7:23:40, 1663.81s/it]Loss_D [-0.83551]; Loss_G [ 0.05104]; D_real [-0.51311]; Loss_D_fake [ 0.3224] 70%|███████ | 35/50 [16:07:11<6:54:30, 1658.04s/it]Loss_D [-0.13401]; Loss_G [-0.1558]; D_real [-0.43032]; Loss_D_fake [-0.29631] 72%|███████▏ | 36/50 [16:32:14<6:25:52, 1653.73s/it]Loss_D [-0.28046]; Loss_G [-0.07256]; D_real [ 0.12695]; Loss_D_fake [ 0.40742] 74%|███████▍ | 37/50 [16:56:40<5:57:12, 1648.67s/it]Loss_D [-0.63494]; Loss_G [-0.18197]; D_real [-0.16491]; Loss_D_fake [ 0.47002] 76%|███████▌ | 38/50 [17:21:13<5:28:48, 1644.05s/it]Loss_D [-0.50106]; Loss_G [-0.24609]; D_real [-0.02722]; Loss_D_fake [ 0.47384] 78%|███████▊ | 39/50 [17:45:59<5:00:39, 1640.00s/it]Loss_D [-0.33549]; Loss_G [-0.27103]; D_real [ 0.11978]; Loss_D_fake [ 0.45527] 80%|████████ | 40/50 [18:10:57<4:32:44, 1636.44s/it]Loss_D [-0.53828]; Loss_G [-0.12616]; D_real [ 0.11586]; Loss_D_fake [ 0.65414] 82%|████████▏ | 41/50 [18:35:41<4:04:54, 1632.73s/it]Loss_D [-0.2754]; Loss_G [ 0.14415]; D_real [ 0.21671]; Loss_D_fake [ 0.49211] 84%|████████▍ | 42/50 [19:00:27<3:37:13, 1629.22s/it]Loss_D [-0.6237]; Loss_G [-0.10215]; D_real [-0.01876]; Loss_D_fake [ 0.60494] 86%|████████▌ | 43/50 [19:26:08<3:09:50, 1627.17s/it]Loss_D [-1.00502]; Loss_G [-0.22498]; D_real [-0.36107]; Loss_D_fake [ 0.64395] 88%|████████▊ | 44/50 [19:50:47<2:42:22, 1623.80s/it]Loss_D [-0.77478]; Loss_G [ 0.02211]; D_real [-0.494]; Loss_D_fake [ 0.28078] 90%|█████████ | 45/50 [20:15:41<2:15:04, 1620.92s/it]Loss_D [-0.46]; Loss_G [-0.36819]; D_real [ 0.07036]; Loss_D_fake [ 0.53036] 92%|█████████▏| 46/50 [20:40:31<1:47:52, 1618.08s/it]Loss_D [-0.83941]; Loss_G [-0.12256]; D_real [-0.48992]; Loss_D_fake [ 0.34949] 94%|█████████▍| 47/50 [21:05:50<1:20:47, 1615.98s/it]Loss_D [-0.65171]; Loss_G [ 0.09275]; D_real [-0.60725]; Loss_D_fake [ 0.04447] 96%|█████████▌| 48/50 [21:30:38<53:46, 1613.31s/it] Loss_D [-0.37492]; Loss_G [ 0.03437]; D_real [ 0.08686]; Loss_D_fake [ 0.46178] 98%|█████████▊| 49/50 [21:55:22<26:50, 1610.67s/it]Loss_D [-1.0074]; Loss_G [-0.02665]; D_real [-0.51458]; Loss_D_fake [ 0.49282] 100%|██████████| 50/50 [22:20:38<00:00, 1608.77s/it]
set_trainable(netD, True)
set_trainable(netG, True)
optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-5)
optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-5)
train(10, False)
0%| | 0/10 [00:00<?, ?it/s]Loss_D [-0.53839]; Loss_G [-0.06873]; D_real [ 0.13422]; Loss_D_fake [ 0.6726] 10%|█ | 1/10 [23:56<3:35:25, 1436.17s/it]Loss_D [-0.55885]; Loss_G [-0.13435]; D_real [ 0.11764]; Loss_D_fake [ 0.67649] 20%|██ | 2/10 [47:51<3:11:27, 1435.99s/it]Loss_D [-0.53274]; Loss_G [-0.40558]; D_real [ 0.17589]; Loss_D_fake [ 0.70863] 30%|███ | 3/10 [1:12:04<2:48:11, 1441.63s/it]Loss_D [-1.08415]; Loss_G [ 0.11718]; D_real [-0.4141]; Loss_D_fake [ 0.67005] 40%|████ | 4/10 [1:36:07<2:24:11, 1441.99s/it]Loss_D [-0.10936]; Loss_G [ 0.30343]; D_real [ 0.53342]; Loss_D_fake [ 0.64279] 50%|█████ | 5/10 [1:59:57<1:59:57, 1439.58s/it]Loss_D [-1.21678]; Loss_G [ 0.46415]; D_real [-0.56848]; Loss_D_fake [ 0.6483] 60%|██████ | 6/10 [2:23:59<1:35:59, 1439.94s/it]Loss_D [-1.13921]; Loss_G [ 0.15855]; D_real [-0.4484]; Loss_D_fake [ 0.69081] 70%|███████ | 7/10 [2:48:08<1:12:03, 1441.25s/it]Loss_D [-1.34583]; Loss_G [ 0.63406]; D_real [-0.65586]; Loss_D_fake [ 0.68998] 80%|████████ | 8/10 [3:12:25<48:06, 1443.24s/it] Loss_D [-0.6132]; Loss_G [ 0.02461]; D_real [ 0.06617]; Loss_D_fake [ 0.67937] 90%|█████████ | 9/10 [3:36:32<24:03, 1443.66s/it]Loss_D [-0.73144]; Loss_G [-0.08948]; D_real [-0.04399]; Loss_D_fake [ 0.68746] 100%|██████████| 10/10 [4:01:28<00:00, 1448.89s/it]
netD.eval(); netG.eval();
fixed_noise = create_noise(bs)
netD.eval(); netG.eval();
fake = netG(fixed_noise).data.cpu()
faked = np.clip(md.trn_ds.denorm(fake),0,1)
plt.figure(figsize=(9,9))
plt.imshow(gallery(faked, 8));
torch.save(netG.state_dict(), TMP_PATH/'netG_2.h5')
torch.save(netD.state_dict(), TMP_PATH/'netD_2.h5')
class DeconvBlock(nn.Module):
def __init__(self, ni, no, ks, bn=True):
super().__init__()
self.conv = nn.Conv2d(ni, no, ks, 1, padding=ks//2, bias=False)
self.bn = nn.BatchNorm2d(no)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = F.upsample(x, scale_factor=2, mode='bilinear')
return self.bn(self.relu(self.conv(x)))
class DCGAN_G(nn.Module):
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
super().__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf//2, 4
while tisize!=isize: cngf*=2; tisize*=2
self.initial = ConvBlock(nz, cngf, 4, 1, pad=3)
csize, cndf = 4, cngf
pyr_layers = []
while csize < isize//2:
pyr_layers.append(DeconvBlock(cngf, cngf//2, 3))
cngf //= 2; csize *= 2
self.pyramid = nn.Sequential(*pyr_layers)
self.extra = nn.Sequential(*[ConvBlock(cngf, cngf, 3, 1)
for t in range(n_extra_layers)])
self.final = nn.Conv2d(cngf, nc, 3, 1, 1, bias=False)
def forward(self, input):
x = self.initial(input)
x = self.pyramid(x)
x = self.extra(x)
x = F.upsample(x, scale_factor=2, mode='bilinear')
x = self.final(x)
return F.tanh(x)
preds = netG(create_noise(4))
pred_ims = md.trn_ds.denorm(preds)
fig, axes = plt.subplots(2, 2, figsize=(6, 6))
for i,ax in enumerate(axes.flat): ax.imshow(pred_ims[i])
Results after <1 epoch
plt.figure(figsize=(9,9))
plt.imshow(gallery(faked, 8));