#!/usr/bin/env python
# coding: utf-8

# # Introduction to PyTorch Lightning ⚡
#
# In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).
#
# ---
# - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)
# - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
# - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)

# ### Setup
# Lightning is easy to install. Simply ```pip install pytorch-lightning```

# In[ ]:

get_ipython().system(' pip install pytorch-lightning --quiet')

# In[2]:

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

# ## Simplest example
#
# Here's the simplest, most minimal example, with just a training loop (no validation, no testing).
#
# **Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features.

# In[3]:

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# By using the `Trainer` you automatically get:
# 1. TensorBoard logging
# 2. Model checkpointing
# 3. Training and validation loop
# 4. Early stopping
#
# (A sketch of configuring checkpointing and early stopping explicitly via callbacks follows the training run below.)

# In[ ]:

# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# Initialize a trainer
trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
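# Lightning saves checkpoints by default, and early stopping can be switched on through a callback.
# Below is a minimal sketch of configuring both explicitly. Note this is an illustration, not part of
# the original tutorial: monitoring `val_loss` assumes the module logs that key (the `LitMNIST` module
# defined later does; the bare `MNISTModel` above has no validation loop), and on some Lightning
# versions `ModelCheckpoint` must be passed via the `checkpoint_callback` argument instead of `callbacks`.

# In[ ]:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Keep only the single best checkpoint and stop early once val_loss plateaus.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=1)
early_stop_callback = EarlyStopping(monitor='val_loss', patience=3)

trainer_with_callbacks = pl.Trainer(
    gpus=1,
    max_epochs=3,
    callbacks=[checkpoint_callback, early_stop_callback],
)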
# ## A more complete MNIST Lightning Module Example
#
# That wasn't so hard, was it?
#
# Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...
#
# This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.
#
# ---
#
# ### Note what the following built-in functions are doing:
#
# 1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾
#     - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download it if it isn't found there.
#     - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)
#
# 2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️
#     - Loads data from file and prepares PyTorch tensor datasets for each split (train, val, test).
#     - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.
#     - If you don't mind loading all your datasets at once, you can set up a condition to allow both 'fit'-related setup and 'test'-related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).
#     - **Note this runs across all GPUs and it *is* safe to make state assignments here**
#
# 3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️
#     - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping the respective datasets we prepared in `setup()`

# In[5]:

class LitMNIST(pl.LightningModule):

    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

# In[ ]:

model = LitMNIST()
trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)
trainer.fit(model)
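# The data hooks above don't have to live inside the model. As a sketch of an alternative layout
# (the class name `MNISTDataModule` is our own, not part of this tutorial), the same hooks can be
# factored into a `pl.LightningDataModule`, which keeps the model reusable across datasets:

# In[ ]:

class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir='./', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download only; no state assignments in this hook
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Same split logic as LitMNIST.setup() above
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

# A datamodule is then passed to the trainer alongside the model, e.g.:
# trainer.fit(LitMNIST(), datamodule=MNISTDataModule())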
# ### Testing
#
# To test a model, call `trainer.test(model)`.
#
# Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on `val_loss`).

# In[ ]:

trainer.test()
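# If you want to evaluate a specific checkpoint rather than the best one, you can restore it
# explicitly. A minimal sketch -- the checkpoint path below is illustrative only (Lightning
# writes checkpoints under `lightning_logs/version_*/checkpoints/`; substitute one of yours):

# In[ ]:

# Restore a trained LitMNIST from a saved checkpoint and run the test loop.
# NOTE: illustrative path -- replace it with a real checkpoint file.
restored_model = LitMNIST.load_from_checkpoint('lightning_logs/version_0/checkpoints/epoch=2.ckpt')
trainer.test(restored_model)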

# ### Bonus Tip
#
# You can keep calling `trainer.fit(model)` as many times as you'd like to continue training.

# In[ ]:

trainer.fit(model)

# In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!

# In[ ]:

# Start TensorBoard.
get_ipython().run_line_magic('load_ext', 'tensorboard')
get_ipython().run_line_magic('tensorboard', '--logdir lightning_logs/')

# # Congratulations - Time to Join the Community!
# Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!
#
# ### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub
# The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.
#
# * Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
#
# ### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!
# The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.
#
# ### Interested in SOTA AI models? Check out [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)
# Bolts is a collection of state-of-the-art models implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) that can be easily integrated into your own projects.
#
# * Please, star [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts)
#
# ### Contributions!
# The best way to contribute to our community is to become a code contributor! At any time you can go to the [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for "good first issue".
#
# * [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
# * [Bolts good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
# * You can also contribute your own notebooks with useful examples!
#
# ### Many thanks from the entire PyTorch Lightning team for your interest!