In this tutorial, we'll learn how to apply a land cover classification model to imagery hosted in the Planetary Computer Data Catalog. In the process, we'll see how to:
- Query the Planetary Computer's STAC API for NAIP imagery
- Load a pretrained PyTorch segmentation model onto GPU workers
- Stitch many cloud-hosted GeoTIFFs into one large xarray DataArray
- Run GPU-accelerated inference in parallel with Dask
If you're running this on the Planetary Computer Hub, make sure to choose the GPU - PyTorch profile when presented with the form to choose your environment.
We'll work with NAIP data, a collection of high-resolution aerial imagery covering the continental US. We'll apply a PyTorch model trained for land cover classification to the data. The model takes in an image and classifies each pixel into a category (e.g. "water", "tree canopy", "road", etc.). We're using a neural network trained by data scientists from Microsoft's AI for Good program. We'll use the model to analyze how land cover changed over a portion of Maryland from 2013 to 2017.
This is a somewhat large computation, and we'll handle the scale in two ways:
- Using Dask to split the dataset into chunks and process them in parallel
- Using GPUs to speed up the model inference on each chunk
We'll start by creating a local Dask cluster whose workers each have access to a GPU.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster(threads_per_worker=4)
client = Client(cluster)
print(f"/proxy/{client.scheduler_info()['services']['dashboard']}/status")
2022-09-20 16:11:38,787 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
/proxy/8787/status
Make sure to open the Dask Dashboard, either by clicking the Dashboard link or by using the dask-labextension to lay out your workspace (see Scale with Dask).
Next, we'll load the model. It's available in a public Azure Blob Storage container. We'll download the model locally and construct the Unet using segmentation_models_pytorch. Deep learning models can be somewhat large and difficult to serialize, so we'll load the model directly on the worker using Client.submit. This returns a Future pointing to the model, which we'll use later on.
import azure.storage.blob
from pathlib import Path
import segmentation_models_pytorch
import torch
import warnings
# ignore SyntaxWarning in pretrainedmodels
warnings.filterwarnings("ignore", category=SyntaxWarning)
def load_model():
    # Download the model weights if they aren't already cached locally.
    p = Path("unet_both_lc.pt")
    if not p.exists():
        blob_client = azure.storage.blob.BlobClient(
            account_url="https://naipeuwest.blob.core.windows.net/",
            container_name="naip-models",
            blob_name="unet_both_lc.pt",
        )
        with p.open("wb") as f:
            f.write(blob_client.download_blob().readall())

    # Build the U-Net and load the trained weights onto the GPU.
    model = segmentation_models_pytorch.Unet(
        encoder_name="resnet18",
        encoder_depth=3,
        encoder_weights=None,
        decoder_channels=(128, 64, 64),
        in_channels=4,
        classes=13,
    )
    model.load_state_dict(torch.load("unet_both_lc.pt", map_location="cuda:0"))
    device = torch.device("cuda")
    model = model.to(device)
    return model
remote_model = client.submit(load_model)
print(remote_model)
<Future: pending, key: load_model-44ec68adfa36409075c18d5e0cd519cf>
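Because remote_model is a Future, later tasks can take it as an argument and Dask will hand them the worker-resident model rather than serializing the weights back and forth. A minimal sketch to confirm that (describe is a hypothetical helper, not part of the tutorial):
# `remote_model` resolves to the live model object on the worker.
def describe(model):
    return type(model).__name__

print(client.submit(describe, remote_model).result())  # "Unet"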
Suppose we've been tasked with analyzing how land use changed from 2013 to 2017 for a region of Maryland. The full NAIP dataset consists of millions of images. How do we find the few hundred files that we care about?
With the Planetary Computer's metadata query API, that's straightforward. First, we'll define our area of interest as a bounding box.
bbox = [-77.9754638671875, 38.58037909468592, -76.37969970703125, 39.812755695478124]
Next, we'll use pystac_client to query the Planetary Computer's STAC endpoint. We'll filter the results by space (to return only images touching our area of interest) and time (to return a set of images from 2013, and a second set for 2017).
import pystac_client
api = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1/")
search_2013 = api.search(
    bbox=bbox,
    datetime="2012-12-31T00:00:00Z/2014-01-01T00:00:00Z",
    collections=["naip"],
)
search_2017 = api.search(
    bbox=bbox,
    datetime="2016-12-31T00:00:00Z/2018-01-01T00:00:00Z",
    collections=["naip"],
)
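Neither search has pulled down full item metadata yet. If you only want the match counts, pystac-client can ask the API directly; a quick optional check (assuming the matched() convenience, which pystac-client provides):
# How many items matched each search? (No item downloads happen here.)
print("2013 matches:", search_2013.matched())
print("2017 matches:", search_2017.matched())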
Each item in those results is a single STAC Item, which includes URLs to cloud-optimized GeoTIFF files stored in Azure Blob Storage.
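To get a feel for one of those Items, we can peek at its ID and assets; for NAIP, the cloud-optimized GeoTIFF should live under the "image" asset key (a small illustrative check, not a step from the main workflow):
# Inspect a single STAC Item from the 2013 search.
item = next(search_2013.get_items())
print(item.id)
print(item.assets["image"].href)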
We have URLs to many files in Blob Storage. We want to treat all those as one big, logical dataset, so we'll use some open-source libraries to stitch them all together.
stac-vrt will take a collection of STAC items and efficiently build a GDAL VRT: a lightweight XML file that presents many individual GeoTIFFs as a single mosaic.
import stac_vrt
data_2013 = search_2013.get_all_items_as_dict()["features"]
data_2017 = search_2017.get_all_items_as_dict()["features"]
print("2013:", len(data_2013), "items")
print("2017:", len(data_2017), "items")
naip_2013 = stac_vrt.build_vrt(
    data_2013, block_width=512, block_height=512, data_type="Byte"
)
mosaic_2017 = stac_vrt.build_vrt(
    data_2017, block_width=512, block_height=512, data_type="Byte"
)
2013: 440 items
2017: 443 items
Once we have a pair of VRTs (one per year), we use rasterio's WarpedVRT to make sure they're pixel-aligned on the same grid.
import rasterio
a = rasterio.open(naip_2013)
naip_2017 = rasterio.vrt.WarpedVRT(
    rasterio.open(mosaic_2017), transform=a.transform, height=a.height, width=a.width
)
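It's worth a quick sanity check that the warp did what we wanted: both years should now share a shape and geotransform. A small optional assertion:
# After warping, the two mosaics are on the same pixel grid.
assert (naip_2017.width, naip_2017.height) == (a.width, a.height)
assert naip_2017.transform == a.transform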
Now we can lazily load both years into a single, Dask-backed DataArray with rioxarray, concatenating them along a new time dimension.
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray
ds1 = rioxarray.open_rasterio(naip_2013, chunks=(4, 8192, 8192), lock=False)
ds2 = rioxarray.open_rasterio(naip_2017, chunks=(4, 8192, 8192), lock=False)
ds = xr.concat([ds1, ds2], dim=pd.Index([2013, 2017], name="time"))
ds
<xarray.DataArray (time: 2, band: 4, y: 149498, x: 145987)>
dask.array<concatenate, shape=(2, 4, 149498, 145987), dtype=uint8, chunksize=(1, 4, 8192, 8192), chunktype=numpy.ndarray>
Coordinates:
  * band         (band) int64 1 2 3 4
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
Attributes:
    scale_factor:  1.0
    add_offset:    0.0
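It's worth pausing on the scale here: at one byte per pixel per band, this array is far larger than memory, which is exactly why we keep it as lazy Dask chunks. A quick check (nbytes doesn't trigger any computation):
# 2 years x 4 bands x 149498 x 145987 uint8 pixels is roughly 175 GB.
print(f"{ds.nbytes / 1e9:.1f} GB")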
Now we have a big dataset that's pixel-aligned on a common grid for the two time periods. The model requires a bit of pre-processing up front: we'll define a couple of variables with the per-band mean and standard deviation for each year, which we'll use to normalize the imagery.
bands = xr.DataArray(
    [1, 2, 3, 4], name="band", dims=["band"], coords={"band": [1, 2, 3, 4]}
)

NAIP_2013_MEANS = xr.DataArray(
    np.array([117.00, 130.75, 122.50, 159.30], dtype="float32"),
    name="mean",
    coords=[bands],
)
NAIP_2013_STDS = xr.DataArray(
    np.array([38.16, 36.68, 24.30, 66.22], dtype="float32"),
    name="std",
    coords=[bands],
)
NAIP_2017_MEANS = xr.DataArray(
    np.array([72.84, 86.83, 76.78, 130.82], dtype="float32"),
    name="mean",
    coords=[bands],
)
NAIP_2017_STDS = xr.DataArray(
    np.array([41.78, 34.66, 28.76, 58.95], dtype="float32"),
    name="std",
    coords=[bands],
)
mean = xr.concat([NAIP_2013_MEANS, NAIP_2017_MEANS], dim="time")
std = xr.concat([NAIP_2013_STDS, NAIP_2017_STDS], dim="time")
With those constants defined, we can normalize the data by subtracting the mean and dividing by the standard deviation. We'll also work around an issue the model has with partial chunks by dropping up to 31 pixels from the bottom and right edges, so that each spatial dimension is a multiple of 32.
# Normalize by per-year mean, std
normalized = (ds - mean) / std

# The Unet model doesn't like partial chunks, so we chop off the
# last 1-31 pixels in each spatial dimension.
slices = {}
for coord in ["y", "x"]:
    remainder = len(ds.coords[coord]) % 32
    slice_ = slice(-remainder) if remainder else slice(None)
    slices[coord] = slice_

normalized = normalized.isel(**slices)
normalized
<xarray.DataArray (time: 2, band: 4, y: 149472, x: 145984)>
dask.array<getitem, shape=(2, 4, 149472, 145984), dtype=float32, chunksize=(1, 4, 8192, 8192), chunktype=numpy.ndarray>
Coordinates:
  * band         (band) int64 1 2 3 4
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
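The shapes above confirm the cropping arithmetic; a quick check:
# y shrank from 149498 to 149472 (149498 % 32 == 26),
# and x from 145987 to 145984 (145987 % 32 == 3).
assert 149498 - 149498 % 32 == 149472
assert 145987 - 145987 % 32 == 145984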
At this point, we're ready to make predictions.
We'll apply the model to the entire dataset, taking care not to over-saturate the GPUs. At the lowest level, the GPUs work on relatively small "chips" which fit comfortably in memory. The prediction, which comes from model(data), happens on the GPU so that it's nice and fast.
Stepping up a level, we have Dask chunks; each chunk is just a regular NumPy array on a worker. We'll break each chunk into a bunch of chips (using dask.array.core.slices_from_chunks) and get a prediction for each chip.
import dask.array

def predict_chip(data: torch.Tensor, model) -> torch.Tensor:
    # The input is already on the GPU; the result is moved back to the CPU.
    with torch.no_grad():
        result = model(data).argmax(dim=1).to(torch.uint8)
    return result.to("cpu")

def copy_and_predict_chunked(tile, model, token=None):
    has_time = tile.ndim == 4
    if has_time:
        assert tile.shape[0] == 1
        tile = tile[0]
    # Let dask's normal auto-chunking pick the chip sizes for this tile.
    slices = dask.array.core.slices_from_chunks(dask.array.empty(tile.shape).chunks)
    out = np.empty(shape=tile.shape[1:], dtype="uint8")
    device = torch.device("cuda")
    for slice_ in slices:
        # Copy one chip to the GPU, predict, and write the result back.
        gpu_chip = torch.as_tensor(tile[slice_][np.newaxis, ...]).to(device)
        out[slice_[1:]] = predict_chip(gpu_chip, model).cpu().numpy()[0]
    if has_time:
        out = out[np.newaxis, ...]
    return out
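To make the chipping step concrete, here's what slices_from_chunks produces for a small, explicitly chunked array (illustrative shapes only, not the chip sizes used above):
# Each tuple of slices selects one "chip" out of a larger tile.
chunks = dask.array.empty((4, 1024, 1024), chunks=(4, 512, 512)).chunks
for s in dask.array.core.slices_from_chunks(chunks)[:2]:
    print(s)
# (slice(0, 4, None), slice(0, 512, None), slice(0, 512, None))
# (slice(0, 4, None), slice(0, 512, None), slice(512, 1024, None))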
Stepping up yet another level, we'll apply the prediction function to the entire xarray DataArray. We'll use map_blocks on the underlying Dask array to run the prediction over every chunk in parallel.
meta = np.array([[]], dtype="uint8")[:0]

predictions_array = normalized.data.map_blocks(
    copy_and_predict_chunked,
    meta=meta,
    drop_axis=1,
    model=remote_model,
    name="predict",
)
predictions = xr.DataArray(
    predictions_array,
    coords=normalized.drop_vars("band").coords,
    dims=("time", "y", "x"),
)
predictions
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 149472, x: 145984)>
dask.array<predict, shape=(2, 149472, 145984), dtype=uint8, chunksize=(1, 8192, 8192), chunktype=numpy.ndarray>
Coordinates:
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
So there are three levels:
- Chips: small arrays that fit comfortably in GPU memory; the model predicts one chip at a time.
- Chunks: regular NumPy arrays, one per Dask task, each broken into chips.
- The DataArray: the full Dask-backed array, whose chunks are processed in parallel across the cluster.
We can kick off a computation by calling predictions.persist(). This should cause some activity on your Dask Dashboard.
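That persist is a one-liner, though it schedules prediction for the entire dataset, so expect it to run for a while:
# Start computing (and caching) predictions for the full array on the
# cluster. Progress appears on the Dask dashboard.
predictions = predictions.persist()
Below, we'll also pull a small window eagerly to inspect the values.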
predictions[:, :200, :200].compute()
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 200, x: 200)>
array([[[1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1]],

       [[3, 3, 3, ..., 3, 3, 3],
        [3, 3, 3, ..., 3, 3, 3],
        [3, 3, 3, ..., 3, 3, 3],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8)
Coordinates:
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
Each element of predictions is an integer encoding the class the PyTorch model thinks the pixel belongs to (tree canopy, building, water, etc.).
Finally, we can compute the result we're interested in: which pixels (spots on the Earth) changed land cover over the four years, at least according to our model.
change = predictions.sel(time=2013) != predictions.sel(time=2017)
change
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (y: 149472, x: 145984)>
dask.array<ne, shape=(149472, 145984), dtype=bool, chunksize=(8192, 8192), chunktype=numpy.ndarray>
Coordinates:
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06
    spatial_ref  int64 0
That's a boolean array where True means "this location changed". We'll mask the predictions array with change. The value other=0 means "no change", so changed_predictions has just the predictions (the integer codes) where there was a change.
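If the where/other semantics are new, here's a toy illustration with made-up values:
# .where keeps values where the condition is True and fills elsewhere.
values = xr.DataArray([1, 2, 3])
mask = xr.DataArray([True, False, True])
print(values.where(mask, other=0).values)  # [1 0 3]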
changed_predictions = predictions.where(change, other=0)
changed_predictions
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 149472, x: 145984)>
dask.array<where, shape=(2, 149472, 145984), dtype=uint8, chunksize=(1, 8192, 8192), chunktype=numpy.ndarray>
Coordinates:
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
Again, we can kick off some computation.
changed_predictions[:, :200, :200].compute()
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 200, x: 200)>
array([[[1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[3, 3, 3, ..., 3, 3, 3],
        [3, 3, 3, ..., 3, 3, 3],
        [3, 3, 3, ..., 3, 3, 3],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)
Coordinates:
  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05
  * y            (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06
    spatial_ref  int64 0
  * time         (time) int64 2013 2017
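Before moving to plots, one handy (if expensive, since it scans every pixel) summary statistic is the overall fraction of changed pixels; an optional sketch:
# Fraction of pixels whose predicted class differs between the two years.
fraction_changed = float(change.mean().compute())
print(f"{fraction_changed:.1%} of pixels changed class")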
Now let's do some visual spot-checking of our model. This requires processing the full-resolution images, so we'll limit ourselves to a subset that fits comfortably in memory.
middle = ds.shape[2] // 2, ds.shape[3] // 2
slice_y = slice(middle[0], middle[0] + 5_000)
slice_x = slice(middle[1], middle[1] + 5_000)
parts = [x.isel(y=slice_y, x=slice_x) for x in [ds, predictions, changed_predictions]]
ds_local, predictions_local, changed_predictions_local = dask.compute(*parts)
import matplotlib.colors
from bokeh.models.tools import BoxZoomTool
import panel
import hvplot.xarray # noqa
cmap = matplotlib.colors.ListedColormap(
    np.array(
        [
            (0, 0, 0),
            (0, 197, 255),
            (0, 168, 132),
            (38, 115, 0),
            (76, 230, 0),
            (163, 255, 115),
            (255, 170, 0),
            (255, 0, 0),
            (156, 156, 156),
            (0, 0, 0),
            (115, 115, 0),
            (230, 230, 0),
            (255, 255, 115),
            (197, 0, 255),
        ]
    )
    / 255
)
def logo(plot, element):
    plot.state.toolbar.logo = None

zoom = BoxZoomTool(match_aspect=True)

style_kwargs = dict(
    width=450,
    height=400,
    xaxis=False,
    yaxis=False,
)
kwargs = dict(
    x="x",
    y="y",
    cmap=cmap,
    rasterize=True,
    aggregator="mode",
    colorbar=False,
    tools=["pan", zoom, "wheel_zoom", "reset"],
    clim=(0, 12),
)
image_2013_plot = (
    ds_local.sel(time=2013)
    .hvplot.rgb(
        bands="band",
        x="x",
        y="y",
        rasterize=True,
        title="NAIP 2013",
        hover=False,
        **style_kwargs,
    )
    .opts(default_tools=[], hooks=[logo])
)

classification_2013_plot = (
    changed_predictions_local.sel(time=2013)
    .hvplot.image(title="Classification 2013", **kwargs, **style_kwargs)
    .opts(default_tools=[])
)

image_2017_plot = (
    ds_local.sel(time=2017)
    .hvplot.rgb(
        bands="band",
        x="x",
        y="y",
        rasterize=True,
        title="NAIP 2017",
        hover=False,
        **style_kwargs,
    )
    .opts(default_tools=[], hooks=[logo])
)

classification_2017_plot = (
    changed_predictions_local.sel(time=2017)
    .hvplot.image(title="Classification 2017", **kwargs, **style_kwargs)
    .opts(default_tools=[])
)

panel.GridBox(
    image_2013_plot,
    classification_2013_plot,
    image_2017_plot,
    classification_2017_plot,
    ncols=2,
)
That visualization uses Panel, a Python dashboarding library. In an interactive Jupyter Notebook you can pan and zoom around the large dataset.
This example created a local Dask "cluster" on this single node. You can scale your computation out to a true GPU cluster with Dask Gateway by setting the gpu=True option when creating a cluster.
import dask_gateway
N_WORKERS = 2
g = dask_gateway.Gateway()
options = g.cluster_options()
options["gpu"] = True
options["worker_memory"] = 25
options["worker_cores"] = 3
options["environment"] = {
"DASK_DISTRIBUTED__WORKERS__RESOURCES__GPU": "1",
}
cluster = g.new_cluster(options)
client = cluster.get_client()
cluster.scale(N_WORKERS)
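GPU nodes can take a few minutes to arrive. If you'd like to block until they're ready before submitting work, distributed's standard Client.wait_for_workers does that:
# Wait until the requested GPU workers have joined the cluster.
client.wait_for_workers(N_WORKERS)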