This notebook walks through Clay model v1 inference on NAIP (National Agriculture Imagery Program) data, followed by a similarity search over the resulting embeddings. The NAIP data comes in annual composites. We use data from one year within a sampled region around San Francisco, California.
The workflow includes the following steps:
Loading and Preprocessing Data: query a STAC catalog for NAIP imagery, then tile each image into chips.
Generating Embeddings: encode each chip, together with its time, location and resolution metadata, with the pretrained Clay encoder.
Saving Embeddings: write each embedding and its chip to a GeoParquet file.
Similarity Search: load the embeddings into a LanceDB table and query it for the chips most similar to a chosen one.
Install the stacchip library.
%pip install stacchip==0.1.33
import sys
sys.path.append("../../") # Model src
# If the pip install for stacchip doesn't work above,
# git clone the repo and point the following path at your local clone
sys.path.append("../../../stacchip/")
import datetime
import glob
import math
import os
import random
import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystac_client
import requests
import shapely
import torch
import yaml
from box import Box
from pyproj import Transformer
from rasterio.io import MemoryFile
from shapely.geometry import box
from stacchip.chipper import Chipper
from stacchip.indexer import NoStatsChipIndexer
from stacchip.processors.prechip import normalize_timestamp
from torchvision.transforms import v2
from src.module import ClayMAEModule
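# Define the platform name and year for the NAIP data.
# The year must be one with NAIP coverage for the search area,
# otherwise the search below returns no items.
PLATFORM_NAME = "naip"
YEAR = 2020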
# Query STAC catalog for NAIP data
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1"
)
# Perform a search on the STAC catalog,
# specifying the collection to search within (NAIP data),
# defining the bounding box for the search area (San Francisco region), and
# setting the date range for the search (the entire year given by YEAR).
# Also limit the search to a maximum of 100 items.
items = catalog.search(
    collections=[PLATFORM_NAME],
    bbox=[-122.6, 37.6, -122.35, 37.85],
    datetime=f"{YEAR}-01-01T00:00:00Z/{YEAR+1}-01-01T00:00:00Z",
    max_items=100,
)
# Convert the search results to an item collection
items = items.item_collection()
# Convert the item collection to a list for easier manipulation
items_list = list(items)
# Randomly shuffle the list of items to ensure random sampling
random.shuffle(items_list)
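NAIP coverage for a given area may not exist for every year, so a search can come back empty, and the failure then only surfaces later when the first chip is accessed. A minimal sketch of an early guard follows; year_interval and require_items are hypothetical helpers written for illustration and are not part of pystac_client.

```python
def year_interval(year: int) -> str:
    """Build the datetime interval string covering one calendar year."""
    return f"{year}-01-01T00:00:00Z/{year + 1}-01-01T00:00:00Z"


def require_items(found: list, year: int) -> list:
    """Fail early, with a readable message, when a search returned nothing."""
    if not found:
        raise ValueError(f"No items found for {year}; try another year or bbox.")
    return found


print(year_interval(2020))
```

In the notebook, a call such as require_items(items_list, YEAR) right after the shuffle would stop execution before any chipping is attempted.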
def get_bounds_centroid(url: str):
    """
    Retrieve the bounds and centroid of an image from its URL.
    Parameters:
    url (str): The URL of the image.
    Returns:
    tuple: Bounds as [min_lon, min_lat, max_lon, max_lat] in EPSG:4326,
    followed by the centroid longitude and latitude.
    """
    response = requests.get(url)
    response.raise_for_status()
    with MemoryFile(response.content) as memfile:
        with memfile.open() as src:
            bounds = src.bounds
            # always_xy=True keeps the output in (x, y) = (lon, lat) order;
            # without it, EPSG:4326 results come back as (lat, lon).
            transformer = Transformer.from_crs(src.crs, 4326, always_xy=True)
            # Calculate centroid
            centroid_x = (bounds.left + bounds.right) / 2
            centroid_y = (bounds.top + bounds.bottom) / 2
            centroid_x, centroid_y = transformer.transform(centroid_x, centroid_y)
            bounds_l, bounds_b = transformer.transform(bounds.left, bounds.bottom)
            bounds_r, bounds_t = transformer.transform(bounds.right, bounds.top)
    return [bounds_l, bounds_b, bounds_r, bounds_t], centroid_x, centroid_y
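Axis order is easy to get wrong when handling bounds. As a small illustration, assume bounds are ordered [min_lon, min_lat, max_lon, max_lat], which is the order of the bbox passed to the STAC search and the (minx, miny, maxx, maxy) order that shapely.geometry.box expects. The helper name bbox_centroid is hypothetical.

```python
def bbox_centroid(bounds):
    """Return (lon, lat) of the center of [min_lon, min_lat, max_lon, max_lat]."""
    min_lon, min_lat, max_lon, max_lat = bounds
    if min_lon > max_lon or min_lat > max_lat:
        raise ValueError("bounds must be ordered as min values, then max values")
    return (min_lon + max_lon) / 2, (min_lat + max_lat) / 2


# The search bounding box used above
search_bbox = [-122.6, 37.6, -122.35, 37.85]
lon, lat = bbox_centroid(search_bbox)
print(f"lon={lon:.3f}, lat={lat:.3f}")
```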
chip_images = [] # List to hold chip pixels
chip_bounds = [] # List to hold chip bounds
for item in items_list[:2]:
    print(f"Working on {item}")
    # Index the chips in the item
    indexer = NoStatsChipIndexer(item)
    # Obtain the item bounds and centroid
    bounds, centroid_x, centroid_y = get_bounds_centroid(item.assets["image"].href)
    print(
        f"Bbox coordinates: {bounds}, centroid coordinates: {centroid_x}, {centroid_y}"
    )
    # Instantiate the chipper
    chipper = Chipper(indexer, asset_blacklist=["metadata"])
    # Get 5 randomly sampled chips from the total
    # number of chips within this item's entire image
    for chip_id in random.sample(range(0, len(chipper)), 5):
        chip_images.append(chipper[chip_id]["image"])
        chip_bounds.append(bounds)
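The chipper addresses chips by a single integer index. As a rough sketch of how such an index can map to a pixel window, assume non-overlapping 256 x 256 chips laid out row by row, with partial chips at the right and bottom edges discarded. This illustrates the idea only: the layout stacchip actually uses may differ, and chip_grid, chip_window and the image size below are hypothetical.

```python
CHIP_SIZE = 256  # assumed chip edge length in pixels


def chip_grid(width: int, height: int, size: int = CHIP_SIZE):
    """Number of whole chips that fit along x and along y."""
    return width // size, height // size


def chip_window(index: int, width: int, height: int, size: int = CHIP_SIZE):
    """Map a row-major chip index to (col_off, row_off, width, height) in pixels."""
    n_x, n_y = chip_grid(width, height, size)
    if not 0 <= index < n_x * n_y:
        raise IndexError(f"chip index {index} out of range for {n_x * n_y} chips")
    row, col = divmod(index, n_x)
    return col * size, row * size, size, size


# Hypothetical image of 10,000 x 12,000 pixels
n_x, n_y = chip_grid(10_000, 12_000)
print(n_x * n_y, chip_window(40, 10_000, 12_000))
```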
Visualize a generated image chip.
fig, ax = plt.subplots(1, 1, gridspec_kw={"wspace": 0.01, "hspace": 0.01}, squeeze=True)
chip = chip_images[0]
# Visualize the data
ax.imshow(chip[:3].swapaxes(0, 1).swapaxes(1, 2))
plt.tight_layout()
plt.show()
Below are some functions we will rely on to prepare the data cubes, generate embeddings, and plot subsets of the chipped images for visualization purposes.
def plot_rgb(stack):
    """
    Plot the RGB bands of the given stack.
    Parameters:
    stack (xarray.DataArray): The input data array containing band information.
    """
    stack.sel(band=[1, 2, 3]).plot.imshow(rgb="band", vmin=0, vmax=2000, col_wrap=6)
    plt.show()
def normalize_latlon(lat, lon):
    """
    Encode latitude and longitude as sine and cosine pairs,
    so that every value lies between -1 and 1.
    Parameters:
    lat (float): Latitude value in degrees.
    lon (float): Longitude value in degrees.
    Returns:
    tuple: (sin, cos) of the latitude and (sin, cos) of the longitude.
    """
    lat = lat * np.pi / 180
    lon = lon * np.pi / 180
    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))
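The sine and cosine pair keeps every value in [-1, 1] and places nearby angles close together, including across the ±180° longitude wrap. A small standalone check of the same formula follows. The function is restated under the hypothetical name encode_angle so the snippet runs on its own, and the coordinates are approximate values for San Francisco chosen for illustration.

```python
import math


def encode_angle(degrees: float):
    """Encode an angle in degrees as a (sin, cos) pair."""
    radians = math.radians(degrees)
    return math.sin(radians), math.cos(radians)


lat_enc = encode_angle(37.77)    # approximate latitude of San Francisco
lon_enc = encode_angle(-122.42)  # approximate longitude of San Francisco
print(lat_enc, lon_enc)

# Longitudes just either side of the antimeridian encode to nearby points,
# even though the raw values differ by almost 360 degrees.
east = encode_angle(179.9)
west = encode_angle(-179.9)
print(math.dist(east, west))
```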
def load_model(ckpt, device="cuda"):
    """
    Load a pretrained Clay model from a checkpoint.
    Parameters:
    ckpt (str): Path to the model checkpoint.
    device (str): Device to load the model onto (default is 'cuda').
    Returns:
    model: Loaded model.
    """
    torch.set_default_device(device)
    model = ClayMAEModule.load_from_checkpoint(
        ckpt, metadata_path="../configs/metadata.yaml", shuffle=False, mask_ratio=0
    )
    model.eval()
    return model.to(device)
def prep_datacube(image, lat, lon, date, gsd, device):
    """
    Prepare a data cube for model input.
    Parameters:
    image (np.array): The input image array.
    lat (float): Latitude value for the location.
    lon (float): Longitude value for the location.
    date (datetime): Acquisition date used for the time encoding.
    gsd (float): Ground sample distance of the imagery.
    device (str): Device to load the data onto.
    Returns:
    dict: Prepared data cube with normalized pixels and metadata tensors.
    """
    platform = "naip"
    # Extract mean, std, and wavelengths from metadata
    with open("../configs/metadata.yaml") as f:
        metadata = Box(yaml.safe_load(f))
    mean = []
    std = []
    waves = []
    bands = ["red", "green", "blue", "nir"]
    for band_name in bands:
        mean.append(metadata[platform].bands.mean[band_name])
        std.append(metadata[platform].bands.std[band_name])
        waves.append(metadata[platform].bands.wavelength[band_name])
    transform = v2.Compose(
        [
            v2.Normalize(mean=mean, std=std),
        ]
    )
    # Prep datetimes embedding
    times = normalize_timestamp(date)
    week_norm = times[0]
    hour_norm = times[1]
    # Prep lat/lon embedding
    latlons = normalize_latlon(lat, lon)
    lat_norm = latlons[0]
    lon_norm = latlons[1]
    # Prep pixels
    pixels = torch.from_numpy(image.astype(np.float32))
    pixels = transform(pixels)
    pixels = pixels.unsqueeze(0)
    # Prepare additional information
    return {
        "pixels": pixels.to(device),
        "time": torch.tensor(
            np.hstack((week_norm, hour_norm)),
            dtype=torch.float32,
            device=device,
        ).unsqueeze(0),
        "latlon": torch.tensor(
            np.hstack((lat_norm, lon_norm)), dtype=torch.float32, device=device
        ).unsqueeze(0),
        "gsd": torch.tensor(gsd, device=device),
        "waves": torch.tensor(waves, device=device),
    }
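The date can be encoded in the same spirit as the location, by mapping week of year and hour of day onto a circle. The sketch below is an assumption about what a helper like normalize_timestamp computes, written out to show the idea; consult stacchip for the actual implementation. The name encode_timestamp is hypothetical.

```python
import datetime
import math


def encode_timestamp(date: datetime.datetime):
    """Encode week of year and hour of day as (sin, cos) pairs."""
    # Assumed scaling: 52 weeks per year and 24 hours per day map to a full turn.
    week = date.isocalendar()[1] * 2 * math.pi / 52
    hour = date.hour * 2 * math.pi / 24
    return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))


week_norm, hour_norm = encode_timestamp(datetime.datetime(2020, 6, 1))
time_vector = [*week_norm, *hour_norm]  # four values, like the "time" tensor
print(time_vector)
```

Stacking the two pairs gives the four-element vector that prep_datacube places under the "time" key.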
def generate_embeddings(model, datacube):
    """
    Generate embeddings from the model.
    Parameters:
    model (ClayMAEModule): The pretrained model.
    datacube (dict): Prepared data cube.
    Returns:
    numpy.ndarray: Generated embeddings.
    """
    with torch.no_grad():
        unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)
    # The first embedding is the class token, which is the
    # overall single embedding.
    return unmsk_patch[:, 0, :].cpu().numpy()
outdir_embeddings = "../data/embeddings/"
os.makedirs(outdir_embeddings, exist_ok=True)
# Download the pretrained model from
# https://huggingface.co/made-with-clay/Clay/blob/main/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt
# and put it in a checkpoints folder.
model = load_model(
    ckpt="../../checkpoints/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt",
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
embeddings = []
i = 0
for tile, bounding_box in zip(chip_images, chip_bounds):
    date = datetime.datetime.strptime(f"{YEAR}-06-01", "%Y-%m-%d")
    gsd = 0.6
    # Build the footprint polygon; box() expects (minx, miny, maxx, maxy)
    footprint = box(*bounding_box)
    lon, lat = footprint.centroid.coords[0]
    datacube = prep_datacube(
        np.array(tile), lat, lon, pd.to_datetime(f"{YEAR}-06-01"), gsd, model.device
    )
    embeddings_ = generate_embeddings(model, datacube)
    embeddings.append(embeddings_)
    data = {
        "source_url": str(i),
        "date": pd.to_datetime(arg=date, format="%Y-%m-%d"),
        "embeddings": [np.ascontiguousarray(embeddings_.squeeze())],
        "image": [np.ascontiguousarray(np.array(tile.transpose(1, 2, 0)).flatten())],
    }
    # Create the GeoDataFrame; the geometry must be a shapely object,
    # not a plain list of coordinates
    gdf = gpd.GeoDataFrame(data, geometry=[footprint], crs="EPSG:4326")
    outpath = f"{outdir_embeddings}/{i}.gpq"
    gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
    print(
        f"Saved {len(gdf)} rows of embeddings of "
        f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
    )
    i += 1
print(f"Created {len(embeddings)} embeddings of shape {embeddings[0].shape[1]}")
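Each chip is stored as a flat array after being transposed from (bands, height, width) to (height, width, bands), which is why the plotting code later reshapes it to (256, 256, 4). A minimal pure-Python sketch of the index arithmetic behind that round trip follows, using a tiny 2 x 3 image with 4 bands. The helper names flat_index and unflatten are hypothetical.

```python
def flat_index(row: int, col: int, band: int, width: int, n_bands: int) -> int:
    """Position of pixel (row, col, band) in a flattened (H, W, C) array."""
    return (row * width + col) * n_bands + band


def unflatten(flat, height: int, width: int, n_bands: int):
    """Rebuild nested [row][col][band] lists from a flat (H, W, C) sequence."""
    if len(flat) != height * width * n_bands:
        raise ValueError("flat length does not match the requested shape")
    return [
        [
            [flat[flat_index(r, c, b, width, n_bands)] for b in range(n_bands)]
            for c in range(width)
        ]
        for r in range(height)
    ]


height, width, n_bands = 2, 3, 4
flat = list(range(height * width * n_bands))
image = unflatten(flat, height, width, n_bands)
print(image[1][2])  # the four band values of the last pixel
```

Reshaping with a different band count or axis order than the one used when flattening would not raise an error as long as the total size matches, but it would scramble the image.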
We will select a random index to search with and plot the corresponding RGB images from the search results.
# Connect to the embeddings database
db = lancedb.connect(outdir_embeddings)
# Data for DB table
data = []
# Dataframe to find overlaps within
gdfs = []
idx = 0
for emb in glob.glob(f"{outdir_embeddings}/*.gpq"):
    gdf = gpd.read_parquet(emb)
    gdf["year"] = gdf.date.dt.year
    gdf["tile"] = gdf["source_url"]
    gdf["idx"] = idx
    gdf["box"] = [shapely.geometry.box(*geom.bounds) for geom in gdf.geometry]
    gdfs.append(gdf)
    for _, row in gdf.iterrows():
        data.append(
            {
                "vector": row["embeddings"],
                "path": row["source_url"],
                "tile": row["tile"],
                "date": row["date"],
                "year": int(row["year"]),
                "idx": row["idx"],
                "box": row["box"].bounds,
                "image": row["image"],
            }
        )
    idx += 1
# Combine the geodataframes into one
embeddings_gdf = pd.concat(gdfs, ignore_index=True)
# Drop existing table if any
try:
    db.drop_table("clay-v001")
except FileNotFoundError:
    pass
db.table_names()
# Create a new table with the embeddings data
tbl = db.create_table("clay-v001", data=data, mode="overwrite")
# Select a random embedding for the search query
# randrange excludes the upper bound, so idx is always a valid row position
idx = random.randrange(len(embeddings_gdf))
v = tbl.to_pandas().iloc[idx]["vector"]
# Perform the search
search_x_images = 6
result = tbl.search(query=v).limit(search_x_images).to_pandas()
result
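Conceptually, the search ranks every stored vector by its distance to the query vector. A brute-force sketch of that ranking in plain Python follows, using Euclidean distance and toy three-dimensional vectors. It illustrates the idea only and is not how LanceDB executes the query; the function name nearest is hypothetical.

```python
import math


def nearest(query, vectors, limit):
    """Return (index, distance) pairs for the `limit` vectors closest to `query`."""
    scored = [(i, math.dist(query, v)) for i, v in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1])
    return scored[:limit]


toy_vectors = [
    [0.0, 0.0, 1.0],
    [0.9, 0.1, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
]
# Query with a stored vector: it is its own closest match, at distance 0.
results = nearest(toy_vectors[2], toy_vectors, limit=2)
print(results)
```

Because the query vector in the notebook is also taken from the table, the first search result is expected to be the query chip itself.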
def plot(df, cols=4, save=False):
    """
    Plot the top similar images.
    Parameters:
    df (pandas.DataFrame): DataFrame containing the search results.
    cols (int): Number of columns to display in the plot.
    save (bool): Whether to also write the figure to similar.png.
    """
    # squeeze=False keeps axs a 2D array, even when cols is 1
    fig, axs = plt.subplots(1, cols, figsize=(20, 10), squeeze=False)
    for ax, (_, row) in zip(axs.flatten(), df.iterrows()):
        # Restore the flattened chip to (height, width, bands) and keep RGB
        chip = np.array(row["image"]).reshape(256, 256, 4)
        chip = chip[:, :, :3]
        ax.imshow(chip)
        ax.set_title(f"{row['idx']}")
    plt.tight_layout()
    if save:
        fig.savefig("similar.png")
# Plot the top similar images
plot(result, search_x_images)