In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and select mode 2, so edited modules
# are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
In [2]:
from fastai.conv_learner import *
from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50

import json
In [3]:
# Preferred CUDA device index for this notebook (machine-specific).
GPU_ID = 3
# Only select the device when it actually exists; otherwise leave PyTorch's
# default device untouched so the notebook still starts on machines with
# fewer GPUs (or none) instead of failing in this cell.
if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    torch.cuda.set_device(GPU_ID)
In [4]:
# Turn on the cuDNN benchmark flag. Presumably this lets cuDNN pick the
# convolution algorithms for the fixed input size — confirm in the PyTorch docs.
torch.backends.cudnn.benchmark=True

Data

In [5]:
# Root directory of the dataset; also passed as `path=` when the datasets are built.
PATH = Path('data/carvana')
# Names of the CSV files located directly under PATH.
MASKS_FN = 'train_masks.csv'
META_FN = 'metadata.csv'
# Load both CSVs; the 'img' column of masks_csv is used later to derive file names.
masks_csv = pd.read_csv(PATH/MASKS_FN)
meta_csv = pd.read_csv(PATH/META_FN)
In [6]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw `im` on a matplotlib axes with the axis decorations hidden.

    Args:
        im: image data, passed straight to `ax.imshow`.
        figsize: size of the new figure; only used when `ax` is None.
        ax: existing axes to draw on. A new figure/axes pair is created
            when this is None.
        alpha: forwarded to `ax.imshow` as its `alpha` argument.

    Returns:
        The axes that was drawn on.
    """
    # Compare against None explicitly: truth-testing an axes-like object can
    # be False or ambiguous (e.g. an empty or array-like container), which
    # would silently replace a caller-supplied axes with a fresh figure.
    if ax is None: _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return ax
In [7]:
# Image and mask directory names (relative to PATH).
# NOTE(review): the following cell assigns all five of these names again, so
# on a top-to-bottom run the values set here are overwritten before use.
TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
sz = 128  # image size handed to tfms_from_model
bs = 64   # batch size handed to ImageData
nw = 16   # presumably the number of data-loading workers — confirm
In [7]:
# Image and mask directory names (relative to PATH). The '-128' suffix
# presumably marks copies pre-resized to 128px — confirm against the data prep.
# These assignments replace the values set in the preceding cell.
TRAIN_DN = 'train-128'
MASKS_DN = 'train_masks-128'
sz = 128  # image size handed to tfms_from_model
bs = 64   # batch size handed to ImageData
nw = 16   # presumably the number of data-loading workers — confirm
In [8]:
class MatchedFilesDataset(FilesDataset):
    """FilesDataset whose targets are image files paired one-to-one with the inputs.

    Args:
        fnames: input file names, forwarded to FilesDataset.
        y: target file names; must have the same length as `fnames`.
        transform: forwarded to FilesDataset.
        path: forwarded to FilesDataset. `get_y` reads `self.path`, which is
            presumably set by the base class from this argument — confirm.
    """
    def __init__(self, fnames, y, transform, path):
        # Targets are stored first, then validated, then the base class is initialised.
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        """Open the i-th target file with open_image and return the result."""
        target_fn = os.path.join(self.path, self.y[i])
        return open_image(target_fn)

    def get_c(self):
        """Return 0 (presumably the class count queried by the framework — confirm)."""
        return 0
In [9]:
# Parallel arrays of image and mask paths (relative to PATH), one pair per
# entry of masks_csv['img'].
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
# Strip the extension with splitext rather than a fixed 4-character slice, so
# the mask name stays correct even if an extension is not exactly 3 letters.
y_names = np.array([Path(MASKS_DN)/f'{os.path.splitext(o)[0]}_mask.png' for o in masks_csv['img']])
In [10]:
# Use the first 1008 entries as the validation split; the rest is training data.
# NOTE(review): 1008 is unexplained here; presumably a whole number of cars
# (1008 = 63 * 16) so that all views of one car stay on the same side — confirm.
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
In [11]:
# Augmentations applied during training. Each one is constructed with
# tfm_y=TfmType.CLASS, the same tfm_y setting given to tfms_from_model.
aug_tfms = [RandomRotate(4, tfm_y=TfmType.CLASS),
            RandomFlip(tfm_y=TfmType.CLASS),
            RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS)]
In [12]:
# Transforms for size `sz` with no cropping; tfm_y=TfmType.CLASS matches the
# setting used for the augmentations.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# Training and validation datasets built from the matched (image, mask) file lists.
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# Use the `nw` setting from the config cell instead of repeating the literal 16,
# so changing the worker count in one place actually takes effect.
md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)
# Kept for later use; presumably undoes the input normalisation for display — confirm.
denorm = md.trn_ds.denorm
In [80]:
# Pull the first batch out of md.trn_dl to inspect it.
x,y = next(iter(md.trn_dl))
In [81]:
# Show the batch shapes; per the recorded output, inputs are (64, 3, 128, 128)
# and targets are (64, 128, 128), i.e. the targets carry no channel dimension.
x.shape,y.shape
Out[81]:
(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))

Simple upsample

In [13]:
# Backbone constructor and its model_meta entry, which yields `cut` (passed to
# cut_model in get_base) and `lr_cut` (used to split the backbone into layer groups).
f = resnet34
cut,lr_cut = model_meta[f]
In [14]:
def get_base():
    """Return the `f` backbone, cut at `cut`, wrapped in an nn.Sequential.

    Reads the module-level `f` and `cut`. `f(True)` presumably requests
    pretrained weights — confirm against the model constructor.
    """
    backbone = f(True)
    kept_layers = cut_model(backbone, cut)
    return nn.Sequential(*kept_layers)
In [15]:
def dice(pred, targs):
    """Dice coefficient between thresholded predictions and targets.

    Args:
        pred: raw model outputs; values > 0 count as a positive prediction.
        targs: target mask; assumed to hold 0/1 values — TODO confirm.

    Returns:
        2 * sum(pred * targs) / (sum(pred) + sum(targs)), with the sums taken
        over every element passed in. When both masks are completely empty
        the ratio would be 0/0; 1 is returned instead of NaN, because two
        empty masks agree perfectly.
    """
    pred = (pred>0).float()
    overlap = (pred*targs).sum()
    total = (pred+targs).sum()
    # Guard the 0/0 case so one empty batch cannot turn the averaged metric into NaN.
    if float(total) == 0: return torch.ones_like(total)
    return 2. * overlap / total
In [16]:
class StdUpsample(nn.Module):
    """Upsampling block: 2x2 stride-2 transposed conv, then ReLU, then BatchNorm.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """
    def __init__(self, nin, nout):
        super().__init__()
        # A 2x2 kernel with stride 2 (no padding) doubles height and width.
        self.conv = nn.ConvTranspose2d(nin, nout, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        """Apply conv -> ReLU -> batch norm, in that order."""
        upsampled = self.conv(x)
        activated = F.relu(upsampled)
        return self.bn(activated)
In [17]:
class Upsample34(nn.Module):
    """Segmentation net: backbone `rn` followed by five stride-2 upsampling stages.

    Args:
        rn: backbone module. Its output is assumed to have 512 channels,
            since the first StdUpsample takes 512 inputs — TODO confirm for
            any backbone other than the one built by get_base.
    """
    def __init__(self, rn):
        super().__init__()
        # Kept as a direct attribute as well as inside `features`;
        # UpsampleModel.get_layer_groups reads `self.model.rn`.
        self.rn = rn
        up_stages = [StdUpsample(512,256)]
        up_stages += [StdUpsample(256,256) for _ in range(3)]
        # Final transposed conv maps the 256 channels to a single output channel.
        to_mask = nn.ConvTranspose2d(256, 1, 2, stride=2)
        self.features = nn.Sequential(rn, nn.ReLU(), *up_stages, to_mask)

    def forward(self,x):
        """Run the network and return channel 0 of the result (channel dim dropped)."""
        out = self.features(x)
        return out[:,0]
In [18]:
class UpsampleModel():
    """Thin wrapper that exposes a model, a name and its layer groups.

    Args:
        model: network with `rn` (backbone) and `features` attributes.
        name: label stored on the wrapper; defaults to 'upsample'.
    """
    def __init__(self,model,name='upsample'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the layer groups: the backbone split at `lr_cut`, then the head.

        `precompute` is accepted but not used. `lr_cut` is read from module scope.
        """
        backbone_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        # features[0] is the backbone itself, so the head is everything after it.
        head_group = children(self.model.features)[1:]
        return backbone_groups + [head_group]
