GroupNorm vs BatchNorm on Pets Dataset

In this notebook, we implement GroupNorm with Weight Standardization and compare the results with BatchNorm. Simply replacing BN with GN leads to sub-optimal results.

In [1]:
from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
set_seed(2)

ResNet Implementation

We copy the implementation of Weight Standardization from the official repository here, and the ResNet implementation from TorchVision.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
In [3]:
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

We replace the convolution layers inside ResNet with the standardized version from the Weight Standardization research paper. Everything else remains the same.
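Concretely, Weight Standardization normalizes each output filter of a convolution before it is applied. Writing $W_i$ for the weights of output filter $i$, the forward pass below computes (with $\epsilon = 10^{-5}$ for numerical stability):

$$\hat{W}_i = \frac{W_i - \mu_{W_i}}{\sigma_{W_i} + \epsilon}$$

where the mean $\mu_{W_i}$ and standard deviation $\sigma_{W_i}$ are taken over the input-channel and spatial dimensions of filter $i$, and the convolution then uses $\hat{W}$ in place of $W$.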

In [4]:
class Conv2d_WS(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return Conv2d_WS(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return Conv2d_WS(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
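As a quick sanity check (this cell is our own addition, not from the official repository), we can repeat the standardization from Conv2d_WS.forward outside the layer and confirm that every output filter ends up with roughly zero mean and unit standard deviation:

conv = Conv2d_WS(3, 8, kernel_size=3, padding=1, bias=False)
w = conv.weight
# repeat the computation from `Conv2d_WS.forward` on the raw weights
w_hat = w - w.mean(dim=(1, 2, 3), keepdim=True)
w_hat = w_hat / (w_hat.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
print(w_hat.view(w.size(0), -1).mean(dim=1))  # ~0 for every filter
print(w_hat.view(w.size(0), -1).std(dim=1))   # ~1 for every filter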
In [5]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
In [6]:
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
In [7]:
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = Conv2d_WS(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)
In [8]:
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)

Pets Dataset

Now we use the wonderful fastai library to get the Pets dataset.

In [10]:
bs = 4
In [11]:
path = untar_data(URLs.PETS); path
Out[11]:
Path('/home/ubuntu/.fastai/data/oxford-iiit-pet')
In [12]:
(path/'images').ls()
Out[12]:
(#7381) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

Dataset

The implementation of PetsDataset is heavily inspired by, and partially copied from (the regex part), the fastai2 repo here.

In [13]:
class PetsDataset:
    def __init__(self, paths, transforms=None):
        self.image_paths = paths
        self.transforms = transforms
        
    def __len__(self): 
        return len(self.image_paths)
    
    def setup(self, pat=r'(.+)_\d+.jpg$', label2int=None):
        "adds a label dictionary to `self`"
        self.pat = re.compile(pat)
        if label2int is not None:
            self.label2int = label2int
        else:
            labels = [os.path.basename(self.pat.search(str(p)).group(1))
                      for p in self.image_paths]
            self.labels = set(labels)
            self.label2int = {label: i for i, label in enumerate(self.labels)}
        # reverse mapping: int -> label
        self.int2label = {v: k for k, v in self.label2int.items()}

    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path)
        img = np.array(img)
        
        target = os.path.basename(self.pat.search(str(img_path)).group(1))
        target = self.label2int[target]
        
        if self.transforms:
            img = self.transforms(image=img)['image']      
            
        return img, torch.tensor(target, dtype=torch.long)
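To see what the regex in `setup` extracts, here is a minimal illustration (our own snippet; the path is hypothetical):

pat = re.compile(r'(.+)_\d+.jpg$')
p = '/some/path/images/great_pyrenees_173.jpg'   # hypothetical path
print(os.path.basename(pat.search(p).group(1)))  # -> 'great_pyrenees'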
In [14]:
image_paths = get_image_files(path/'images')
image_paths
Out[14]:
(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]
In [15]:
# remove images that are not 3-channel RGB
from tqdm.notebook import tqdm
run_remove = False
def remove(o):
    img = Image.open(o)
    img = np.array(img)
    # grayscale images are 2-D and RGBA images have 4 channels; drop both
    if img.ndim != 3 or img.shape[2] != 3:
        os.remove(o)
if run_remove:
    for o in tqdm(image_paths): remove(o)
In [16]:
image_paths = get_image_files(path/'images')
image_paths
Out[16]:
(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]
In [17]:
# augmentations using the `albumentations` library
sz = 224
tfms = albumentations.Compose([
    albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
    albumentations.OneOf(
        # note: `random.randint` runs once, when the pipeline is built,
        # so the number of holes is fixed for the whole run
        [albumentations.Cutout(random.randint(1, 8), 16, 16),
         albumentations.CoarseDropout(random.randint(1, 8), 16, 16)]
    ),
    albumentations.Normalize(always_apply=True),
    ToTensorV2()
])
In [18]:
dataset = PetsDataset(image_paths, tfms)
In [19]:
# to set up the `label2int` dictionary
dataset.setup()
In [20]:
dataset[0]
Out[20]:
(tensor([[[ 0.8618,  0.1597,  0.4166,  ..., -0.6452, -0.3198, -0.2171],
          [ 1.1872,  0.3481,  0.4166,  ..., -0.3027,  0.0912,  0.3138],
          [ 0.8104,  0.6049,  0.0227,  ..., -0.3712, -0.1657, -0.1828],
          ...,
          [ 1.2385,  0.4851,  0.0227,  ...,  0.8789,  1.2214,  0.8961],
          [ 0.7077,  0.9474, -0.6965,  ...,  0.1254,  1.5297,  1.6667],
          [ 0.1083, -0.0801,  0.3652,  ...,  0.2111,  0.5193,  0.6734]],
 
         [[ 0.9230,  0.4328,  0.4503,  ..., -0.2850, -0.0224, -0.0399],
          [ 1.3256,  0.7304,  0.4678,  ..., -0.0399,  0.1527,  0.3277],
          [ 0.8354,  0.8354,  0.3102,  ..., -0.2500, -0.1975, -0.3200],
          ...,
          [ 1.3606,  1.3431,  0.6078,  ...,  0.9755,  1.3957,  1.1331],
          [ 0.7654,  1.0455, -0.0574,  ...,  0.7654,  1.6232,  1.7458],
          [ 0.4153,  0.5903,  0.9230,  ...,  0.7654,  0.8529,  1.0980]],
 
         [[ 0.3393, -0.3578, -0.4275,  ..., -0.7936, -0.4624, -0.3578],
          [ 0.6531, -0.2358, -0.4973,  ..., -0.3753, -0.0615,  0.1128],
          [ 0.0431,  0.1128, -1.0201,  ..., -0.4101, -0.2707, -0.3578],
          ...,
          [ 0.7228,  0.3219, -0.5321,  ...,  0.4439,  1.0017,  0.7576],
          [ 0.2173,  0.4265, -1.1247,  ..., -0.0790,  1.1411,  1.2457],
          [-0.4450, -0.2881,  0.1302,  ...,  0.0082,  0.2696,  0.4439]]]),
 tensor(24))
In [21]:
dataset[0][0].shape
Out[21]:
torch.Size([3, 224, 224])

DataLoaders

We divide the image_paths into train and validation with 20% split.

In [22]:
nval = int(len(image_paths)*0.2)
nval
Out[22]:
1475
In [23]:
trn_img_paths = image_paths[:-nval]
val_img_paths = image_paths[-nval:]
assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)
len(trn_img_paths), len(val_img_paths)
Out[23]:
(5903, 1475)
In [24]:
trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)
val_dataset = PetsDataset(val_img_paths, transforms=tfms)
In [25]:
# use same `label2int` dictionary as in `dataset` for consistency across train and val
trn_dataset.setup(label2int=dataset.label2int)
val_dataset.setup(label2int=dataset.label2int)
In [27]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)
In [28]:
# make sure everything works so far
next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape
Out[28]:
(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))

Model

Now we define resnet34 from the torchvision repo with pretrained=False; since no pretrained weights exist for the GroupNorm variant, both models are trained from scratch.

In [29]:
# Vanilla resnet with `BatchNorm`
resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int), pretrained=False)
resnet34_bn
Out[29]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=37, bias=True)
)

Next, we define the GroupNorm_32 class with a default of 32 groups, as in the Group Normalization research paper here.

In [30]:
class GroupNorm_32(torch.nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, **kwargs):
        super().__init__(num_groups, num_channels, **kwargs)
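As a small illustration (our own check), GroupNorm_32 computes statistics over channel groups within each sample, so it behaves the same regardless of batch size, which is the property we rely on for small-batch training:

gn = GroupNorm_32(num_channels=64)      # 32 groups of 2 channels each
x = torch.randn(1, 64, 8, 8)            # works even with a single sample
y = gn(x)
print(y.mean().item(), y.std().item())  # roughly 0 and 1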
In [31]:
# resnet34 with `GroupNorm` and `Standardized Weights`
# `conv2d` replaced with `Conv2d_WS` and `BatchNorm` replaced with `GroupNorm`
resnet34_gn = resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))
resnet34_gn
Out[31]:
ResNet(
  (conv1): Conv2d_WS(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d_WS(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      (downsample): Sequential(
        (0): Conv2d_WS(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d_WS(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (downsample): Sequential(
        (0): Conv2d_WS(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d_WS(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)
      (downsample): Sequential(
        (0): Conv2d_WS(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): GroupNorm_32(32, 512, eps=1e-05, affine=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=37, bias=True)
)
In [32]:
# make sure we are able to make forward pass
resnet34_gn(next(iter(trn_loader))[0]).shape
Out[32]:
torch.Size([4, 37])

Training using PyTorch Lightning

Finally, we use PyTorch Lightning to train the model.

In [33]:
from pytorch_lightning import LightningModule, Trainer
In [34]:
class Model(LightningModule):
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def step(self, batch):
        x, y  = batch
        y_hat = self(x)
        loss  = nn.CrossEntropyLoss()(y_hat, y)
        return loss, y, y_hat

    def training_step(self, batch, batch_nb):
        loss, _, _ = self.step(batch)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        loss, y, y_hat = self.step(batch)
        return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        acc = self.get_accuracy(outputs)
        print(f"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}")
        return {'loss': avg_loss}
    
    def get_accuracy(self, outputs):
        from sklearn.metrics import accuracy_score
        y = torch.cat([x['y'] for x in outputs])
        y_hat = torch.cat([x['y_hat'] for x in outputs])
        preds = y_hat.argmax(1)
        return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())
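Before handing things to the Trainer, a quick smoke test of the wrapper (our own cell, not in the original run): push a single batch through `step` and check the loss and logits.

m = Model(resnet34_gn)
batch = next(iter(val_loader))
loss, y, y_hat = m.step(batch)
print(loss.item(), y_hat.shape)  # scalar loss, logits of shape [bs, 37]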
In [35]:
# define PL versions 
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
In [36]:
debug = False
gpus = torch.cuda.device_count()
trainer = Trainer(gpus=gpus, max_epochs=50, 
                  num_sanity_val_steps=1 if debug else 0)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

batch_size=4

In [37]:
# train the `GroupNorm` model with `bs=4` on the `Pets` dataset
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:3.638690710067749 | Accuracy:0.022372881355932205
Epoch:1 | Loss:3.5767452716827393 | Accuracy:0.03728813559322034
Epoch:2 | Loss:3.532081365585327 | Accuracy:0.05152542372881356
Epoch:3 | Loss:3.497438907623291 | Accuracy:0.06033898305084746
Epoch:4 | Loss:3.437784194946289 | Accuracy:0.07457627118644068
Epoch:5 | Loss:3.3992772102355957 | Accuracy:0.07322033898305084
Epoch:6 | Loss:3.3322556018829346 | Accuracy:0.08203389830508474
Epoch:7 | Loss:3.278475761413574 | Accuracy:0.09220338983050848
Epoch:8 | Loss:3.2041637897491455 | Accuracy:0.12
Epoch:9 | Loss:3.1338086128234863 | Accuracy:0.13288135593220338
Epoch:10 | Loss:2.9662578105926514 | Accuracy:0.15593220338983052
Epoch:11 | Loss:2.9380886554718018 | Accuracy:0.16203389830508474
Epoch:12 | Loss:2.7531585693359375 | Accuracy:0.21627118644067797
Epoch:13 | Loss:2.7896103858947754 | Accuracy:0.2223728813559322
Epoch:14 | Loss:2.5649585723876953 | Accuracy:0.26372881355932204
Epoch:15 | Loss:2.5243453979492188 | Accuracy:0.3071186440677966
Epoch:16 | Loss:2.453778028488159 | Accuracy:0.3220338983050847
Epoch:17 | Loss:2.575655460357666 | Accuracy:0.33016949152542374
Epoch:18 | Loss:2.723491668701172 | Accuracy:0.3193220338983051
Epoch:19 | Loss:3.0088090896606445 | Accuracy:0.3369491525423729
Epoch:20 | Loss:3.221853494644165 | Accuracy:0.3213559322033898
Epoch:21 | Loss:3.3212766647338867 | Accuracy:0.34576271186440677
Epoch:22 | Loss:3.6144063472747803 | Accuracy:0.3247457627118644
Epoch:23 | Loss:3.542142868041992 | Accuracy:0.34440677966101696
Epoch:24 | Loss:3.8027701377868652 | Accuracy:0.32610169491525426

Out[37]:
1
In [38]:
# train the `BatchNorm` model with `bs=4` on the `Pets` dataset
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:4.403476715087891 | Accuracy:0.01966101694915254
Epoch:1 | Loss:3.615051746368408 | Accuracy:0.03932203389830508
Epoch:2 | Loss:3.6922903060913086 | Accuracy:0.05084745762711865
Epoch:3 | Loss:3.4302172660827637 | Accuracy:0.062372881355932205
Epoch:4 | Loss:3.351684331893921 | Accuracy:0.08271186440677966
Epoch:5 | Loss:3.2836146354675293 | Accuracy:0.0935593220338983
Epoch:6 | Loss:3.2269628047943115 | Accuracy:0.10915254237288136
Epoch:7 | Loss:3.2704873085021973 | Accuracy:0.1023728813559322
Epoch:8 | Loss:3.071798801422119 | Accuracy:0.1423728813559322
Epoch:9 | Loss:3.0656063556671143 | Accuracy:0.15457627118644068
Epoch:10 | Loss:3.0375216007232666 | Accuracy:0.17288135593220338
Epoch:11 | Loss:2.8739380836486816 | Accuracy:0.2094915254237288
Epoch:12 | Loss:2.7329418659210205 | Accuracy:0.23186440677966103
Epoch:13 | Loss:2.737560510635376 | Accuracy:0.24813559322033898
Epoch:14 | Loss:2.541532516479492 | Accuracy:0.27728813559322035
Epoch:15 | Loss:2.540792226791382 | Accuracy:0.3064406779661017
Epoch:16 | Loss:2.485729217529297 | Accuracy:0.3328813559322034
Epoch:17 | Loss:2.7257814407348633 | Accuracy:0.31389830508474575
Epoch:18 | Loss:3.07981276512146 | Accuracy:0.3247457627118644
Epoch:19 | Loss:3.1801645755767822 | Accuracy:0.31661016949152543
Epoch:20 | Loss:3.270585298538208 | Accuracy:0.3328813559322034
Epoch:21 | Loss:3.355048656463623 | Accuracy:0.3376271186440678
Epoch:22 | Loss:3.362093687057495 | Accuracy:0.29898305084745763
Epoch:23 | Loss:3.470551013946533 | Accuracy:0.3389830508474576
Epoch:24 | Loss:3.5411648750305176 | Accuracy:0.31254237288135595

Out[38]:
1

batch_size=64

In [39]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)
In [44]:
# re-wrap the models in fresh PL modules (note: this alone does not reinitialize the underlying resnet weights)
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
In [45]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:4.571255683898926 | Accuracy:0.33152542372881355
Epoch:1 | Loss:4.823599815368652 | Accuracy:0.33084745762711865
Epoch:2 | Loss:4.738388538360596 | Accuracy:0.33152542372881355
Epoch:3 | Loss:4.6921844482421875 | Accuracy:0.3383050847457627
Epoch:4 | Loss:5.571420669555664 | Accuracy:0.3227118644067797
Epoch:5 | Loss:4.973819255828857 | Accuracy:0.31864406779661014
Epoch:6 | Loss:4.960039138793945 | Accuracy:0.31186440677966104
Epoch:7 | Loss:4.72049617767334 | Accuracy:0.33152542372881355
Epoch:8 | Loss:4.7438859939575195 | Accuracy:0.3410169491525424
Epoch:9 | Loss:4.7650861740112305 | Accuracy:0.33220338983050846
Epoch:10 | Loss:4.842560768127441 | Accuracy:0.33491525423728813
Epoch:11 | Loss:5.002099514007568 | Accuracy:0.3410169491525424
Epoch:12 | Loss:4.969579696655273 | Accuracy:0.3328813559322034
Epoch:13 | Loss:4.797631740570068 | Accuracy:0.3328813559322034
Epoch:14 | Loss:4.790388107299805 | Accuracy:0.33220338983050846
Epoch:15 | Loss:4.84404993057251 | Accuracy:0.3464406779661017
Epoch:16 | Loss:4.882577896118164 | Accuracy:0.3416949152542373
Epoch:17 | Loss:4.831890106201172 | Accuracy:0.3403389830508475
Epoch:18 | Loss:4.815413475036621 | Accuracy:0.34576271186440677
Epoch:19 | Loss:4.880715370178223 | Accuracy:0.34779661016949154
Epoch:20 | Loss:4.870474815368652 | Accuracy:0.34508474576271186
Epoch:21 | Loss:4.8547258377075195 | Accuracy:0.3430508474576271
Epoch:22 | Loss:4.814042568206787 | Accuracy:0.3505084745762712
Epoch:23 | Loss:5.573678970336914 | Accuracy:0.29152542372881357
Epoch:24 | Loss:4.861083030700684 | Accuracy:0.33220338983050846

Out[45]:
1
In [46]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:4.338170051574707 | Accuracy:0.36135593220338985
Epoch:1 | Loss:4.264873027801514 | Accuracy:0.3593220338983051
Epoch:2 | Loss:4.475521564483643 | Accuracy:0.368135593220339
Epoch:3 | Loss:4.5568928718566895 | Accuracy:0.37559322033898307
Epoch:4 | Loss:4.563418865203857 | Accuracy:0.36610169491525424
Epoch:5 | Loss:4.532094955444336 | Accuracy:0.36677966101694914
Epoch:6 | Loss:4.709390163421631 | Accuracy:0.36474576271186443
Epoch:7 | Loss:4.703502178192139 | Accuracy:0.34983050847457625
Epoch:8 | Loss:4.687512397766113 | Accuracy:0.36135593220338985
Epoch:9 | Loss:4.453052997589111 | Accuracy:0.37559322033898307
Epoch:10 | Loss:4.729727745056152 | Accuracy:0.3423728813559322
Epoch:11 | Loss:4.887462139129639 | Accuracy:0.34847457627118644
Epoch:12 | Loss:4.761058807373047 | Accuracy:0.36
Epoch:13 | Loss:4.628625869750977 | Accuracy:0.36610169491525424
Epoch:14 | Loss:4.939492225646973 | Accuracy:0.3735593220338983
Epoch:15 | Loss:4.9373321533203125 | Accuracy:0.36
Epoch:16 | Loss:4.884154796600342 | Accuracy:0.3701694915254237
Epoch:17 | Loss:5.015425682067871 | Accuracy:0.34576271186440677
Epoch:18 | Loss:5.0034356117248535 | Accuracy:0.34372881355932206
Epoch:19 | Loss:5.081662178039551 | Accuracy:0.34372881355932206
Epoch:20 | Loss:5.115207195281982 | Accuracy:0.3403389830508475
Epoch:21 | Loss:4.923257827758789 | Accuracy:0.368135593220339
Epoch:22 | Loss:5.064967632293701 | Accuracy:0.3701694915254237
Epoch:23 | Loss:4.966062545776367 | Accuracy:0.368135593220339
Epoch:24 | Loss:5.010922431945801 | Accuracy:0.376271186440678

Out[46]:
1

batch_size=1

In [47]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=1, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=4, shuffle=False)
In [58]:
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
In [59]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:3.6087236404418945 | Accuracy:0.04067796610169491
Epoch:1 | Loss:3.8362090587615967 | Accuracy:0.025084745762711864
Epoch:2 | Loss:3.6673178672790527 | Accuracy:0.03593220338983051
Epoch:3 | Loss:3.7399044036865234 | Accuracy:0.03389830508474576
Epoch:4 | Loss:4.054337501525879 | Accuracy:0.0488135593220339
Epoch:5 | Loss:4.010653972625732 | Accuracy:0.04542372881355932
Epoch:6 | Loss:4.764206886291504 | Accuracy:0.05288135593220339
Epoch:7 | Loss:10.56059455871582 | Accuracy:0.04474576271186441
Epoch:8 | Loss:5.048521041870117 | Accuracy:0.05830508474576271
Epoch:9 | Loss:4.828557014465332 | Accuracy:0.06508474576271187
Epoch:10 | Loss:7.225879192352295 | Accuracy:0.05694915254237288
Epoch:11 | Loss:6.472527027130127 | Accuracy:0.06779661016949153
Epoch:12 | Loss:9.755941390991211 | Accuracy:0.07050847457627119
Epoch:13 | Loss:13.05939769744873 | Accuracy:0.059661016949152545
Epoch:14 | Loss:18.591503143310547 | Accuracy:0.06508474576271187
Epoch:15 | Loss:11.946345329284668 | Accuracy:0.06915254237288136
Epoch:16 | Loss:16.744611740112305 | Accuracy:0.06983050847457627
Epoch:17 | Loss:12.913531303405762 | Accuracy:0.07661016949152542
Epoch:18 | Loss:23.76015281677246 | Accuracy:0.06508474576271187
Epoch:19 | Loss:26.5297794342041 | Accuracy:0.06576271186440678
Epoch:20 | Loss:35.212242126464844 | Accuracy:0.05898305084745763
Epoch:21 | Loss:16.634546279907227 | Accuracy:0.06169491525423729
Epoch:22 | Loss:21.815725326538086 | Accuracy:0.062372881355932205
Epoch:23 | Loss:12.68907356262207 | Accuracy:0.0711864406779661
Epoch:24 | Loss:19.639753341674805 | Accuracy:0.06779661016949153

Out[59]:
1
In [60]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  
Epoch:0 | Loss:3.6178038120269775 | Accuracy:0.03593220338983051
Epoch:1 | Loss:3.5887539386749268 | Accuracy:0.03864406779661017
Epoch:2 | Loss:3.4937922954559326 | Accuracy:0.0576271186440678
Epoch:3 | Loss:3.426539421081543 | Accuracy:0.06508474576271187
Epoch:4 | Loss:3.4010708332061768 | Accuracy:0.06915254237288136
Epoch:5 | Loss:3.352757453918457 | Accuracy:0.08949152542372882
Epoch:6 | Loss:3.3006396293640137 | Accuracy:0.10033898305084746
Epoch:7 | Loss:3.2513763904571533 | Accuracy:0.09966101694915254
Epoch:8 | Loss:3.2186825275421143 | Accuracy:0.11254237288135593
Epoch:9 | Loss:3.1824042797088623 | Accuracy:0.12067796610169491
Epoch:10 | Loss:3.1842432022094727 | Accuracy:0.1152542372881356
Epoch:11 | Loss:3.080850839614868 | Accuracy:0.1342372881355932
Epoch:12 | Loss:3.1100575923919678 | Accuracy:0.1430508474576271
Epoch:13 | Loss:3.085071563720703 | Accuracy:0.14508474576271185
Epoch:14 | Loss:3.007901906967163 | Accuracy:0.17559322033898306
Epoch:15 | Loss:3.1437573432922363 | Accuracy:0.1694915254237288
Epoch:16 | Loss:3.110459089279175 | Accuracy:0.1864406779661017
Epoch:17 | Loss:3.5012593269348145 | Accuracy:0.18847457627118644
Epoch:18 | Loss:3.4454123973846436 | Accuracy:0.21084745762711865
Epoch:19 | Loss:3.8177714347839355 | Accuracy:0.21152542372881356
Epoch:20 | Loss:4.031371116638184 | Accuracy:0.1952542372881356
Epoch:21 | Loss:4.404645919799805 | Accuracy:0.1769491525423729
Epoch:22 | Loss:4.856805324554443 | Accuracy:0.1959322033898305
Epoch:23 | Loss:4.558755874633789 | Accuracy:0.21152542372881356
(the remaining epoch output was truncated by the notebook server's IOPub rate limit)

Conclusion

  • The model with GroupNorm + Weight Standardization achieved performance similar to BatchNorm's, so GroupNorm can be considered a viable alternative to BatchNorm.

  • On the Pets dataset, GroupNorm did not consistently outperform BatchNorm at smaller batch sizes, unlike the results reported in the paper.

  • The research paper uses the ImageNet dataset, whereas this experiment was run on the Pets dataset because of the compute required to train on ImageNet.

  • For more details, refer to my blogpost.

  • For bs=1, GroupNorm performs significantly better than BatchNorm.