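This notebook contains a complete example of how to run Clay. It combines three aspects: retrieving Sentinel-2 imagery for a point of interest from a STAC API, creating embeddings for that imagery with a Clay model checkpoint, and analyzing the resulting embeddings.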
# Ensure working directory is the repo home.
# NOTE: os.chdir("..") is relative, so re-running this cell moves one more
# directory up. It must stay before the `src.*` imports below, which
# presumably only resolve when the repository root is the working directory.
import os
os.chdir("..")
import warnings
from pathlib import Path
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
# Imported only for its side effect: it provides the `.rio` accessor that is
# used later to write tiles with `tile.rio.to_raster(...)`.
import rioxarray # noqa: F401
import stackstac
import torch
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition
from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule
# Suppress all warnings for the rest of the session.
warnings.filterwarnings("ignore")
# Asset names per band group. Only the "rgb" and "nir" groups are used in
# this notebook.
BAND_GROUPS = dict(
    rgb=["red", "green", "blue"],
    rededge=["rededge1", "rededge2", "rededge3", "nir08"],
    nir=["nir"],
    swir=["swir16", "swir22"],
    sar=["vv", "vh"],
)
# STAC endpoint and collection that are searched for imagery.
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"
In this example we use a location and a date range to visualize a forest fire that happened in Monchique, Portugal, in 2018.
# Point over Monchique, Portugal, as (latitude, longitude).
poi = 37.30939, -8.57207
# Date range covering a large forest fire.
start = "2018-07-01"
end = "2018-09-01"
# Search for scenes that intersect a tiny box around the point and have less
# than 80% cloud cover. The STAC bbox is (min_lon, min_lat, max_lon, max_lat),
# so the (lat, lon) order of `poi` is swapped here.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)
# `item_collection()` replaces the deprecated `get_all_items()`; the result
# supports len(), indexing and iteration like before.
items = search.item_collection()
print(f"Found {len(items)} items")
Found 12 items
Load the data into an in-memory array and visualize the imagery. The burn scar is visible in the last five images.
# Extract the coordinate system from the first item. Older versions of the
# STAC projection extension store it as an integer in "proj:epsg"; version 2
# replaced that with "proj:code" (an "AUTHORITY:CODE" string), so fall back
# to parsing the code when "proj:epsg" is absent.
item_props = items[0].properties
epsg = item_props.get("proj:epsg")
if epsg is None:
    epsg = int(item_props["proj:code"].rsplit(":", 1)[-1])

# Convert the point (stored as lat, lon) into the image projection.
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(poi[1], poi[0])],
).to_crs(epsg)
coords = poidf.iloc[0].geometry.coords[0]

# Create bounds of the correct size: the model requires 512x512 pixels at
# 10 m resolution, i.e. a square of 5120 m centred on the point.
CHIP_SIZE_PX = 512
RESOLUTION_M = 10
half_extent = CHIP_SIZE_PX * RESOLUTION_M // 2  # 2560 m
bounds = (
    coords[0] - half_extent,
    coords[1] - half_extent,
    coords[0] + half_extent,
    coords[1] + half_extent,
)
# Retrieve the pixel values for the bounding box in the target projection.
# Only the RGB and NIR band groups are used in this example; the band order
# of the stack follows this asset list (red, green, blue, nir).
selected_assets = BAND_GROUPS["rgb"] + BAND_GROUPS["nir"]
stack = stackstac.stack(
    items,
    assets=selected_assets,
    epsg=epsg,
    bounds=bounds,
    snap_bounds=False,
    resolution=10,
    resampling=Resampling.nearest,
    dtype="float32",
    rescale=False,
    fill_value=0,
).compute()
# Show an RGB composite of every time step (six panels per row), with pixel
# values stretched between 0 and 2000.
stack.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time",
    rgb="band",
    vmin=0,
    vmax=2000,
    col_wrap=6,
)
To use the mini datacube with the Clay data loader, we need to write the images to GeoTIFF files on disk. The Clay data loader then reads these files to create the embeddings below.
outdir = Path("data/minicubes")
outdir.mkdir(exist_ok=True, parents=True)
# NOTE(review): files from earlier runs are not removed from `outdir`; the
# data module below presumably loads every tif in this directory, so stale
# tiles would end up in the embeddings — clear the directory if needed.

# Write each time step of the stack to its own GeoTIFF in the output dir.
for tile in stack:
    # Grid code like "MGRS-29SNB"; keep only the part after the dash.
    mgrs = str(tile.coords["grid:code"].values).split("-")[1]
    # The string form of the datetime64 value starts with "YYYY-MM-DD".
    date = str(tile.time.values)[:10]
    name = outdir / f"claytile_{mgrs}_{date.replace('-', '')}.tif"
    tile.rio.to_raster(name, compress="deflate")
    # Also store the acquisition date as a GeoTIFF tag on the written file.
    with rasterio.open(name, "r+") as rst:
        rst.update_tags(date=date)
Now switch gears and load the tiles to create embeddings and analyze them.
The model checkpoint can be loaded directly from Hugging Face, and the data directory points to the directory we created in the steps above.
Note that the normalization parameters of the data module need to be adapted to match the band groups that were selected as partial input. The full set of normalization parameters can be found here.
# Directory holding the GeoTIFF chips written above, and the checkpoint URL.
DATA_DIR = "data/minicubes"
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"

# Overrides for the checkpoint's hyperparameters: the model only sees two
# band groups, "rgb" built from input channel indices (2, 1, 0) and "nir"
# from channel index 3 of the 4-band chips.
model_overrides = dict(
    mask_ratio=0.0,  # presumably disables masking so all patches are encoded — confirm
    band_groups={"rgb": (2, 1, 0), "nir": (3,)},
    bands=4,
)
rgb_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    strict=False,  # ignore the extra parameters in the checkpoint
    **model_overrides,
)
# Switch the model to evaluation mode for inference.
rgb_model.eval()
# Load the datamodule with the reduced set of bands (RGB + NIR). The
# normalization statistics are listed in the band order of the GeoTIFFs
# written above: red, green, blue, nir.
class ClayDataModuleRGB(ClayDataModule):
    """ClayDataModule with normalization stats for red, green, blue and nir only."""

    MEAN = [
        1369.03,  # red
        1597.68,  # green
        1741.10,  # blue
        2858.43,  # nir
    ]
    STD = [
        2026.96,  # red
        2011.88,  # green
        2146.35,  # blue
        2016.38,  # nir
    ]


data_dir = Path(DATA_DIR)
dm = ClayDataModuleRGB(data_dir=str(data_dir.absolute()), batch_size=20)
dm.setup(stage="predict")
# Iterator over the prediction batches. Despite the `trn_` prefix this is the
# predict dataloader, not a training one.
trn_dl = iter(dm.predict_dataloader())
Total number of chips: 12
The next cell loops through the batches returned by the data loader and runs the model encoder on every image. The raw patch embeddings are then averaged into one vector per image to simplify the analysis.
embeddings = []
# No gradients are needed: run the whole loop in inference mode.
with torch.inference_mode():
    for batch in trn_dl:
        # Move the model inputs to the same device as the model.
        for key in ("pixels", "timestep", "latlon"):
            batch[key] = batch[key].to(rgb_model.device)
        # Pass pixels, latlon and timestep through the encoder to create
        # encoded patches; only `unmasked_patches` is used below.
        (
            unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = rgb_model.model.encoder(batch)
        embeddings.append(unmasked_patches.detach().cpu().numpy())

embeddings = numpy.vstack(embeddings)
# Average over the token axis (axis 1) to get one vector per image. The last
# two tokens are excluded — presumably the latlon and time embeddings that
# the encoder appends (TODO confirm in src/model_clay.py).
embeddings_mean = embeddings[:, :-2, :].mean(axis=1)
print(f"Average embeddings have shape {embeddings_mean.shape}")
Average embeddings have shape (12, 768)
Now we can make a simple analysis of the embeddings. We reduce each embedding to a single number using Principal Component Analysis (PCA) and then plot this first principal component over time. The effect of the fire on the embeddings is clearly visible. We use the following color code in the graph:
| Color | Interpretation |
|---|---|
| Green | Cloudy images |
| Blue | Before the fire |
| Red | After the fire |
# Project every mean embedding onto its first principal component.
pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_mean)
plt.xticks(rotation=-30)
# NOTE(review): points are placed at `stack.time`, which assumes the data
# loader yields the chips in the same order as the stack's time axis —
# confirm, otherwise values are plotted at the wrong dates.

# Every image first in blue; cloudy and post-fire points are drawn over it.
plt.scatter(stack.time, pca_result, color="blue")
# Cloudy images sit at positions 0 and 2.
for cloudy_idx in (0, 2):
    plt.scatter(stack.time[cloudy_idx], pca_result[cloudy_idx], color="green")
# The final five images were taken after the fire.
after_fire = slice(-5, None)
plt.scatter(stack.time[after_fire], pca_result[after_fire], color="red")
<matplotlib.collections.PathCollection at 0x7f9948d29890>
In the plot above, each image embedding is one point. One can clearly distinguish the two cloudy images, and the values after the fire are consistently low.