In [21]:
# Build the cut backbone that the upsampling model is stacked on.
m_base = get_base()
In [22]:
# Assemble the full network on top of the backbone and pass it through to_gpu,
# then wrap it so the learner can query its layer groups.
m = to_gpu(Upsample34(m_base))
models = UpsampleModel(m)
In [24]:
# Create the learner from the data and wrapped model, then set optimiser, loss and metrics.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
# The network ends in a ConvTranspose2d with no activation after it, so the
# loss that takes logits is used.
learn.crit=nn.BCEWithLogitsLoss()
# NOTE(review): dice thresholds the raw outputs at 0, while accuracy_thresh is
# given 0.5; if accuracy_thresh compares raw outputs against its threshold, the
# two metrics use different cut-offs — confirm.
learn.metrics=[accuracy_thresh(0.5),dice]
In [25]:
# Show the per-layer summary. In the recorded output every backbone layer
# appears twice (e.g. 'Conv2d-1' and 'Conv2d-2' are identical) — presumably
# because the backbone is registered on the model both as `rn` and inside
# `features`; confirm before relying on the layer numbering.
learn.summary()
Out[25]:
OrderedDict([('Conv2d-1',
              OrderedDict([('input_shape', [-1, 3, 128, 128]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', tensor(9408))])),
             ('Conv2d-2',
              OrderedDict([('input_shape', [-1, 3, 128, 128]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', tensor(9408))])),
             ('BatchNorm2d-3',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-4',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-5',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('nb_params', 0)])),
             ('ReLU-6',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('nb_params', 0)])),
             ('MaxPool2d-7',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('MaxPool2d-8',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-9',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-10',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-11',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-12',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-13',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-14',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-15',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-16',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-17',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-18',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-19',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-20',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-21',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-22',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-23',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-24',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-25',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-26',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-27',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-28',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-29',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-30',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-31',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-32',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-33',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-34',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-35',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-36',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-37',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-38',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-39',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-40',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-41',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-42',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-43',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('Conv2d-44',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(36864))])),
             ('BatchNorm2d-45',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('BatchNorm2d-46',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', tensor(128))])),
             ('ReLU-47',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('ReLU-48',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-49',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-50',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-51',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(73728))])),
             ('Conv2d-52',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(73728))])),
             ('BatchNorm2d-53',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-54',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-55',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-56',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-57',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-58',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-59',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-60',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('Conv2d-61',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(8192))])),
             ('Conv2d-62',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(8192))])),
             ('BatchNorm2d-63',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-64',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-65',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-66',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-67',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-68',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-69',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-70',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-71',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-72',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-73',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-74',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-75',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-76',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-77',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-78',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-79',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-80',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-81',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-82',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-83',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-84',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-85',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-86',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-87',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-88',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-89',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-90',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-91',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-92',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-93',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-94',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-95',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-96',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-97',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-98',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-99',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-100',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-101',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-102',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-103',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('Conv2d-104',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(1.4746e+05))])),
             ('BatchNorm2d-105',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('BatchNorm2d-106',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', tensor(256))])),
             ('ReLU-107',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('ReLU-108',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-109',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-110',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-111',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(2.9491e+05))])),
             ('Conv2d-112',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(2.9491e+05))])),
             ('BatchNorm2d-113',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-114',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-115',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-116',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-117',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-118',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-119',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-120',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('Conv2d-121',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(32768))])),
             ('Conv2d-122',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(32768))])),
             ('BatchNorm2d-123',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-124',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-125',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-126',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-127',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-128',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-129',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-130',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-131',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-132',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-133',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-134',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-135',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-136',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-137',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-138',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-139',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-140',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-141',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-142',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-143',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-144',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-145',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-146',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-147',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-148',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-149',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-150',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-151',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-152',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-153',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-154',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-155',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-156',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-157',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-158',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-159',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-160',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-161',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-162',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-163',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-164',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-165',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-166',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-167',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-168',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-169',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-170',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-171',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-172',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-173',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-174',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-175',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-176',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-177',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-178',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-179',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-180',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-181',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-182',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-183',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-184',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-185',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-186',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-187',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-188',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-189',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-190',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-191',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('Conv2d-192',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(5.8982e+05))])),
             ('BatchNorm2d-193',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('BatchNorm2d-194',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', tensor(512))])),
             ('ReLU-195',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ReLU-196',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-197',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-198',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-199',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1.1796e+06))])),
             ('Conv2d-200',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1.1796e+06))])),
             ('BatchNorm2d-201',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-202',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-203',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-204',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-205',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('Conv2d-206',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('BatchNorm2d-207',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-208',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('Conv2d-209',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1.3107e+05))])),
             ('Conv2d-210',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1.3107e+05))])),
             ('BatchNorm2d-211',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-212',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-213',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-214',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-215',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-216',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-217',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('Conv2d-218',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('BatchNorm2d-219',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-220',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-221',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-222',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-223',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('Conv2d-224',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('BatchNorm2d-225',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-226',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-227',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-228',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-229',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-230',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-231',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('Conv2d-232',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('BatchNorm2d-233',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-234',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-235',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-236',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-237',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('Conv2d-238',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(2.3593e+06))])),
             ('BatchNorm2d-239',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('BatchNorm2d-240',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', tensor(1024))])),
             ('ReLU-241',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-242',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-243',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-244',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ReLU-245',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-246',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', True),
                           ('nb_params', tensor(5.2454e+05))])),
             ('BatchNorm2d-247',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', True),
                           ('nb_params', tensor(512))])),
             ('StdUpsample-248',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-249',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('trainable', True),
                           ('nb_params', tensor(2.6240e+05))])),
             ('BatchNorm2d-250',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('trainable', True),
                           ('nb_params', tensor(512))])),
             ('StdUpsample-251',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-252',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('trainable', True),
                           ('nb_params', tensor(2.6240e+05))])),
             ('BatchNorm2d-253',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('trainable', True),
                           ('nb_params', tensor(512))])),
             ('StdUpsample-254',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-255',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('trainable', True),
                           ('nb_params', tensor(2.6240e+05))])),
             ('BatchNorm2d-256',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('trainable', True),
                           ('nb_params', tensor(512))])),
             ('StdUpsample-257',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-258',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 1, 128, 128]),
                           ('trainable', True),
                           ('nb_params', tensor(1025))]))])
