This example follows PyTorch's transfer learning tutorial. We will fine-tune a pretrained model on a small dataset, but the primary focus is using a Dask cluster for batch prediction.
Note that the base environment on the examples.dask.org Binder does not include PyTorch or torchvision. To run this example, you'll need to run
```
!conda install -y pytorch-cpu torchvision
```

which will take a bit of time to run.
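If you're not sure whether the install succeeded, a quick sanity check (a sketch using only the standard library) is to look the packages up with `importlib` before running the rest of the notebook:

```python
from importlib.util import find_spec

# Names of the packages this notebook needs; adjust if your environment differs.
required = ["torch", "torchvision"]
missing = [name for name in required if find_spec(name) is None]

if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```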
The PyTorch documentation hosts a small set of data. We'll download and extract it locally.
```python
import urllib.request
import zipfile
```
```python
filename, _ = urllib.request.urlretrieve(
    "https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip"
)
zipfile.ZipFile(filename).extractall()
```
The directory looks like
```
hymenoptera_data/
    train/
        ants/
            0013035.jpg
            ...
            1030023514_aad5c608f9.jpg
        bees/
            1092977343_cb42b38d62.jpg
            ...
            2486729079_62df0920be.jpg
    val/
        ants/
            0013025.jpg
            ...
            1030023514_aad5c606d9.jpg
        bees/
            1092977343_cb42b38e62.jpg
            ...
            2486729079_62df0921be.jpg
```
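To sanity-check a layout like this, you can count images per class with `pathlib`. Here's a sketch against a tiny throwaway directory tree (in the notebook you would point `root` at `hymenoptera_data` instead):

```python
import tempfile
from pathlib import Path

# Build a tiny stand-in for the extracted dataset.
root = Path(tempfile.mkdtemp())
for split in ("train", "val"):
    for cls in ("ants", "bees"):
        d = root / split / cls
        d.mkdir(parents=True)
        for i in range(3):  # three placeholder "images" per class
            (d / f"img_{i}.jpg").touch()

# Count images per (split, class) pair, as you would on hymenoptera_data/.
counts = {
    f"{split.name}/{cls.name}": len(list(cls.glob("*.jpg")))
    for split in root.iterdir()
    for cls in split.iterdir()
}
print(counts)
```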
```python
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
                             dataloaders, class_names, finetune_model)
```
Our base model is resnet18. It predicts for 1,000 categories, while ours just predicts 2 (ants or bees). To make this model train quickly on examples.dask.org, we'll only use a couple of epochs.
```python
%%time
model = finetune_model()
```
```
Epoch 0/1
----------
train Loss: 0.6196 Acc: 0.6844
val Loss: 0.2042 Acc: 0.9281

Epoch 1/1
----------
train Loss: 0.4517 Acc: 0.7787
val Loss: 0.1458 Acc: 0.9477

Training complete in 0m 4s
Best val Acc: 0.947712
CPU times: user 3.92 s, sys: 2.03 s, total: 5.95 s
Wall time: 6.33 s
```
Things seem OK on a few random images:
Now for the main topic: using a pretrained model for batch prediction on a Dask cluster. There are two main complications, both of which deal with minimizing the amount of data moved around:

1. We'll use `dask.delayed` to load the data on the workers, rather than loading it on the client and sending it to the workers.
2. We'll also wrap the model itself in `dask.delayed`, so it appears only once in the task graph rather than once per batch.
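Why load on the workers? Whatever the client builds into the graph gets serialized and shipped through the scheduler. A small sketch (plain Python, with a made-up 1 MB payload standing in for a decoded image) of the size difference between shipping a path and shipping the data itself:

```python
import pickle

path = "hymenoptera_data/val/ants/0013025.jpg"
fake_image_bytes = b"\x00" * 1_000_000  # stand-in for a decoded image

path_size = len(pickle.dumps(path))
data_size = len(pickle.dumps(fake_image_bytes))

print(f"pickled path:  {path_size} bytes")
print(f"pickled image: {data_size} bytes")
# Shipping only paths, and loading on the workers, avoids moving the large payload.
```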
```python
from distributed import Client

client = Client(n_workers=2, threads_per_worker=2)
client
```
```python
import glob

import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image


@dask.delayed
def load(path, fs=__builtins__):
    # ``fs`` defaults to the builtins module, so ``fs.open`` is the regular
    # built-in ``open`` when reading from the local filesystem.
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        return img


@dask.delayed
def transform(img):
    trn = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return trn(img)
```
```python
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]
```
To load the data from cloud storage, say Amazon S3, you would use
```python
import s3fs

fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]
```
The PyTorch model expects tensors of a specific shape, so let's transform them.
```python
tensors = [transform(x) for x in objs]
```
And the model expects batches of inputs, so let's stack a few together.
```python
batches = [dask.delayed(torch.stack)(batch)
           for batch in toolz.partition_all(10, tensors)]
batches[:5]
```
```
[Delayed('stack-da59d324-464a-4dce-adfa-0dc99dc53299'),
 Delayed('stack-939f881b-58ba-4bb5-b4eb-1df6ccfa850f'),
 Delayed('stack-e3809d5d-84f2-4279-a1a6-71131f4d2c53'),
 Delayed('stack-a172c545-7cdd-467f-a2bc-e5c5ae611d50'),
 Delayed('stack-8698c88b-6e05-442d-8346-8af67d0992ae')]
```
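`toolz.partition_all(10, tensors)` yields successive chunks of at most 10 items, with the last chunk possibly shorter. If `toolz` isn't handy, an equivalent sketch with just `itertools`:

```python
from itertools import islice

def partition_all(n, seq):
    """Yield lists of at most ``n`` items from ``seq``; the last may be shorter."""
    it = iter(seq)
    while chunk := list(islice(it, n)):
        yield chunk

chunks = list(partition_all(10, range(25)))
print([len(c) for c in chunks])  # → [10, 10, 5]
```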
Finally, we'll write a small `predict` helper to predict the output class (0 or 1).
```python
@dask.delayed
def predict(batch, model):
    with torch.no_grad():
        out = model(batch)
        _, predicted = torch.max(out, 1)
        predicted = predicted.numpy()
    return predicted
```
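`torch.max(out, 1)` reduces over dimension 1: for each row of the `(batch_size, num_classes)` output it returns the maximum value and its column index, and we keep only the indices, i.e. the predicted classes. In plain Python the same reduction looks like this (a sketch with hypothetical logits):

```python
def argmax_rows(out):
    """For each row, return the index of the largest entry (the predicted class)."""
    return [max(range(len(row)), key=row.__getitem__) for row in out]

# Two logits per image (ants vs. bees), three images in the batch.
logits = [
    [2.1, -0.3],  # class 0 wins
    [-1.0, 0.4],  # class 1 wins
    [0.0, 0.2],   # class 1 wins
]
print(argmax_rows(logits))  # → [0, 1, 1]
```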
PyTorch neural networks can be large, so we don't want to repeat the model many times in our task graph (once per batch).
```python
import pickle

dask.utils.format_bytes(len(pickle.dumps(model)))
```
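`dask.utils.format_bytes` just renders a byte count in human-readable units. A self-contained sketch of the same measurement, with a minimal stand-in formatter and a nested dict playing the role of model weights:

```python
import pickle

def format_bytes(n):
    # Minimal stand-in for dask.utils.format_bytes.
    for unit in ("B", "kiB", "MiB", "GiB"):
        if n < 1024 or unit == "GiB":
            return f"{n:.2f} {unit}"
        n /= 1024

# A stand-in "model": a dict of float lists playing the role of weight tensors.
fake_weights = {f"layer_{i}.weight": [0.0] * 10_000 for i in range(8)}
print(format_bytes(len(pickle.dumps(fake_weights))))
```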
Instead, we'll also wrap the model itself in `dask.delayed`. This means the model only shows up once in the Dask graph.
Additionally, since we performed fine-tuning above (which runs on a GPU if one is available), we should move the model back to the CPU.
```python
dmodel = dask.delayed(model.cpu())  # ensure the model is on the CPU
```
Now we'll use the (delayed) `predict` helper to get our predictions.
```python
predictions = [predict(batch, dmodel) for batch in batches]
dask.visualize(predictions[:2])
```