In this notebook, we will see how to handle data in PyTorch using the Dataset and DataLoader classes.
import torch
from torch.utils.data import Dataset
To work with data, PyTorch provides a Dataset class that can be subclassed.
A dataset is an object that can be queried with an index and returns the corresponding sample.
It should implement two methods:
- __len__: returns the size of the dataset
- __getitem__: returns one sample from the dataset
class DummyDataset(Dataset):
    def __init__(self):
        # 10 random points in 2 dimensions
        self.data = torch.rand(10, 2)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        sample = self.data[index]
        # Label is True when the first coordinate is larger than the second
        label = sample[0] > sample[1]
        return (sample, label)
dataset = DummyDataset()
dataset.data
When indexed, the dataset returns a tuple (sample, class label):
dataset[1]
A DataLoader is a PyTorch utility class to iterate over the dataset.
It supports multi-process data loading, automatic batching, shuffling and more.
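Roughly speaking, with shuffle=True and the default collate function, a DataLoader draws a random permutation of indices, queries the dataset at those indices, and stacks the results into batch tensors. The sketch below imitates that idea by hand; PairDataset and manual_batches are hypothetical names, and this is an illustration, not DataLoader's actual implementation.

```python
import torch
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Hypothetical stand-in for DummyDataset: 10 random 2-d points."""

    def __init__(self):
        self.data = torch.rand(10, 2)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample, sample[0] > sample[1]


def manual_batches(dataset, batch_size):
    """Hypothetical helper: shuffle the indices, then stack samples and labels per batch."""
    indices = torch.randperm(len(dataset)).tolist()
    for start in range(0, len(indices), batch_size):
        items = [dataset[i] for i in indices[start:start + batch_size]]
        samples = torch.stack([s for s, _ in items])  # shape (batch, 2)
        labels = torch.stack([l for _, l in items])   # shape (batch,), dtype bool
        yield samples, labels


batches = list(manual_batches(PairDataset(), batch_size=5))
for samples, labels in batches:
    print(samples.shape, labels.shape)  # torch.Size([5, 2]) torch.Size([5])
```

The 0-dimensional label tensors are stacked into a 1-dimensional tensor, which is also the shape you get from the DataLoader loop below.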
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=5, shuffle=True)
for sample, label in loader:
    print(sample, label, sep="\n")
    break
Exercise: use multiple worker processes to load the data in parallel. Find the relevant DataLoader argument in the PyTorch documentation.
loader = DataLoader(dataset, batch_size=5, shuffle=True)  # YOUR TURN
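One possible answer, as a sketch: DataLoader accepts a num_workers argument that sets how many subprocesses prepare batches in parallel (the default, 0, loads everything in the main process). PairDataset is a hypothetical stand-in for DummyDataset. On platforms where workers are started with spawn (Windows, macOS), a script using workers should keep the loading loop under an `if __name__ == "__main__":` guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical stand-in for DummyDataset: 10 random 2-d points with a boolean label."""

    def __init__(self):
        self.data = torch.rand(10, 2)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample, sample[0] > sample[1]


# num_workers=2: two worker subprocesses load samples and build batches in parallel
loader = DataLoader(PairDataset(), batch_size=5, shuffle=True, num_workers=2)

shapes = []
for sample, label in loader:
    shapes.append((tuple(sample.shape), tuple(label.shape)))
print(shapes)  # [((5, 2), (5,)), ((5, 2), (5,))]
```

With such a tiny dataset, workers bring no speed-up; they pay off when __getitem__ does real work, such as decoding images from disk.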
Install the dependencies, then download the dataset and place it in the data/alien-vs-predator folder.
!wget -q https://raw.githubusercontent.com/theevann/amld-pytorch-workshop/master/binder/requirements.txt -O requirements.txt
!pip install -qr requirements.txt
!mkdir -p data
!curl -Lo alien-vs-predator.zip "https://docs.google.com/uc?export=download&id=1hct3PjRf14ZBp83ob3f6Uo_0mqrT9FGZ"
!unzip -oq alien-vs-predator.zip -d data/
!rm alien-vs-predator.zip
!ls -l data/alien-vs-predator/
# for PIL.Image
!pip install --no-cache-dir -I Pillow==7.1.2
The dataset is located in data/alien-vs-predator:
!tree -nd ./data/alien-vs-predator # This command will not work on Colab
Each directory contains images of the corresponding class:
from PIL import Image
img_predator = Image.open("./data/alien-vs-predator/train/predator/10.jpg").convert('RGB')
img_alien = Image.open("./data/alien-vs-predator/train/alien/10.jpg").convert('RGB')
img_alien
img_predator
The code below implements a Dataset class for these images.
It collects all the image paths and stores each of them in the img_instances list along with its label.
The alien class has label 0 while the predator class has label 1.
This code is incomplete: you need to fill in the __len__ and __getitem__ methods.
You can use this snippet to load an image from a path:
with open(path, 'rb') as f:
    img = Image.open(f).convert('RGB')
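Two details of this snippet are worth seeing in isolation, here on a made-up grayscale image kept in memory: convert('RGB') guarantees three channels even for grayscale or RGBA files, which keeps tensor shapes consistent across the dataset, and since Image.open reads pixel data lazily, calling convert() inside the with block loads the pixels before the file is closed.

```python
import io

from PIL import Image

# Hypothetical grayscale JPEG written to an in-memory buffer instead of a file
buffer = io.BytesIO()
Image.new("L", (40, 30), 128).save(buffer, format="JPEG")
buffer.seek(0)

img = Image.open(buffer)
print(img.mode)  # L (one channel)
rgb = img.convert("RGB")  # convert() also forces the lazy pixel loading
print(rgb.mode, len(rgb.getbands()), rgb.size)  # RGB 3 (40, 30)
```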
from pathlib import Path
from PIL import Image
class AlienPredatorDataset(Dataset):
    def __init__(self, root, split):
        self.root = root
        self.split = split
        # Load and save all image paths
        self.img_instances = []
        for img_path in Path(root, split, "alien").glob("*.jpg"):
            self.img_instances.append((img_path, 0))
        for img_path in Path(root, split, "predator").glob("*.jpg"):
            self.img_instances.append((img_path, 1))
    def __len__(self):
        return  # YOUR TURN
    def __getitem__(self, index):
        # YOUR TURN
        return (img, target)
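Spoiler: one possible completion of the exercise, sketched under our own assumptions. __len__ returns the number of stored (path, label) pairs, and __getitem__ loads the image with the snippet above. The optional transform argument is our addition, not part of the given skeleton. The usage part builds a hypothetical two-image folder tree in a temporary directory so that it runs without the real data.

```python
import tempfile
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class AlienPredatorDataset(Dataset):
    def __init__(self, root, split, transform=None):
        self.root = root
        self.split = split
        self.transform = transform  # assumption: optional callable applied to each image
        # Load and save all image paths with their label (alien: 0, predator: 1)
        self.img_instances = []
        for img_path in Path(root, split, "alien").glob("*.jpg"):
            self.img_instances.append((img_path, 0))
        for img_path in Path(root, split, "predator").glob("*.jpg"):
            self.img_instances.append((img_path, 1))

    def __len__(self):
        # One sample per stored (path, label) pair
        return len(self.img_instances)

    def __getitem__(self, index):
        path, target = self.img_instances[index]
        # The image is read from disk only when the sample is requested
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return (img, target)


# Usage on a tiny fake folder tree (hypothetical data, not the real dataset)
root = tempfile.mkdtemp()
for cls, color in (("alien", "green"), ("predator", "red")):
    folder = Path(root, "train", cls)
    folder.mkdir(parents=True)
    Image.new("RGB", (64, 48), color).save(folder / "0.jpg")

dataset = AlienPredatorDataset(root, "train")
img, target = dataset[0]
print(len(dataset), img.size, target)  # 2 (64, 48) 0
```

Storing only paths in __init__ and opening files in __getitem__ keeps memory usage low, and it lets DataLoader workers share the decoding work.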
dataset = AlienPredatorDataset("./data/alien-vs-predator/", "train")
len(dataset)
dataset[0] # Here again it returns a tuple (image, class label)
dataset[0][0]
dataset[1][0]
Note that we get PIL images of different sizes.
To build proper PyTorch batches, the samples must be tensors of the same size.
To do so, we will use Torchvision transforms.
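To see why equal sizes matter: for tensor samples, the DataLoader's default batching essentially stacks them with torch.stack, which only accepts tensors of identical shape. A small sketch with made-up image sizes:

```python
import torch

# Two "images" as (channels, height, width) tensors of different sizes
a = torch.zeros(3, 120, 150)
b = torch.zeros(3, 100, 100)

try:
    torch.stack([a, b])  # stacking requires identical shapes
    stack_failed = False
except RuntimeError:
    stack_failed = True
print(stack_failed)  # True

# Once both are cropped to 100x100, they stack into a single batch tensor
batch = torch.stack([a[:, :100, :100], b])
print(batch.shape)  # torch.Size([2, 3, 100, 100])
```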
from torchvision.transforms import ToTensor, RandomCrop
crop_transform = RandomCrop(100)
img = dataset[0][0]
img
crop_transform(img)
from torchvision.transforms import Compose
all_transforms = Compose((
    RandomCrop(100),  # random 100x100 patch of the PIL image
    ToTensor(),       # PIL image -> float tensor of shape (3, 100, 100)
))
all_transforms(img)
all_transforms(img).shape
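Let's apply it to our dataset! For this, AlienPredatorDataset must take a transform argument and apply it to the image in __getitem__ (YOUR TURN):
dataset = AlienPredatorDataset("./data/alien-vs-predator/", "train", transform=all_transforms)
loader = DataLoader(dataset, batch_size=5, shuffle=True)  # you can also add workers here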
for sample, label in loader:
    print(sample.shape, label)
    break
Torchvision provides many more useful classes to deal with images.
In particular, since image classification is a common computer vision task, torchvision provides a dataset class named ImageFolder that loads images from a root folder in which each subfolder contains the images of one class.
from torchvision.datasets import ImageFolder
dataset = ImageFolder(root="./data/alien-vs-predator/train", transform=all_transforms)
dataset[0]