#!/usr/bin/env python
# coding: utf-8

# # MODIS Snow Cover Data Processing with Dask and Azure Blob Storage

# ## Setup and Imports

# In[ ]:


# !pip install easysnowdata
# !pip install coiled
# !pip install geodatasets
# !pip install adlfs
# !pip install bottleneck
# !pip install -U shapely


# In[ ]:


import numpy as np
import pandas as pd
import xarray as xr
import zarr
#from azure.storage.blob import BlobServiceClient
import easysnowdata
import modis_masking
import coiled
import tqdm
import dask
from dask.distributed import Client, wait
import numcodecs
import concurrent.futures
import os
import logging
import traceback
import geopandas as gpd
import geodatasets
import matplotlib.pyplot as plt
import odc.stac
import rioxarray  # registers the .rio accessor used when writing the CRS below
import getpass
import adlfs
import json
import pathlib

odc.stac.configure_rio(cloud_defaults=True)


# ## Configuration

# In[ ]:


WY_start = 2015
WY_end = 2024

# get token from https://github.com/egagli/azure_authentication/raw/main/sas_token.txt
sas_token = pathlib.Path('sas_token.txt').read_text()
store = adlfs.AzureBlobFileSystem(account_name="snowmelt", credential=sas_token).get_mapper("snowmelt/snow_mask_v2/global_modis_snow_mask.zarr")


# ## Prepare and visualize MODIS Grid

# In[ ]:


modis_grid = gpd.read_file('zip+http://book.ecosens.org/wp-content/uploads/2016/06/modis_grid.zip!modis_sinusoidal_grid_world.shp')
land = gpd.read_file(geodatasets.get_url('naturalearth land'))
land_modis_crs = gpd.GeoSeries(land.union_all(), crs='EPSG:4326').to_crs(modis_grid.crs)

modis_grid_land_idx = modis_grid.intersects(land_modis_crs.union_all())
modis_grid_land_idx[600] = False
modis_grid_land = modis_grid[modis_grid_land_idx]
modis_grid_not_land = modis_grid[~modis_grid_land_idx]

modis_grid_land_list = list(modis_grid_land.iterrows())
tile_processing_list = [f'h{tile["h"]}_v{tile["v"]}' for _, tile in modis_grid_land_list]


# In[ ]:


f, ax = plt.subplots(figsize=(15, 15))

land_modis_crs.plot(ax=ax, color='green')
modis_grid_not_land.geometry.boundary.plot(ax=ax, color='gray', linewidth=0.5)
modis_grid_land.geometry.boundary.plot(ax=ax, color='blue', linewidth=2)

h_values = sorted(modis_grid['h'].unique())
v_values = sorted(modis_grid['v'].unique(), reverse=True)
h_coords = [modis_grid[modis_grid['h'] == h].geometry.centroid.x.mean() for h in h_values]
v_coords = [modis_grid[modis_grid['v'] == v].geometry.centroid.y.mean() for v in v_values]

ax.set_xticks(h_coords)
ax.set_xticklabels([f'h{h}' for h in h_values])
ax.set_yticks(v_coords)
ax.set_yticklabels([f'v{v}' for v in v_values])
ax.tick_params(axis='both', which='both', length=0)
ax.set_xlabel('Horizontal tile number')
ax.set_ylabel('Vertical tile number')
ax.set_xlim(modis_grid.total_bounds[0], modis_grid.total_bounds[2])
ax.set_ylim(modis_grid.total_bounds[1], modis_grid.total_bounds[3])
ax.set_title('MODIS grid system\nland tiles in blue')

f.tight_layout()
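
# As an illustrative aside (not part of the original processing chain), the next cell sketches how to
# look up which MODIS sinusoidal tile contains an arbitrary longitude/latitude point using the
# `modis_grid` GeoDataFrame loaded above. The helper name `lonlat_to_tile` is hypothetical; it only
# relies on the `h`/`v` columns already used elsewhere in this notebook.

# In[ ]:


from shapely.geometry import Point

def lonlat_to_tile(lon, lat):
    """Illustrative sketch: return the 'h{h}_v{v}' tile ID containing a WGS84 point, or None over ocean."""
    point = gpd.GeoSeries([Point(lon, lat)], crs='EPSG:4326').to_crs(modis_grid.crs).iloc[0]
    match = modis_grid[modis_grid.contains(point)]
    if match.empty:
        return None
    tile = match.iloc[0]
    return f'h{tile["h"]}_v{tile["v"]}'

lonlat_to_tile(-121.7, 48.8)  # e.g. a point in the North Cascades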

# ## Define Processing Functions

# In[ ]:


def create_azure_zarr_store(store):
    """Initialize the global zarr store (metadata and coordinates only) covering the full MODIS grid for all water years."""
    water_years = np.arange(WY_start, WY_end + 1)
    num_years = len(water_years)

    modis_snow_entire_extent_footprint = modis_masking.get_modis_MOD10A2_full_grid()
    y = modis_snow_entire_extent_footprint.y.values
    x = modis_snow_entire_extent_footprint.x.values

    # 18 rows x 36 columns of tiles, each 2400 x 2400 pixels, chunked one tile per water year
    shape = (num_years, 18 * 2400, 36 * 2400)
    chunks = (1, 2400, 2400)
    fill_value = np.iinfo(np.int16).min

    ds = xr.Dataset(
        {
            'SAD_DOWY': (('water_year', 'y', 'x'), dask.array.full(shape, fill_value=fill_value, chunks=chunks, dtype=np.int16)),
            'SDD_DOWY': (('water_year', 'y', 'x'), dask.array.full(shape, fill_value=fill_value, chunks=chunks, dtype=np.int16)),
            'max_consec_snow_days': (('water_year', 'y', 'x'), dask.array.full(shape, fill_value=fill_value, chunks=chunks, dtype=np.int16)),
        },
        coords={
            'water_year': water_years,
            'y': y,
            'x': x,
        },
    )

    ds.water_year.attrs['description'] = ("Water year. In northern hemisphere, water year starts on October 1st "
                                          "and ends on September 30th. For the southern hemisphere, water year "
                                          "starts on April 1st and ends on March 31st. e.g. in NH WY 2015 is "
                                          "[2014-10-01,2015-09-30] and in SH WY 2015 is [2015-04-01,2016-03-31].")

    ds.attrs['processed_tiles'] = []

    encoding = {var: {'chunks': chunks, 'compressor': zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)} for var in ds.data_vars}

    ds.rio.write_crs(modis_snow_entire_extent_footprint.rio.crs, inplace=True)
    # https://github.com/pydata/xarray/issues/6288#issuecomment-1230970216
    for var in ds.data_vars:
        ds[str(var)].attrs['grid_mapping'] = 'spatial_ref'

    # compute=False writes only the metadata and coordinates; tile regions are filled in later by process_tile
    ds.to_zarr(store, mode='w', encoding=encoding, compute=False, consolidated=True)

    return ds
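
# For reference, the next cell is a small standalone sketch restating the water-year convention recorded
# in the attrs above (NH water years run Oct 1 - Sep 30, SH water years run Apr 1 - Mar 31). The pipeline
# itself uses easysnowdata.utils.datetime_to_WY / datetime_to_DOWY; `water_year_sketch` is just an
# illustrative name and is not called anywhere else.

# In[ ]:


def water_year_sketch(timestamp, hemisphere="northern"):
    """Illustrative only: water year of a timestamp under the convention described in the attrs above."""
    ts = pd.Timestamp(timestamp)
    if hemisphere == "northern":
        # e.g. 2014-10-01 through 2015-09-30 -> WY 2015
        return ts.year + 1 if ts.month >= 10 else ts.year
    else:
        # e.g. 2015-04-01 through 2016-03-31 -> WY 2015
        return ts.year if ts.month >= 4 else ts.year - 1

assert water_year_sketch("2014-10-01", "northern") == 2015
assert water_year_sketch("2016-03-31", "southern") == 2015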

# In[ ]:


def process_tile(tile, store):
    """Fetch and mask MOD10A2 snow cover for one MODIS tile, compute the per-water-year metrics (SAD_DOWY, SDD_DOWY, max_consec_snow_days), and write them into the tile's region of the global zarr store."""
    h, v = (int(part[1:]) for part in tile.split('_'))

    #odc.stac.configure_rio(cloud_defaults=True)

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logging.getLogger('azure').setLevel(logging.WARNING)

    logger.info(f"Starting process for tile {tile}")

    try:
        #logger.info(f"Zarr store opened successfully")

        hemisphere = "northern" if v < 9 else "southern"

        # last complete water year differs by hemisphere
        if hemisphere == "southern":
            WY_end = 2023
        else:
            WY_end = 2024

        #logger.info(f"Fetching MODIS data for tile {tile}")
        if hemisphere == "northern":
            modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
                vertical_tile=v,
                horizontal_tile=h,
                start_date=f"{WY_start-1}-10-01",
                end_date=f"{WY_end}-09-30",
                #chunks={"time": -1, "y": 600, "x": 600},
                chunks={"time": 1, "y": 2400, "x": 2400},
            ).chunk({"time": -1, "y": 600, "x": 600})
        else:
            modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
                vertical_tile=v,
                horizontal_tile=h,
                start_date=f"{WY_start}-04-01",
                end_date=f"{WY_end+1}-03-31",
                #chunks={"time": -1, "y": 600, "x": 600},
                chunks={"time": 1, "y": 2400, "x": 2400},
            ).chunk({"time": -1, "y": 600, "x": 600})

        #logger.info(f"Processing MODIS data for tile {tile}")
        modis_snow_da.coords["water_year"] = (
            "time",
            pd.to_datetime(modis_snow_da.time).map(
                lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)
            ),
        )
        modis_snow_da.coords["DOWY"] = (
            "time",
            pd.to_datetime(modis_snow_da.time).map(
                lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)
            ),
        )
        modis_snow_da = modis_snow_da[
            (modis_snow_da.water_year >= WY_start) & (modis_snow_da.water_year <= WY_end)
        ]

        #logger.info(f"Applying binarize_with_cloud_filling for tile {tile}")
        effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)

        #logger.info(f"Calculating seasonal snow presence for tile {tile}")
        seasonal_snow_presence = effective_snow_da.groupby("water_year").apply(
            modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY
        )

        #logger.info(f"Writing results to zarr store for tile {tile}")
        num_years = len(seasonal_snow_presence.water_year)
        water_year_slice = slice(0, num_years)
        y_slice = slice(v * 2400, (v + 1) * 2400)
        x_slice = slice(h * 2400, (h + 1) * 2400)

        existing_ds = xr.open_zarr(store, consolidated=True)
        y_coords = existing_ds.y[y_slice].values
        x_coords = existing_ds.x[x_slice].values

        # only write if both axes line up with the corresponding block of the global grid
        if np.allclose(y_coords, seasonal_snow_presence.y.values, atol=0.1) and np.allclose(x_coords, seasonal_snow_presence.x.values, atol=0.1):
            seasonal_snow_presence = seasonal_snow_presence.assign_coords(y=y_coords, x=x_coords)
        else:
            logger.error(f"y or x coordinates do not match for tile {tile}")
            raise ValueError(f"y or x coordinates do not match for tile {tile}")

        # region="auto" infers the target slices from the (snapped) water_year/y/x coordinates
        seasonal_snow_presence.drop_vars('spatial_ref').chunk({'water_year': 1, 'y': 2400, 'x': 2400}).to_zarr(store, region="auto", mode="r+", consolidated=True)

        #logger.info(f"Tile {tile} processed and written successfully")

        #existing_ds.attrs['processed_tiles'].append(tile)
        #logger.info(f"Tile {tile} processed and written, added to processed_tiles list")

        return True

    except Exception as e:
        logger.error(f"(PT) Error processing tile {tile}: {str(e)}")
        logger.error(f"(PT) Traceback: {traceback.format_exc()}")
        return False
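
# As a side note, the next cell spells out the slice arithmetic that process_tile uses to place a tile
# into the global (water_year, 18*2400, 36*2400) arrays: vertical tile v occupies rows v*2400:(v+1)*2400
# and horizontal tile h occupies columns h*2400:(h+1)*2400. `tile_to_global_slices` is an illustrative
# helper only and is not called anywhere else in this notebook.

# In[ ]:


def tile_to_global_slices(tile):
    """Illustrative helper mirroring the slice logic in process_tile, e.g. 'h23_v5' -> (row slice, column slice)."""
    h, v = (int(part[1:]) for part in tile.split('_'))
    y_slice = slice(v * 2400, (v + 1) * 2400)
    x_slice = slice(h * 2400, (h + 1) * 2400)
    return y_slice, x_slice

tile_to_global_slices('h23_v5')  # (slice(12000, 14400, None), slice(55200, 57600, None))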

# In[ ]:


# In[ ]:


start_fresh = False

if start_fresh:
    zarr_store_ds = create_azure_zarr_store(store)
    zarr_store_ds


# ## Set Up Dask Cluster with Coiled

# In[ ]:


# cluster = coiled.Cluster(idle_timeout="15 minutes",
#                          n_workers=[20,100],
#                          worker_memory="8 GiB",
#                          #worker_options={"nthreads": 8},
#                          spot_policy="spot",
#                          environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
#                          workspace="azure",
#                          )
# #cluster.adapt(minimum=10, maximum=100)
# client = cluster.get_client()


# In[ ]:


from dask.distributed import Client

client = Client()
client.cluster.scale(100)
client


# In[ ]:


odc.stac.configure_rio(cloud_defaults=True, client=client)


# ## Process MODIS Tiles

# In[ ]:


failed_tiles = []
processed_tiles_list_initial = zarr.open(store).attrs['processed_tiles']

for tile in tqdm.tqdm(tile_processing_list):
    if tile in processed_tiles_list_initial:
        print(f"Tile {tile} already processed, skipping")
        continue
    result = process_tile(tile, store)
    if result:
        # checkpoint progress in the store's processed_tiles attribute so reruns can skip finished tiles
        zarr_store = zarr.open(store)
        processed_tile_list = zarr_store.attrs['processed_tiles']
        processed_tile_list.append(tile)
        zarr_store.attrs['processed_tiles'] = processed_tile_list
        print(f"Tile {tile} SUCCESS, added to processed_tiles attribute")
        client.restart(wait_for_workers=True)  # restart workers to clear memory between tiles
    else:
        print(f"Tile {tile} FAIL, adding to failed list")
        failed_tiles.append(tile)

if failed_tiles:
    print("Run this cell again. The following tiles could not be processed:")
    for tile in failed_tiles:
        print(tile)
else:
    print("Now consolidating metadata...")
    zarr.consolidate_metadata(store)
    print("All tiles processed successfully!!!")


# In[ ]:


# In[ ]:


# In[ ]:


client.cluster.scale(1)


# In[ ]:


# In[ ]:


# In[ ]:


# In[ ]:
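
# Once tiles have been written, a quick way to sanity-check the output is to lazily open the store and
# plot one tile-sized window. This cell is only an illustrative check (the variable names follow the
# store schema defined in create_azure_zarr_store; the h=23, v=5 window and water year 2020 are
# arbitrary examples) and is not required by the workflow.

# In[ ]:


check_ds = xr.open_zarr(store, consolidated=True)
check_tile = check_ds['max_consec_snow_days'].sel(water_year=2020).isel(y=slice(5 * 2400, 6 * 2400), x=slice(23 * 2400, 24 * 2400))
# mask the int16 fill value before plotting
check_tile.where(check_tile != np.iinfo(np.int16).min).plot(robust=True)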

# ## other approaches (code graveyard)

# ### serverless approach (this got close, couldn't push it across finish line though)

# In[ ]:


# https://docs.coiled.io/user_guide/functions.html
# inspired by: https://github.com/earth-mover/serverless-datacube-demo/blob/main/src/lib.py
# maybe another option: https://xarray.dev/blog/cubed-xarray

# @coiled.function(
#     n_workers=50,
#     cpu=4,
#     #threads_per_worker=8,
#     memory="16GiB",
#     spot_policy="spot",
#     region="westeurope",
#     environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
#     keepalive="5m",
#     workspace="azure"
# )
# # def process_chunks(tile_list, store):
# #     odc.stac.configure_rio(cloud_defaults=True)
# #     results = []
# #     for _, tile in tile_list:
# #         h = tile['h']
# #         v = tile['v']
# #         result = process_and_write_tile(h, v, store, serverless=False)
# #         results.append(result)
# #     return results
# def process_chunk(tile, store):
#     odc.stac.configure_rio(cloud_defaults=True)
#     #with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
#     process = process_and_write_tile(tile, store, serverless=False)
#     return process

# def spawn_coiled_jobs(modis_grid_land_list, store):
#     h_list = [tile['h'] for _, tile in modis_grid_land_list]
#     v_list = [tile['v'] for _, tile in modis_grid_land_list]
#     results = list(
#         tqdm.tqdm(
#             process_chunk.map(
#                 h_list,
#                 v_list,
#                 store=store,
#                 retries=5
#             ),
#             total=len(h_list),
#             desc="Jobs Completed",
#         )
#     )
#     return results

# # def spawn_coiled_jobs(modis_grid_land_list, store, batch_size=10):
# #     batches = [modis_grid_land_list[i:i+batch_size] for i in range(0, len(modis_grid_land_list), batch_size)]
# #     results = list(
# #         tqdm.tqdm(
# #             process_chunks.map(
# #                 batches,
# #                 store=store,
# #                 retries=5
# #             ),
# #             total=len(batches),
# #             desc="Batch completed",
# #         )
# #     )
# #     return [item for sublist in results for item in sublist]

# #results = spawn_coiled_jobs(modis_grid_land_list, store)
# #results


# In[ ]:


#futures = []

# # for _, tile in tqdm.tqdm(modis_grid_land_list):
# #     h = tile['h']
# #     v = tile['v']
# #     try:
# #         process_and_write_tile(h, v, store)
# #         print(f"Tile h{h}_v{v} processed and written")
# #     except Exception as e:
# #         print(f"Error processing tile h{h}_v{v}: {str(e)}")
# #         print(f"Traceback: {traceback.format_exc()}")
# #         # maybe append to a list of all tiles that need to be rerun
# #     #future = client.submit(process_and_write_tile, h, v, store)
# #     #futures.append(future)
# #     #results = wait(futures)

# for future in futures:
#     try:
#         result = future.result()
#         print(result)
#     except Exception as e:
#         print(f"Task failed: {str(e)}")
#         print(f"Traceback: {future.traceback()}")

# client.close()
# cluster.close()

# #seasonal_snow_presence.drop_vars('spatial_ref').chunk({'water_year':1,'y':2400,'x':2400}).to_zarr(store, region={'water_year':water_year_slice,'y':y_slice,'x':x_slice}, mode="r+")

# if serverless:
#     print(f'running serverless mode, using threadpoolexecutor...')
#     with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
#         for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#             data = seasonal_snow_presence[var].values
#             root[var][:,y_start:y_end,x_start:x_end] = data
# else:
#     for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#         data = seasonal_snow_presence[var].values
#         root[var][:,y_start:y_end,x_start:x_end] = data

# root[:, time_slice, y_slice, x_slice] = data
#root[var][time_slice, y_slice, x_slice] = data

# if data.shape[0] == 9 and data.shape[1] == 2400 and data.shape[2] == 2400:
#     print(f'transpose necessary h{h}_v{v}')
#     data = np.transpose(data, (1, 2, 0))

# store.flush()

# with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:  # Adjust number as needed
#     futures = [executor.submit(process_and_write_tile, h, v, azure_zarr_path)
#                for h, v in modis_grid_land_list]

# def process_batch(batch):
#     results = []
#     for h, v in batch:
#         results.append(process_and_write_tile(h, v, azure_zarr_path))
#     return results

# batch_size = 10  # Adjust based on your workload
# batches = [modis_grid_land_list[i:i+batch_size] for i in range(0, len(modis_grid_land_list), batch_size)]
# futures = client.map(process_batch, batches)

# def create_azure_zarr_store(connection_string, container_name, zarr_store_path):
#     blob_service_client = BlobServiceClient.from_connection_string(connection_string)
#     container_client = blob_service_client.get_container_client(container_name)

#     class AzureBlobStore(zarr.ABSStore):
#         def __init__(self, container_client, prefix):
#             self.container_client = container_client
#             self.prefix = prefix
#             self.client = container_client  # Add this line

#         def __getitem__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             return blob_client.download_blob().readall()

#         def __setitem__(self, key, value):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             blob_client.upload_blob(value, overwrite=True)

#         def __contains__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             return blob_client.exists()

#         def __delitem__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             blob_client.delete_blob()

#         def rmdir(self, path):
#             dir_path = self.prefix
#             if path:
#                 dir_path += "/" + path
#             dir_path += "/"
#             blobs_to_delete = self.container_client.list_blobs(name_starts_with=dir_path)
#             for blob in blobs_to_delete:
#                 self.container_client.delete_blob(blob)

#     store = AzureBlobStore(container_client, zarr_store_path)
#     root = zarr.open(store, mode="w")

#     # root.create_dataset('SAD_DOWY', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')
#     # root.create_dataset('SDD_DOWY', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')
#     # root.create_dataset('max_consec_snow_days', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')

#     water_years = list(range(WY_start, WY_end + 1))
#     num_years = len(water_years)

#     compressor = numcodecs.Blosc(cname="zstd", clevel=3, shuffle=numcodecs.Blosc.SHUFFLE)

#     # Create datasets
#     for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#         dataset = root.create_dataset(
#             var,
#             shape=(num_years, 18 * 2400, 36 * 2400),
#             chunks=(1, 2400, 2400),
#             dtype="i2",
#             compressor=compressor,
#         )

#     # Add dimension names as attributes
#     #root.create_dataset("water_year", data=water_years, shape=(num_years,), dtype="i2")

#     # root["time"].attrs[
#     #     "description"
#     # ] = "Water year. In northern hemisphere, water year starts on October 1st and ends on September 30th. For the southern hemisphere, water year starts on April 1st and ends on March 31st. For example, in the northern hemisphere water year 2015 starts on October 1st, 2014 and ends on September 30th, 2015, and in the southern hemisphere water year 2015 starts on April 1st, 2015 and ends on March 31st, 2016."
#     return f"azure://{container_name}/{zarr_store_path}"

# from azure.core.exceptions import ResourceNotFoundError

#blob_service_client = BlobServiceClient.from_connection_string(connection_string)
#container_client = blob_service_client.get_container_client(container_name)
#store = AzureBlobStore(container_client, zarr_store_path)
#root = zarr.open(store, mode="w")

#y = np.arange(0, 18 * 2400)
#x = np.arange(0, 36 * 2400)

#connection_string = os.environ["azure-storage-connection-string"]
#parts = azure_zarr_path.split("/")
#container_name = parts[2]
#zarr_store_path = "/".join(parts[3:])

# blob_service_client = BlobServiceClient.from_connection_string(
#     connection_string
# )
#container_client = blob_service_client.get_container_client(container_name)
#store = AzureBlobStore(connection_string,container_client, zarr_store_path)
#root = zarr.open(store, mode="a")

# x_start, x_end = h * 2400, (h + 1) * 2400
# y_start, y_end = v * 2400, (v + 1) * 2400

# with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
#     data = seasonal_snow_presence[['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']].to_array().values

#'water_year':water_years,
#time_slice = slice(0, data.shape[0])

#seasonal_snow_presence.drop_vars('spatial_ref').chunk({'water_year':num_years,'y':2400,'x':2400}).to_zarr(store, region={'water_year':water_year_slice,'y':y_slice,'x':x_slice}, mode="r+")

# def check_environment():
#     import sys
#     import os
#     result = {
#         "sys.path": sys.path,
#         "current_dir": os.getcwd(),
#         "list_dir": os.listdir(),
#         "env_vars": dict(os.environ),
#     }
#     try:
#         import easysnowdata
#         result["easysnowdata_version"] = easysnowdata.__version__
#     except ImportError as e:
#         result["easysnowdata_error"] = str(e)
#     try:
#         import modis_masking
#         result["modis_masking_file"] = modis_masking.__file__
#     except ImportError as e:
#         result["modis_masking_error"] = str(e)
#     return result

# # Run this on all workers
# environment_info = client.run(check_environment)

# # Print the results
# for worker, info in environment_info.items():
#     print(f"Worker {worker}:")
#     for key, value in info.items():
#         print(f"  {key}: {value}")
#     print()

# Set the Azure Blob Storage path for the zarr store
#container_name = "snowmelt"
#zarr_store_path = "modis_mask/global_modis_snow_mask.zarr"
"modis_mask/global_modis_snow_mask.zarr" #azure_zarr_path = f"azure://{container_name}/{zarr_store_path}" # # Load progress # progress = load_progress() # processed_tiles = set(progress['processed']) # failed_tiles = set(progress['failed']) # # Load processed tiles from zarr # zarr_store = zarr.open(store, mode='r') # zarr_processed_tiles = set(zarr_store.attrs['processed_tiles']) # failed_tiles = [] # def process_tile(tile, store): # result = process_and_write_tile(tile, store) # client.restart() # Restart workers to clear memory # return result # # First pass: process all tiles # for tile in tqdm.tqdm(tile_processing_list): # try: # result = process_tile(tile, store) # print(f"Tile {tile} processed and written") # except Exception as e: # print(f"Error processing tile {tile}: {str(e)}") # print(f"Traceback: {traceback.format_exc()}") # failed_tiles.append(tile) # # Second pass: retry failed tiles # max_retries = 3 # retry_count = 0 # while failed_tiles and retry_count < max_retries: # retry_count += 1 # print(f"Retry attempt {retry_count} for failed tiles") # still_failed = [] # for tile in tqdm.tqdm(failed_tiles): # try: # result = process_tile(tile, store) # print(f"Tile {tile} processed and written on retry") # except Exception as e: # print(f"Error processing tile {tile} on retry: {str(e)}") # print(f"Traceback: {traceback.format_exc()}") # still_failed.append(tile) # failed_tiles = still_failed # if failed_tiles: # print("The following tiles could not be processed after all retries:") # for tile in failed_tiles: # print(f"{tile}") # client.close() # cluster.close() # In[ ]: # # fixed # In[ ]: