# imports for the tutorial
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# pytorch
import torch
import torchvision
import torchvision.transforms as transforms
torchvision.datasets.VOCSegmentation(root, year='2012', ...)
torchvision.datasets.CocoDetection()
torchvision.datasets.Cityscapes()
Results: 62.2% mIoU score on the 2012 PASCAL VOC segmentation challenge using pretrained models on the 2012 ImageNet dataset.
Results: The best PSPNet with a pretrained ResNet has reached an 85.4% mIoU score on the 2012 PASCAL VOC segmentation challenge.
Results: The best Mask R-CNN uses a ResNeXt (2016) to extract features and a Feature Pyramid Network (FPN) architecture. It has obtained a 37.1% AP score on the 2016 COCO segmentation challenge and a 41.8% AP score on the 2017 COCO segmentation challenge.
Results: DeepLab V2 using a ResNet-101 as backbone has reached a 79.7% mIoU score on the 2012 PASCAL VOC challenge, a 45.7% mIoU score on the PASCAL-Context challenge and a 70.4% mIoU score on the Cityscapes challenge.
Results: The best DeepLabv3+ has obtained an 89.0% mIoU score on the 2012 PASCAL VOC challenge. The model trained on the Cityscapes dataset has reached an 82.1% mIoU score for the associated challenge.
# Fetch the pre-trained 'deeplabv3_resnet101' entry point from the
# 'pytorch/vision' hub repository, pinned to tag v0.5.0.
model = torch.hub.load(
    'pytorch/vision:v0.5.0',
    'deeplabv3_resnet101',
    pretrained=True,
)
# Switch to inference mode. eval() hands back the same module, so re-binding
# it keeps `model` unchanged while leaving no dangling expression value.
model = model.eval()
Using cache found in C:\Users\tabad/.cache\torch\hub\pytorch_vision_v0.5.0
The model returns an OrderedDict with two Tensors that are of the same height and width as the input Tensor, but with 21 classes.
output['out'] contains the semantic masks, and output['aux'] contains the auxiliary loss values per-pixel.
In inference mode, output['aux'] is not useful. So, output['out'] is of shape $(N, 21, H, W)$.
# define device
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image to segment (swap in the alternative path below to try another picture).
filename = "./assets/tut_seg_dog.jpg"
# filename = "./assets/kofiko.jpg"
input_image = Image.open(filename)

# Pre-processing pipeline: convert the image to a tensor, then normalize each
# channel with the fixed mean / std values listed here.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Run the pipeline on the loaded image.
input_tensor = preprocess(input_image)

# Move the model to the chosen device.
model = model.to(device)
# The model expects a mini-batch, so add a leading batch axis of size 1 and
# place the batch on the same device as the model.
input_batch = input_tensor.unsqueeze(0).to(device)

# Bare expression: shows the image when run as the last line of a notebook cell.
input_image
# Forward pass. Gradients are not needed for inference, so the model call is
# wrapped in torch.no_grad(); the statement under the `with` must be indented
# to form its body.
with torch.no_grad():
    # The model returns a mapping; 'out' holds the per-class score maps.
    # Index [0] selects the single image of the size-1 batch.
    output = model(input_batch)['out'][0]

# For every pixel, pick the index of the highest-scoring class (axis 0 is the
# class axis after the batch axis has been dropped).
output_predictions = output.argmax(0)
print("output shape: ", output.shape)
print("output_predictions shape: ", output_predictions.shape)
output shape: torch.Size([21, 1213, 1546]) output_predictions shape: torch.Size([1213, 1546])
output_predictions = output.argmax(0)
# Create a color palette, selecting a color for each class.
# Each of the 21 class indices is multiplied by three fixed constants (one per
# color channel) and reduced modulo 255, giving a (21, 3) table of uint8 values.
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.arange(21)[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# Plot the semantic segmentation predictions of 21 classes in each color:
# turn the per-pixel class indices into an image, resize it to the input
# image's size, and attach the palette so each index maps to its color.
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

# Plot the colorized prediction without axes.
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111)
ax.imshow(r)
ax.set_axis_off()
# Object class names. The list is 0-based, but the printed listing numbers the
# entries from 1 so that entry k lines up with predicted class id k
# (e.g. 12 -> 'dog'). NOTE(review): id 0 is presumably background -- confirm.
labels = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
print([f"{class_id}: {name}" for class_id, name in enumerate(labels, start=1)])
['1: aeroplane', '2: bicycle', '3: bird', '4: boat', '5: bottle', '6: bus', '7: car', '8: cat', '9: chair', '10: cow', '11: diningtable', '12: dog', '13: horse', '14: motorbike', '15: person', '16: pottedplant', '17: sheep', '18: sofa', '19: train', '20: tvmonitor']
# Which class ids occur anywhere in the prediction map?
# Copy the predictions to host memory as a NumPy array, then list the distinct
# values (bare expression: displayed when it ends a notebook cell).
predicted_ids = output_predictions.cpu().numpy()
np.unique(predicted_ids)
array([ 0, 8, 12], dtype=int64)
# Build a float mask that is 1 where the predicted class id matches the chosen
# class and 0 everywhere else.
mask = (output_predictions == 12).float().to(device)  # 12 is dog
# mask = (output_predictions == 15).float().to(device)  # 15 is person

# Append a trailing axis to the 2-D mask so it can broadcast against the image
# (assumes the image array is height x width x channels -- TODO confirm), then
# convert it to a uint8 NumPy array on the host.
mask_np = mask.unsqueeze(2).byte().cpu().numpy()
# Zero out every pixel that lies outside the mask.
masked_img = input_image * mask_np

# Show the masked image without axes.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(masked_img)
ax.set_axis_off()