! pip install pytorch-lightning --quiet import os import torch from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader, random_split from torchvision.datasets import MNIST from torchvision import transforms import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy class MNISTModel(pl.LightningModule): def __init__(self): super(MNISTModel, self).__init__() self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) def training_step(self, batch, batch_nb): x, y = batch loss = F.cross_entropy(self(x), y) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.02) # Init our model mnist_model = MNISTModel() # Init DataLoader from MNIST Dataset train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) train_loader = DataLoader(train_ds, batch_size=32) # Initialize a trainer trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20) # Train the model ⚡ trainer.fit(mnist_model, train_loader) class LitMNIST(pl.LightningModule): def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4): super().__init__() # Set our init args as class attributes self.data_dir = data_dir self.hidden_size = hidden_size self.learning_rate = learning_rate # Hardcode some dataset specific attributes self.num_classes = 10 self.dims = (1, 28, 28) channels, width, height = self.dims self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # Define PyTorch model self.model = nn.Sequential( nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, self.num_classes) ) def forward(self, x): x = self.model(x) return F.log_softmax(x, dim=1) def training_step(self, batch, batch_idx): x, y = batch logits = self(x) loss = F.nll_loss(logits, y) return loss def validation_step(self, batch, batch_idx): x, y = batch logits = self(x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) acc = accuracy(preds, y) # Calling self.log will surface up scalars for you in TensorBoard self.log('val_loss', loss, prog_bar=True) self.log('val_acc', acc, prog_bar=True) return loss def test_step(self, batch, batch_idx): # Here we just reuse the validation_step for testing return self.validation_step(batch, batch_idx) def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer #################### # DATA RELATED HOOKS #################### def prepare_data(self): # download MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) def setup(self, stage=None): # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) # Assign test dataset for use in dataloader(s) if stage == 'test' or stage is None: self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=32) def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) model = LitMNIST() trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20) trainer.fit(model) trainer.test() trainer.fit(model) # Start tensorboard. %load_ext tensorboard %tensorboard --logdir lightning_logs/