Using Xarray and Dask to process 1 TB of climate data
The slides I'm using are available from the CMS wiki - http://climate-cms.wikis.unsw.edu.au/Training
I'm using 'big data' to describe datasets that are bigger than your computer's memory
At NCI there are several different computer types
Raijin:
normal: 32 / 64 / 128 GB per 16 CPU cores
normalbw: 128 / 256 GB per 28 CPU cores
normalsl: 192 GB per 32 CPU cores
VDI: 32 GB per 8 CPU cores
Work on one block of data at a time, doing as much as you can
Test with small subsets of the data
If you create intermediate files delete them when they're no longer needed
Consider how big your final output is going to be
CMS maintain replicas of many useful datasets at NCI so you don't have to download them yourself - http://climate-cms.wikis.unsw.edu.au/Category:Dataset
Heatwave detection requires a threshold value for each grid cell at each day of the year. This threshold is the 90th percentile, for each day of the year, of the 15 day rolling mean of daily maximum temperature over the base period
For thirty years of 0.25 degree hourly ERA-5 data this is going to involve processing about a terabyte of data
Hopefully you can take some of the techniques used here for your own analyses of large datasets
What data do we need for each step?
1. Daily maximum temperature
2. 15 day rolling mean
3. 90th percentile at each day of the year
The data comes in one file per calendar month
Operations 1 and 2 have similar data needs - they could easily be combined into a single operation acting on a couple of files at a time. They reduce the data size to 1/24th of the input size
Operation 3 is more complex - you could gather all of the January files together and process them, then all of the February files together etc. - but be careful of leap years. If we've got 30 years of input this reduces the data size to 1/30th of the input size
The final output should be only a couple of gigabytes, and will be much easier to use when detecting events
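As a quick sanity check of these numbers, here's the arithmetic (a rough sketch for 30 years of hourly 721 x 1440 float32 fields):
hours = 30 * 365.25 * 24                      # ~30 years of hourly steps
input_gb = hours * 721 * 1440 * 4 / 1024**3   # float32 values in memory
daily_gb = input_gb / 24                      # after steps 1 & 2
output_gb = daily_gb * 366 / (30 * 365.25)    # one field per day of year
print(input_gb, daily_gb, output_gb)          # roughly 1000, 42 and 1.4 GB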
A perfectly reasonable way to do this analysis is to manually run through the months as described above.
import numpy
import pandas
import xarray

for month in range(1, 2):
    this_month = []
    for year in range(1980, 1982):
        # Open files for this month plus the next and previous (for the
        # rolling mean) - load_data() is sketched below
        this_year_month = load_data(year, month)
        # Process steps 1 & 2 for this month
        daily_max = this_year_month.resample(time='D').max('time').chunk({'time': -1})
        # Start of the month we're processing
        this_date = pandas.Timestamp(year=year, month=month, day=1)
        rolling_average = (daily_max
                           .rolling(time=15, center=True)
                           .mean()
                           .sel(time=slice(this_date,
                                           this_date + pandas.tseries.offsets.MonthEnd())))
        this_month.append(rolling_average)
    # Process step 3 and save this month's thresholds to file
    month_threshold = (xarray.concat(this_month, 'time')
                       .groupby('time.dayofyear')
                       .reduce(numpy.percentile, q=90, dim='time'))
    month_threshold.to_netcdf('era5-monthly-threshold-%02d.nc' % month)
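The load_data() helper above is left undefined - here's a minimal sketch of what it could look like, assuming the ERA5 archive layout shown later in these slides:
import glob
import pandas
import xarray

def load_data(year, month):
    # Open this month's file plus the previous and next months',
    # so the rolling mean has data across the month boundaries
    start = pandas.Timestamp(year=year, month=month, day=1)
    months = [start - pandas.tseries.offsets.MonthBegin(),
              start,
              start + pandas.tseries.offsets.MonthBegin()]
    paths = sorted(sum([glob.glob(
        '/g/data/ub4/era5/netcdf/surface/MX2T/%04d/MX2T_era5_global_%04d%02d01_*.nc'
        % (m.year, m.year, m.month)) for m in months], []))
    return xarray.open_mfdataset(paths).mx2t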
This example, however, is going to use the Dask library to do the analysis - it allows for parallel processing and gives us more control over memory use.
I'm using the conda/analysis3-unstable environment for this demo; you can load it on Raijin/VDI with
module use /g/data3/hh5/public/modules
module load conda/analysis3-unstable
To start off with let's load some libraries
%matplotlib inline
import xarray
import dask
import bottleneck
import numpy
import dask.diagnostics
dask.diagnostics.ProgressBar().register()
Our local archives have hourly fields at 0.25 x 0.25 degree resolution. For the single-level MX2T variable that's 408 GB of data, split across 467 files.
ds = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/MX2T_era5_global_*.nc')
ds
<xarray.Dataset>
Dimensions:    (latitude: 721, longitude: 1440, time: 341057)
Coordinates:
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * time       (time) datetime64[ns] 1979-01-01T07:00:00 ... 2019-02-28T23:00:00
Data variables:
    mx2t       (time, latitude, longitude) float32 dask.array<shape=(341057, 721, 1440), chunksize=(737, 721, 1440)>
Attributes:
    Conventions:  CF-1.6
    history:      2019-03-11 22:11:16 GMT by grib_to_netcdf-2.10.0: /opt/ecmw...
open_mfdataset() creates a virtual dataset out of multiple files by concatenating their contents together
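Conceptually it's doing something like this sketch (you wouldn't normally open 467 files by hand; chunks={} just asks for lazy dask arrays):
import glob
import xarray

paths = sorted(glob.glob('/g/data/ub4/era5/netcdf/surface/MX2T/*/MX2T_era5_global_*.nc'))
# Open each file lazily, then join them along the record dimension
datasets = [xarray.open_dataset(p, chunks={}) for p in paths]
ds = xarray.concat(datasets, dim='time')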
This data goes from 1979/01/01 0700Z to 2019/02/28 2300Z. To make analysis simpler let's trim it to the 30 years from 1980 to 2009, so it starts at Jan 1 0000Z and ends at Dec 31 2300Z, using a selector
ds = ds.sel(time=slice('19800101T0000Z', '20091231T2300Z'))
Since the data in the files is compressed, the in-memory size of this data is much bigger than the size on disk - a bit over one TB
mx2t = ds.mx2t
print(mx2t.nbytes / 1024 ** 3, 'GB')
1017.1860980987549 GB
Dask lets us work with data much larger than our computer's memory, because it only loads data when it's really needed - if we make a plot of a certain region it will only load that region's data
http://docs.dask.org/en/latest/array.html
https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.nbytes.html
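For instance, selecting one day over Australia and taking its mean only reads that day's data for that region (a quick sketch):
# Data is only read from disk when .values asks for the result
aus = mx2t.sel(time='19800101',
               latitude=slice(0, -60), longitude=slice(100, 160))
print(aus.mean().values)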
Now to start our analysis - we have hourly max temperature, but want daily.
We could use resample, but this takes a bit of time for Dask to process
%%time
daily_max = mx2t.resample(time='D').max('time')
print(daily_max)
print()
<xarray.DataArray 'mx2t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(1, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0

CPU times: user 19.1 s, sys: 25 ms, total: 19.2 s
Wall time: 19.1 s
In general Dask operations should return pretty much immediately. You can see that chunksize=(1, 721, 1440) - Dask has created a new chunk for each individual day. With 10958 days to keep track of, Dask can get bogged down
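You can get a feel for how much bookkeeping that is by counting chunks and tasks (a rough gauge; exact numbers depend on the Dask version):
print(daily_max.data.npartitions)  # one chunk per day - 10958 of them
print(len(daily_max.data.dask))    # tasks in the graph, several per chunk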
A different way to do the daily maximum is by manipulating the shape of the array - since this is well-structured, published data we know there are always 24 values in each day's data
%%time
daily_time = mx2t.time.data.reshape((-1, 24))[:,0]
print(daily_time[0:4])
print()
daily_max_data = mx2t.data.reshape((-1, 24, mx2t.shape[1], mx2t.shape[2])).max(axis=1)
print(daily_max_data)
print()
['1980-01-01T00:00:00.000000000' '1980-01-02T00:00:00.000000000'
 '1980-01-03T00:00:00.000000000' '1980-01-04T00:00:00.000000000']

dask.array<mean_agg-aggregate, shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>

CPU times: user 11 ms, sys: 1e+03 µs, total: 12 ms
Wall time: 11.1 ms
We've lost the xarray metadata, but you can see that the chunks are much bigger - we have a chunk for each month rather than each day
We can add metadata back by copying it from the original field
daily_max = xarray.DataArray(daily_max_data,
name = 'daily_max_t',
dims = mx2t.dims,
coords = {
'time': ('time', daily_time),
'latitude': mx2t.latitude,
'longitude': mx2t.longitude,
})
daily_max
<xarray.DataArray 'daily_max_t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
Our next step is to do a rolling mean of the daily maximum
rolling = daily_max.rolling(time=15, center=True).mean()
rolling
<xarray.DataArray 'getitem-cee79fa3c2352a993a66c0dc8ded60c9' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(10958, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
This time we have the opposite problem - rather than one chunk per day, we have a single chunk for the entire 1 TB dataset! This is going to cause trouble when analysing the data, since it's going to get loaded as one giant chunk
http://xarray.pydata.org/en/stable/generated/xarray.DataArray.rolling.html#xarray.DataArray.rolling
This is due to a bug in Xarray. It's possible, although clunky, to do the rolling mean directly in dask and keep the original chunking
rolling_data = dask.array.overlap.map_overlap(
daily_max.data,
func=bottleneck.move_mean,
window=15,
axis=0,
depth=(14,0,0),
boundary='reflect',
trim=True,
dtype=daily_max.data.dtype)
rolling_data
dask.array<_trim, shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>
Note that the rolling mean is not centred when calculated this way - it's the mean over this day and the 14 previous days, rather than over the 7 days before, this day, and the 7 days after.
http://docs.dask.org/en/latest/array-api.html#dask.array.overlap.map_overlap
https://kwgoodman.github.io/bottleneck-doc/reference.html#bottleneck.move_mean
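A tiny example shows the window alignment of bottleneck.move_mean:
import numpy
import bottleneck

print(bottleneck.move_mean(numpy.arange(5.0), window=3))
# [nan nan  1.  2.  3.] - each output is the mean of that element
# and the (window - 1) elements before it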
When we convert back to a DataArray we need to keep this time offset in mind
rolling_time = daily_max.time.data - numpy.timedelta64(7, 'D')
rolling = xarray.DataArray(rolling_data,
name = 'daily_max_t',
dims = daily_max.dims,
coords = {
'time': ('time', rolling_time),
'latitude': daily_max.latitude,
'longitude': daily_max.longitude,
})
rolling
<xarray.DataArray 'daily_max_t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1979-12-25 1979-12-26 ... 2009-12-24
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
At this point all of the time domain processing is done - we've calculated the 15 day rolling mean of the daily maximum temperature. It's a good idea to plot some data to make sure everything's looking reasonable
rolling.sel(time='19960728').plot()
[########################################] | 100% Completed | 32.6s
[########################################] | 100% Completed | 29.2s
<matplotlib.collections.QuadMesh at 0x7fe920883898>
Now it's time to get the 90th percentile for each day in the year.
The simplest way to do this is
rolling.groupby('time.dayofyear').reduce(numpy.percentile, dim='time', q=90)
however there are problems with our giant dataset:
- .reduce() is not Dask-aware - it loads all the data. This can be fixed by adding allow_lazy=True to the arguments
- numpy.percentile isn't Dask-aware either, so we need to add Dask support to the existing function
- Dask's map_blocks() allows us to make a Dask-aware version of numpy.percentile - it tells Dask to run a function on each chunk. Since percentile is a reduction operation we need to join up all of the chunks along the time axis first and add the drop_axis argument
def dask_percentile(array, axis, q):
    # Merge all chunks along the reduction axis so each block sees the
    # full time series, then apply numpy.percentile blockwise and drop
    # the reduced axis from the output
    array = array.rechunk({axis: -1})
    return array.map_blocks(
        numpy.percentile,
        axis=axis,
        q=q,
        dtype=array.dtype,
        drop_axis=axis)
doy_p90 = (rolling
.groupby('time.dayofyear')
.reduce(dask_percentile, dim='time', q=90,
allow_lazy=True))
doy_p90
<xarray.DataArray 'daily_max_t' (dayofyear: 366, latitude: 721, longitude: 1440)>
dask.array<shape=(366, 721, 1440), dtype=float32, chunksize=(1, 721, 1440)>
Coordinates:
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * dayofyear  (dayofyear) int64 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366
Now we're ready to try it out!
doy_p90.sel(dayofyear=39).plot()
[############## ] | 35% Completed | 5min 16.6s
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-d2d1ceb6f742> in <module>
----> 1 doy_p90.sel(dayofyear=39).plot()

... (long traceback through xarray, dask and netCDF4) ...

RuntimeError: NetCDF: HDF error
It's working, but it takes much longer than we'd like for test purposes, and on VDI you're likely to run into HDF5 errors. Let's gather our progress up to now in a function so we can try some work-arounds
def rolling_maximum(dataset):
    # Step 1: daily maximum, by reshaping the hourly axis into (day, 24)
    daily_time = dataset.time.data.reshape((-1, 24))[:, 0]
    daily_max_data = dataset.data.reshape((-1, 24,
                                           dataset.shape[1],
                                           dataset.shape[2])
                                          ).max(axis=1)
    # Step 2: 15 day rolling mean, shifting the times back 7 days to
    # re-centre the trailing window
    rolling_time = daily_time - numpy.timedelta64(7, 'D')
    rolling_data = dask.array.overlap.map_overlap(
        daily_max_data,
        func=bottleneck.move_mean,
        window=15,
        axis=0,
        depth=(14, 0, 0),
        boundary='reflect',
        trim=True,
        dtype=daily_max_data.dtype)
    rolling = xarray.DataArray(rolling_data,
                               dims=dataset.dims,
                               coords={
                                   'time': ('time', rolling_time),
                                   'latitude': dataset.latitude,
                                   'longitude': dataset.longitude,
                               })
    return rolling
Dask isn't the only thing that splits data into chunks. NetCDF4 files do it too (it helps with compression)
You can look at the .encoding attribute of an Xarray variable to see information about the file storage
ds.mx2t.encoding
{'zlib': True, 'shuffle': True, 'complevel': 5, 'fletcher32': False, 'contiguous': False, 'chunksizes': (93, 91, 180), 'source': '/g/data/ub4/era5/netcdf/surface/MX2T/1979/MX2T_era5_global_19790101_19790131.nc', 'original_shape': (737, 721, 1440), 'dtype': dtype('int16'), 'missing_value': -32767, '_FillValue': -32767, 'scale_factor': 0.0016965627572058163, 'add_offset': 265.9024415135433}
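Each file chunk here holds 93 hours of a 91 x 180 cell region; a quick size check (values are stored as int16 and decoded to float32 in memory):
chunk_values = 93 * 91 * 180
print(chunk_values * 2 / 1024**2, 'MB of int16 per chunk, before compression')
print(chunk_values * 4 / 1024**2, 'MB in memory once decoded to float32')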
The size of chunks can have a huge effect on performance and memory use. You don't want them too big, as then you'll fill up your memory, and you don't want too many of them or Dask gets bogged down trying to keep track of everything.
Chunk sizes are set by the chunks argument to open_dataset(). A good starting point is to use the chunk sizes from the file. If you're not doing much time processing it's good to try a small value for the time dimension too.
ds_c = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/'
'MX2T_era5_global_*.nc',
chunks={'latitude': 91, 'longitude': 180})
mx2t_c = ds_c.mx2t
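It's worth confirming the chunking you end up with - the time dimension should keep one chunk per monthly file (a quick check; the exact numbers depend on the files):
print(mx2t_c.data.chunksize)  # something like (744, 91, 180)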
Another thing that can help for testing very large datasets is to just run the analysis on a small section of the full dataset - say we feed our analysis just a few years of the whole range, only over Australia
mx2t_aus = mx2t_c.sel(time=slice('19800101T0000Z', '19841231T2300Z'),
                      latitude=slice(0, -60),
                      longitude=slice(100, 160))
rolling_c = rolling_maximum(mx2t_aus)
doy_p90_c = (rolling_c
.groupby('time.dayofyear')
.reduce(dask_percentile, dim='time', q=90,
allow_lazy=True))
doy_p90_c.sel(dayofyear=39).plot()
[########################################] | 100% Completed | 1min 1.0s
[########################################] | 100% Completed | 29.5s
<matplotlib.collections.QuadMesh at 0x7fe920b04828>
A minute and a half is reasonable, though you may want to reduce the domain even further if you need to debug something
With test jobs working interactively we can move on to doing the full analysis. This will take a bit of time and memory, so rather than the shared VDI nodes we'll make a script to run on Raijin.
When running on a compute node we want a bit more control over how Dask runs - for instance the number of workers and the memory limit
import dask.distributed
import tempfile

if __name__ == '__main__':
    client = dask.distributed.Client(
        n_workers=8, threads_per_worker=1,
        memory_limit='4gb', local_dir=tempfile.mkdtemp())
Note that if you're using a client you must put an if __name__ == '__main__' guard around your script
import tempfile
import numpy
import xarray
import bottleneck
import dask.array
import dask.distributed

# rolling_maximum() and dask_percentile() as defined earlier

if __name__ == '__main__':
    client = dask.distributed.Client(
        n_workers=8, threads_per_worker=1,
        memory_limit='4gb', local_dir=tempfile.mkdtemp())
    ds = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/'
                               'MX2T_era5_global_*.nc',
                               chunks={'latitude': 91})
    # Trim to full days, with a bit of a buffer for the rolling mean
    ds = ds.sel(time=slice('19791201', '20100201'))
    # Perform the rolling mean of daily max and trim the output to the
    # target period
    rolled = rolling_maximum(ds.mx2t).sel(time=slice('19800101', '20100101'))
    # Perform the percentile
    doy_p90 = (rolled.groupby('time.dayofyear')
               .reduce(dask_percentile, dim='time', q=90,
                       allow_lazy=True))
    # Save to a file
    doy_p90 = doy_p90.to_dataset(name='mx2t_doy_p90')
    saver = doy_p90.to_netcdf('mx2t_doy_p90.nc', compute=False)
    # Add a progress bar
    future = client.persist(saver)
    dask.distributed.progress(future)
    future.compute()
Slides & code are available from http://climate-cms.wikis.unsw.edu.au/Training
Contact CMS at cws_help@nci.org.au or https://arccss.slack.com