GroupNorm vs BatchNorm on Pets Dataset
In this notebook, we implement GroupNorm with Weight Standardization and compare the results with BatchNorm. Simply replacing BN with GN leads to sub-optimal results.
from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
# Fix the random seed (fastai helper) so that repeated runs of the notebook are comparable.
set_seed(2)
ResNet Implementation
We copy the implementation of Weight Standardization from the official repository here and also copy the implementation of ResNet from TorchVision.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
# Checkpoint download locations, keyed by the architecture name that the
# factory functions pass to `_resnet`.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    resnext50_32x4d='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    resnext101_32x8d='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    wide_resnet50_2='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    wide_resnet101_2='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
)
We replace the Convolution layer inside ResNet with the standardized version described in the Weight Standardization research paper. Everything else remains the same.
class Conv2d_WS(nn.Conv2d):
    """`nn.Conv2d` whose kernel is standardized on every forward pass.

    The stored parameter `self.weight` is left untouched; the forward pass
    derives a standardized copy of it (zero mean per output filter, divided by
    the filter's standard deviation plus 1e-5) and convolves with that copy.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        """Accepts the same leading arguments as `nn.Conv2d` and forwards them unchanged."""
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)

    def forward(self, x):
        """Convolve `x` with the standardized version of `self.weight`."""
        kernel = self.weight
        # Per-output-filter mean, reduced one axis at a time over the
        # remaining three axes (dims 1, 2 and 3 of the weight tensor).
        filter_mean = kernel
        for axis in (1, 2, 3):
            filter_mean = filter_mean.mean(dim=axis, keepdim=True)
        centered = kernel - filter_mean
        # Per-filter standard deviation; the small constant guards the division.
        flat = centered.view(centered.size(0), -1)
        filter_std = flat.std(dim=1).view(-1, 1, 1, 1) + 1e-5
        standardized = centered / filter_std.expand_as(centered)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 weight-standardized convolution with padding and no bias.

    The padding is set equal to `dilation`.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=dilation,
                        groups=groups, bias=False, dilation=dilation)
    return Conv2d_WS(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 weight-standardized convolution without bias."""
    pointwise = Conv2d_WS(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by a norm layer.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before it is added
            back to the main branch.
        groups, base_width: must be 1 and 64; any other value raises ValueError.
        dilation: must not exceed 1; larger values raise NotImplementedError.
        norm_layer: callable building a norm layer from a channel count
            (`nn.BatchNorm2d` when None).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, conv1 and the downsample module both reduce the input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual block of three convolutions (1x1, 3x3, 1x1), each followed by a norm layer.

    The inner width is `int(planes * (base_width / 64.)) * groups`, and the
    block outputs `planes * expansion` channels.

    Args:
        inplanes: channels of the block input.
        planes: base channel count of the block.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before it is added
            back to the main branch.
        groups, dilation: forwarded to the 3x3 convolution.
        base_width: scales the inner width relative to 64.
        norm_layer: callable building a norm layer from a channel count
            (`nn.BatchNorm2d` when None).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # When stride != 1, conv2 and the downsample module both reduce the input.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet backbone whose convolutions are all `Conv2d_WS` layers.

    Args:
        block: residual block class (`BasicBlock` or `Bottleneck`).
        layers: four integers, the number of blocks in each stage.
        num_classes: output size of the final linear layer.
        zero_init_residual: if True, the weight of the last norm layer of every
            residual block is set to zero after the regular initialisation.
        groups, width_per_group: forwarded to every block as `groups` and
            `base_width`.
        replace_stride_with_dilation: None or three booleans; a True entry makes
            the matching stage (2, 3 or 4) use dilation instead of its stride.
        norm_layer: callable building a norm layer from a channel count
            (`nn.BatchNorm2d` when None).

    Raises:
        ValueError: if `replace_stride_with_dilation` does not have three entries.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state that `_make_layer` reads and updates stage by stage.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        dilate2, dilate3, dilate4 = replace_stride_with_dilation
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 7x7 convolution, norm, ReLU, max-pool.
        self.conv1 = Conv2d_WS(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=dilate2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=dilate3)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=dilate4)

        # Head: global average pooling and a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last norm layer in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_norm = module.bn3
                elif isinstance(module, BasicBlock):
                    last_norm = module.bn2
                else:
                    continue
                nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of `blocks` residual blocks as an `nn.Sequential`.

        Only the first block receives `stride` and the optional downsample
        module. `self.inplanes` (and `self.dilation` when `dilate` is set) are
        updated as a side effect.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # The shortcut must be projected whenever the main branch changes
            # the resolution or the channel count.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def forward(self, x):
        """Run the network on `x` and return the output of the final linear layer."""
        return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a `ResNet` and optionally load the checkpoint registered for `arch`.

    Args:
        arch: key into `model_urls`.
        block: residual block class passed to `ResNet`.
        layers: per-stage block counts passed to `ResNet`.
        pretrained: if True, download the checkpoint and load it into the model.
        progress: forwarded to `load_state_dict_from_url`.
        **kwargs: forwarded to `ResNet`.

    NOTE(review): the checkpoint is loaded with the default `load_state_dict`
    call, so it presumably fails when `kwargs` change the parameter layout
    (different `num_classes` or norm layer) - confirm before using
    `pretrained=True` with such options.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    weights = load_state_dict_from_url(model_urls[arch],
                                       progress=progress)
    model.load_state_dict(weights)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-18 (``BasicBlock``, stage depths 2-2-2-2), the model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-34 (``BasicBlock``, stage depths 3-4-6-3), the model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-50 (``Bottleneck``, stage depths 3-4-6-3), the model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-101 (``Bottleneck``, stage depths 3-4-23-3), the model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-152 (``Bottleneck``, stage depths 3-8-36-3), the model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""Build ResNeXt-50 32x4d, the model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    ``groups`` and ``width_per_group`` are always set to 32 and 4, replacing
    any values supplied by the caller.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Build ResNeXt-101 32x8d, the model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    ``groups`` and ``width_per_group`` are always set to 32 and 8, replacing
    any values supplied by the caller.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Build Wide ResNet-50-2, the model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    ``width_per_group`` is always set to ``64 * 2``, replacing any value
    supplied by the caller.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Build Wide ResNet-101-2, the model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    ``width_per_group`` is always set to ``64 * 2``, replacing any value
    supplied by the caller.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``ResNet``
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)
Pets Dataset
Now we use the wonderful fastai library to get the Pets dataset.
# Batch size used by the first pair of data loaders.
bs = 4
# Fetch the Pets dataset through fastai's `untar_data` and show where it was placed.
path = untar_data(URLs.PETS); path
# List the contents of the `images` sub-folder.
(path/'images').ls()
Dataset
The implementation of the PetsDataset has been heavily inspired by, and partially copied (the regex part) from, the fastai2 repo here.
class PetsDataset:
    """Indexable dataset of image files whose class label is encoded in the file name.

    `setup` must be called before indexing: it compiles the file-name pattern
    and builds the `label2int` / `int2label` dictionaries used by `__getitem__`.

    Args:
        paths: sequence of image file paths.
        transforms: optional callable invoked as `transforms(image=array)` and
            expected to return a mapping with an `'image'` entry.
    """

    def __init__(self, paths, transforms=None):
        self.image_paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def _label_of(self, p):
        "class name captured by `self.pat` from the path `p`, without directories"
        return os.path.basename(self.pat.search(str(p)).group(1))

    def setup(self, pat=r'(.+)_\d+.jpg$', label2int=None):
        "adds a label dictionary to `self`"
        self.pat = re.compile(pat)
        if label2int is not None:
            # Reuse a mapping built elsewhere so that several datasets agree.
            self.label2int = label2int
        else:
            self.labels = {self._label_of(p) for p in self.image_paths}
            # Enumerate the labels in sorted order: iterating the set directly
            # gives an ordering that can change from one interpreter run to the
            # next, which would silently change the meaning of every class index.
            self.label2int = {label: i for i, label in enumerate(sorted(self.labels))}
        self.int2label = {i: label for label, i in self.label2int.items()}

    def __getitem__(self, idx):
        """Return `(image, target)` for position `idx`; `target` is a long tensor."""
        img_path = self.image_paths[idx]
        # The context manager releases the file handle as soon as the pixels
        # are copied; converting to RGB guarantees a three-channel array even
        # for grayscale, palette or RGBA files.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        target = self.label2int[self._label_of(img_path)]

        if self.transforms:
            img = self.transforms(image=img)['image']

        return img, torch.tensor(target, dtype=torch.long)
