Using the notebook_launcher to use Accelerate from inside a Jupyter Notebook
This notebook covers how to run the cv_example.py script as a Jupyter Notebook and train it on a distributed system. It also covers the specific requirements for making sure your environment is configured properly and your data has been prepared properly, and finally how to launch training.
Before any training can be performed, an accelerate config file must exist in the system. Usually this can be done by running the following in a terminal:
accelerate config
However, if general defaults are fine and you are not running on a TPU, accelerate has a utility to quickly write your GPU configuration into a config file via write_basic_config.
The following cell restarts Jupyter after writing the configuration, because CUDA code was called to write it. CUDA can't be initialized more than once (once for the single GPU that notebooks use by default, and then again when notebook_launcher is called). It's fine to debug in the notebook and make calls to CUDA, but remember that before the final training run a full cleanup and restart is needed, such as what is shown below:
#import os
#from accelerate.utils import write_basic_config
#write_basic_config() # Write a config file
#os._exit(0) # Restart the notebook
Next you should prepare your dataset. As mentioned earlier, great care should be taken when preparing the DataLoaders and model to make sure that nothing is put on any GPU.
If you do need to put something on a GPU, it is recommended to put that specific code into a function and call it from within the notebook launcher interface, which will be shown later.
Make sure the dataset is downloaded based on the directions here
import os, re, torch, PIL
import numpy as np
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
from accelerate import Accelerator
from accelerate.utils import set_seed
from timm import create_model
First we'll create a function to extract the class name from a file name:
import os
data_dir = "../../images"
fnames = os.listdir(data_dir)
fname = fnames[0]
print(fname)
beagle_32.jpg
In this case, the label is beagle:
import re
def extract_label(fname):
    stem = fname.split(os.path.sep)[-1]
    return re.search(r"^(.*)_\d+\.jpg$", stem).groups()[0]
extract_label(fname)
'beagle'
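Because the pattern is greedy up to the final underscore, class names that themselves contain underscores are kept whole. The sketch below repeats the same pattern on a few file names. Apart from beagle_32.jpg the names are made up for illustration, and safe_extract_label is a hypothetical helper (not part of the notebook) that returns None instead of raising when a name does not match.

```python
import os
import re

# Same pattern as extract_label: everything before the final "_<digits>.jpg".
LABEL_PATTERN = re.compile(r"^(.*)_\d+\.jpg$")

def safe_extract_label(fname):
    """Hypothetical variant of extract_label that returns None on a non-match."""
    stem = fname.split(os.path.sep)[-1]
    match = LABEL_PATTERN.search(stem)
    return match.group(1) if match else None

# Illustrative file names; only beagle_32.jpg appears in the notebook.
examples = [
    "beagle_32.jpg",
    os.path.join("images", "great_pyrenees_7.jpg"),
    "README.txt",
]
labels = [safe_extract_label(name) for name in examples]
print(labels)
```

The original extract_label would raise an AttributeError on a name such as README.txt, which is why the dataset-building code filters on the .jpg extension first.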
Next we'll create a Dataset class:
class PetsDataset(Dataset):
    def __init__(self, file_names, image_transform=None, label_to_id=None):
        self.file_names = file_names
        self.image_transform = image_transform
        self.label_to_id = label_to_id
    def __len__(self):
        return len(self.file_names)
    def __getitem__(self, idx):
        fname = self.file_names[idx]
        raw_image = PIL.Image.open(fname)
        image = raw_image.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        label = extract_label(fname)
        if self.label_to_id is not None:
            label = self.label_to_id[label]
        return {"image": image, "label": label}
And build our dataset:
# Grab all the image filenames
fnames = [
    os.path.join(data_dir, fname)
    for fname in fnames
    if fname.endswith(".jpg")
]
# Build the labels
all_labels = [
    extract_label(fname)
    for fname in fnames
]
id_to_label = list(set(all_labels))
id_to_label.sort()
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}
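Sorting the label list before enumerating it matters because the iteration order of a set of strings is not guaranteed to be the same from one Python session to the next, so an unsorted mapping could assign different ids to the same class on different runs. A minimal sketch with made-up labels:

```python
# Illustrative labels only; in the notebook these come from extract_label(fname).
all_labels = ["beagle", "pug", "beagle", "abyssinian", "pug"]

# Deduplicate, then sort so the id assigned to each label is reproducible.
id_to_label = sorted(set(all_labels))
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}

print(id_to_label)
print(label_to_id)

# The list and the dict are inverses of each other.
round_trip = [id_to_label[label_to_id[lbl]] for lbl in all_labels]
print(round_trip == all_labels)
```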
Note: This will be stored inside of a function as we'll be setting our seed during training.
def get_dataloaders(batch_size:int=64):
    "Builds a set of dataloaders with a batch_size"
    random_perm = np.random.permutation(len(fnames))
    cut = int(0.8 * len(fnames))
    # First 80% of the shuffled indices for training, remaining 20% for evaluation
    train_split = random_perm[:cut]
    eval_split = random_perm[cut:]
    # For training we use a simple RandomResizedCrop
    train_tfm = Compose([
        RandomResizedCrop((224, 224), scale=(0.5, 1.0)),
        ToTensor()
    ])
    train_dataset = PetsDataset(
        [fnames[i] for i in train_split],
        image_transform=train_tfm,
        label_to_id=label_to_id
    )
    # For evaluation we use a deterministic Resize
    eval_tfm = Compose([
        Resize((224, 224)),
        ToTensor()
    ])
    eval_dataset = PetsDataset(
        [fnames[i] for i in eval_split],
        image_transform=eval_tfm,
        label_to_id=label_to_id
    )
    # Instantiate dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=batch_size*2,
        num_workers=4
    )
    return train_dataloader, eval_dataloader
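The 80/20 split is meant to produce disjoint training and evaluation sets: the first cut shuffled indices go to training and the remainder to evaluation. The sketch below checks that property on plain index lists. It swaps in the standard library's random module for NumPy so it runs without extra dependencies, and split_indices is a hypothetical helper, not part of the notebook.

```python
import random

def split_indices(n_items, train_fraction=0.8, seed=42):
    """Hypothetical helper: shuffle range(n_items) and cut it into two disjoint parts."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    cut = int(train_fraction * n_items)
    return indices[:cut], indices[cut:]

train_split, eval_split = split_indices(10)

# No index may appear in both splits, and together they must cover every item.
print(set(train_split).isdisjoint(eval_split))
print(sorted(train_split + eval_split) == list(range(10)))
```

If the evaluation split reused the training indices, the reported accuracy would measure memorization of training images rather than performance on unseen ones.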
Now we can build our training loop. notebook_launcher works by passing in a function to call that will be run across the distributed system.
Here is a basic training loop for our animal classification problem:
from torch.optim.lr_scheduler import OneCycleLR
def training_loop(mixed_precision="fp16", seed:int=42, batch_size:int=64):
    set_seed(seed)
    # Initialize accelerator
    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Build dataloaders
    train_dataloader, eval_dataloader = get_dataloaders(batch_size)
    # Instantiate the model (we build the model here so that the seed also controls new weight initializations)
    model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))
    # Freeze the base model
    for param in model.parameters():
        param.requires_grad=False
    for param in model.get_classifier().parameters():
        param.requires_grad=True
    # We normalize the batches of images to be a bit faster
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None]
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None]
    # To make this constant available on the active device, we set it to the accelerator device
    mean = mean.to(accelerator.device)
    std = std.to(accelerator.device)
    # Instantiate the optimizer
    optimizer = torch.optim.Adam(params=model.parameters(), lr = 3e-2/25)
    # Instantiate the learning rate scheduler
    lr_scheduler = OneCycleLR(
        optimizer=optimizer,
        max_lr=3e-2,
        epochs=5,
        steps_per_epoch=len(train_dataloader)
    )
    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    # Now we train the model
    for epoch in range(5):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        model.eval()
        accurate = 0
        num_elems = 0
        for _, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
            predictions = outputs.argmax(dim=-1)
            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()
        eval_metric = accurate.item() / num_elems
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
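The optimizer's starting rate of 3e-2/25 lines up with how OneCycleLR behaves: by default it begins at max_lr / div_factor with div_factor=25, and it expects to be stepped once per batch for epochs * steps_per_epoch steps in total. The arithmetic below is only a sketch of those quantities on a single process. The number of training images is an assumed figure for illustration, not a value taken from this notebook.

```python
import math

max_lr = 3e-2
div_factor = 25          # OneCycleLR's default div_factor
epochs = 5
batch_size = 64
num_train_images = 1000  # assumed figure, for illustration only

# Batches per epoch on a single process (the last batch may be partial).
steps_per_epoch = math.ceil(num_train_images / batch_size)
total_steps = epochs * steps_per_epoch
initial_lr = max_lr / div_factor

print(f"steps per epoch: {steps_per_epoch}")
print(f"total scheduler steps: {total_steps}")
print(f"initial learning rate: {initial_lr:.4f}")
```

Stepping the scheduler more times than total_steps raises an error in PyTorch, which is why lr_scheduler.step() sits inside the batch loop and the epoch count passed to the scheduler matches the range of the outer loop.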
All that's left is to use the notebook_launcher.
We pass in the function, the arguments (as a tuple), and the number of processes to train on. (See the documentation for more information.)
from accelerate import notebook_launcher
args = ("fp16", 42, 64)
notebook_launcher(training_loop, args, num_processes=2)
Launching training on 2 GPUs.
epoch 0: 88.12
epoch 1: 91.73
epoch 2: 92.58
epoch 3: 93.90
epoch 4: 94.71
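The args tuple is matched to the training function's parameters by position, so ("fp16", 42, 64) becomes mixed_precision="fp16", seed=42, batch_size=64. The sketch below illustrates only that unpacking. Both describe_run and fake_launcher are hypothetical stand-ins: they start no processes and do not reproduce how Accelerate launches them.

```python
def describe_run(mixed_precision="fp16", seed=42, batch_size=64):
    """Hypothetical stand-in with the same parameters as training_loop."""
    return {
        "mixed_precision": mixed_precision,
        "seed": seed,
        "batch_size": batch_size,
    }

def fake_launcher(function, args=()):
    """Hypothetical stand-in: calls function(*args) once, in the current process."""
    return function(*args)

# Positional order matters: swapping 0 and 32 would swap seed and batch size.
config = fake_launcher(describe_run, ("bf16", 0, 32))
print(config)
```

Omitting trailing entries from the tuple leaves the corresponding parameters at their defaults, as with any positional call.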
And that's it!
This notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:
Keep any code that initializes CUDA inside the function passed to notebook_launcher.
Set num_processes to be the number of devices used for training (such as number of GPUs, CPUs, TPUs, etc).