In [25]:
# NOTE(review): presumably freezes the layer group(s) before index 1 so that only the
# later groups train — confirm against fastai's `Learner.freeze_to`.
learn.freeze_to(1)
In [62]:
# Run the learning-rate finder, then plot what its scheduler recorded.
# The captured progress bar shows the run ending at 55 of 64 batches.
learn.lr_find()
learn.sched.plot()
 86%|█████████████████████████████████████████████████████████████          | 55/64 [00:22<00:03,  2.46it/s, loss=3.21]
In [19]:
# Learning rate used for the first fit.
# NOTE(review): presumably read off the lr_find plot — confirm.
lr=4e-2
# Weight decay, passed as `wds` to the first fit.
wd=1e-7
# Three learning rates in increasing order (lr/200, lr/20, lr/2), used by the fit
# that follows `learn.unfreeze()`.
lrs = np.array([lr/100,lr/10,lr])/2
In [35]:
# First training run: one cycle of 4 epochs at `lr` with weight decay `wd`
# (the log below lists epochs 0-3).
# NOTE(review): `use_clr=(20,8)` configures fastai's cyclical LR schedule — check the
# fastai docs for the meaning of the two numbers.
learn.fit(lr,1, wds=wd, cycle_len=4,use_clr=(20,8))
  0%|          | 0/64 [00:00<?, ?it/s]
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.216882   0.133512   0.938017   0.855221  
    1      0.169544   0.115158   0.946518   0.878381       
    2      0.153114   0.099104   0.957748   0.903353       
    3      0.144105   0.093337   0.964404   0.915084       