# Collect the image files found under the dataset's `images` folder.
image_paths = get_image_files(path/'images')
image_paths
# remove those images that are not 3 channel
from tqdm.notebook import tqdm
# Set to True to actually delete the offending files from disk.
run_remove = False


def remove(o):
    """Delete the image file `o` unless it decodes to a 3-channel array.

    Arrays with fewer than three dimensions (e.g. grayscale or palette images)
    have no channel axis at all and are deleted as well.
    """
    # Close the file before a possible `os.remove`; an open handle is a leak
    # and can block deletion on some platforms.
    with Image.open(o) as img:
        pixels = np.array(img)
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        os.remove(o)
# Optional clean-up pass over every file, with a progress bar.
if run_remove:
    for o in tqdm(image_paths):
        remove(o)

# Re-read the folder so the list reflects any deletions.
image_paths = get_image_files(path/'images')
image_paths
# augmentations using `albumentations` library
# Target side length of the square resize; a falsy value disables resizing.
sz = 224
tfms = albumentations.Compose([
    # Resize to `sz` x `sz`, or do nothing when `sz` is falsy.
    albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
    # One of two hole-dropping augmentations is picked by `OneOf`.
    # NOTE(review): `random.randint(1,8)` is evaluated once, while this list is
    # built, so each transform receives a single fixed number for the whole run
    # rather than a fresh draw per image - confirm this is intended.
    albumentations.OneOf(
        [albumentations.Cutout(random.randint(1,8), 16, 16),
         albumentations.CoarseDropout(random.randint(1,8), 16, 16)]
    ),
    # Normalisation with the library's default statistics, then conversion to a tensor.
    albumentations.Normalize(always_apply=True),
    ToTensorV2()
])
# Dataset over every image; it is only used to build the shared label dictionary
# and to inspect a sample.
dataset = PetsDataset(image_paths, tfms)
# to setup the `label2int` dictionary
dataset.setup()
# Inspect the first (image, target) pair and the shape of its image.
dataset[0]
dataset[0][0].shape
DataLoaders
We divide the image_paths into train and validation sets, holding out 20% for validation.
# Number of images held out for validation: 20% of all files, truncated to an integer.
nval = int(len(image_paths)*0.2)
nval
# The last `nval` paths form the validation set, everything before them the training set.
# NOTE(review): the list is split in its existing order, without shuffling. Confirm that
# `get_image_files` does not return the paths grouped by breed; otherwise some classes
# could be missing from one of the two splits.
trn_img_paths = image_paths[:-nval]
val_img_paths = image_paths[-nval:]
# Sanity check: the two slices together cover the whole list.
assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)
len(trn_img_paths), len(val_img_paths)
# Deterministic preprocessing for validation: the same resize, normalisation and
# tensor conversion as `tfms`, but without the random Cutout/CoarseDropout step.
# Random augmentation at evaluation time makes the validation loss and accuracy
# noisy and no longer comparable between epochs or models.
val_tfms = albumentations.Compose([
    albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
    albumentations.Normalize(always_apply=True),
    ToTensorV2()
])
trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)
val_dataset = PetsDataset(val_img_paths, transforms=val_tfms)
# use same `label2int` dictionary as in `dataset` for consistency across train and val
trn_dataset.setup(label2int=dataset.label2int)
val_dataset.setup(label2int=dataset.label2int)
# Only the training loader shuffles its samples.
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)
# make sure everything works so far
next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape
Model
Now, we define resnet34 from the torchvision repo with pretrained=False, as we do not have pretrained weights for the GroupNorm layer.
# Vanilla resnet with `BatchNorm`
# Taken from torchvision's `models`, without pretrained weights and with one
# output per entry of the label dictionary.
resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int), pretrained=False)
resnet34_bn
Next, we define the GroupNorm_32 class, with 32 groups by default as in the Group Normalization research paper here.
class GroupNorm_32(torch.nn.GroupNorm):
    """`GroupNorm` that takes the channel count first and defaults to 32 groups.

    With this argument order the class can be called as `norm_layer(planes)`,
    the way the ResNet code above builds its norm layers.
    """

    def __init__(self, num_channels, num_groups=32, **kwargs):
        """Swap the leading arguments into `GroupNorm`'s own order; `kwargs` pass through."""
        super().__init__(num_groups=num_groups, num_channels=num_channels, **kwargs)
