#!/usr/bin/env python
# coding: utf-8

# # Calculating Climatologies and Anomalies with Xarray and Dask:
# 
# ## A Workaround for a Longstanding Problem
# 
# Climatologies and anomalies are core operations in climate science. Many workflows start with the following steps:
# - Group spatiotemporal data by month or dayofyear (depending on the resolution of the dataset)
# - Take the mean of each group to determine the "climatology"
# - Broadcast the climatology back against the original dataset and subtract it, producing the "anomaly"
# 
# Xarray makes this easy. We often write code like
# 
#     gb = ds.groupby('time.month')
#     clim = gb.mean(dim='time')
#     anom = gb - clim
# 
# Unfortunately there are problems with how dask handles this operation:
# 
# - https://github.com/pydata/xarray/issues/1832
# - https://github.com/dask/dask/issues/874
# - https://github.com/pangeo-data/pangeo/issues/271
# 
# There have been many attempted fixes over the years (see the issues linked above), but none of them has been totally successful.
# 
# Here we describe a new approach.

# In[15]:


import xarray as xr
from dask.distributed import Client
import gcsfs

get_ipython().run_line_magic('matplotlib', 'inline')

xr.__version__


# ### The Dataset: MERRA2 Daily Surface Temperature

# In[2]:


gcs = gcsfs.GCSFileSystem(token='anon')
to_map = gcs.get_mapper("ivanovich_merra2/t2maxdaily.zarr/")
ds = xr.open_zarr(to_map)
ds


# In[3]:


ds.t2mmax.data


# ### Default, No Rechunking

# In[4]:


gb = ds.groupby('T.dayofyear')
clim = gb.mean(dim='T')
anom = gb - clim
anom_std = anom.std(dim='T')
anom_std.t2mmax.data


# In[5]:


from dask.distributed import Client

client = Client("tcp://10.32.5.32:43525")
client


# We see we have ballooned up to almost 100,000 tasks.

# In[6]:


get_ipython().run_line_magic('time', 'anom_std.load()')


# Two minutes is a really long time to process 12 GB of data, and the dask cluster almost choked in the process.
# 
# The parallelism became too fine-grained, resulting in too much communication overhead.
# 
# ### With Rechunking
# 
# Since the operation is embarrassingly parallel in the space dimension, but the data are chunked in the time dimension, one idea is that rechunking could help.

# In[6]:


ds_rechunk = ds.chunk({'T': -1, 'Y': 3})
ds_rechunk.t2mmax.data


# In[8]:


gb = ds_rechunk.groupby('T.dayofyear')
clim = gb.mean(dim='T')
anom = gb - clim
anom_std = anom.std(dim='T')
anom_std.t2mmax.data


# This created **4.5 million tasks**! Clearly not the solution we were hoping for. I'm not even going to try to compute it. For whatever reason, the way these operations (mostly indexing and broadcasting) are interpreted by dask array does not allow them to leverage the parallelism we know is possible.
# 
# ### The Workaround: `xarray.map_blocks`
# 
# Since the computation is embarrassingly parallel in the space dimension, I could use `dask.array.map_blocks` to operate on each chunk in isolation. The problem is, I don't know how to write the groupby and broadcasting logic in pure numpy. I need xarray and its indexes.
# 
# The solution is xarray's new `map_blocks` function.

# In[9]:


def calculate_anomaly(ds):
    # needed to work around xarray's check with zero-sized dimensions
    # https://github.com/pydata/xarray/issues/3575
    if len(ds['T']) == 0:
        return ds
    gb = ds.groupby("T.dayofyear")
    clim = gb.mean(dim='T')
    return gb - clim


# In[10]:


t2mmax_anom = xr.map_blocks(calculate_anomaly, ds_rechunk.t2mmax)
t2mmax_anom.data


# That seems great! Only 300 chunks! Let's see how it performs.
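# First, though, a quick sanity check (a sketch, not part of the original analysis): on a
# small spatial subset, the `map_blocks` result should match the plain groupby anomaly.
# The slice bounds below are purely illustrative, and this assumes the horizontal
# dimensions are named `Y` and `X`.

# In[ ]:


subset = ds_rechunk.t2mmax.isel(Y=slice(0, 3), X=slice(0, 10))
gb_small = subset.groupby('T.dayofyear')
expected = gb_small - gb_small.mean(dim='T')
actual = xr.map_blocks(calculate_anomaly, subset)
# compute both and compare values; this should pass silently
xr.testing.assert_allclose(expected.compute(), actual.compute())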
# In[13]:


get_ipython().run_line_magic('time', "t2mmax_std = t2mmax_anom.std(dim='T').load()")


# This was about twice as fast. Moreover, it feels like a more scalable approach.

# In[15]:


t2mmax_std.plot()


# ## Compute the climatology and anomalies as 2D maps
# 
# The advantage of the `map_blocks` approach is that it doesn't create too many chunks. That way we can lazily build more operations on top of the anomaly dataset.
# 
# Below we count the number of "hot events" (anomaly > 1 degree for two consecutive days) per year.

# In[12]:


rolling = t2mmax_anom.rolling(T=2, center=True)
rolling_hot = rolling.max()
rolling_hot


# In[19]:


yearly_events = (rolling_hot > 1).astype('int').resample(T='YS').sum()
yearly_events.data


# In[20]:


yearly_events.load()


# In[26]:


# skip 2019
yearly_events_mean = yearly_events[:-1].mean(dim='T')
yearly_events_anom = yearly_events[:-1] - yearly_events_mean


# In[27]:


yearly_events_mean.plot()


# In[30]:


yearly_events_anom[0].plot()


# In[31]:


yearly_events_anom[-1].plot()
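# Since `t2mmax_anom` is still lazy, one natural follow-up (a sketch, not in the original
# notebook) is to persist it for later analysis by writing it to zarr. The output
# bucket/path below is hypothetical, and writing requires credentials, unlike the
# anonymous read at the top of the notebook.

# In[ ]:


gcs_rw = gcsfs.GCSFileSystem()  # hypothetical: a filesystem with write credentials
anom_map = gcs_rw.get_mapper("some-bucket/t2mmax_anom.zarr")  # hypothetical output path
t2mmax_anom.to_dataset(name='t2mmax_anom').to_zarr(anom_map, mode='w')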