Out[35]:
[0.09333742126112893, 0.9644036065964472, 0.9150839788573129]
In [36]:
# Checkpoint the learner under the name 'tmp' so the next stage can restart from here.
learn.save('tmp')
In [27]:
# Restore the 'tmp' checkpoint written by the preceding `learn.save('tmp')`.
learn.load('tmp')
In [ ]:
# Undo the earlier freeze so every layer group can be trained.
learn.unfreeze()
# NOTE(review): presumably keeps the batch-norm layers frozen while the rest of the
# network is fine-tuned — confirm against fastai's `Learner.bn_freeze`.
learn.bn_freeze(True)
In [38]:
# Second training run: one cycle of 4 epochs using the per-group learning rates `lrs`.
# NOTE(review): unlike the first fit, `wds=wd` is not passed here — confirm that
# dropping weight decay for this stage is intentional.
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.174897   0.061603   0.976321   0.94382   
    1      0.122911   0.053625   0.982206   0.957624       
    2      0.106837   0.046653   0.985577   0.965792       
    3      0.099075   0.042291   0.986519   0.968925        

Out[38]:
[0.042291240323157536, 0.986519161670927, 0.9689251193924556]
In [39]:
# Checkpoint the fine-tuned learner under the name '128' (the notebook's `sz` is 128).
learn.save('128')
In [40]:
# Pull the first batch from the validation loader (inputs `x`, target masks `y`).
x,y = next(iter(md.val_dl))
# Run the model on that batch and keep the raw outputs.
# NOTE(review): `V` / `to_np` are fastai helpers, presumably wrapping the input for the
# model and converting the result to a NumPy array — confirm.
py = to_np(learn.model(V(x)))
In [41]:
# Show the prediction for the first item of the batch, thresholding the raw output at 0.
# NOTE(review): a threshold of 0 matches probability 0.5 only if the outputs are logits — confirm.
show_img(py[0]>0);
In [42]:
# Show the target mask of the same batch item for visual comparison with the prediction.
show_img(y[0]);