# resnet34 with `GroupNorm` and `Standardized Weights`
# `conv2d` replaced with `Conv2d_WS` and `BatchNorm` replaced with `GroupNorm`
# This calls the `resnet34` factory defined in this notebook, not torchvision's.
resnet34_gn = resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))
resnet34_gn
# make sure we are able to make forward pass
resnet34_gn(next(iter(trn_loader))[0]).shape
PytorchLightning
Finally, we use PytorchLightning to train the model.
from pytorch_lightning import LightningModule, Trainer
class Model(LightningModule):
    """Lightning wrapper that trains a classification network `base`.

    Uses cross-entropy loss and Adam with a learning rate of 1e-3, and prints
    the mean validation loss and the accuracy at the end of each validation epoch.
    """

    def __init__(self, base):
        super().__init__()
        # The wrapped network is stored by reference, not copied.
        self.base = base

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def step(self, batch):
        """Shared forward pass: return `(loss, targets, predictions)` for one batch."""
        inputs, targets = batch
        logits = self(inputs)
        criterion = nn.CrossEntropyLoss()
        return criterion(logits, targets), targets, logits

    def training_step(self, batch, batch_nb):
        return {'loss': self.step(batch)[0]}

    def validation_step(self, batch, batch_nb):
        # Targets and predictions are kept (detached) for the epoch-level accuracy.
        loss, targets, logits = self.step(batch)
        return {'loss': loss, 'y': targets.detach(), 'y_hat': logits.detach()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses, report loss and accuracy, return the mean loss."""
        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        acc = self.get_accuracy(outputs)
        print(f"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}")
        return {'loss': avg_loss}

    def get_accuracy(self, outputs):
        """Accuracy over all validation batches collected in `outputs`."""
        from sklearn.metrics import accuracy_score
        targets = torch.cat([out['y'] for out in outputs])
        logits = torch.cat([out['y_hat'] for out in outputs])
        predicted = logits.argmax(1)
        return accuracy_score(targets.cpu().numpy(), predicted.cpu().numpy())
# define PL versions
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
# When True, `num_sanity_val_steps` is set to 1 instead of 0 for every Trainer.
debug = False
# Number of CUDA devices reported by PyTorch; passed to every Trainer as `gpus`.
gpus = torch.cuda.device_count()
# NOTE: `trainer` is re-assigned before each `fit` call further down, so this
# instance (the only one with max_epochs=50) is never used for training.
trainer = Trainer(gpus=gpus, max_epochs=50,
                  num_sanity_val_steps=1 if debug else 0)
batch_size=4
# train model with `GroupNorm` with `bs=4` on the `Pets` dataset
# A new Trainer is created for every run.
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
# train model with `BatchNorm` with `bs=4` on the `Pets` dataset
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
batch_size=64
# Rebuild both loaders with a batch size of 64.
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)
# redefine PL versions to remove trained weights
# The networks themselves are rebuilt here: `Model` only stores a reference to
# the network it is given, so wrapping the existing `resnet34_bn` / `resnet34_gn`
# objects again would continue from the weights trained in the previous experiment.
n_classes = len(trn_dataset.label2int)
model_bn = Model(models.resnet34(num_classes=n_classes, pretrained=False))
model_gn = Model(resnet34(norm_layer=GroupNorm_32, num_classes=n_classes))
# train model with `BatchNorm` with `bs=64`
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
# train model with `GroupNorm` with `bs=64`
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
batch_size=1
# Rebuild both loaders with a batch size of 1.
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=1, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=4, shuffle=False)
# Start from freshly initialised networks: `Model` only stores a reference to
# the network it is given, so wrapping the existing `resnet34_bn` / `resnet34_gn`
# objects again would continue from the weights trained in the previous experiments.
n_classes = len(trn_dataset.label2int)
model_bn = Model(models.resnet34(num_classes=n_classes, pretrained=False))
model_gn = Model(resnet34(norm_layer=GroupNorm_32, num_classes=n_classes))
# train model with `BatchNorm` with `bs=1`
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
# train model with `GroupNorm` with `bs=1`
trainer = Trainer(gpus=gpus, max_epochs=25,
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
The model with GroupNorm + Standardised Weights was able to achieve performance similar to that of BatchNorm. Thus, GroupNorm can be considered an alternative to BatchNorm.
On the Pets dataset, GroupNorm does not necessarily achieve better performance than BatchNorm at lower batch sizes, unlike what is reported in the paper.
The research paper uses the Imagenet dataset, whereas this experiment was run on the Pets dataset due to the lack of compute required to train on Imagenet.
For more details, refer to my blogpost.
For bs=1, GroupNorm performs significantly better than BatchNorm.