from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
# Oxford-IIIT Pets dataset: high-res originals plus a "crappified" mirror.
path = untar_data(URLs.PETS)
path_hr = path / 'images'  # high-resolution source images (targets)
path_lr = path / 'crappy'  # degraded copies produced by crappify() (inputs)
Prepare the input data by crappifying images.
from PIL import Image, ImageDraw, ImageFont
def crappify(fn, i):
    "Write a degraded copy of image `fn` under `path_lr`, mirroring its folder layout."
    dest = path_lr / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Shrink so the shorter side is 96px, then force RGB for JPEG output.
    img = Image.open(fn)
    img = img.resize(resize_to(img, 96, use_min=True), resample=Image.BILINEAR).convert('RGB')
    # Pick a random JPEG quality, then stamp that number somewhere in the
    # top-left quadrant as a "watermark" the generator must learn to remove.
    quality = random.randint(10, 70)
    w, h = img.size
    spot = (random.randint(0, w // 2), random.randint(0, h // 2))
    ImageDraw.Draw(img).text(spot, str(quality), fill=(255, 255, 255))
    img.save(dest, quality=quality)
Uncomment the first time you run this notebook.
# il = ImageItemList.from_folder(path_hr)
# parallel(crappify, il.items) # il.items returns array of images
For gradual resizing we can change the commented line here.
# Batch-size / image-size pairs for gradual (progressive) resizing:
# start small, later swap in a larger size with a smaller batch.
bs, size = 32, 128
# bs, size = 24, 160
# bs, size = 8, 256
arch = models.resnet34  # encoder backbone for the U-Net generator
Now let's pretrain the generator.
arch = models.resnet34
src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=42)
def get_data(bs, size):
    "Build the generator's DataBunch: crappified image in, matching high-res image out."
    # Label each low-res item with the same-named file from the high-res folder.
    labelled = src.label_from_func(lambda x: path_hr / x.name)
    # tfm_y/do_y: apply the same transforms and normalization to the targets.
    transformed = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    data = transformed.databunch(bs=bs).normalize(imagenet_stats, do_y=True)
    data.c = 3  # RGB channel count
    return data
data_gen = get_data(bs, size)
# Sanity check
type(data_gen)
fastai.vision.data.ImageDataBunch
data_gen.show_batch(4)
wd = 1e-3  # weight decay used across all learners below
y_range = (-3., 3.)  # sigmoid output range for the U-Net head
loss_gen = MSELossFlat()  # pixel-wise MSE for generator pretraining
# create generator learner
def create_gen_learner():
    "U-Net learner used as the GAN generator, pretrained with pixel MSE."
    return unet_learner(
        data_gen, arch,
        wd=wd,
        blur=True,  # smooths PixelShuffle_ICNR upsampling artifacts
        norm_type=NormType.Weight,
        self_attention=True,
        y_range=y_range,
        loss_func=loss_gen,
    )
learn_gen = create_gen_learner()
learn_gen.fit_one_cycle(2, pct_start=0.8)
epoch train_loss valid_loss 1 0.063843 0.057545 2 0.052317 0.049272
learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6, 1e-3))
epoch train_loss valid_loss 1 0.050844 0.048165 2 0.048521 0.046062 3 0.046637 0.045033
learn_gen.show_results(rows=4)
learn_gen.save('gen_pre2')
learn_gen.load('gen_pre2')
Learner(data=ImageDataBunch; Train: LabelList y: ImageItemList (6651 items) [Image (3, 331, 500), Image (3, 500, 335), Image (3, 375, 500), Image (3, 500, 333), Image (3, 375, 500)]... Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/crappy x: ImageImageList (6651 items) [Image (3, 96, 145), Image (3, 143, 96), Image (3, 96, 128), Image (3, 144, 96), Image (3, 96, 128)]... Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/crappy; Valid: LabelList y: ImageItemList (739 items) [Image (3, 500, 366), Image (3, 492, 500), Image (3, 375, 500), Image (3, 300, 190), Image (3, 375, 500)]... Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/crappy x: ImageImageList (739 items) [Image (3, 131, 96), Image (3, 96, 97), Image (3, 96, 128), Image (3, 151, 96), Image (3, 96, 128)]... Path: /home/ubuntu/.fastai/data/oxford-iiit-pet/crappy; Test: None, model=DynamicUnet( (layers): ModuleList( (0): Sequential( (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace) (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (4): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(64, 
64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (5): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (7): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU() (3): Sequential( (0): Sequential( (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (1): Sequential( (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) ) (4): UnetBlock( (shuf): PixelShuffle_ICNR( (conv): Sequential( (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (shuf): PixelShuffle(upscale_factor=2) (pad): ReplicationPad2d((1, 0, 1, 0)) (blur): AvgPool2d(kernel_size=2, stride=1, padding=0) (relu): ReLU(inplace) ) (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1): Sequential( (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (conv2): Sequential( (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (relu): ReLU() ) (5): UnetBlock( (shuf): PixelShuffle_ICNR( (conv): Sequential( (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1)) ) (shuf): PixelShuffle(upscale_factor=2) (pad): ReplicationPad2d((1, 0, 1, 0)) (blur): AvgPool2d(kernel_size=2, stride=1, padding=0) (relu): ReLU(inplace) ) (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1): Sequential( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (conv2): Sequential( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): SelfAttention( (query): Conv1d(384, 48, kernel_size=(1,), stride=(1,), bias=False) (key): Conv1d(384, 48, kernel_size=(1,), stride=(1,), bias=False) (value): Conv1d(384, 384, kernel_size=(1,), stride=(1,), bias=False) ) ) (relu): ReLU() ) (6): UnetBlock( (shuf): PixelShuffle_ICNR( (conv): Sequential( (0): Conv2d(384, 768, kernel_size=(1, 
1), stride=(1, 1)) ) (shuf): PixelShuffle(upscale_factor=2) (pad): ReplicationPad2d((1, 0, 1, 0)) (blur): AvgPool2d(kernel_size=2, stride=1, padding=0) (relu): ReLU(inplace) ) (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1): Sequential( (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (conv2): Sequential( (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (relu): ReLU() ) (7): UnetBlock( (shuf): PixelShuffle_ICNR( (conv): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1)) ) (shuf): PixelShuffle(upscale_factor=2) (pad): ReplicationPad2d((1, 0, 1, 0)) (blur): AvgPool2d(kernel_size=2, stride=1, padding=0) (relu): ReLU(inplace) ) (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1): Sequential( (0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (conv2): Sequential( (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (relu): ReLU() ) (8): PixelShuffle_ICNR( (conv): Sequential( (0): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1)) ) (shuf): PixelShuffle(upscale_factor=2) (pad): ReplicationPad2d((1, 0, 1, 0)) (blur): AvgPool2d(kernel_size=2, stride=1, padding=0) (relu): ReLU(inplace) ) (9): MergeLayer() (10): SequentialEx( (layers): ModuleList( (0): Sequential( (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (1): Sequential( (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) ) (2): MergeLayer() ) ) (11): Sequential( (0): Conv2d(99, 3, kernel_size=(1, 1), stride=(1, 1)) ) (12): SigmoidRange() ) ), opt_func=functools.partial(<class 'torch.optim.adam.Adam'>, betas=(0.9, 0.99)), loss_func=<fastai.layers.FlattenedLoss object at 0x7f5772cea908>, metrics=[], true_wd=True, bn_wd=True, wd=0.001, train_bn=True, 
path=PosixPath('/home/ubuntu/.fastai/data/oxford-iiit-pet/crappy'), model_dir='models', callback_fns=[<class 'fastai.basic_train.Recorder'>], callbacks=[], layer_groups=[Sequential( (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace) (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): ReLU(inplace) (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (11): ReLU(inplace) (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (15): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (16): ReLU(inplace) (17): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (18): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (19): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (20): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (21): ReLU(inplace) (22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (23): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (24): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (25): BatchNorm2d(128, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) (26): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (27): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (28): ReLU(inplace) (29): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (30): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (31): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (32): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (33): ReLU(inplace) (34): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (35): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (36): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (37): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (38): ReLU(inplace) (39): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (40): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace) (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (9): ReLU(inplace) (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (11): BatchNorm2d(256, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (14): ReLU(inplace) (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (16): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (19): ReLU(inplace) (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (22): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (23): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (24): ReLU(inplace) (25): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (26): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (27): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (28): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (29): ReLU(inplace) (30): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (31): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (32): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (33): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (34): ReLU(inplace) (35): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (36): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (37): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (38): BatchNorm2d(512, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) (39): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (40): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (41): ReLU(inplace) (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (43): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (44): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (45): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (46): ReLU(inplace) (47): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (48): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), Sequential( (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace) (4): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (5): ReLU(inplace) (6): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1)) (7): PixelShuffle(upscale_factor=2) (8): ReplicationPad2d((1, 0, 1, 0)) (9): AvgPool2d(kernel_size=2, stride=1, padding=0) (10): ReLU(inplace) (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (12): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace) (14): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace) (16): ReLU() (17): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1)) (18): PixelShuffle(upscale_factor=2) (19): ReplicationPad2d((1, 0, 1, 0)) (20): AvgPool2d(kernel_size=2, stride=1, padding=0) (21): ReLU(inplace) (22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (23): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (24): ReLU(inplace) (25): Conv2d(384, 384, kernel_size=(3, 
3), stride=(1, 1), padding=(1, 1)) (26): ReLU(inplace) (27): Conv1d(384, 48, kernel_size=(1,), stride=(1,), bias=False) (28): Conv1d(384, 48, kernel_size=(1,), stride=(1,), bias=False) (29): Conv1d(384, 384, kernel_size=(1,), stride=(1,), bias=False) (30): ReLU() (31): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1)) (32): PixelShuffle(upscale_factor=2) (33): ReplicationPad2d((1, 0, 1, 0)) (34): AvgPool2d(kernel_size=2, stride=1, padding=0) (35): ReLU(inplace) (36): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (37): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (38): ReLU(inplace) (39): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (40): ReLU(inplace) (41): ReLU() (42): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1)) (43): PixelShuffle(upscale_factor=2) (44): ReplicationPad2d((1, 0, 1, 0)) (45): AvgPool2d(kernel_size=2, stride=1, padding=0) (46): ReLU(inplace) (47): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (48): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (49): ReLU(inplace) (50): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (51): ReLU(inplace) (52): ReLU() (53): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1)) (54): PixelShuffle(upscale_factor=2) (55): ReplicationPad2d((1, 0, 1, 0)) (56): AvgPool2d(kernel_size=2, stride=1, padding=0) (57): ReLU(inplace) (58): MergeLayer() (59): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (60): ReLU(inplace) (61): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (62): ReLU(inplace) (63): MergeLayer() (64): Conv2d(99, 3, kernel_size=(1, 1), stride=(1, 1)) (65): SigmoidRange() )])
name_gen = 'image_gen'
path_gen = path / name_gen  # folder holding generator outputs ("fakes") for critic training
# shutil.rmtree(path_gen)  # uncomment to regenerate from scratch
path_gen.mkdir(exist_ok=True)
def save_preds(dl):
    """Run the generator over every batch of `dl` and save each prediction
    image under `path_gen`, reusing the source file's name.

    NOTE(review): pairs predictions with `dl.dataset.items` positionally,
    so `dl` must iterate in dataset order without shuffling (e.g. `fix_dl`)
    — confirm before using with any other DataLoader.
    """
    # Iterator over source paths replaces the original manual index counter.
    names = iter(dl.dataset.items)
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen / next(names).name)
save_preds(data_gen.fix_dl)
PIL.Image.open(path_gen.ls()[0])
learn_gen=None
gc.collect()
22678
Pretrain the critic on crappy vs not crappy.
def get_crit_data(classes, bs, size):
    "Classification DataBunch over the folders in `classes` (generated vs. real) for the critic."
    # `il` is local on purpose: avoids shadowing confusion with the module-level `src`.
    il = ImageItemList.from_folder(path, include=classes).random_split_by_pct(0.1, seed=42)
    labelled = il.label_from_folder(classes=classes)
    data = (labelled
            .transform(get_transforms(max_zoom=2.), size=size)
            .databunch(bs=bs)
            .normalize(imagenet_stats))
    data.c = 3  # channel count used for display
    return data
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
def create_critic_learner(data, metrics):
    "Learner wrapping fastai's `gan_critic` with the adaptive BCE critic loss."
    critic = gan_critic()
    return Learner(data, critic, metrics=metrics, loss_func=loss_critic, wd=wd)
learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)
learn_critic.fit_one_cycle(6, 1e-3)
epoch train_loss valid_loss accuracy_thresh_expand 1 0.687101 0.687558 0.550000 2 0.400085 0.717118 0.683219 3 0.188067 0.608230 0.777493 4 0.122945 0.313099 0.883989 5 0.086369 0.131634 0.953989 6 0.064002 0.098896 0.969914
learn_critic.save('critic_pre2')
Now we'll combine those pretrained models in a GAN.
learn_gen = None
learn_critic = None
gc.collect()
15877
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic_pre2')
learn_gen = create_gen_learner().load('gen_pre2')
To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator/critic to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss back < 0.5 then one iteration of the generator.
The loss of the critic is given by learn_crit.loss_func
. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).
The loss of the generator is a weighted sum (weights in weights_gen
) of learn_crit.loss_func
on the batch of fake (passed through the critic to become predictions) with a target of 1, and the learn_gen.loss_func
applied to the output (batch of fake) and the target (corresponding batch of superres images).
doc(AdaptiveGANSwitcher)
# Switcher (a LearnerCallback) that goes back to generator/critic when the loss goes below gen_thresh/crit_thresh.
doc(GANLearner)
# A Learner suitable for GANs.
doc(GANDiscriminativeLR)
# Callback that handles multiplying the learning rate by mult_lr for the critic.
# Train the critic until its loss falls below 0.65, then give the generator one step.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# weights_gen: (critic-feedback weight, pixel-MSE weight) in the generator loss.
# betas=(0., 0.99): no momentum, the usual Adam setting for GAN training.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1., 50.), show_img=False, switcher=switcher,
opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=wd)
# Give the critic a 5x larger learning rate than the generator.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
lr = 1e-4
learn.fit(40, lr)
epoch train_loss gen_loss disc_loss 1 2.125576 2.035303 4.047349 2 2.038668 1.940441 3.742569 3 1.998752 1.655455 3.683571 4 1.935438 1.988385 3.592564 5 2.016738 1.840756 3.617384 6 1.895863 2.181625 3.580500 7 1.938236 1.875703 3.570604 8 1.919673 1.955053 3.559882 9 1.995669 1.908432 3.565393 10 1.995930 1.853988 3.602203 11 2.025528 1.788676 3.594409 12 2.030285 1.993515 3.600706 13 2.013598 1.817824 3.617030 14 2.037255 1.816738 3.615322 15 1.930506 1.833512 3.546154 16 1.996290 1.958869 3.592804 17 1.999194 2.107492 3.626837 18 1.933447 1.129894 3.584512 19 2.001979 1.877815 3.591382 20 1.909830 2.009139 3.564503 21 1.987533 1.981500 3.559192 22 1.995646 2.088564 3.570924 23 2.015997 2.085294 3.580336 24 2.021464 1.937294 3.601035 25 2.062711 1.893884 3.645248 26 1.997037 1.930555 3.608972 27 1.979673 1.929840 3.537084 28 2.000307 1.855003 3.564431 29 1.950889 1.804022 3.525722 30 2.048434 1.784515 3.568585 31 2.037204 1.289834 3.612416 32 2.019243 1.826666 3.580298 33 2.037152 1.605087 3.572932 34 2.008617 2.068817 3.583849 35 1.891210 2.087649 3.535491 36 2.012012 0.735636 3.589628 37 2.048726 1.251969 3.543407 38 2.021612 1.930033 3.584294 39 1.935026 0.824109 3.547806 40 2.030452 1.993718 3.565240
learn.save('gan_1c')
# change learner data_gen by assigning get_data(bs, sz)
learn.data = get_data(16, 192) # previously bs=32 and sz=128. increase img size, so half batch size.
learn.fit(10, lr/2) # progressive resizing technique in play here. halfing learning rate.
epoch train_loss gen_loss disc_loss 1 2.573608 2.731844 4.739538 2 2.738971 2.613524 4.903574 3 2.666267 2.335060 4.768784 4 2.650287 2.534996 4.781271 5 2.546113 2.452332 4.709240 6 2.559172 2.411619 4.747292 7 2.520839 2.539139 4.705715 8 2.507591 2.571561 4.638741 9 2.617532 2.494489 4.770554 10 2.475456 2.401739 4.664874
learn.show_results(rows=16)
learn.save('gan_1c_size_192')