U-net (ish)

In [19]:
class SaveFeatures():
    """Record the most recent forward output of a module through a forward hook."""

    # Stays None until the hooked module has produced an output.
    features=None

    def __init__(self, m):
        """Register `hook_fn` as a forward hook on `m` and keep the returned handle."""
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: store `output`; `module` and `input` are ignored."""
        self.features = output

    def remove(self):
        """Detach the hook from the module using the stored handle."""
        self.hook.remove()
In [20]:
class UnetBlock(nn.Module):
    """One decoder step: upsample the incoming activation, merge it with a skip input.

    Args:
        up_in: channels of the activation coming from the previous (coarser) stage.
        x_in: channels of the skip-connection input.
        n_out: channels of the block output (and of the batch-norm layer).

    The upsampled branch contributes ``n_out//2`` channels and the skip branch
    the remaining ``n_out - n_out//2``, so the concatenation always has exactly
    ``n_out`` channels, including when ``n_out`` is odd.
    """
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = n_out//2
        # Skip branch takes whatever is left, so up_out + x_out == n_out even for odd n_out
        # (with two equal halves an odd n_out would not match BatchNorm2d(n_out)).
        x_out = n_out - up_out
        # 1x1 convolution: changes only the channel count of the skip input.
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        # Kernel 2, stride 2 transposed convolution: doubles height and width.
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        """Return bn(relu(cat(tr_conv(up_p), x_conv(x_p)))) concatenated along the channel axis.

        `x_p` must have twice the spatial size of `up_p` so the two branches line up.
        """
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))
In [33]:
class Unet34(nn.Module):
    """U-Net style decoder stacked on an indexable, callable backbone `rn`.

    Forward hooks on ``rn[2]``, ``rn[4]``, ``rn[5]`` and ``rn[6]`` capture the
    intermediate activations that serve as skip connections. Call `close()`
    to remove those hooks when the model is no longer needed.
    """
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # Backbone positions whose outputs feed the skip connections.
        hooked_layers = (2, 4, 5, 6)
        self.sfs = [SaveFeatures(rn[idx]) for idx in hooked_layers]
        # Decoder stages; arguments are (channels from below, skip channels, output channels).
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        # Last stage merges with the 3-channel network input itself.
        self.up5 = UnetBlock(256,3,16)
        # 1x1 transposed convolution reducing 16 channels to a single output channel.
        self.up6 = nn.ConvTranspose2d(16, 1, 1)

    def forward(self,x):
        """Run the backbone, decode through the five up blocks, return the single channel squeezed out."""
        inp = x
        x = F.relu(self.rn(x))
        # Deepest hook first: up1 pairs with sfs[3], up2 with sfs[2], and so on.
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for block, saved in zip(decoder, reversed(self.sfs)):
            x = block(x, saved.features)
        x = self.up5(x, inp)
        logits = self.up6(x)
        # Drop the channel axis by selecting channel 0.
        return logits[:,0]

    def close(self):
        """Remove every forward hook registered in `__init__`."""
        for saved in self.sfs:
            saved.remove()
In [34]:
class UnetModel():
    """Thin wrapper pairing a model with a name and exposing its layer groups."""

    def __init__(self,model,name='unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the backbone split at `lr_cut`, followed by one group holding the rest of the model.

        `precompute` is accepted but not used. `lr_cut` is read from the
        notebook's global scope.
        """
        groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        # Everything after the first child of the model (the backbone `rn`) forms the final group.
        head_group = children(self.model)[1:]
        groups.append(head_group)
        return groups
