#!/usr/bin/env python
# coding: utf-8

# # CLAY v0 - Interpolation between images

# In[ ]:


import sys

sys.path.append("../")


# In[ ]:


import os
from pathlib import Path

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from src.datamodule import ClayDataModule, ClayDataset
from src.model_clay import CLAYModule


# In[ ]:


# Data directory for all chips
DATA_DIR = "../data/02"
# Path of the best model checkpoint for Clay v0
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"


# ## Load Model & DataModule

# In[ ]:


# Load the model & set it to eval mode
model = CLAYModule.load_from_checkpoint(
    CKPT_PATH, mask_ratio=0.0, shuffle=False
)  # No masking or shuffling of patches
model.eval();


# In[ ]:


data_dir = Path(DATA_DIR)

# Load the Clay DataModule
ds = ClayDataset(chips_path=list(data_dir.glob("**/*.tif")))
dm = ClayDataModule(data_dir=str(data_dir), batch_size=2)
dm.setup(stage="fit")

# Load the train DataLoader
trn_dl = iter(dm.train_dataloader())


# In[ ]:


# Load the first batch of chips
batch = next(trn_dl)
batch.keys()


# In[ ]:


batch["pixels"].shape, batch["latlon"].shape, batch["timestep"].shape


# In[ ]:


def show(sample, idx=None, save=False):
    Path("animate").mkdir(exist_ok=True)
    # Channels-first -> channels-last for plotting
    sample = rearrange(sample, "c h w -> h w c")
    # Undo the normalization applied by the DataModule
    denorm_sample = sample * torch.as_tensor(dm.STD) + torch.as_tensor(dm.MEAN)
    # Reorder the first three bands to RGB for display
    rgb = denorm_sample[..., [2, 1, 0]]
    # Min-max stretch so the chip renders with full contrast
    plt.imshow((rgb - rgb.min()) / (rgb.max() - rgb.min()))
    plt.axis("off")
    if save:
        plt.savefig(f"animate/chip_{idx}.png")


# In[ ]:


sample1, sample2 = batch["pixels"]


# In[ ]:


show(sample1)


# In[ ]:


show(sample2)


# Each batch contains chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coordinates, and normalized timestep information as `year`, `month` & `day`.
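# As a quick sanity check, we can assert the layout just described. This is an
# illustrative sketch: it assumes the `pixels`, `latlon` & `timestep` keys and
# the normalization conventions noted above.

# In[ ]:


# Illustrative check of the batch layout described above: 13-band chips of
# 512 x 512 pixels, one (lat, lon) pair & one (year, month, day) triple per chip
assert batch["pixels"].shape[1:] == (13, 512, 512)
assert batch["latlon"].shape[1:] == (2,)
assert batch["timestep"].shape[1:] == (3,)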
# In[ ]:


# Save a copy of the batch to visualize later
_batch = batch["pixels"].detach().clone().cpu().numpy()


# ## Pass data through the CLAY model

# In[ ]:


# Pass the pixels through the encoder & decoder of CLAY
with torch.no_grad():
    # Move data to the device of the model
    batch["pixels"] = batch["pixels"].to(model.device)
    batch["timestep"] = batch["timestep"].to(model.device)
    batch["latlon"] = batch["latlon"].to(model.device)

    # Pass pixels, latlon & timestep through the encoder to create encoded patches
    (
        unmasked_patches,
        unmasked_indices,
        masked_indices,
        masked_matrix,
    ) = model.model.encoder(batch)


# ### Create images by interpolating the embedding values between the 2 chips

# *Images are saved inside `./animate`*

# In[ ]:


for idx, alpha in enumerate(np.linspace(0, 1, 20)):
    patch_break = 128
    l1, l2 = unmasked_patches
    # Linear interpolation between the two patch embeddings:
    # alpha=0 gives pure l2, alpha=1 gives pure l1
    l3 = alpha * l1 + (1 - alpha) * l2
    # Patch-swap composite: the first `patch_break` patch embeddings from l1,
    # the rest from l2 (not decoded in this loop; see the sketch at the end)
    l4 = torch.vstack((l1[:patch_break, :], l2[patch_break:, :]))

    # Pass the interpolated embedding through the decoder to reconstruct pixel space
    with torch.no_grad():
        pixels = model.model.decoder(
            rearrange(l3, "gl d -> 1 gl d"), unmasked_indices[[0]], masked_indices[[0]]
        )

    # Reassemble the 16 x 16 grid of 32 x 32 patches into a 512 x 512 chip
    image = rearrange(pixels, "b c (h w) (p1 p2) -> b c (h p1) (w p2)", h=16, p1=32)
    _image = image[0].detach().cpu()
    show(_image, idx, save=True)


# In[ ]:


fig, axs = plt.subplots(2, 10, figsize=(20, 4))

for ax, idx in zip(axs.flatten(), range(20)):
    ax.imshow(Image.open(f"./animate/chip_{idx}.png"))
    ax.set_title(f"Seq {idx}")
    ax.set_axis_off()
plt.tight_layout()


# #### Create a GIF of the interpolation of images

# In[ ]:


img_paths = [f"./animate/chip_{idx}.png" for idx in range(20)]

# Write ~100 ms per frame (duration units depend on the imageio plugin version)
with imageio.get_writer("animate/sample.gif", mode="I", duration=100) as writer:
    for img_path in img_paths:
        img = imageio.imread(img_path)
        writer.append_data(img)

# Delete the intermediate images
for img_path in img_paths:
    os.remove(img_path)


# In[ ]:


# Import under an alias to avoid shadowing PIL.Image imported above
from IPython.display import Image as IPyImage, display

display(IPyImage(filename="./animate/sample.gif"))
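# The interpolation loop above also builds `l4`, a patch-swap composite (the
# first `patch_break` patch embeddings from `l1`, the rest from `l2`), but
# never decodes it. Below is a minimal sketch of decoding it, reusing the same
# decoder call as for `l3` (illustrative, not part of the original walkthrough).

# In[ ]:


# Illustrative: decode the patch-swap composite `l4` the same way as `l3`
with torch.no_grad():
    pixels = model.model.decoder(
        rearrange(l4, "gl d -> 1 gl d"), unmasked_indices[[0]], masked_indices[[0]]
    )

# Reassemble the patches into a 512 x 512 chip & display the stitched result
image = rearrange(pixels, "b c (h w) (p1 p2) -> b c (h p1) (w p2)", h=16, p1=32)
show(image[0].detach().cpu())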