This notebook applies the Clay model to imagery composites, specifically a Sentinel-2 Geometric Median (GeoMAD) composite. We will use Digital Earth Pacific's STAC API to obtain these datasets and apply the model to a mineral resource detection downstream task.
This is a demonstration of how one can use Clay to identify possible new events, given a reference dataset of locations where known events occur.
import glob
import random
from pathlib import Path
import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
import rioxarray # noqa: F401
import shapely
import stackstac
import torch
from rasterio.enums import Resampling
from shapely.geometry import Point, Polygon, box
from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule
pd.set_option("display.max_colwidth", None)
BAND_GROUPS = {
"rgb": ["B04", "B03", "B02"],
"rededge": ["B05", "B06", "B07", "B8A"],
"nir": ["B08"],
"swir": ["B11", "B12"],
"sar": ["mean_vv", "mean_vh"],
}
STAC_API = "https://stac.staging.digitalearthpacific.org"
COLLECTION = "dep_s2_geomad"
We will load ground truth points for known mining extraction sites and use them as references in a similarity search.
mrd = gpd.read_file(
"https://raw.githubusercontent.com/digitalearthpacific/mineral-resource-detection/"
"d117d04703f77ff21c15c7ffc424c3c55b51c492/training_data/draft_inputs/MRD_dissagregated_25.geojson"
)
# Assuming mining extraction relates to the quarry lulc class
mrd_mining = mrd[mrd["lulc_class"] == "quarry"]
# Extent of the mining extraction activity per the ground truth reference points
mrd_mining_bounds = mrd_mining.total_bounds
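As a quick look at the reference data (a minimal check, using the columns and variables loaded above):
# How many quarry points do we have, and what is their extent?
print(f"Number of quarry points: {len(mrd_mining)}")
print(f"Extent (xmin, ymin, xmax, ymax): {mrd_mining_bounds}")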
Since we are demoing a single 512x512 tile in this tutorial, let's identify a cluster where several ground truth points exist.
# sample cluster
bbox_bl = (177.4199, -17.8579)
bbox_tl = (177.4156, -17.6812)
bbox_br = (177.5657, -17.8572)
bbox_tr = (177.5657, -17.6812)
Define the spatiotemporal query.
# Define area of interest
area_of_interest = shapely.box(
xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1]
)
# Define temporal range: one annual composite for 2021
daterange = ["2021-01-01T00:00:00Z", "2021-12-31T23:59:59Z"]
catalog = pystac_client.Client.open(url=STAC_API)
sen2_search = catalog.search(
collections=[COLLECTION],
datetime=daterange,
intersects=area_of_interest,
max_items=100,
)
items = sen2_search.get_all_items()
print(f"Found {len(items)} items")
Get the composite data into a numpy array and visualize the imagery. A STAC Browser for this catalog is available at https://stac-browser.staging.digitalearthpacific.org.
# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]
# Convert point from lon/lat to UTM projection
poidf = gpd.GeoDataFrame(crs="OGC:CRS84", geometry=[area_of_interest.centroid]).to_crs(
epsg
)
geom = poidf.iloc[0].geometry
# Create bounds of the correct size: the model requires
# 512x512 pixels at 10 m resolution, i.e. a half-width of
# 512 * 10 / 2 = 2560 m around the centroid.
bounds = (geom.x - 2560, geom.y - 2560, geom.x + 2560, geom.y + 2560)
# Retrieve the pixel values, for the bounding box in
# the target projection. In this example we use only
# the RGB and NIR band groups.
stack = stackstac.stack(
items,
bounds=bounds,
snap_bounds=False,
epsg=epsg,
resolution=10,
dtype="float32",
rescale=False,
fill_value=0,
assets=BAND_GROUPS["rgb"] + BAND_GROUPS["nir"],
resampling=Resampling.nearest,
)
stack = stack.compute()
assert stack.shape == (1, 4, 512, 512)
stack.sel(band=["B04", "B03", "B02"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000
)
outdir = Path("data/minicubes")
outdir.mkdir(exist_ok=True, parents=True)
write = True
if write:
# Write tile to output dir
for tile in stack:
date = str(tile.time.values)[:10]
name = "{dir}/claytile_{date}.tif".format(
dir=outdir,
date=date.replace("-", ""),
)
tile.rio.to_raster(name, compress="deflate")
with rasterio.open(name, "r+") as rst:
rst.update_tags(date=date)
We will use the geospatial bounds of the 32x32 windowed subsets ("chunks") to store the patch-level embeddings.
# Define the chunk size for tiling
chunk_size = {"x": 32, "y": 32} # Adjust the chunk size as needed
# Tile the data
ds_chunked = stack.chunk(chunk_size)
# Iterate over the chunks and compute the geospatial bounds for each chunk
chunk_bounds = {}
# Get the geospatial transform and CRS. Note that the transform is in
# the tile's projected CRS (UTM metres), so the lon_*/lat_* variables
# below actually hold projected x/y coordinates.
transform = ds_chunked.attrs["transform"]
crs = ds_chunked.attrs["crs"]
for x in range(ds_chunked.sizes["x"] // chunk_size["x"]):
for y in range(ds_chunked.sizes["y"] // chunk_size["y"]):
# Compute chunk coordinates
x_start = x * chunk_size["x"]
y_start = y * chunk_size["y"]
x_end = min(x_start + chunk_size["x"], ds_chunked.sizes["x"])
y_end = min(y_start + chunk_size["y"], ds_chunked.sizes["y"])
# Compute chunk geospatial bounds
lon_start, lat_start = transform * (x_start, y_start)
lon_end, lat_end = transform * (x_end, y_end)
# print(lon_start, lat_start, lon_end, lat_end, x, y)
# Store chunk bounds
chunk_bounds[(x, y)] = {
"lon_start": lon_start,
"lat_start": lat_start,
"lon_end": lon_end,
"lat_end": lat_end,
}
# Print chunk bounds
# for key, value in chunk_bounds.items():
# print(f"Chunk {key}: {value}")
We will generate patch-level embeddings averaged over the band groups.
DATA_DIR = "data/minicubes"
CKPT_PATH = (
"https://huggingface.co/made-with-clay/Clay/resolve/main/"
"Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
)
# Load model
multi_model = CLAYModule.load_from_checkpoint(
CKPT_PATH,
mask_ratio=0.0,
band_groups={"rgb": (2, 1, 0), "nir": (3,)},
bands=4,
strict=False, # ignore the extra parameters in the checkpoint
embeddings_level="group",
)
# Set the model to evaluation mode
multi_model.eval()
# Load the datamodule, with the reduced set of bands
class ClayDataModuleMulti(ClayDataModule):
MEAN = [
1369.03, # red
1597.68, # green
1741.10, # blue
2893.86, # nir
]
STD = [
2026.96, # red
2011.88, # green
2146.35, # blue
1917.12, # nir
]
data_dir = Path(DATA_DIR)
dm = ClayDataModuleMulti(data_dir=str(data_dir.absolute()), batch_size=1)
dm.setup(stage="predict")
trn_dl = iter(dm.predict_dataloader())
embeddings = []
for batch in trn_dl:
with torch.no_grad():
        # Move data to the device of the model
        batch["pixels"] = batch["pixels"].to(multi_model.device)
        batch["timestep"] = batch["timestep"].to(multi_model.device)
        batch["latlon"] = batch["latlon"].to(multi_model.device)
# Pass pixels, latlon, timestep through the encoder to create encoded patches
(
unmasked_patches,
unmasked_indices,
masked_indices,
masked_matrix,
) = multi_model.model.encoder(batch)
print(unmasked_patches.detach().cpu().numpy())
embeddings.append(unmasked_patches.detach().cpu().numpy())
print(len(embeddings[0])) # embeddings is a list
print(embeddings[0].shape) # with date and lat/lon
print(embeddings[0][:, :-2, :].shape) # remove date and lat/lon
# remove date and lat/lon and reshape to disaggregated patches
embeddings_patch = embeddings[0][:, :-2, :].reshape([2, 16, 16, 768])
embeddings_patch.shape
# average over the band groups
embeddings_patch_avg_group = embeddings_patch.mean(axis=0)
embeddings_patch_avg_group.shape
Save the patch-level embeddings with the matching geospatial bounds from the chunks we computed earlier. We correlate patches to chunk bounds by matching indices, which assumes the patches and chunks both define non-overlapping 32x32 pixel subsets on the same grid.
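Before saving, we can verify that assumption: the 16x16 patch grid should pair one-to-one with the 256 chunk bounds computed earlier. A minimal sanity check:
# Index-based pairing is only valid if both grids have the same shape
assert embeddings_patch_avg_group.shape[:2] == (16, 16)
assert len(chunk_bounds) == 16 * 16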
outdir_embeddings = Path("data/embeddings")
outdir_embeddings.mkdir(exist_ok=True, parents=True)
# Iterate through each patch
for i in range(embeddings_patch_avg_group.shape[0]):
for j in range(embeddings_patch_avg_group.shape[1]):
embeddings_output_patch = embeddings_patch_avg_group[i, j]
        # Look up the matching chunk bounds directly by (i, j) index
        bounds_ = chunk_bounds[(i, j)]
        box_ = [
            bounds_["lon_start"],
            bounds_["lat_start"],
            bounds_["lon_end"],
            bounds_["lat_end"],
        ]
date = batch["date"]
data = {
"source_url": batch["source_url"][0],
"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
}
# Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
# The box_ list is encoded as
# [bottom left x, bottom left y, top right x, top right y]
box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
# Create the GeoDataFrame
gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{epsg}")
# Reproject to WGS84 (lon/lat coordinates)
gdf = gdf.to_crs(epsg=4326)
        outpath = (
            f"{outdir_embeddings}/"
            f"{batch['source_url'][0].split('/')[-1][:-4]}_{i}_{j}.gpq"
        )
        # Write the single-row GeoDataFrame to a GeoParquet file
        gdf.to_parquet(path=outpath)
        print(
            f"Saved {len(gdf)} rows of embeddings of "
            f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
        )
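As a round-trip check (assuming the loop above wrote at least one file), we can read one of the saved GeoParquet files back:
# Read back one saved embedding row to confirm the round trip
sample_path = sorted(outdir_embeddings.glob("*.gpq"))[0]
print(gpd.read_parquet(sample_path))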
We will use reference lon,lat points from the ground truth mining extraction data to define a filtered search where a point maps to its overlapping patch, and that patch is used to find similar patches (aka potential new mining extraction sites).
db = lancedb.connect("embeddings")
# Data for DB table
data = []
# Dataframe to find overlaps within
gdfs = []
for emb in glob.glob(f"{outdir_embeddings}/*.gpq"):
gdf = gpd.read_parquet(emb)
gdf["year"] = gdf.date.dt.year
gdf["tile"] = gdf["source_url"].apply(
lambda x: Path(x).stem.rsplit("/")[-1].rsplit("_")[0]
)
gdf["idx"] = "_".join(emb.split("/")[-1].split("_")[2:]).replace(".gpq", "")
gdf["box"] = [box(*geom.bounds) for geom in gdf.geometry]
gdfs.append(gdf)
for _, row in gdf.iterrows():
data.append(
{
"vector": row["embeddings"],
"path": row["source_url"],
"tile": row["tile"],
"date": row["date"],
"year": int(row["year"]),
"idx": row["idx"],
"box": row["box"].bounds,
}
)
# Combine patch level geodataframes into one
embeddings_gdf = pd.concat(gdfs, ignore_index=True)
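Since we processed a single 512x512 tile, we expect 16x16 = 256 patch rows in total; a quick check:
# One row per 32x32 patch of the 512x512 tile
print(len(embeddings_gdf))  # expected: 256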
# Sample a random embedding patch and grab its bounding box,
# which we will visualize below
embeddings_gdf_shuffled = embeddings_gdf.sample(frac=1).reset_index(drop=True)
area_of_interest_embedding = embeddings_gdf_shuffled.box.iloc[0]
# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]
# Convert point from lon/lat to UTM projection
box_embedding = gpd.GeoDataFrame(
crs="OGC:CRS84", geometry=[area_of_interest_embedding]
).to_crs(epsg)
geom_embedding = box_embedding.iloc[0].geometry
# The embedding patch bounds already cover 32x32 pixels at
# 10 m resolution, so we can use them directly.
# Retrieve the pixel values for the bounding box in the
# target projection. In this example we use only the RGB
# and NIR band groups.
stack_embedding = stackstac.stack(
items,
bounds=geom_embedding.bounds,
snap_bounds=False,
epsg=epsg,
resolution=10,
dtype="float32",
rescale=False,
fill_value=0,
assets=BAND_GROUPS["rgb"] + BAND_GROUPS["nir"],
resampling=Resampling.nearest,
)
stack_embedding = stack_embedding.compute()
assert stack_embedding.shape == (1, 4, 32, 32)
stack_embedding.sel(band=["B04", "B03", "B02"]).plot.imshow(
row="time", rgb="band", vmin=0, vmax=2000
)
# Drop any existing table; mode="overwrite" below would also
# replace it when re-running this notebook
db.drop_table("clay-v001")
db.table_names()
tbl = db.create_table("clay-v001", data=data, mode="overwrite")
# Peek at one stored embedding vector
v = tbl.head(1).to_pandas()["vector"].values[0]
# Function to check whether a point falls within a patch's bounding box
def check_intersection(row, point):
    return point.intersects(row["box"])


intersecting_rows_ = []
for _, row in mrd_mining.iterrows():
    # Create a Point geometry from the longitude and latitude
    point = Point(row.x, row.y)
    # Apply the check to each row of the embeddings dataframe
    intersects = embeddings_gdf.apply(check_intersection, axis=1, point=point)
    # Keep the patches whose bounding box contains this point
    intersecting_rows = embeddings_gdf[intersects]
    if not intersecting_rows.empty:
        intersecting_rows_.append(intersecting_rows)
# Number of intersections in our AOI (which depicts a cluster)
len(intersecting_rows_)
# Pick a random ground-truth-matched patch to use as the reference
reference_number = random.randint(0, len(intersecting_rows_) - 1)
reference = tbl.to_pandas().query(
f"idx == '{intersecting_rows_[reference_number].idx.values[0]}'"
)
result = tbl.search(query=reference.iloc[0]["vector"]).limit(10).to_pandas()
# result.head(10)
result.columns
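Depending on the LanceDB version, the results typically expose a `_distance` column, with the reference patch itself at distance ~0 in the first row. A quick check, in case that column is present:
# Inspect similarity distances, if exposed by this LanceDB version
if "_distance" in result.columns:
    print(result[["idx", "_distance"]])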
def plot(df, cols=10):
fig, axs = plt.subplots(1, cols, figsize=(20, 10))
row_0 = df.iloc[0]
path = row_0["path"]
chip = rasterio.open(path)
tile = row_0["tile"]
width = chip.width
height = chip.height
# Define the window size
window_size = (32, 32)
idxs_windows = {"idx": [], "window": []}
# Iterate over the image in 32x32 windows
for col in range(0, width, window_size[0]):
for row in range(0, height, window_size[1]):
# Define the window
window = ((row, row + window_size[1]), (col, col + window_size[0]))
# Read the data within the window
data = chip.read(window=window)
            # Get the (x, y) index of the window
            index = (col // window_size[0], row // window_size[1])
            idxs_windows["idx"].append("_".join(map(str, index)))
            idxs_windows["window"].append(data)
for ax, (_, row) in zip(axs.flatten(), df.iterrows()):
idx = row["idx"]
# Find the corresponding window based on the idx
window_index = idxs_windows["idx"].index(idx)
window_data = idxs_windows["window"][window_index]
        # Scale and clip the RGB bands for display
        subset_img = numpy.clip(
            (window_data.transpose(1, 2, 0)[:, :, :3] / 10_000) * 3, 0, 1
        )
ax.imshow(subset_img)
ax.set_title(f"{tile}/{idx}")
ax.set_axis_off()
plt.tight_layout()
fig.savefig("similar.png")
plot(result)
Visualize the search results over the tile. The similar patches are plotted in red, the reference patch in yellow, and the ground truth quarry points in blue.
# Make a geodataframe of the search results.
# Each "box" entry is (xmin, ymin, xmax, ymax) in EPSG:4326.
result_boxes = [box(*bbox) for bbox in result["box"]]
result_gdf = gpd.GeoDataFrame(result, geometry=result_boxes)
result_gdf.crs = "EPSG:4326"
# Plot the AOI in RGB
stack.sel(band=["B04", "B03", "B02"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000
)
# Overlay the bounding boxes of the patches identified from the similarity search
result_gdf.to_crs(epsg).plot(ax=plt.gca(), color="red", alpha=0.5)
# Reference embedding
reference_gdf = gpd.GeoDataFrame(
intersecting_rows_[reference_number],
geometry=intersecting_rows_[reference_number]["box"],
)
reference_gdf.to_crs(epsg).plot(ax=plt.gca(), color="yellow", alpha=0.5)
# Overlay the ground truth quarry points
mrd_mining.to_crs(epsg).cx[bounds[0] : bounds[2], bounds[1] : bounds[3]].plot(
ax=plt.gca(), color="blue", markersize=5
)
# Set plot title and labels
plt.title("Sentinel-2 with ground truth and similar embeddings")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
# Show the plot
plt.show()