In [30]:
m_base = get_base()
In [31]:
# Wrap the backbone in Unet34 and pass it through to_gpu.
m = to_gpu(Unet34(m_base))
# UnetModel exposes get_layer_groups for the learner built in the next cell.
models = UnetModel(m)
In [32]:
# Build the learner around the U-Net wrapper and the current data object.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
# Unet34.forward applies no sigmoid, so the loss takes raw logits.
learn.crit=nn.BCEWithLogitsLoss()
# Two metrics: thresholded accuracy (shows up as the <lambda> column in the logs) and dice.
learn.metrics=[accuracy_thresh(0.5),dice]
In [34]:
# Shapes of the activations captured by the four forward hooks.
# NOTE(review): SaveFeatures.features is None until the model has run a forward pass, so on a
# fresh kernel this cell needs a prior forward call (the recorded output has batch size 3) —
# confirm which cell provided it.
[o.features.size() for o in m.sfs]
Out[34]:
[torch.Size([3, 64, 64, 64]),
 torch.Size([3, 64, 32, 32]),
 torch.Size([3, 128, 16, 16]),
 torch.Size([3, 256, 8, 8])]
In [35]:
learn.freeze_to(1)
In [90]:
learn.lr_find()
learn.sched.plot()
  9%|▉         | 6/64 [00:05<00:50,  1.14it/s, loss=0.669]
 91%|█████████ | 58/64 [00:13<00:01,  4.39it/s, loss=1.57] 
In [ ]:
# Base learning rate and weight decay for the U-net fits at 128px.
lr=4e-2
wd=1e-7

# Three learning rates (lr/200, lr/20, lr), each halved; the unfrozen fit below halves them again.
lrs = np.array([lr/200,lr/20,lr])/2
In [37]:
# Train with the single rate lr and weight decay wd; the log below shows 8 epochs.
learn.fit(lr,1,wds=wd,cycle_len=8,use_clr=(5,8))
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.142851   0.038569   0.988547   0.971717  
    1      0.101173   0.03713    0.988605   0.975992       
    2      0.08908    0.039838   0.986759   0.972326        
    3      0.081489   0.029241   0.992063   0.979739        
    4      0.079016   0.028968   0.991646   0.97719         
    5      0.076655   0.023804   0.992451   0.982042        
    6      0.07636    0.024756   0.992738   0.983401        
    7      0.07495    0.028137   0.993225   0.983095        

Out[37]:
[0.028136500290461948, 0.993224613250248, 0.9830952419175042]
In [38]:
learn.save('128urn-tmp')
In [56]:
learn.load('128urn-tmp')
In [39]:
# Unfreeze before fine-tuning with the lrs array.
learn.unfreeze()
# NOTE(review): bn_freeze(True) presumably keeps batch-norm layers fixed during training — confirm against the fastai version in use.
learn.bn_freeze(True)
In [40]:
# Fine-tune with lrs halved once more (lrs already carries a /2 from its definition); the log shows 10 epochs.
learn.fit(lrs/2, 1, wds=wd, cycle_len=10,use_clr=(20,10))
epoch      trn_loss   val_loss   <lambda>   dice            
    0      0.074118   0.020005   0.993274   0.983666  
    1      0.073695   0.021194   0.993121   0.984339        
    2      0.073121   0.024444   0.993327   0.984351        
    3      0.071765   0.021549   0.993588   0.984733        
    4      0.072239   0.019336   0.993528   0.985078        
    5      0.071364   0.01934    0.993351   0.985128        
    6      0.071399   0.018902   0.993654   0.985341        
    7      0.071632   0.020567   0.993527   0.985378        
    8      0.070797   0.019375   0.993866   0.985641        
    9      0.070172   0.019292   0.993802   0.985716        

