Welcome to the benchmark solution tutorial for our newest competition run in partnership with Microsoft AI for Earth and Cloud to Street.
In this challenge, you are tasked with detecting the presence of floodwater in Sentinel-1 global synthetic aperture radar (SAR) imagery. While traditional optical sensors measure reflected solar light, SAR operates in the microwave band of the electromagnetic spectrum and can image the Earth's surface day and night, through clouds, smoke, vegetation, and precipitation. With the increasing prevalence of extreme weather events due to climate change, your machine learning models and insights have the potential to strengthen flood mapping algorithms and improve disaster risk management efforts around the world.
The training data for this competition consist of 542 "chips" from flood events around the world, which each contain two polarization bands. Each chip has a corresponding label mask that indicates which pixels in a scene contain water. You can optionally supplement this training data with information on permanent water sources and elevation, available through the Microsoft AI for Earth STAC API (see the competition STAC resources page for more information on bringing in this data). To enter this challenge, you will package and submit your code to generate predictions as single-band tifs with binary pixel values.
In this post, we'll cover:

- exploring the training data and Sentinel-1 imagery
- splitting the data into training and validation sets
- training a benchmark U-Net segmentation model with PyTorch Lightning
- packaging our inference code and preparing a submission
Let's get started!
!pip install watermark
Collecting watermark Using cached watermark-2.2.0-py2.py3-none-any.whl (6.8 kB) Requirement already satisfied: ipython in /srv/conda/envs/notebook/lib/python3.8/site-packages (from watermark) (7.26.0) Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (3.0.19) Requirement already satisfied: backcall in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (0.2.0) Requirement already satisfied: matplotlib-inline in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (0.1.2) Requirement already satisfied: decorator in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (5.0.9) Requirement already satisfied: setuptools>=18.5 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (57.4.0) Requirement already satisfied: jedi>=0.16 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (0.18.0) Requirement already satisfied: pickleshare in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (0.7.5) Requirement already satisfied: pexpect>4.3 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (4.8.0) Requirement already satisfied: pygments in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (2.10.0) Requirement already satisfied: traitlets>=4.2 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from ipython->watermark) (5.0.5) Requirement already satisfied: parso<0.9.0,>=0.8.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from jedi>=0.16->ipython->watermark) (0.8.2) Requirement already satisfied: ptyprocess>=0.5 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pexpect>4.3->ipython->watermark) (0.7.0) Requirement already satisfied: wcwidth in /srv/conda/envs/notebook/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->watermark) (0.2.5) Requirement already satisfied: ipython-genutils in /srv/conda/envs/notebook/lib/python3.8/site-packages (from traitlets>=4.2->ipython->watermark) (0.2.0) Installing collected packages: watermark Successfully installed watermark-2.2.0
%load_ext watermark
%watermark -v
Python implementation: CPython
Python version       : 3.8.10
IPython version      : 7.26.0
%watermark -i
First, we'll download the training data to a new subdirectory called `training_data/` in our current working directory. If you are working on the Planetary Computer Hub, consider using Azure Blob Storage to store your data in the West Europe region, which is where your server is running. See Writing data to Azure Blob Storage for an example.
$ tree training_data/
training_data/
├── flood-training-metadata.csv
├── train_features
│ ├── awc00_vh.tif
│ ├── awc00_vv.tif
...
│ ├── wvy31_vh.tif
│ └── wvy31_vv.tif
└── train_labels
├── awc00.tif
├── awc01.tif
...
├── wvy30.tif
└── wvy31.tif
2 directories, 1627 files
To better understand what's in our data, we can begin by exploring trends across different locations and flood events. Let's load the metadata and look at what we have.
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# This is where our downloaded images and metadata live locally
DATA_PATH = Path("../../training_data")
FEATURE_PATH = DATA_PATH / "train_features"
LABEL_PATH = DATA_PATH / "train_labels"
train_metadata = pd.read_csv(
DATA_PATH / "flood-training-metadata.csv", parse_dates=["scene_start"]
)
train_metadata.head()
|   | image_id | chip_id | flood_id | polarization | location | scene_start |
|---|----------|---------|----------|--------------|----------|-------------|
| 0 | awc00_vh | awc00   | awc      | vh           | Bolivia  | 2018-02-15  |
| 1 | awc00_vv | awc00   | awc      | vv           | Bolivia  | 2018-02-15  |
| 2 | awc01_vh | awc01   | awc      | vh           | Bolivia  | 2018-02-15  |
| 3 | awc01_vv | awc01   | awc      | vv           | Bolivia  | 2018-02-15  |
| 4 | awc02_vh | awc02   | awc      | vh           | Bolivia  | 2018-02-15  |
The training data for this competition consist of two polarization bands that bring out distinct physical properties in a scene: VV (vertical transmit, vertical receive) and VH (vertical transmit, horizontal receive). Each band is identified by a unique `image_id`, which is equivalent to `{chip_id}_{polarization}`. A `chip_id` consists of a three-letter `flood_id` and a two-digit chip number. Each unique chip has both a `_vv` and a `_vh` band. For example, `awc00_vh` represents the `vh` band of chip number `00` for flood `awc`.
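If you want to verify this naming convention yourself, a quick optional check like the one below should confirm that every `image_id` is just the `chip_id` and `polarization` joined by an underscore.

# Optional sanity check: image_id should equal "{chip_id}_{polarization}"
reconstructed_ids = train_metadata.chip_id + "_" + train_metadata.polarization
assert (reconstructed_ids == train_metadata.image_id).all()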
train_metadata.shape
(1084, 6)
train_metadata.chip_id.nunique()
542
Let's take a look at how many flood events are in the training data and how many chips we have per event.
flood_counts = train_metadata.groupby("flood_id")["chip_id"].nunique()
flood_counts.describe()
count    13.000000
mean     41.692308
std      19.947367
min      15.000000
25%      28.000000
50%      32.000000
75%      65.000000
max      69.000000
Name: chip_id, dtype: float64
The training data include chips from 13 flood events. We have anywhere from 15 to 69 chips (30 to 138 images) per unique event, with half of the events containing 32 chips (64 images) or fewer.
We can also take a look at the distribution of chips by location.
import seaborn as sns
plt.figure(figsize=(12, 4))
sns.countplot(
x=train_metadata.location,
order=train_metadata.location.value_counts().index,
color="lightgray",
)
plt.title("Number of Images by Location")
Text(0.5, 1.0, 'Number of Images by Location')
The data cover a wide geographic range. We have more than 60 chips (120 images) for floods in the United States, India, Paraguay, and Slovakia.
Let's also take a look at when the images were originally captured.
year = train_metadata.scene_start.dt.year
year_counts = train_metadata.groupby(year)["flood_id"].nunique()
year_counts
scene_start
2016    1
2017    2
2018    6
2019    3
2020    1
Name: flood_id, dtype: int64
train_metadata.groupby("flood_id")["scene_start"].nunique()
flood_id
awc    1
ayt    1
coz    1
hbe    1
hxu    1
jja    1
kuo    1
pxs    1
qus    1
qxb    1
tht    1
tnp    1
wvy    1
Name: scene_start, dtype: int64
The training data cover flood events that occurred between 2016 and 2020. Images for each event were captured on the same day.
Next, we can begin exploring the image data!
A GeoTIFF is a raster image file that contains geographic metadata describing the location of the image. This metadata can include bounding coordinates, an affine transform, and a coordinate reference system (CRS) projection. The package rasterio makes it easy to interact with our geospatial raster data.
Recall that each set of two polarizations (vv and vh) corresponds to a single water label.
import rasterio
# Examine an arbitrary image
image_path = FEATURE_PATH / f"{train_metadata.image_id[0]}.tif"
with rasterio.open(image_path) as img:
metadata = img.meta
bounds = img.bounds
data = img.read(1) # read a single band
metadata
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': 0.0, 'width': 512, 'height': 512, 'count': 1, 'crs': CRS.from_epsg(32720), 'transform': Affine(10.0, 0.0, 314030.0, 0.0, -10.0, 8585890.0)}
bounds
BoundingBox(left=314030.0, bottom=8580770.0, right=319150.0, top=8585890.0)
data
array([[-16.208015 , -17.71951  , -16.281353 , ...,   0.       ,   0.       ,   0.       ],
       [-15.3288965, -18.231857 , -16.451893 , ...,   0.       ,   0.       ,   0.       ],
       [-15.353134 , -16.88831  , -15.585904 , ...,   0.       ,   0.       ,   0.       ],
       ...,
       [-15.741662 , -15.230668 , -13.455255 , ...,   0.       ,   0.       ,   0.       ],
       [-15.498258 , -14.100984 , -13.11027  , ...,   0.       ,   0.       ,   0.       ],
       [-16.055603 , -14.1121   , -14.76084  , ...,   0.       ,   0.       ,   0.       ]],
      dtype=float32)
We will need to be able to identify pixels with missing data, since we will only be evaluated on predictions made for valid input pixels. The metadata tells us that a value of 0.0 represents missing data for an input image. In `rasterio`, you can access two different kinds of missing data masks. The first mask is a GDAL-style mask, in which non-zero elements (typically 255) indicate that the corresponding data elements are valid.
with rasterio.open(image_path) as img:
gdal_mask = img.dataset_mask()
gdal_mask
array([[255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0],
       ...,
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0]], dtype=uint8)
The second mask is a numpy masked array, which has the inverse sense: `True` values indicate that the corresponding data elements are invalid. To load the data as a numpy masked array and access this type of missing data mask, simply pass a `masked` flag to `read`.
with rasterio.open(image_path) as img:
numpy_mask = img.read(1, masked=True)
numpy_mask
masked_array(
  data=[[-16.20801544189453, -17.71950912475586, -16.281352996826172, ..., --, --, --],
        [-15.328896522521973, -18.231857299804688, -16.451892852783203, ..., --, --, --],
        [-15.353134155273438, -16.888309478759766, -15.585904121398926, ..., --, --, --],
        ...,
        [-15.74166202545166, -15.230668067932129, -13.455254554748535, ..., --, --, --],
        [-15.498257637023926, -14.100983619689941, -13.110269546508789, ..., --, --, --],
        [-16.05560302734375, -14.112099647521973, -14.76084041595459, ..., --, --, --]],
  mask=[[False, False, False, ..., True, True, True],
        [False, False, False, ..., True, True, True],
        [False, False, False, ..., True, True, True],
        ...,
        [False, False, False, ..., True, True, True],
        [False, False, False, ..., True, True, True],
        [False, False, False, ..., True, True, True]],
  fill_value=0.0,
  dtype=float32)
Pixel values represent the energy that was reflected back to the satellite, measured in decibels. To better visualize the bands or channels of Sentinel-1 images, we will create a false color composite by treating the two bands and their ratio as red, green, and blue channels, respectively. We will prepare a few helper functions to visualize the data.
import warnings
warnings.filterwarnings("ignore")
# Helper functions for visualizing Sentinel-1 images
def scale_img(matrix):
"""
Returns a scaled (H, W, D) image that is visually inspectable.
    Image is linearly scaled between min_values and max_values, by channel.
Args:
matrix (np.array): (H, W, D) image to be scaled
Returns:
np.array: Image (H, W, 3) ready for visualization
"""
# Set min/max values
min_values = np.array([-23, -28, 0.2])
max_values = np.array([0, -5, 1])
# Reshape matrix
w, h, d = matrix.shape
matrix = np.reshape(matrix, [w * h, d]).astype(np.float64)
# Scale by min/max
matrix = (matrix - min_values[None, :]) / (
max_values[None, :] - min_values[None, :]
)
matrix = np.reshape(matrix, [w, h, d])
# Limit values to 0/1 interval
return matrix.clip(0, 1)
def create_false_color_composite(path_vv, path_vh):
"""
    Returns an S1 false color composite for visualization.
Args:
path_vv (str): path to the VV band
path_vh (str): path to the VH band
Returns:
np.array: image (H, W, 3) ready for visualization
"""
# Read VV/VH bands
with rasterio.open(path_vv) as vv:
vv_img = vv.read(1)
with rasterio.open(path_vh) as vh:
vh_img = vh.read(1)
# Stack arrays along the last dimension
s1_img = np.stack((vv_img, vh_img), axis=-1)
# Create false color composite
img = np.zeros((512, 512, 3), dtype=np.float32)
img[:, :, :2] = s1_img.copy()
img[:, :, 2] = s1_img[:, :, 0] / s1_img[:, :, 1]
return scale_img(img)
def display_random_chip(random_state):
"""
Plots a 3-channel representation of VV/VH polarizations as a single chip (image 1).
Overlays a chip's corresponding water label (image 2).
Args:
random_state (int): random seed used to select a chip
Returns:
plot.show(): chip and labels plotted with pyplot
"""
f, ax = plt.subplots(1, 2, figsize=(9, 9))
# Select a random chip from train_metadata
random_chip = train_metadata.chip_id.sample(random_state=random_state).values[0]
# Extract paths to image files
vv_path = FEATURE_PATH / f"{random_chip}_vv.tif"
vh_path = FEATURE_PATH / f"{random_chip}_vh.tif"
label_path = LABEL_PATH / f"{random_chip}.tif"
# Create false color composite
s1_img = create_false_color_composite(vv_path, vh_path)
# Visualize features
ax[0].imshow(s1_img)
ax[0].set_title("S1 Chip", fontsize=14)
# Load water mask
with rasterio.open(label_path) as lp:
lp_img = lp.read(1)
# Mask missing data and 0s for visualization
label = np.ma.masked_where((lp_img == 0) | (lp_img == 255), lp_img)
# Visualize water label
ax[1].imshow(s1_img)
ax[1].imshow(label, cmap="cool", alpha=1)
ax[1].set_title("S1 Chip with Water Label", fontsize=14)
plt.tight_layout(pad=5)
plt.show()
Let's inspect a few chips and their water labels.
display_random_chip(7)
display_random_chip(66)
display_random_chip(90)
Pretty interesting!
You'll notice that some images contain high floodwater coverage, while others have little to no coverage. Water spread may be concentrated in one location or spread out, depending on a variety of surface characteristics like topography, terrain, and ground cover. Missing pixels are displayed in yellow.
Finally, let's confirm that the first few training images are the expected size of 512 x 512 pixels.
imgs = [FEATURE_PATH / f"{train_metadata.image_id[x]}.tif" for x in range(5)]
for img in imgs:
print(rasterio.open(img).shape)
(512, 512) (512, 512) (512, 512) (512, 512) (512, 512)
For this benchmark, we will train a first pass model using Sentinel-1 imagery only. You are encouraged to incorporate supplementary elevation and permanent water source data into your own model training using the Planetary Computer STAC API. See Fetching supplementary model input for an example.
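If you do want to pull in supplementary data, a rough, hedged sketch of querying the Planetary Computer STAC API for NASADEM elevation might look like the snippet below. The chip bounding box here is hypothetical and the "elevation" asset key is our assumption about the NASADEM collection; see the Fetching supplementary model input example for the authoritative workflow.

import planetary_computer
import rasterio
from pystac_client import Client

# Hypothetical chip bounds in lon/lat (left, bottom, right, top)
chip_bbox = [-65.7, -12.8, -65.6, -12.7]

catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
search = catalog.search(collections=["nasadem"], bbox=chip_bbox)
items = list(search.get_items())

# Sign the first matching item so its assets can be read anonymously,
# then open the elevation asset (asset key assumed to be "elevation")
signed_item = planetary_computer.sign(items[0])
with rasterio.open(signed_item.assets["elevation"].href) as src:
    elevation = src.read(1)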
The test set for this competition may include flood events not included in the training data. We do not want to overstate our model's performance by overfitting to one or more specific events. To be sure that our method is sufficiently generalizable, we will set aside a portion of the training imagery to validate the model during its development.
Since individual events are independent, we can randomly split the training data into training and validation sets using the `flood_id` field. For this benchmark, we will hold out three flood events, or approximately one third of the available images, for our validation set. You are encouraged to think carefully about the implications of different holdouts when determining how best to validate your model.
import random
random.seed(9) # set a seed for reproducibility
# Sample 3 random floods for validation set
flood_ids = train_metadata.flood_id.unique().tolist()
val_flood_ids = random.sample(flood_ids, 3)
val_flood_ids
['pxs', 'qxb', 'jja']
val_df = train_metadata[train_metadata.flood_id.isin(val_flood_ids)].copy()
train_df = train_metadata[~train_metadata.flood_id.isin(val_flood_ids)].copy()
# Confirm approx. 1/3 of chips are in the validation set
len(val_df) / (len(val_df) + len(train_df)) * 100
32.47232472324723
The goal of our first pass is to build a relatively simple model that takes radar imagery as input and outputs binary masks that indicate which pixels in a scene contain floodwater. Once we test this basic approach, we can attempt to improve our model by adding additional sophistication and complexity. We will use a lightweight PyTorch wrapper called PyTorch Lightning for this benchmark solution.
Rather than train an entire convolutional neural network (CNN) from scratch, we will fine-tune a U-Net model for semantic segmentation. U-Net was first designed in 2015 to localize abnormalities in biomedical images. By applying a variety of data augmentation techniques, it can achieve high performance with relatively small training datasets. U-Net can be broadly thought of as an encoder network followed by a decoder network, where the encoder can be initialized with a pretrained backbone. For this exercise, we will initialize our encoder with ResNet34, developed by Microsoft Research in 2015 and pretrained on the ImageNet dataset.
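Concretely, segmentation-models-pytorch lets us build this architecture in a few lines. The hedged sketch below simply mirrors the configuration we use later in this post: a U-Net with a ResNet34 encoder, ImageNet weights, two input channels (VV and VH), and two output classes.

import segmentation_models_pytorch as smp

# U-Net with a ResNet34 encoder, adapted to 2-channel SAR chips
unet = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=2,
    classes=2,
)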
First, we will need to read the training data into memory, convert the data to PyTorch tensors, and serve the data to our model in batches. Luckily, the PyTorch `Dataset` and `DataLoader` classes make implementing these complex tasks relatively straightforward. A `Dataset` object allows us to define custom methods for working with the data, and a `DataLoader` object parallelizes data loading. If you haven't worked with these classes before, we highly recommend this short tutorial.

Our custom dataset will inherit the abstract class `torch.utils.data.Dataset` and override the following methods:

- `__len__()`: returns the length of the dataset, measured as the number of samples
- `__getitem__()`: given an index, returns a sample from the dataset

The dataset object will return samples as dictionaries with keys for:

- `chip_id`: the chip id
- `chip`: a two-band image tensor (VV and VH)
- `label`: the label mask, if it exists

import torch
class FloodDataset(torch.utils.data.Dataset):
"""Reads in images, transforms pixel values, and serves a
dictionary containing chip ids, image tensors, and
label masks (where available).
"""
def __init__(self, metadata, feature_dir, label_dir=None, transforms=None):
self.chip_ids = metadata.chip_id.tolist()
self.feature_dir = feature_dir
self.label_dir = label_dir
self.transforms = transforms
def __len__(self):
return len(self.chip_ids)
def __getitem__(self, idx):
# Loads a 2-channel image from a chip_id
chip_id = self.chip_ids[idx]
vv_path = self.feature_dir / f"{chip_id}_vv.tif"
vh_path = self.feature_dir / f"{chip_id}_vh.tif"
with rasterio.open(vv_path) as vv:
vv_img = vv.read(1)
with rasterio.open(vh_path) as vh:
vh_img = vh.read(1)
x_arr = np.stack([vv_img, vh_img], axis=-1)
# Min-max normalization
min_norm = -77
max_norm = 26
x_arr = np.clip(x_arr, min_norm, max_norm)
x_arr = (x_arr - min_norm) / (max_norm - min_norm)
        # Load label if available - training only
        y_arr = None
        if self.label_dir is not None:
            label_path = self.label_dir / f"{chip_id}.tif"
            with rasterio.open(label_path) as lp:
                y_arr = lp.read(1)
        # Apply data augmentations, if provided; the image and label are
        # transformed together so random crops and flips stay spatially aligned
        if self.transforms:
            if y_arr is not None:
                transformed = self.transforms(image=x_arr, mask=y_arr)
                x_arr, y_arr = transformed["image"], transformed["mask"]
            else:
                x_arr = self.transforms(image=x_arr)["image"]
        x_arr = np.transpose(x_arr, [2, 0, 1])
        # Prepare sample dictionary
        sample = {"chip_id": chip_id, "chip": x_arr}
        if y_arr is not None:
            sample["label"] = y_arr
        return sample
Our custom dataset normalizes input pixel values by applying min-max normalization with a minimum of -77 and a maximum of 26.
To prevent overfitting during training, we'll increase the size of our training data by applying a set of data augmentations to our input, including random cropping, random 90 degree rotations, and horizontal and vertical flipping. The image augmentation library albumentations is a helpful resource for this task.
!pip install albumentations
Collecting albumentations Using cached albumentations-1.1.0-py3-none-any.whl (102 kB) Requirement already satisfied: PyYAML in /srv/conda/envs/notebook/lib/python3.8/site-packages (from albumentations) (5.4.1) Requirement already satisfied: scipy in /srv/conda/envs/notebook/lib/python3.8/site-packages (from albumentations) (1.7.1) Requirement already satisfied: scikit-image>=0.16.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from albumentations) (0.18.2) Collecting qudida>=0.0.4 Using cached qudida-0.0.4-py3-none-any.whl (3.5 kB) Collecting opencv-python-headless>=4.1.1 Using cached opencv_python_headless-4.5.4.58-cp38-cp38-manylinux2014_x86_64.whl (47.6 MB) Requirement already satisfied: numpy>=1.11.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from albumentations) (1.21.2) Requirement already satisfied: scikit-learn>=0.19.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from qudida>=0.0.4->albumentations) (0.24.2) Requirement already satisfied: typing-extensions in /srv/conda/envs/notebook/lib/python3.8/site-packages (from qudida>=0.0.4->albumentations) (3.10.0.0) Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (3.4.3) Requirement already satisfied: networkx>=2.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (2.5) Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (8.2.0) Requirement already satisfied: imageio>=2.3.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (2.9.0) Requirement already satisfied: tifffile>=2019.7.26 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (2021.4.8) Requirement already satisfied: PyWavelets>=1.1.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations) (1.1.1) Requirement already satisfied: python-dateutil>=2.7 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.16.1->albumentations) (2.7.5) Requirement already satisfied: pyparsing>=2.2.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.16.1->albumentations) (2.4.7) Requirement already satisfied: cycler>=0.10 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.16.1->albumentations) (0.10.0) Requirement already satisfied: kiwisolver>=1.0.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.16.1->albumentations) (1.3.1) Requirement already satisfied: six in /srv/conda/envs/notebook/lib/python3.8/site-packages (from cycler>=0.10->matplotlib!=3.0.0,>=2.0.0->scikit-image>=0.16.1->albumentations) (1.16.0) Requirement already satisfied: decorator>=4.3.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from networkx>=2.0->scikit-image>=0.16.1->albumentations) (5.0.9) Requirement already satisfied: joblib>=0.11 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-learn>=0.19.1->qudida>=0.0.4->albumentations) (1.0.1) Requirement already satisfied: threadpoolctl>=2.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from scikit-learn>=0.19.1->qudida>=0.0.4->albumentations) (2.2.0) Installing collected packages: 
opencv-python-headless, qudida, albumentations Successfully installed albumentations-1.1.0 opencv-python-headless-4.5.4.58 qudida-0.0.4
import albumentations
# These transformations will be passed to our model class
training_transformations = albumentations.Compose(
[
albumentations.RandomCrop(256, 256),
albumentations.RandomRotate90(),
albumentations.HorizontalFlip(),
albumentations.VerticalFlip(),
]
)
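As a quick, optional smoke test of the pieces so far, the sketch below builds a FloodDataset from our training split, wraps it in a DataLoader, and inspects one batch; the batch size here is arbitrary.

# Illustrative check: build a dataset and loader, then look at one batch
train_dataset = FloodDataset(
    metadata=train_df,
    feature_dir=FEATURE_PATH,
    label_dir=LABEL_PATH,
    transforms=training_transformations,
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

batch = next(iter(train_loader))
print(batch["chip"].shape)   # torch.Size([8, 2, 256, 256]) after the 256 x 256 random crop
print(batch["label"].shape)  # torch.Size([8, 256, 256])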
Next, we will create a custom class to define our training loss function and a helper function to calculate our validation metric, Intersection over Union (IOU).
For training, we will use a standard mixture of 50% cross-entropy loss and 50% dice loss, which improves learning when classes are unbalanced. Since our images tend to contain more non-water than water pixels, this loss should be a good fit. Broadly speaking, cross-entropy loss evaluates differences between predicted and ground truth pixels and averages over all pixels, while dice loss measures the overlap between predicted and ground truth pixels and divides a function of this value by the total number of pixels in both images. This custom class will inherit `torch.nn.Module`, which is the base class for all neural network modules in PyTorch. A lower `XEDiceLoss` score indicates better performance.
For validation, we will calculate IOU. As a reminder, IOU represents the size of the intersection divided by the size of the union of pixels. A higher IOU indicates better performance.
class XEDiceLoss(torch.nn.Module):
"""
Computes (0.5 * CrossEntropyLoss) + (0.5 * DiceLoss).
"""
def __init__(self):
super().__init__()
self.xe = torch.nn.CrossEntropyLoss(reduction="none")
def forward(self, pred, true):
valid_pixel_mask = true.ne(255) # valid pixel mask
# Cross-entropy loss
temp_true = torch.where(~valid_pixel_mask, 0, true) # cast 255 to 0 temporarily
xe_loss = self.xe(pred, temp_true)
xe_loss = xe_loss.masked_select(valid_pixel_mask).mean()
# Dice loss
pred = torch.softmax(pred, dim=1)[:, 1]
pred = pred.masked_select(valid_pixel_mask)
true = true.masked_select(valid_pixel_mask)
dice_loss = 1 - (2.0 * torch.sum(pred * true)) / (torch.sum(pred + true) + 1e-7)
return (0.5 * xe_loss) + (0.5 * dice_loss)
def intersection_and_union(pred, true):
"""
Calculates intersection and union for a batch of images.
Args:
pred (torch.Tensor): a tensor of predictions
        true (torch.Tensor): a tensor of labels
Returns:
intersection (int): total intersection of pixels
union (int): total union of pixels
"""
valid_pixel_mask = true.ne(255) # valid pixel mask
true = true.masked_select(valid_pixel_mask).to("cpu")
pred = pred.masked_select(valid_pixel_mask).to("cpu")
# Intersection and union totals
intersection = np.logical_and(true, pred)
union = np.logical_or(true, pred)
return intersection.sum(), union.sum()
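To see how these pieces behave, here is a small, hedged example on toy tensors: it computes the combined loss for random logits and then scores a hard prediction with intersection_and_union. Pixels equal to 255 are treated as missing and excluded from both calculations.

# Toy tensors: one 2 x 2 chip, two classes, one missing pixel (255)
logits = torch.randn(1, 2, 2, 2, requires_grad=True)  # (batch, classes, H, W)
target = torch.tensor([[[1, 0], [255, 1]]])

criterion = XEDiceLoss()
loss = criterion(logits, target)
loss.backward()  # gradients only flow through valid pixels

# Hard predictions for the same chip: two water pixels predicted, one correct
pred = torch.tensor([[[1, 1], [0, 0]]])
intersection, union = intersection_and_union(pred, target)
print(intersection / union)  # 1 overlapping valid pixel / 3 in the union = 1/3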
Given our new `FloodDataset`, PyTorch Lightning will allow us to train a model using minimal for loops. This library keeps the flexibility of PyTorch but removes most of the boilerplate training code, making it less error prone, cleaner to read, and easier to update. By subclassing `pl.LightningModule`, most of the training logic will happen for us behind the scenes. We can simply define what a `forward` call and a `training_step` are, and provide our model with a `train_dataloader` to serve the data. Behavior like checkpoint saving and early stopping is easily parameterized.

We define the following methods within our model:

- `forward`: define the forward pass for an image
- `training_step`: switch the model to train mode, implement the forward pass, and calculate training loss for a batch
- `validation_step`: switch the model to eval mode and calculate the validation metric for a batch
- `train_dataloader`: call an iterable over the training dataset for automatic batching
- `val_dataloader`: call an iterable over the validation dataset for automatic batching
- `configure_optimizers`: configure an optimizer and a scheduler to dynamically adjust the learning rate based on the number of epochs; this step automatically calls `backward` and `step` in each epoch
- `validation_epoch_end`: reset the global IOU trackers at the end of each epoch

We'll also add a couple of additional, optional helper methods:

- `_build_df`: instantiate `FloodDataset` objects for training and validation
- `_prepare_model`: initialize a pretrained U-Net model
- `_get_trainer_params`: prepare checkpoint and early stopping behavior
- `fit`: fit a `pl.Trainer` object for training automation

!pip install pytorch_lightning
Collecting pytorch_lightning Using cached pytorch_lightning-1.5.1-py3-none-any.whl (1.0 MB) Collecting tensorboard>=2.2.0 Using cached tensorboard-2.7.0-py3-none-any.whl (5.8 MB) Collecting pyDeprecate==0.3.1 Using cached pyDeprecate-0.3.1-py3-none-any.whl (10 kB) Collecting torchmetrics>=0.4.1 Using cached torchmetrics-0.6.0-py3-none-any.whl (329 kB) Requirement already satisfied: future>=0.17.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (0.18.2) Requirement already satisfied: typing-extensions in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (3.10.0.0) Requirement already satisfied: PyYAML>=5.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (5.4.1) Requirement already satisfied: numpy>=1.17.2 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (1.21.2) Requirement already satisfied: tqdm>=4.41.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (4.62.1) Requirement already satisfied: torch>=1.6 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (1.7.0) Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (2021.7.0) Requirement already satisfied: packaging>=17.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pytorch_lightning) (21.0) Requirement already satisfied: aiohttp in /srv/conda/envs/notebook/lib/python3.8/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (3.7.4.post0) Requirement already satisfied: requests in /srv/conda/envs/notebook/lib/python3.8/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (2.25.1) Requirement already satisfied: pyparsing>=2.0.2 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from packaging>=17.0->pytorch_lightning) (2.4.7) Collecting werkzeug>=0.11.15 Using cached Werkzeug-2.0.2-py3-none-any.whl (288 kB) Requirement already satisfied: wheel>=0.26 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (0.37.0) Requirement already satisfied: markdown>=2.6.8 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (3.3.4) Requirement already satisfied: google-auth<3,>=1.6.3 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (1.34.0) Collecting tensorboard-plugin-wit>=1.6.0 Using cached tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB) Collecting grpcio>=1.24.3 Using cached grpcio-1.41.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.9 MB) Requirement already satisfied: setuptools>=41.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (57.4.0) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (0.4.5) Collecting protobuf>=3.6.0 Using cached protobuf-3.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB) Collecting tensorboard-data-server<0.7.0,>=0.6.0 Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB) Collecting absl-py>=0.4 Using cached absl_py-1.0.0-py3-none-any.whl (126 kB) Requirement already satisfied: six in /srv/conda/envs/notebook/lib/python3.8/site-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch_lightning) 
(1.16.0) Requirement already satisfied: rsa<5,>=3.1.4 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (4.7.2) Requirement already satisfied: cachetools<5.0,>=2.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (4.2.2) Requirement already satisfied: pyasn1-modules>=0.2.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (0.2.7) Requirement already satisfied: requests-oauthlib>=0.7.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch_lightning) (1.3.0) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (0.4.8) Requirement already satisfied: idna<3,>=2.5 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (2.10) Requirement already satisfied: chardet<5,>=3.0.2 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (4.0.0) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (1.26.6) Requirement already satisfied: certifi>=2017.4.17 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (2021.5.30) Requirement already satisfied: oauthlib>=3.0.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch_lightning) (3.1.1) Collecting dataclasses Using cached dataclasses-0.6-py3-none-any.whl (14 kB) Requirement already satisfied: multidict<7.0,>=4.5 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (5.1.0) Requirement already satisfied: attrs>=17.3.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (21.2.0) Requirement already satisfied: async-timeout<4.0,>=3.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (3.0.1) Requirement already satisfied: yarl<2.0,>=1.0 in /srv/conda/envs/notebook/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (1.6.3) Installing collected packages: dataclasses, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, grpcio, absl-py, torchmetrics, tensorboard, pyDeprecate, pytorch-lightning Successfully installed absl-py-1.0.0 dataclasses-0.6 grpcio-1.41.1 protobuf-3.19.1 pyDeprecate-0.3.1 pytorch-lightning-1.5.1 tensorboard-2.7.0 tensorboard-data-server-0.6.1 tensorboard-plugin-wit-1.8.0 torchmetrics-0.6.0 werkzeug-2.0.2
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
class FloodModel(pl.LightningModule):
def __init__(self, train_metadata, val_metadata, hparams):
super(FloodModel, self).__init__()
self.hparams.update(hparams)
self.feature_dir = self.hparams.get("feature_dir")
self.label_dir = self.hparams.get("label_dir")
self.train_chip_ids = self.hparams.get("train_chip_ids")
self.val_chip_ids = self.hparams.get("val_chip_ids")
self.transform = training_transformations
self.backbone = self.hparams.get("backbone", "resnet34")
self.weights = self.hparams.get("weights", "imagenet")
self.learning_rate = self.hparams.get("lr", 1e-3)
self.max_epochs = self.hparams.get("max_epochs", 1000)
self.min_epochs = self.hparams.get("min_epochs", 6)
self.patience = self.hparams.get("patience", 4)
self.num_workers = self.hparams.get("num_workers", 2)
self.batch_size = self.hparams.get("batch_size", 32)
self.fast_dev_run = self.hparams.get("fast_dev_run", False)
self.val_sanity_checks = self.hparams.get("val_sanity_checks", 0)
self.gpu = self.hparams.get("gpu", False)
self.output_path = Path.cwd() / self.hparams.get("output_path", "model_outputs")
self.output_path.mkdir(exist_ok=True)
self.log_path = Path.cwd() / hparams.get("log_path", "tensorboard_logs")
self.log_path.mkdir(exist_ok=True)
# Instantiate training and val datasets
self.train_dataset = self._build_df(train_metadata, group="train")
self.val_dataset = self._build_df(val_metadata, group="val")
# Instantiate model and trainer params
self.model = self._prepare_model()
self.trainer_params = self._get_trainer_params()
# Track validation IOU globally (reset each epoch)
self.intersection = 0
self.union = 0
# Required LightningModule methods #
def forward(self, image):
# Forward pass
return self.model(image)
def training_step(self, batch, batch_idx):
# Switch on training mode
self.model.train()
torch.set_grad_enabled(True)
# Load images and labels
x = batch["chip"]
y = batch["label"].long()
if self.gpu:
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
# Forward pass
preds = self.forward(x)
# Calculate training loss
criterion = XEDiceLoss()
xe_dice_loss = criterion(preds, y)
# Log batch xe_dice_loss
self.log(
"xe_dice_loss",
xe_dice_loss,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return xe_dice_loss
def validation_step(self, batch, batch_idx):
# Switch on validation mode
self.model.eval()
torch.set_grad_enabled(False)
# Load images and labels
x = batch["chip"]
y = batch["label"].long()
if self.gpu:
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
# Forward pass & softmax
preds = self.forward(x)
preds = torch.softmax(preds, dim=1)[:, 1]
preds = (preds > 0.5) * 1
# Calculate validation IOU (global)
intersection, union = intersection_and_union(preds, y)
self.intersection += intersection
self.union += union
# Log batch IOU
batch_iou = intersection / union
self.log(
"iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
return batch_iou
def train_dataloader(self):
# DataLoader class for training
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
pin_memory=True,
)
def val_dataloader(self):
# DataLoader class for validation
return torch.utils.data.DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=0,
shuffle=False,
pin_memory=True,
)
def configure_optimizers(self):
# Define optimizer
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
# Define scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=self.patience
)
scheduler = {
"scheduler": scheduler,
"interval": "epoch",
"monitor": "iou_epoch",
} # logged value to monitor
return [optimizer], [scheduler]
def validation_epoch_end(self, outputs):
# Reset metrics before next epoch
self.intersection = 0
self.union = 0
# Convenience Methods #
def _build_df(self, metadata, group):
# Instantiate datasets
if group == "train":
df = FloodDataset(
metadata=metadata,
feature_dir=self.feature_dir,
label_dir=self.label_dir,
transforms=self.transform,
)
elif group == "val":
df = FloodDataset(
metadata=metadata,
feature_dir=self.feature_dir,
label_dir=self.label_dir,
transforms=None,
)
return df
def _prepare_model(self):
# Instantiate U-Net model
unet_model = smp.Unet(
encoder_name=self.backbone,
encoder_weights=self.weights,
in_channels=2,
classes=2,
)
if self.gpu:
unet_model.cuda()
return unet_model
def _get_trainer_params(self):
# Define callback behavior
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=self.output_path,
monitor="iou_epoch",
mode="max",
verbose=True,
)
# Define early stopping behavior
early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
monitor="iou_epoch",
patience=(self.patience * 3),
mode="max",
verbose=True,
)
# Specify where TensorBoard logs will be saved
logger = pl.loggers.TensorBoardLogger(self.log_path, name="benchmark-model")
trainer_params = {
"callbacks": [checkpoint_callback, early_stop_callback],
"max_epochs": self.max_epochs,
"min_epochs": self.min_epochs,
"default_root_dir": self.output_path,
"logger": logger,
"gpus": None if not self.gpu else 1,
"fast_dev_run": self.fast_dev_run,
"num_sanity_val_steps": self.val_sanity_checks,
}
return trainer_params
def fit(self):
# Set up and fit Trainer object
self.trainer = pl.Trainer(**self.trainer_params)
self.trainer.fit(self)
Finally, it's time to fit our model. A `FloodModel` can be instantiated using a dictionary of `hparams`. The only required `hparams` are the feature directory and the label directory, but there are several additional hyperparameters we can specify to explore modeling strategies, including the learning rate, patience, and batch size. Consider experimenting with different models and using hyperparameter tuning to find the best combination of parameters to improve performance.

Once we specify our `hparams`, we can simply call the `fit` method to begin training!
hparams = {
# Required hparams
"feature_dir": FEATURE_PATH,
"label_dir": LABEL_PATH,
# Optional hparams
"backbone": "resnet34",
"weights": "imagenet",
"lr": 1e-3,
"min_epochs": 6,
"max_epochs": 1000,
"patience": 4,
"batch_size": 32,
"num_workers": 2,
"val_sanity_checks": 0,
"fast_dev_run": False,
"output_path": "model_outputs",
"log_path": "tensorboard_logs",
"gpu": torch.cuda.is_available(),
}
flood_model = FloodModel(train_metadata=train_df, val_metadata=val_df, hparams=hparams)
flood_model.fit()
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Unet | 24.4 M
-------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.734    Total estimated model params size (MB)
Training: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Metric iou_epoch improved. New best score: 0.011 Epoch 0, global step 22: iou_epoch reached 0.01073 (best 0.01073), saving model to "/home/jovyan/PlanetaryComputerExamples/competitions/s1floods/model_outputs/epoch=0-step=22.ckpt" as top 1
Validating: 0it [00:00, ?it/s]
Metric iou_epoch improved by 0.193 >= min_delta = 0.0. New best score: 0.204 Epoch 1, global step 45: iou_epoch reached 0.20409 (best 0.20409), saving model to "/home/jovyan/PlanetaryComputerExamples/competitions/s1floods/model_outputs/epoch=1-step=45.ckpt" as top 1
Validating: 0it [00:00, ?it/s]
Metric iou_epoch improved by 0.017 >= min_delta = 0.0. New best score: 0.221 Epoch 2, global step 68: iou_epoch reached 0.22133 (best 0.22133), saving model to "/home/jovyan/PlanetaryComputerExamples/competitions/s1floods/model_outputs/epoch=2-step=68.ckpt" as top 1
Validating: 0it [00:00, ?it/s]
Metric iou_epoch improved by 0.093 >= min_delta = 0.0. New best score: 0.314 Epoch 3, global step 91: iou_epoch reached 0.31400 (best 0.31400), saving model to "/home/jovyan/PlanetaryComputerExamples/competitions/s1floods/model_outputs/epoch=3-step=91.ckpt" as top 1
Validating: 0it [00:00, ?it/s]
Epoch 4, global step 114: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 5, global step 137: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 6, global step 160: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 7, global step 183: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 8, global step 206: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 9, global step 229: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 10, global step 252: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 11, global step 275: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 12, global step 298: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 13, global step 321: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Epoch 14, global step 344: iou_epoch was not in top 1
Validating: 0it [00:00, ?it/s]
Monitored metric iou_epoch did not improve in the last 12 records. Best score: 0.314. Signaling Trainer to stop. Epoch 15, global step 367: iou_epoch was not in top 1
PyTorch Lightning lets you log PyTorch models and metrics into a directory for visualization within the TensorBoard UI. TensorBoard is a machine learning visualization toolkit that's helpful for tracking metrics across batches, epochs, and models.
To view your TensorBoard logs, start a TensorBoard server by opening a new terminal, navigating to your notebook directory, and then running
$ tensorboard --logdir=tensorboard_logs
That will start the TensorBoard server on the notebook server at `localhost:6006`. You can access it from your local browser at the following link (add your Planetary Computer email address where specified): https://pccompute.westeurope.cloudapp.azure.com/compute/user/%5BUSER_EMAIL_HERE%5D/proxy/6006
flood_model.trainer_params["callbacks"][0].best_model_score
tensor(0.3140, device='cuda:0')
Our best validation IOU is 0.3140.
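Note that the checkpoint callback also records where the best-scoring weights were written, in case you would rather reload that checkpoint than use the final in-memory weights we export below.

checkpoint_callback = flood_model.trainer_params["callbacks"][0]
checkpoint_callback.best_model_path  # path to the best-scoring .ckpt file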
Now that our model is trained, we can finally prepare our submission. For this competition, your inference code will be executed to generate a score. This means that we'll need to save our model, write inference code, package everything up, and upload `submission.zip` to the competition submissions page.
Let's walk through the process. Below we will:

- Create a `benchmark-pytorch/` directory with a `benchmark-pytorch/assets` subdirectory to store our trained model and weights
- Write a `main.py` file to load our model and weights, perform inference on images stored in `data/test_features`, generate chip-level predictions, and save predictions to `submission/` (see an example `main.py` file here)
- `zip` the contents of `benchmark-pytorch/` (not the directory itself) into a file called `submission.zip`
- Upload `submission.zip` to the competition submissions page; the file will be unzipped and `main.py` will be run in a containerized execution environment to obtain our model's IOU

Note: If you'd like to look at this benchmark submission already completed, you can see the end product at the runtime repo.
First, let's make the `benchmark-pytorch/` directory and its `benchmark-pytorch/assets` subdirectory.
submission_path = Path("benchmark-pytorch")
submission_path.mkdir(exist_ok=True)
submission_assets_path = submission_path / "assets"
submission_assets_path.mkdir(exist_ok=True)
Next, let's save our model and weights.
In PyTorch, the learnable parameters of a model, including its weights and biases, are contained in its parameters. A `state_dict` is a Python dictionary object that maps each layer to its parameter tensors. PyTorch provides a handy `torch.save()` function for serializing our model's `state_dict` and, in turn, saving our model's weights.
weight_path = submission_assets_path / "flood_model.pt"
torch.save(flood_model.state_dict(), weight_path)
To load our trained model in the competition runtime environment, we'll save a simplified `FloodModel` class to `benchmark-pytorch/flood_model.py`. This model class will instantiate a generic U-Net model and will contain a helper method to perform inference for a given chip.
%%file benchmark-pytorch/flood_model.py
import numpy as np
import pytorch_lightning as pl
import rasterio
import segmentation_models_pytorch as smp
import torch
class FloodModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = smp.Unet(
encoder_name="resnet34",
encoder_weights=None,
in_channels=2,
classes=2,
)
def forward(self, image):
# Forward pass
return self.model(image)
def predict(self, vv_path, vh_path):
# Switch on evaluation mode
self.model.eval()
torch.set_grad_enabled(False)
# Create a 2-channel image
with rasterio.open(vv_path) as vv:
vv_img = vv.read(1)
with rasterio.open(vh_path) as vh:
vh_img = vh.read(1)
x_arr = np.stack([vv_img, vh_img], axis=-1)
# Min-max normalization
min_norm = -77
max_norm = 26
x_arr = np.clip(x_arr, min_norm, max_norm)
x_arr = (x_arr - min_norm) / (max_norm - min_norm)
# Transpose
x_arr = np.transpose(x_arr, [2, 0, 1])
x_arr = np.expand_dims(x_arr, axis=0)
# Perform inference
preds = self.forward(torch.from_numpy(x_arr))
preds = torch.softmax(preds, dim=1)[:, 1]
preds = (preds > 0.5) * 1
return preds.detach().numpy().squeeze().squeeze()
Writing benchmark-pytorch/flood_model.py
The next step is to generate our `main.py` script in the root of our `benchmark-pytorch/` directory. The code submission format page contains helpful guidance for creating `main.py`. This file assumes that `FloodModel` can be imported from `flood_model.py`.

Note: For each test chip geography, supplementary elevation and permanent water images are available as static resources in the `/codeexecution/data/` directory. You have access to seven subdirectories, which each contain a single image per test chip. To use supplementary data for inference, simply load these images by their respective `chip_id`.
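For example, a hedged sketch of reading a chip's supplementary NASADEM elevation image inside your inference code could look like the following, reusing the `DATA_DIRECTORY` constant and the `nasadem` subdirectory referenced (commented out) in the script below; `chip_id` stands for whatever test chip you are predicting.

# Illustrative only: read a chip's supplementary elevation image by chip_id
import rasterio

with rasterio.open(DATA_DIRECTORY / "nasadem" / f"{chip_id}.tif") as src:
    elevation = src.read(1)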
%%file benchmark-pytorch/main.py
import os
from pathlib import Path
from loguru import logger
import numpy as np
from tifffile import imwrite
from tqdm import tqdm
import torch
import typer
from flood_model import FloodModel
ROOT_DIRECTORY = Path("/codeexecution")
SUBMISSION_DIRECTORY = ROOT_DIRECTORY / "submission"
ASSETS_DIRECTORY = ROOT_DIRECTORY / "assets"
DATA_DIRECTORY = ROOT_DIRECTORY / "data"
INPUT_IMAGES_DIRECTORY = DATA_DIRECTORY / "test_features"
# EXAMPLE_NASADEM_DIRECTORY = DATA_DIRECTORY / "nasadem"
# Make sure the smp loader can find our torch assets because we don't have internet!
os.environ["TORCH_HOME"] = str(ASSETS_DIRECTORY / "torch")
def make_prediction(chip_id, model):
"""
Given a chip_id, read in the vv/vh bands and predict a water mask.
Args:
chip_id (str): test chip id
Returns:
output_prediction (arr): prediction as a numpy array
"""
logger.info("Starting inference.")
try:
vv_path = INPUT_IMAGES_DIRECTORY / f"{chip_id}_vv.tif"
vh_path = INPUT_IMAGES_DIRECTORY / f"{chip_id}_vh.tif"
# example_nasadem_path = EXAMPLE_NASADEM_DIRECTORY / f"{chip_id}.tif"
# etc.
output_prediction = model.predict(vv_path, vh_path)
except Exception as e:
logger.error(f"No bands found for {chip_id}. {e}")
raise
return output_prediction
def get_expected_chip_ids():
"""
Use the test features directory to see which images are expected.
"""
paths = INPUT_IMAGES_DIRECTORY.glob("*.tif")
# Return one chip id per two bands (VV/VH)
ids = list(sorted(set(path.stem.split("_")[0] for path in paths)))
return ids
def main():
"""
For each set of two input bands, generate an output file
    using the `make_prediction` function.
"""
logger.info("Loading model")
# Explicitly set where we expect smp to load the saved resnet from just to be sure
torch.hub.set_dir(ASSETS_DIRECTORY / "torch/hub")
model = FloodModel()
model.load_state_dict(torch.load(ASSETS_DIRECTORY / "flood_model.pt"))
logger.info("Finding chip IDs")
chip_ids = get_expected_chip_ids()
if not chip_ids:
typer.echo("No input images found!")
raise typer.Exit(code=1)
logger.info(f"Found {len(chip_ids)} test chip_ids. Generating predictions.")
for chip_id in tqdm(chip_ids, miniters=25):
output_path = SUBMISSION_DIRECTORY / f"{chip_id}.tif"
output_data = make_prediction(chip_id, model).astype(np.uint8)
imwrite(output_path, output_data, dtype=np.uint8)
logger.success(f"Inference complete.")
if __name__ == "__main__":
typer.run(main)
Writing benchmark-pytorch/main.py
We also need to copy the pretrained backbone weights that `segmentation-models-pytorch` downloaded to our computer. The container doesn't have access to the Internet, so you are responsible for packaging up any local assets and pointing your code to them. In the `main.py` above, you will see that we set the `TORCH_HOME` environment variable to look at where our submission gets decompressed instead of its default (usually something like `~/.cache/torch`). See the torch.hub docs for more on this.
!cp -R ~/.cache/torch benchmark-pytorch/assets/
Now our submission directory looks like this:
!tree benchmark-pytorch/
benchmark-pytorch/
├── assets
│   ├── flood_model.pt
│   └── torch
│       └── hub
│           └── checkpoints
│               ├── resnet34-333f7ec4.pth
│               └── tf_efficientnet_b8_ap-00e169fa.pth
├── flood_model.py
└── main.py

4 directories, 5 files
Finally, let's `zip` everything up, submit it, and get our score!
# Remember to avoid including the inference dir itself
!cd benchmark-pytorch && zip -r ../submission.zip *
  adding: assets/ (stored 0%)
  adding: assets/flood_model.pt (deflated 7%)
  adding: assets/torch/ (stored 0%)
  adding: assets/torch/hub/ (stored 0%)
  adding: assets/torch/hub/checkpoints/ (stored 0%)
  adding: assets/torch/hub/checkpoints/tf_efficientnet_b8_ap-00e169fa.pth (deflated 8%)
  adding: assets/torch/hub/checkpoints/resnet34-333f7ec4.pth (deflated 7%)
  adding: flood_model.py (deflated 58%)
  adding: main.py (deflated 58%)
!du -h submission.zip
472M submission.zip
We can now head over to the competition submissions page to submit our code and obtain our model's IOU!
Our submission takes about 20 minutes to execute. We should see an IOU of 0.6255 on the leaderboard. That's a great start, but now it's up to you to improve this score. We hope this benchmark solution provides a helpful framework for exploring the data, designing a machine learning pipeline, training a deep learning model, and preparing your code for submission.
Head over to the STAC Overflow challenge homepage to get started. We can't wait to see what you build!