#!/usr/bin/env python
# coding: utf-8

# # Burn scar analysis using embeddings from partial inputs

# This notebook contains a complete example of how to run Clay. It
# combines the following three aspects:
#
# 1. Create single-chip datacubes with time series data for a location and a date range
# 2. Run the model with partial inputs, in this case RGB + NIR
# 3. Study burn scars through the embeddings generated for that datacube

# ## Let's start by importing packages and defining constants

# In[1]:

# Ensure working directory is the repo home
import os

os.chdir("..")

# In[2]:

import warnings
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy
import pandas as pd
import pystac_client
import rasterio
import rioxarray  # noqa: F401
import stackstac
import torch
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition

from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule

warnings.filterwarnings("ignore")

BAND_GROUPS = {
    "rgb": ["red", "green", "blue"],
    "rededge": ["rededge1", "rededge2", "rededge3", "nir08"],
    "nir": ["nir"],
    "swir": ["swir16", "swir22"],
    "sar": ["vv", "vh"],
}

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# ## Search for imagery over an area of interest

# In this example we use a location and date range to visualize a forest fire
# that happened in [Monchique in 2018](https://pt.wikipedia.org/wiki/Inc%C3%AAndio_de_Monchique_de_2018).

# In[3]:

# Point over Monchique, Portugal
poi = 37.30939, -8.57207

# Dates of a large forest fire
start = "2018-07-01"
end = "2018-09-01"

catalog = pystac_client.Client.open(STAC_API)

search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

items = search.get_all_items()

print(f"Found {len(items)} items")

# ## Download the data

# Fetch the data into an xarray DataArray and visualize the imagery. The burn
# scar is visible in the last five images.

# In[ ]:

# Extract coordinate system from first item
epsg = items[0].properties["proj:epsg"]

# Convert point of interest into the image projection
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(poi[1], poi[0])],
).to_crs(epsg)

coords = poidf.iloc[0].geometry.coords[0]

# Create bounds of the correct size; the model
# requires 512x512 pixels at 10m resolution.
bounds = (
    coords[0] - 2560,
    coords[1] - 2560,
    coords[0] + 2560,
    coords[1] + 2560,
)

# Retrieve the pixel values for the bounding box in
# the target projection. In this example we use only
# the RGB and NIR band groups.
stack = stackstac.stack(
    items,
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=10,
    dtype="float32",
    rescale=False,
    fill_value=0,
    assets=BAND_GROUPS["rgb"] + BAND_GROUPS["nir"],
    resampling=Resampling.nearest,
)

stack = stack.compute()

stack.sel(band=["red", "green", "blue"]).plot.imshow(
    row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6
)

# ![Minicube visualization](https://github.com/Clay-foundation/model/assets/901647/c6e924e5-6ba1-4924-b99a-df8b90731a5f)

# ## Write data to tif files

# To use the mini datacube in the Clay dataloader, we need to write the
# images to tif files on disk. These tif files are then used by the Clay
# data loader for creating embeddings below.
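# Before writing, it is worth confirming that the minicube has the shape the
# model expects: 4 bands of 512x512 pixels per time step. The cell below is a
# minimal optional check; it only relies on the `stack` array created in the
# download step above.

# In[ ]:

# The bounds above span 5120m at 10m resolution, so each
# tile should be exactly 512x512 pixels with 4 bands.
assert stack.sizes["band"] == 4
assert stack.sizes["x"] == 512
assert stack.sizes["y"] == 512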
# In[5]:

outdir = Path("data/minicubes")
outdir.mkdir(exist_ok=True, parents=True)

# Write each tile to the output dir
for tile in stack:
    # Grid code like MGRS-29SNB
    mgrs = str(tile.coords["grid:code"].values).split("-")[1]
    date = str(tile.time.values)[:10]

    name = "{dir}/claytile_{mgrs}_{date}.tif".format(
        dir=outdir,
        mgrs=mgrs,
        date=date.replace("-", ""),
    )
    tile.rio.to_raster(name, compress="deflate")

    with rasterio.open(name, "r+") as rst:
        rst.update_tags(date=date)

# ## Create embeddings

# Now switch gears and load the tiles to create embeddings and analyze them.
#
# The model checkpoint can be loaded directly from Hugging Face, and the data
# directory points to the directory we created in the steps above.
#
# Note that the normalization parameters for the data module need to be
# adapted based on the band groups that were selected as partial input. The
# full set of normalization parameters can be found [here](https://github.com/Clay-foundation/model/blob/main/src/datamodule.py#L108).

# ### Load the model and set up the data module

# In[6]:

DATA_DIR = "data/minicubes"
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"

# Load the model
rgb_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    band_groups={"rgb": (2, 1, 0), "nir": (3,)},
    bands=4,
    strict=False,  # ignore the extra parameters in the checkpoint
)
# Set the model to evaluation mode
rgb_model.eval()


# Load the datamodule, with the reduced set of bands
class ClayDataModuleRGB(ClayDataModule):
    MEAN = [
        1369.03,  # red
        1597.68,  # green
        1741.10,  # blue
        2858.43,  # nir
    ]
    STD = [
        2026.96,  # red
        2011.88,  # green
        2146.35,  # blue
        2016.38,  # nir
    ]


data_dir = Path(DATA_DIR)

dm = ClayDataModuleRGB(data_dir=str(data_dir.absolute()), batch_size=20)
dm.setup(stage="predict")
trn_dl = iter(dm.predict_dataloader())

# ### Create the embeddings for the images over the forest fire

# This loops through the images returned by the data loader
# and evaluates the model for each one. The raw embeddings
# are reduced to mean values to simplify the data.

# In[7]:

embeddings = []
for batch in trn_dl:
    with torch.inference_mode():
        # Move data to the device of the model
        batch["pixels"] = batch["pixels"].to(rgb_model.device)
        batch["timestep"] = batch["timestep"].to(rgb_model.device)
        batch["latlon"] = batch["latlon"].to(rgb_model.device)

        # Pass pixels, latlon, and timestep through the encoder to create encoded patches
        (
            unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = rgb_model.model.encoder(batch)

        embeddings.append(unmasked_patches.detach().cpu().numpy())

embeddings = numpy.vstack(embeddings)

# Average over the patch tokens, excluding the last two tokens
embeddings_mean = embeddings[:, :-2, :].mean(axis=1)

print(f"Average embeddings have shape {embeddings_mean.shape}")
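# Before reducing the embeddings with PCA in the next section, a quick way to
# see the change signal is to compare each scene's mean embedding against a
# reference scene using cosine similarity. The sketch below assumes the first
# scene is a usable reference, which the cloudy acquisitions show is not
# guaranteed; a cloud-free scene should be picked in practice.

# In[ ]:

# Cosine similarity of every mean embedding against scene 0.
norms = numpy.linalg.norm(embeddings_mean, axis=1)
reference = embeddings_mean[0] / norms[0]
similarity = embeddings_mean @ reference / norms
print(numpy.round(similarity, 3))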
# ## Analyze embeddings

# Now we can do a simple analysis of the embeddings. We reduce each mean
# embedding to a single number using Principal Component Analysis, and then
# plot the first principal component over time. The effect of the fire on the
# embeddings is clearly visible. We use the following color code in the graph:
#
# | Color | Interpretation |
# |---|---|
# | Green | Cloudy images |
# | Blue | Before the fire |
# | Red | After the fire |

# In[8]:

pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_mean)

plt.xticks(rotation=-30)

# All points
plt.scatter(stack.time, pca_result, color="blue")

# Cloudy images
plt.scatter(stack.time[0], pca_result[0], color="green")
plt.scatter(stack.time[2], pca_result[2], color="green")

# After the fire
plt.scatter(stack.time[-5:], pca_result[-5:], color="red")

# In the plot above, each image embedding is one point. One can clearly
# distinguish the two cloudy images, and the values after the fire are
# consistently low.
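# As a hypothetical follow-up, the first principal component could be
# thresholded to flag post-fire scenes automatically. The cutoff of 0.0 below
# is an assumption read off the plot, not a calibrated value, and the sign of
# a principal component is arbitrary between runs.

# In[ ]:

# Flag scenes whose first principal component falls below the cutoff
# as candidate post-fire acquisitions.
cutoff = 0.0  # assumed from the plot; flip the comparison if the PCA sign differs
flagged = stack.time.values[pca_result[:, 0] < cutoff]
print("Candidate post-fire dates:", flagged)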