Out[40]:
[0.019291689264632407, 0.9938022683537195, 0.9857159994897389]
In [41]:
learn.save('128urn-0')
In [26]:
learn.load('128urn-0')
In [42]:
# Take the first validation batch (inputs x, target masks y).
x,y = next(iter(md.val_dl))
# Run the U-net on that batch; py is the model output passed through to_np.
py = to_np(learn.model(V(x)))
In [43]:
# The model is trained with BCEWithLogitsLoss, so a logit > 0 corresponds to probability > 0.5:
# display the thresholded prediction for the first image of the batch.
show_img(py[0]>0);
In [44]:
# Target mask of the same image, for visual comparison with the prediction.
show_img(y[0]);
In [45]:
# Remove the forward hooks that Unet34 registered on the backbone.
m.close()

512x512

In [23]:
# Point at the 'train' / 'train_masks_png' folders instead of the '-128' ones used so far.
TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
In [24]:
# Larger images with a smaller batch (the 128px run used sz=128, bs=64).
sz=512
bs=8
In [25]:
# Rebuild the image paths under the new TRAIN_DN folder.
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
# Mask file name = image name with its last 4 characters replaced by '_mask.png'.
y_names = np.array([Path(MASKS_DN)/f'{o[:-4]}_mask.png' for o in masks_csv['img']])
In [26]:
# Use the first 1008 entries as the validation set, the same indices as in the 128px run.
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
In [27]:
# Transforms at the new size, no cropping, reusing the augmentation list defined earlier.
# NOTE(review): tfm_y=TfmType.CLASS presumably applies the geometric transforms to the mask as well — confirm.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
# NOTE(review): num_workers is the literal 16 here although `nw = 16` is defined at the top of the notebook.
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm
In [28]:
# Base learning rate is halved relative to the 128px run (4e-2 -> 2e-2); weight decay unchanged.
lr=2e-2
wd=1e-7

# Three learning rates (lr/200, lr/20, lr), each halved.
lrs = np.array([lr/200,lr/20,lr])/2
In [35]:
# Fresh backbone and U-net for the 512px run; weights are loaded from a checkpoint a few cells below.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
In [36]:
# Same learner configuration as the 128px run, now on the 512px data object.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
# Unet34.forward applies no sigmoid, so the loss takes raw logits.
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
In [37]:
learn.freeze_to(1)
In [38]:
# Start the 512px run from the checkpoint saved at the end of the 128px U-net training.
learn.load('128urn-0')
In [39]:
# Train with the single rate lr and weight decay wd; the log below shows 5 epochs.
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.072306   0.024262   0.995649   0.990836  
    1      0.070101   0.018257   0.996652   0.992687          
    2      0.068719   0.013143   0.996877   0.992726          
    3      0.067462   0.01673    0.997156   0.993161          
    4      0.069176   0.010519   0.997248   0.993573          

Out[39]:
[0.010519474116523587, 0.9972476694318984, 0.9935726703159393]
In [40]:
learn.save('512urn-tmp')
In [41]:
learn.unfreeze()
learn.bn_freeze(True)
In [38]:
learn.load('512urn-tmp')
In [42]:
# Fine-tune with lrs halved once more (lrs already carries a /2 from its definition); the log shows 8 epochs.
learn.fit(lrs/2,1,wds=wd, cycle_len=8,use_clr=(20,8))
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.066793   0.009628   0.997183   0.993656  
    1      0.064545   0.010879   0.997386   0.993893          
    2      0.06125    0.010519   0.997421   0.994018          
    3      0.06288    0.01147    0.997403   0.99407           
    4      0.064336   0.008862   0.997411   0.994132          
    5      0.063767   0.009989   0.997463   0.994194          
    6      0.063922   0.009927   0.997504   0.994277          
    7      0.06392    0.008977   0.997514   0.994242          

