With the release of pytorch-lightning
version 0.9.0, we have included a new class called LightningDataModule
to help you decouple data-related hooks from your LightningModule.
This notebook will walk you through how to start using Datamodules.
The most up to date documentation on datamodules can be found here.
Lightning is easy to install. Simply run pip install pytorch-lightning.
! pip install pytorch-lightning --quiet
First, we'll go over a regular LightningModule
implementation without the use of a LightningDataModule.
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms
Below, we reuse a LightningModule
from our hello world tutorial that classifies MNIST Handwritten Digits.
Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢
This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets.
class LitMNIST(pl.LightningModule):
    """MNIST classifier that bundles its own data pipeline.

    The dataset-specific pieces (input dims, class count, transforms,
    download/split logic and dataloaders) all live on the model itself,
    which ties this module to MNIST.

    Args:
        data_dir: Directory the MNIST files are downloaded to / read from.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):
        # nn.Module must be initialized before any submodule is assigned,
        # otherwise `self.model = ...` below raises.
        super().__init__()

        # We hardcode dataset specific stuff here.
        self.data_dir = data_dir
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            # Normalize works on tensors, so convert the loaded images first.
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            # Collapse (channels, width, height) into one feature vector so
            # it matches the first Linear layer's input size.
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            # Nonlinearities keep the stacked Linear layers from collapsing
            # into a single linear map.
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch and log both."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        """Download the MNIST train and test files into ``data_dir``."""
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds the train/val split, ``'test'`` builds the test set,
        and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split (batch size 32)."""
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        """Return a DataLoader over the validation split (batch size 32)."""
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        """Return a DataLoader over the test set (batch size 32)."""
        return DataLoader(self.mnist_test, batch_size=32)
# Train the self-contained MNIST model; it supplies its own data hooks,
# so nothing besides the model is handed to fit().
model = LitMNIST()
trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)
# Without this call the trainer above is configured but never run.
trainer.fit(model)
DataModules are a way of decoupling data-related hooks from the LightningModule
so you can develop dataset agnostic models.
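Let's go over each function in the class below and talk about what they're doing:
1. __init__
   - Takes in a data_dir arg that points to where you have downloaded/wish to download the MNIST dataset.
   - Defines a transform that will be applied across the train, val, and test dataset splits.
   - Defines default self.dims, which is a tuple returned from datamodule.size() that can help you initialize models.
2. prepare_data
   - This is where we can download the dataset. We point to our desired data directory and ask torchvision's MNIST dataset class to download if the dataset isn't found there.
   - Note that we do not make any state assignments in this function (i.e. self.something = ...).
3. setup
   - Loads the data from file and prepares a dataset for each split (train, val, test).
   - setup expects a stage arg, which is used to separate the logic for 'fit' and 'test'.
   - Both the 'fit'-related and the 'test'-related setup run whenever None is passed to stage.
4. x_dataloader
   - train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances that are created by wrapping their respective datasets that we prepared in setup().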
class MNISTDataModule(pl.LightningDataModule):
    """Datamodule that downloads, splits and serves MNIST.

    Args:
        data_dir: Directory the MNIST files are downloaded to / read from.
    """

    def __init__(self, data_dir: str = './'):
        # Initialize the base datamodule before setting our own attributes.
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            # Normalize works on tensors, so convert the loaded images first.
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the MNIST train and test files; assigns no state."""
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds the train/val split, ``'test'`` builds the test set,
        and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split (batch size 32)."""
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        """Return a DataLoader over the validation split (batch size 32)."""
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        """Return a DataLoader over the test set (batch size 32)."""
        return DataLoader(self.mnist_test, batch_size=32)
Below, we define the same model as the LitMNIST
model we made earlier.
However, this time our model has the freedom to use any input data that we'd like 🔥.
class LitModel(pl.LightningModule):
    """Dataset-agnostic classifier sized from its constructor arguments.

    Args:
        channels: Number of input channels.
        width: Input width.
        height: Input height.
        num_classes: Number of output classes.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        # nn.Module must be initialized before any submodule is assigned,
        # otherwise `self.model = ...` below raises.
        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            # Collapse (channels, width, height) into one feature vector so
            # it matches the first Linear layer's input size.
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            # Nonlinearities keep the stacked Linear layers from collapsing
            # into a single linear map.
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch and log both."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
Training the LitModel using the MNISTDataModule
Now, we initialize and train the LitModel using the MNISTDataModule's configuration settings and dataloaders.
# The datamodule owns the MNIST download, split and dataloaders.
dm = MNISTDataModule()
# Size the model from the datamodule: dm.size() provides the input dims,
# dm.num_classes the number of outputs.
model = LitModel(*dm.size(), num_classes=dm.num_classes)
# Configure the trainer for three epochs on one GPU.
trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)
# Passing the datamodule to fit() makes its hooks override the model's.
trainer.fit(model, dm)
Let's prove the LitModel
we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset.
class CIFAR10DataModule(pl.LightningDataModule):
    """Datamodule that downloads, splits and serves CIFAR10.

    Args:
        data_dir: Directory the CIFAR10 files are downloaded to / read from.
    """

    def __init__(self, data_dir: str = './'):
        # Initialize the base datamodule before setting our own attributes.
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            # Normalize works on tensors, so convert the loaded images first.
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Input dims and class count that models can be sized from.
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        """Download the CIFAR10 train and test files; assigns no state."""
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds the train/val split, ``'test'`` builds the test set,
        and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split (batch size 32)."""
        return DataLoader(self.cifar_train, batch_size=32)

    def val_dataloader(self):
        """Return a DataLoader over the validation split (batch size 32)."""
        return DataLoader(self.cifar_val, batch_size=32)

    def test_dataloader(self):
        """Return a DataLoader over the test set (batch size 32)."""
        return DataLoader(self.cifar_test, batch_size=32)
Training the LitModel using the CIFAR10DataModule
Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.
The point here is that we can see that our LitModel has no problem using a different datamodule as its input data.
# Swap in the CIFAR10 datamodule; the LitModel class itself is unchanged.
dm = CIFAR10DataModule()
# Size the model from the datamodule, with wider hidden layers this time.
model = LitModel(*dm.size(), num_classes=dm.num_classes, hidden_size=256)
# Five epochs on one GPU.
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)
Congratulations - Time to Join the Community!
Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in #general.
Bolts has a collection of state-of-the-art models, all implemented in Lightning and can be easily integrated within your own projects.
The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolts GitHub Issues pages and filter for "good first issue".