This notebook contains a complete example of how to run Clay. It combines the following three aspects: building a small datacube of Sentinel-2 imagery from a STAC search, generating embeddings for each image with the pretrained Clay model, and analyzing those embeddings with PCA and t-SNE.
# Ensure working directory is the repo home
import os
os.chdir("..")
import warnings
from pathlib import Path
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
import rioxarray # noqa: F401
import stackstac
import torch
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition
from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule
warnings.filterwarnings("ignore")
BAND_GROUPS = {
"rgb": ["red", "green", "blue"],
"rededge": ["rededge1", "rededge2", "rededge3", "nir08"],
"nir": [
"nir",
],
"swir": ["swir16", "swir22"],
"sar": ["vv", "vh"],
}
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"
In this example we use a location and date range to visualize a major monsoon flood that happened in Padidan, Pakistan, in 2022.
# Point over Padidan, Pakistan
poi = 26.776567, 68.287374
# Date range covering a major monsoon flood (around August 20, 2022)
start = "2022-06-01"
end = "2022-09-30"
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
collections=[COLLECTION],
datetime=f"{start}/{end}",
bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),
max_items=100,
query={"eo:cloud_cover": {"lt": 50}},
)
items = search.item_collection()
print(f"Found {len(items)} items")
Get the data into a numpy array and visualize the imagery. The flood is visible in the last seven images. Note that SWIR is very helpful for flood mapping (see, for example, *Satellite flood detection integrating hydrogeomorphic and spectral indices*).
# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]
# Convert point into the image projection
poidf = gpd.GeoDataFrame(
pd.DataFrame(),
crs="EPSG:4326",
geometry=[Point(poi[1], poi[0])],
).to_crs(epsg)
coords = poidf.iloc[0].geometry.coords[0]
# Create bounds of the correct size, the model
# requires 512x512 pixels at 10m resolution.
bounds = (
coords[0] - 2560,
coords[1] - 2560,
coords[0] + 2560,
coords[1] + 2560,
)
# Retrieve the pixel values for the bounding box in
# the target projection. In this example we use the
# RGB, NIR, and SWIR band groups.
stack = stackstac.stack(
items,
bounds=bounds,
snap_bounds=False,
epsg=epsg,
resolution=10,
dtype="float32",
rescale=False,
fill_value=0,
assets=BAND_GROUPS["rgb"] + BAND_GROUPS["nir"] + BAND_GROUPS["swir"],
resampling=Resampling.nearest,
)
stack = stack.compute()
stack.sel(band=["red", "green", "blue"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)
Notice the significant signal starting on August 26th (the first image after the major flood). It appears more strongly on the following date because August 26th was fairly cloudy.
stack.sel(band=["nir", "nir", "nir"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)
The same signal is visible in the NIR band, again starting on August 26th and appearing more strongly on the following date.
stack.sel(band=["swir16", "swir16", "swir16"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)
The signal also shows up in the SWIR bands which, as noted above, are particularly helpful for flood mapping.
stack.sel(band=["swir22", "swir22", "swir22"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)
To use the mini datacube in the Clay dataloader, we need to write the images to GeoTIFF files on disk. These files are then used by the Clay data loader to create the embeddings below.
outdir = Path("data/minicubes")
outdir.mkdir(exist_ok=True, parents=True)
# Write tile to output dir
for tile in stack:
# Grid code like MGRS-29SNB
mgrs = str(tile.coords["grid:code"].values).split("-")[1]
date = str(tile.time.values)[:10]
name = "{dir}/claytile_{mgrs}_{date}.tif".format(
dir=outdir,
mgrs=mgrs,
date=date.replace("-", ""),
)
tile.rio.to_raster(name, compress="deflate")
with rasterio.open(name, "r+") as rst:
rst.update_tags(date=date)
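As a sanity check, the sketch below reopens one of the tiles we just wrote and prints its band count, dimensions, and date tag. Given the 5,120 m bounds at 10 m resolution, the expected output is six bands at 512x512 pixels.

```python
# Sanity check: reopen a written tile and confirm bands, size, and tag
sample = sorted(outdir.glob("claytile_*.tif"))[0]
with rasterio.open(sample) as rst:
    print(sample.name, rst.count, rst.width, rst.height, rst.tags().get("date"))
```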
Now switch gears and load the tiles to create embeddings and analyze them.
The model checkpoint can be loaded directly from Hugging Face, and the data directory points to the directory we created in the steps above.
Note that the normalization parameters for the data module need to be adapted to the band groups that were selected as partial input. The full set of normalization parameters can be found in the `ClayDataModule` source in `src/datamodule.py`.
DATA_DIR = "data/minicubes"
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
# Load model
multi_model = CLAYModule.load_from_checkpoint(
CKPT_PATH,
mask_ratio=0.0,
band_groups={"rgb": (2, 1, 0), "nir": (3,), "swir": (4, 5)},
bands=6,
strict=False, # ignore the extra parameters in the checkpoint
)
# Set the model to evaluation mode
multi_model.eval()
# Load the datamodule with the reduced set of bands and the
# matching normalization parameters
class ClayDataModuleMulti(ClayDataModule):
MEAN = [
1369.03, # red
1597.68, # green
1741.10, # blue
2893.86, # nir
2303.00, # swir16
1807.79, # swir22
]
STD = [
2026.96, # red
2011.88, # green
2146.35, # blue
1917.12, # nir
1679.88, # swir16
1568.06, # swir22
]
data_dir = Path(DATA_DIR)
dm = ClayDataModuleMulti(data_dir=str(data_dir.absolute()), batch_size=2)
dm.setup(stage="predict")
trn_dl = iter(dm.predict_dataloader())
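Before running the full loop, it can help to peek at a single batch to see which keys the datamodule produces and the tensor shapes involved; these are the same keys consumed in the loop below.

```python
# Peek at one batch: keys and shapes (the dates are plain strings)
sample_batch = next(iter(dm.predict_dataloader()))
for key, value in sample_batch.items():
    print(key, getattr(value, "shape", type(value)))
```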
This will loop through the images returned by the data loader and run the encoder on each one. The raw patch embeddings are averaged into a single vector per image to simplify the analysis.
embeddings = []
ts = []
for batch in trn_dl:
with torch.no_grad():
        # Move the tensors to the model's device; the dates are plain
        # strings and stay on the CPU
        batch["pixels"] = batch["pixels"].to(multi_model.device)
        batch["timestep"] = batch["timestep"].to(multi_model.device)
        batch["latlon"] = batch["latlon"].to(multi_model.device)
# Pass pixels, latlon, timestep through the encoder to create encoded patches
(
unmasked_patches,
unmasked_indices,
masked_indices,
masked_matrix,
) = multi_model.model.encoder(batch)
embeddings.append(unmasked_patches.detach().cpu().numpy())
ts.append(batch["date"])
embeddings = numpy.vstack(embeddings)
# Drop the last two tokens (the appended lat/lon and time embeddings)
# and average over the patch dimension
embeddings_mean = embeddings[:, :-2, :].mean(axis=1)
print(f"Average embeddings have shape {embeddings_mean.shape}")
Check the dates. Notice they are in sublists of size 2 because of the batch size.
ts
Flatten the dates
tss = [t for tb in ts for t in tb]
tss
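As an optional step, the averaged embeddings can be written to disk together with their dates so the analysis below can be rerun without reloading the model; the output file name here is just an example.

```python
# Optional: persist the mean embeddings with their dates for reuse.
# The output file name is illustrative.
emb_df = pd.DataFrame(embeddings_mean)
emb_df.insert(0, "date", tss)
emb_df.to_csv(outdir / "embeddings.csv", index=False)
```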
Now we can do a simple analysis of the embeddings. We reduce each embedding to a single number using Principal Component Analysis and then plot the first principal component. The effect of the flood on the embeddings is clearly visible. We use the following color code in the graph:
| Color | Interpretation   |
|-------|------------------|
| Green | Cloudy images    |
| Blue  | Before the flood |
| Red   | After the flood  |
pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_mean)
plt.xticks(rotation=-30)
# All points
plt.scatter(tss, pca_result, color="blue")
# Cloudy images
plt.scatter(tss[7], pca_result[7], color="green")
plt.scatter(tss[8], pca_result[8], color="green")
# After flood
plt.scatter(tss[-7:], pca_result[-7:], color="red")
In the plot above, each image embedding is one point. One can clearly distinguish the two cloudy images, and the values after the flood are consistently low.
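As an optional, hedged follow-up, that observation can be turned into a crude before/after split by thresholding the first principal component. The mean-value threshold below is an illustrative assumption, and the sign of a principal component is arbitrary, so the label orientation may need flipping on other runs.

```python
# Optional sketch: crude before/after split by thresholding PC1.
# The threshold and label orientation are illustrative assumptions.
threshold = pca_result[:, 0].mean()
for date, value in zip(tss, pca_result[:, 0]):
    label = "after-flood-like" if value < threshold else "before/cloudy"
    print(date, round(float(value), 2), label)
```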
A quick t-SNE calculation shows that the dates indeed cluster as we'd expect, with the before-flood dates grouped together and the after-flood dates grouped together.
from sklearn.manifold import TSNE
# Perform t-SNE on the embeddings
tsne = TSNE(n_components=2, perplexity=5)
X_tsne = tsne.fit_transform(embeddings_mean)
# Plot the results
plt.figure(figsize=(10, 6))
plt.scatter(X_tsne[:, 0], X_tsne[:, 1])
# Annotate each point with the corresponding date
for i, (x, y) in enumerate(zip(X_tsne[:, 0], X_tsne[:, 1])):
plt.annotate(f"{tss[i]}", (x, y))
plt.title("t-SNE Visualization")
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
plt.show()