Out[42]:
[0.00897712022110465, 0.9975136726621597, 0.9942418502436744]
In [73]:
# Second, identical 8-epoch cycle. In the recorded logs its final dice (0.994118) is slightly
# below the previous cycle's final dice (0.994242), so it brought no further gain.
learn.fit(lrs/2,1,wds=wd, cycle_len=8,use_clr=(20,8))
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.06605    0.013602   0.997      0.993014  
    1      0.066885   0.011252   0.997248   0.993563          
    2      0.065796   0.009802   0.997223   0.993817          
    3      0.065089   0.009668   0.997296   0.993744          
    4      0.064552   0.011683   0.997269   0.993835          
    5      0.065089   0.010553   0.997415   0.993827          
    6      0.064303   0.009472   0.997431   0.994046          
    7      0.062506   0.009623   0.997441   0.994118          

Out[73]:
[0.009623114736602894, 0.9974409020136273, 0.9941179137381296]
In [74]:
learn.save('512urn')
In [26]:
learn.load('512urn')
In [75]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
In [76]:
show_img(py[0]>0);
In [77]:
show_img(y[0]);
In [78]:
m.close()

1024x1024

In [26]:
# Double the image size again and halve the batch (the 512px run used sz=512, bs=8).
sz=1024
bs=4
In [27]:
# NOTE(review): unlike the 128px and 512px runs, aug_tfms is not passed here — confirm that
# training at 1024px without the augmentation list is intentional.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
# Reuses trn_x/trn_y/val_x/val_y from the 512px section (same folders, same split).
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm
In [28]:
# Fresh backbone and U-net for the 1024px run; weights are loaded from a checkpoint a few cells below.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
In [29]:
# Same learner configuration as the earlier runs, now on the 1024px data object.
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
# Unet34.forward applies no sigmoid, so the loss takes raw logits.
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
In [30]:
# Start the 1024px run from the checkpoint saved at the end of the 512px training.
learn.load('512urn')
In [31]:
learn.freeze_to(1)
In [32]:
# Train for the 2 epochs shown in the log below.
# NOTE(review): lr and wd are not set in this section; they are whatever the 512px section
# left behind (lr=2e-2, wd=1e-7) — confirm that is the intended setting.
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.007656   0.008155   0.997247   0.99353   
    1      0.004706   0.00509    0.998039   0.995437             

Out[32]:
[0.005090427414942828, 0.9980387706605215, 0.995437301104031]
In [33]:
learn.save('1024urn-tmp')
In [30]:
learn.load('1024urn-tmp')
In [31]:
learn.unfreeze()
learn.bn_freeze(True)
In [32]:
# Learning-rate array for the unfrozen 1024px fits, which divide it by 10 at the call site.
# NOTE(review): the middle rate is lr/30 here versus lr/20 in the earlier sections, and there is no /2 — confirm.
lrs = np.array([lr/200,lr/30,lr])
In [33]:
# Fine-tune with lrs scaled down by 10; the log below shows 4 epochs.
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.005688   0.006135   0.997616   0.994616  
    1      0.004412   0.005223   0.997983   0.995349             
    2      0.004186   0.004975   0.99806    0.99554              
    3      0.004016   0.004899   0.99812    0.995627             

Out[33]:
[0.004898778487196458, 0.9981196409180051, 0.9956271404784823]
In [33]:
# Second, identical 4-epoch cycle; in the recorded logs dice rises from 0.995627 to 0.995991.
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.004169   0.004962   0.998049   0.995517  
    1      0.004022   0.004595   0.99823    0.995818             
    2      0.003772   0.004497   0.998215   0.995916             
    3      0.003618   0.004435   0.998291   0.995991             

Out[33]:
[0.004434524739663753, 0.9982911745707194, 0.9959913929776539]
In [34]:
# Plot the loss recorded on learn.sched (presumably for the most recent fit — confirm).
learn.sched.plot_loss()
In [35]:
learn.save('1024urn')
In [26]:
learn.load('1024urn')
In [36]:
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
In [37]:
show_img(py[0]>0);
In [38]:
show_img(y[0]);
In [ ]: