In this notebook, we'll go over the basics of Lightning by preparing models to train on the MNIST Handwritten Digits dataset.
Lightning is easy to install: simply run pip install pytorch-lightning.
! pip install pytorch-lightning --quiet
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
Here's the simplest, most minimal example, with just a training loop (no validation, no testing).
Keep in mind: a LightningModule is a PyTorch nn.Module - it just has a few more helpful features.
class MNISTModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer mapping 28x28 inputs to 10 outputs.

    Only the training loop is defined (no validation or test steps).
    """

    def __init__(self):
        super().__init__()
        # Single fully connected layer over the flattened 28 * 28 input.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten every sample in ``x`` and return ReLU(linear(x))."""
        flat = x.view(x.size(0), -1)
        scores = self.l1(flat)
        return torch.relu(scores)

    def training_step(self, batch, batch_nb):
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 0.02)."""
        return torch.optim.Adam(self.parameters(), lr=0.02)
By using the Trainer you automatically get:
1. TensorBoard logging
2. Model checkpointing
3. The training and validation loop
# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset (downloaded into the current working directory)
train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# Initialize a trainer.
# Use one GPU when available and fall back to CPU otherwise, so the cell
# also runs on machines without CUDA instead of raising at Trainer creation.
trainer = pl.Trainer(
    gpus=min(1, torch.cuda.device_count()),
    max_epochs=3,
    progress_bar_refresh_rate=20,
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
That wasn't so hard was it?
Now that we've got our feet wet, let's dive in a bit deeper and write a more complete LightningModule for MNIST...
This time, we'll bake in all the dataset specific pieces directly in the LightningModule. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.
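Note what the following built-in functions are doing:
1. prepare_data() 💾
   - This is where we can download the dataset. We point to our desired dataset and ask torchvision's MNIST dataset class to download if the dataset isn't found there.
   - Note that we do not make any state assignments in this function (i.e. self.something = ...).
2. setup(stage) ⚙️
   - Loads the data from file and prepares the PyTorch datasets for each split (train, val, test).
   - setup() expects a stage argument, which is used to separate the logic for 'fit' and 'test'.
   - If you don't mind loading all your datasets at once, you can set up a condition that lets both the 'fit'-related and the 'test'-related setup run whenever None is passed to stage (or ignore it altogether and exclude any conditionals).
3. train_dataloader(), val_dataloader(), and test_dataloader()
   - These all return PyTorch DataLoader instances that are created by wrapping the respective datasets that we prepared in setup().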
class LitMNIST(pl.LightningModule):
    """MNIST classifier bundling the model, data preparation and dataloaders.

    Args:
        data_dir: Directory MNIST is downloaded to and read from.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate passed to the Adam optimizer.
        batch_size: Batch size used by the train/val/test dataloaders.
    """

    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4, batch_size=32):
        # nn.Module must be initialised before any submodule can be assigned.
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            # Normalize works on tensors, so the PIL image is converted first.
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # Define PyTorch model
        self.model = nn.Sequential(
            # Collapse (channels, width, height) into one feature vector so the
            # first Linear layer receives channels * width * height inputs.
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            # Without non-linearities the stacked Linear layers would be
            # equivalent to a single linear map.
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the loss for one batch; accuracy is logged too."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch.

        Reuses ``validation_step``, so test metrics are logged under the
        ``val_loss`` / ``val_acc`` names.
        """
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        """Download the MNIST train and test sets into ``data_dir``.

        No attributes are assigned here; the datasets are built in ``setup``.
        """
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` builds the train/val datasets, ``'test'`` builds
                the test dataset, and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Random 55,000 / 5,000 train/validation split.
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split prepared in ``setup``."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a DataLoader over the validation split prepared in ``setup``."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a DataLoader over the test set prepared in ``setup``."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
# Init the self-contained module; it downloads and splits MNIST on its own.
model = LitMNIST()

# Use one GPU when available and fall back to CPU otherwise, so the cell
# also runs on machines without CUDA instead of raising at Trainer creation.
trainer = pl.Trainer(
    gpus=min(1, torch.cuda.device_count()),
    max_epochs=3,
    progress_bar_refresh_rate=20,
)

# Train the model; the dataloaders come from the module itself.
trainer.fit(model)
To test a model, call trainer.test(model).
Or, if you've just trained a model, you can just call trainer.test() and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss).
You can keep calling trainer.fit(model) as many times as you'd like to continue training.
In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!
# Start tensorboard.
# Load the TensorBoard notebook extension, then open TensorBoard on the
# lightning_logs/ directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
Congratulations - Time to Join the Community!
Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in #general
Bolts has a collection of state-of-the-art models, all implemented in Lightning, which can be easily integrated within your own projects.
The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolts GitHub Issues page and filter for "good first issue".