import sys
sys.path.append("../")
import warnings
from pathlib import Path
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from src.datamodule import ClayDataModule, ClayDataset
from src.model_clay import CLAYModule
warnings.filterwarnings("ignore")
L.seed_everything(42)
# data directory for all chips
DATA_DIR = "../data/02"
# path of best model checkpoint for Clay v0
CKPT_PATH = "../checkpoints/v0/mae_epoch-24_val-loss-0.46.ckpt"
# Load the model & set it to eval mode
model = CLAYModule.load_from_checkpoint(CKPT_PATH, mask_ratio=0.7)
model.eval();
data_dir = Path(DATA_DIR)
# Load the Clay DataModule
ds = ClayDataset(chips_path=list(data_dir.glob("**/*.tif")))
dm = ClayDataModule(data_dir=str(data_dir), batch_size=100)
dm.setup(stage="fit")
# Load the train DataLoader
trn_dl = iter(dm.train_dataloader())
# Load the first batch of chips
batch = next(trn_dl)
batch.keys()
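Each batch is a dictionary of tensors plus chip metadata. A quick way to inspect it (a minimal sketch; the exact keys are whatever ClayDataModule yields):
# Print the shape of every tensor entry in the batch; non-tensor entries are skipped
for k, v in batch.items():
    if torch.is_tensor(v):
        print(k, tuple(v.shape))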
# Save a copy of batch to visualize later
_batch = batch["pixels"].detach().clone().cpu().numpy()
# Pass the pixels through the encoder & decoder of CLAY
with torch.no_grad():
    # Move data to the device of the model
    batch["pixels"] = batch["pixels"].to(model.device)
    batch["timestep"] = batch["timestep"].to(model.device)
    batch["latlon"] = batch["latlon"].to(model.device)

    # Pass pixels, latlon, timestep through the encoder to create encoded patches
    (
        unmasked_patches,
        unmasked_indices,
        masked_indices,
        masked_matrix,
    ) = model.model.encoder(batch)

    # Pass the unmasked_patches through the decoder to reconstruct the pixel space
    pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)
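Since we saved a copy of the input pixels above, we can compare an input chip with its reconstruction. This is a minimal sketch, assuming the decoder returns pixels in the same (batch, channel, height, width) layout as _batch and that channels [2, 1, 0] form an RGB composite (matching the band order used for visualization later):
recon = pixels.detach().cpu().numpy()
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, arr, title in zip(axes, (_batch[0], recon[0]), ("Input", "Reconstruction")):
    img = arr[[2, 1, 0]].transpose(1, 2, 0)  # assumed RGB band order
    img = (img - img.min()) / (img.max() - img.min())  # min-max stretch for display
    ax.imshow(img)
    ax.set_title(title)
    ax.set_axis_off()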
In CLAY, the encoder receives the unmasked patches along with latitude-longitude and timestep information. The last two embeddings in the encoder output are the latitude-longitude and timestep embeddings, respectively.
latlon_embeddings = unmasked_patches[:, -2, :].detach().cpu().numpy()
time_embeddings = unmasked_patches[:, -1, :].detach().cpu().numpy()
# Get the normalized latlon values that were input to the model
latlon = batch["latlon"].detach().cpu().numpy()
We will focus only on the location embeddings in this notebook.
latlon.shape, latlon_embeddings.shape
Latitude & longitude map to a 768-dimensional vector. We use PCA to project it down to 2 dimensions for visualization and clustering.
pca = PCA(n_components=2)
latlon_embeddings = pca.fit_transform(latlon_embeddings)
latlon_embeddings.shape
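To gauge how much of the 768-dimensional structure survives the 2-D projection, we can inspect the fitted PCA (a quick, model-agnostic check):
# Fraction of variance captured by each of the two principal components
print(pca.explained_variance_ratio_, pca.explained_variance_ratio_.sum())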
Cluster the raw latlon inputs
kmeans = KMeans(n_clusters=5)
kmeans.fit(latlon)
latlon = np.column_stack((latlon, kmeans.labels_))
Cluster the latlon embeddings
kmeans = KMeans(n_clusters=5)
kmeans.fit(latlon_embeddings)
latlon_embeddings = np.column_stack((latlon_embeddings, kmeans.labels_))
latlon.shape, latlon_embeddings.shape
We appended a third column to both latlon & latlon_embeddings holding the cluster labels.
plt.figure(figsize=(15, 15), dpi=80)
plt.scatter(latlon[:, 0], latlon[:, 1], c=latlon[:, 2], label="Actual", alpha=0.3)
for i in range(100):
    txt = f"{latlon[:, 0][i]:.2f},{latlon[:, 1][i]:.2f}"
    plt.annotate(txt, (latlon[:, 0][i] + 1e-5, latlon[:, 1][i] + 1e-5))
As the scatter plot above shows, there is nothing unique about the latlon values that go into the model: they cluster simply by their longitude values.
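One way to verify this (a minimal sketch; it assumes column 0 holds the normalized latitude, column 1 the normalized longitude, and column 2 the cluster label added above):
# Per-cluster min/max of the two coordinate columns: if the clusters split on
# longitude, the column-1 ranges should be nearly disjoint across clusters
for c in np.unique(latlon[:, 2]):
    members = latlon[latlon[:, 2] == c]
    print(
        f"cluster {int(c)}: "
        f"lat [{members[:, 0].min():.3f}, {members[:, 0].max():.3f}], "
        f"lon [{members[:, 1].min():.3f}, {members[:, 1].max():.3f}]"
    )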
plt.figure(figsize=(15, 15), dpi=80)
plt.scatter(
    latlon_embeddings[:, 0],
    latlon_embeddings[:, 1],
    c=latlon_embeddings[:, 2],
    label="Predicted",
    alpha=0.3,
)
for i in range(100):
    txt = str(i)  # annotate each point with its chip index
    plt.annotate(txt, (latlon_embeddings[:, 0][i], latlon_embeddings[:, 1][i]))
def show_cluster(ids):
    """Display the RGB composites of the chips at the given batch indices."""
    fig, axes = plt.subplots(1, len(ids), figsize=(10, 5))
    for i, ax in zip(ids, axes.flatten()):
        img_path = batch["source_url"][i]
        # Read the RGB bands (3, 2, 1) & rearrange to (height, width, channel)
        img = rio.open(img_path).read([3, 2, 1]).transpose(1, 2, 0)
        # Min-max stretch for display
        img = (img - img.min()) / (img.max() - img.min())
        ax.imshow(img)
        ax.set_axis_off()
show_cluster((87, 37, 40))
show_cluster((23, 11, 41))
show_cluster((68, 71, 7))
We can see the location embeddings capturing semantic information as well: chips that fall in the same cluster show similar scenes.
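As a complementary check, we can look up a chip's nearest neighbours in the 2-D embedding space and display them with show_cluster; semantically similar chips should surface together. A minimal sketch using scikit-learn's NearestNeighbors (chip index 0 is an arbitrary query):
from sklearn.neighbors import NearestNeighbors

# Fit on the two PCA components only (column 2 holds the cluster label)
nn = NearestNeighbors(n_neighbors=3).fit(latlon_embeddings[:, :2])
_, idx = nn.kneighbors(latlon_embeddings[:1, :2])
# Shows the query chip itself plus its two closest neighbours
show_cluster(tuple(idx[0]))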