Big Data Analysis with Xarray and Dask

Scott Wales, CLEX CMS

Using Xarray and Dask to process 1 TB of climate data

The slides I'm using are available from the CMS wiki - http://climate-cms.wikis.unsw.edu.au/Training

What is 'big' data?

I'm using 'big data' to describe datasets that are bigger than your computer's memory

At NCI there are several different computer types

  • Raijin:

    • normal: 32 / 64 / 128 GB per 16 CPU cores
    • normalbw: 128 / 256 GB per 28 CPU cores
    • normalsl: 192 GB per 32 CPU cores
  • VDI: 32 GB per 8 CPU cores

Managing Big Datasets

  • Work on one block of data at a time, doing as much as you can

  • Test with small subsets of the data (see the sketch after this list)

  • If you create intermediate files delete them when they're no longer needed

  • Consider how big your final output is going to be
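
As a quick illustration of testing on a small subset, here's a minimal sketch - the file path, variable name and region are hypothetical, not part of the analysis below:

import xarray

# Open the full dataset lazily, then cut out one month over one region to
# develop and debug the analysis before running it on everything
ds = xarray.open_mfdataset('/path/to/data/*.nc')

test = ds['tas'].sel(time=slice('1980-01-01', '1980-01-31'),
                     latitude=slice(-10, -45),    # slice direction depends on how
                     longitude=slice(110, 155))   # the coordinates are ordered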

Heatwave Detection Thresholds

CMS maintain replicas of many useful datasets at NCI so you don't have to download them yourself - http://climate-cms.wikis.unsw.edu.au/Category:Dataset

Heatwave detection requires a threshold value for each grid cell at each day of the year. This threshold is

  • the 90th percentile at each day of the year
  • of a 15-day rolling mean
  • of the daily maximum temperature

For thirty years of 0.25 degree hourly ERA5 data this is going to involve processing about a terabyte of data
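
As a rough back-of-the-envelope check (assuming the 0.25 degree grid is 721 x 1440 points and that the values unpack to 4-byte floats in memory):

# Input: 30 years of hourly fields on a 721 x 1440 grid, float32 once decoded
hours = 30 * 365.25 * 24
print(hours * 721 * 1440 * 4 / 1024**3)   # ~1017 GB - about a terabyte

# Output: one threshold value per day of year per grid cell
print(366 * 721 * 1440 * 4 / 1024**3)     # ~1.4 GB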

Hopefully you can apply some of the techniques used here to your own analyses of large datasets

What data do we need for each step?

  1. Daily maximum temperature

    • High locality in time (need 24 consecutive values)
    • Spatially independent
  2. 15 day rolling mean

    • High locality in time (need 15 consecutive values)
    • Spatially independent
  3. 90th percentile at each day of the year

    • Low locality in time (need every ~365th value)
    • Spatially independent

The data comes in one file per calendar month

Operations 1 and 2 have similar data needs - they could easily be combined into a single operation acting on a couple of files at a time. They reduce the data size to 1/24th of the input size

Operation 3 is more complex - you could gather all of the January files together and process them, then all of the February files together etc. - but be careful of leap years. If we've got 30 years of input this reduces the data size to 1/30th of the input size

The final output should be only a couple of gigabytes, and will be much easier to use when detecting events

A perfectly reasonable way to do this analysis is to manually run through the months as described above.

for month in range(1, 2):           # cut down for the demo - the full analysis loops over all 12 months
    this_month = []

    for year in range(1980, 1982):  # and over all 30 years
        # Open files for this month plus the next and previous (for the rolling mean)
        this_date = pandas.Timestamp(year=year, month=month, day=1)
        mx2t = load_data(year, month)   # placeholder - see the full version below

        # Process steps 1 & 2 for this month
        daily_max = mx2t.resample(time='D').max('time').chunk({'time':-1})
        rolling_average = (daily_max
                           .rolling(time=15, center=True)
                           .mean()
                           .sel(time=slice(this_date,
                                           this_date + pandas.tseries.offsets.MonthEnd())))
        this_month.append(rolling_average)

    # Process step 3 and save this month to file
    month_threshold = (xarray.concat(this_month, 'time')
                       .groupby('time.dayofyear')
                       .reduce(numpy.percentile, q=90, dim='time'))
    month_threshold.to_netcdf('era5-monthly-threshold-%02d.nc'%month)
In [21]:
import pandas
import numpy
import xarray
import glob

output = []

for month in range(1,2):
    this_month = []
    
    for year in range(1980,1982):
        # Open files for this month plus the next and previous (for the rolling mean)
        this_date = pandas.Timestamp(year=year, month=month, day=1)
        prev_date = this_date - pandas.tseries.offsets.MonthBegin()
        next_date = this_date + pandas.tseries.offsets.MonthBegin()
        month_data = []
        for date in [prev_date, this_date, next_date]:
            print(date)
            edate = date + pandas.tseries.offsets.MonthEnd()
            path = ('/g/data/ub4/era5/netcdf/surface/MX2T/%04d/MX2T_era5_global_%04d%02d01_%04d%02d%02d.nc'%
                            (date.year, date.year, date.month, edate.year, edate.month, edate.day))
            month_data.append(xarray.open_dataset(path).mx2t)
        mx2t = xarray.concat(month_data, 'time')
        print(mx2t.nbytes/1024**3)
        
        # Process steps 1 & 2 for this month
        daily_max = mx2t.resample(time='D').max('time').chunk({'time':-1})
        rolling_average = (daily_max
                           .rolling(time=15, center=True)
                           .mean()
                           .sel(time=slice(this_date, this_date + pandas.tseries.offsets.MonthEnd())))
        
        this_month.append(rolling_average)
    
    # Process step 3 and save this month to file
    month_threshold = (xarray.concat(this_month, 'time')
                       .groupby('time.dayofyear')
                       .reduce(numpy.percentile, q=90, dim='time'))
    month_threshold.to_netcdf('era5-monthly-threshold-%02d.nc'%month)
1979-12-01 00:00:00
1980-01-01 00:00:00
1980-02-01 00:00:00
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b44f1e296cae> in <module>
     21                             (date.year, date.year, date.month, edate.year, edate.month, edate.day))
     22             month_data.append(xarray.open_dataset(path).mx2t)
---> 23         mx2t = xarray.concat(month_data, 'time')
     24         print(mx2t.nbytes/1024**3)
     25 

[... traceback through xarray, numpy and netCDF4 internals trimmed ...]

netCDF4/_netCDF4.pyx in netCDF4._netCDF4._ensure_nc_success()

RuntimeError: NetCDF: HDF error

This example, however, is going to use the Dask library to do the analysis - it allows for parallel processing and gives us more control over memory use.

I'm using the conda/analysis3-unstable environment for this demo. You can load it on Raijin/VDI with

module use /g/data3/hh5/public/modules
module load conda/analysis3-unstable

To start off with let's load some libraries

In [1]:
%matplotlib inline
import xarray
import dask
import bottleneck
import numpy

import dask.diagnostics
dask.diagnostics.ProgressBar().register()

Our local archives have hourly fields at 0.25 x 0.25 degree resolution. For the single-level MX2T variable that's 408 GB of data, split across 467 files.

In [2]:
ds = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/MX2T_era5_global_*.nc')
ds
Out[2]:
<xarray.Dataset>
Dimensions:    (latitude: 721, longitude: 1440, time: 341057)
Coordinates:
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * time       (time) datetime64[ns] 1979-01-01T07:00:00 ... 2019-02-28T23:00:00
Data variables:
    mx2t       (time, latitude, longitude) float32 dask.array<shape=(341057, 721, 1440), chunksize=(737, 721, 1440)>
Attributes:
    Conventions:  CF-1.6
    history:      2019-03-11 22:11:16 GMT by grib_to_netcdf-2.10.0: /opt/ecmw...

open_mfdataset() creates a virtual dataset out of multiple files by concatenating the contents together

This data runs from 1979/01/01 0700Z to 2019/02/28 2300Z. To make the analysis simpler, let's use a selector to trim it to thirty complete years, starting at 1 Jan 1980 0000Z and ending at 31 Dec 2009 2300Z

In [3]:
ds = ds.sel(time = slice('19800101T0000Z', '20091231T2300'))

Since the data in the files is compressed, the in-memory size of this data is much bigger than the size on disk - a bit over one TB

In [4]:
mx2t = ds.mx2t

print(mx2t.nbytes / 1024 ** 3, 'GB')
1017.1860980987549 GB

Dask lets us work with data much larger than our computer's memory, because it only loads data when it's really needed - if we make a plot of a certain region it will only load that region's data

http://docs.dask.org/en/latest/array.html
https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.nbytes.html
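
As a small illustration of this laziness (the region and time step here are arbitrary): selecting a subset only builds a task graph, and data is read only once the plot actually needs it.

# Selecting is lazy - no data has been read at this point
aus = mx2t.sel(latitude=slice(0, -60), longitude=slice(100, 160))

# Only the chunks covering this region at this time step are loaded for the plot
aus.isel(time=0).plot()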

Daily max temperature

Now to start our analysis - we have hourly max temperature, but want daily.

We could use resample, but this takes a bit of time for Dask to process

In [9]:
%%time
daily_max = mx2t.resample(time='D').max('time')
print(daily_max)
print()
<xarray.DataArray 'mx2t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(1, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0

CPU times: user 19.1 s, sys: 25 ms, total: 19.2 s
Wall time: 19.1 s

In general Dask operations should return almost immediately - here just building the task graph took around 19 seconds

You can see that chunksize=(1, 721, 1440) - Dask has created a new chunk for each individual day. With 10958 chunks to keep track of, Dask can get bogged down

http://xarray.pydata.org/en/stable/generated/xarray.DataArray.resample.html#xarray.DataArray.resample

A different way to compute the daily maximum is by manipulating the shape of the array - since this is a well-structured published dataset, we know there are always 24 values in each day's data

In [8]:
%%time

daily_time = mx2t.time.data.reshape((-1, 24))[:,0]
print(daily_time[0:4])
print()

daily_max_data = mx2t.data.reshape((-1, 24, mx2t.shape[1], mx2t.shape[2])).max(axis=1)
print(daily_max_data)
print()
['1980-01-01T00:00:00.000000000' '1980-01-02T00:00:00.000000000'
 '1980-01-03T00:00:00.000000000' '1980-01-04T00:00:00.000000000']

dask.array<mean_agg-aggregate, shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>

CPU times: user 11 ms, sys: 1e+03 µs, total: 12 ms
Wall time: 11.1 ms

We've lost the xarray metadata, but you can see that the chunks are much bigger - we have a chunk for each month rather than each day

We can add metadata back by copying it from the original field

In [7]:
daily_max = xarray.DataArray(daily_max_data,
                             name = 'daily_max_t',
                             dims = mx2t.dims,
                             coords = {
                                 'time': ('time', daily_time),
                                 'latitude': mx2t.latitude,
                                 'longitude': mx2t.longitude,
                             })
daily_max
Out[7]:
<xarray.DataArray 'daily_max_t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75

15-day rolling mean

Our next step is to do a rolling mean of the daily maximum

In [8]:
rolling = daily_max.rolling(time=15, center=True).mean()
rolling
Out[8]:
<xarray.DataArray 'getitem-cee79fa3c2352a993a66c0dc8ded60c9' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(10958, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1980-01-01 1980-01-02 ... 2009-12-31
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75

This time we have the opposite problem - rather than having one chunk per day, we have one chunk for the entire 1 TB dataset! This is going to cause trouble when analysing the data, since it's going to get loaded as one giant chunk

http://xarray.pydata.org/en/stable/generated/xarray.DataArray.rolling.html#xarray.DataArray.rolling

This is due to a bug in Xarray. It's possible, although clunky, to do the rolling mean directly in Dask and keep the original chunking

In [9]:
rolling_data = dask.array.overlap.map_overlap(
        daily_max.data,
        func=bottleneck.move_mean,
        window=15,
        axis=0,
        depth=(14,0,0),
        boundary='reflect',
        trim=True,
        dtype=daily_max.data.dtype)

rolling_data
Out[9]:
dask.array<_trim, shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>

Note that the rolling mean is not centred when calculated this way - it's the mean over this day and the 14 previous days, rather than the mean of the 7 days before, this day, and the 7 days after.

http://docs.dask.org/en/latest/array-api.html#dask.array.overlap.map_overlap
https://kwgoodman.github.io/bottleneck-doc/reference.html#bottleneck.move_mean

When we convert back to a DataArray we need to keep this time offset in mind

In [10]:
rolling_time = daily_max.time.data - numpy.timedelta64(7, 'D')

rolling = xarray.DataArray(rolling_data,
                             name = 'daily_max_t',
                             dims = daily_max.dims,
                             coords = {
                                 'time': ('time', rolling_time),
                                 'latitude': daily_max.latitude,
                                 'longitude': daily_max.longitude,
                             })

rolling
Out[10]:
<xarray.DataArray 'daily_max_t' (time: 10958, latitude: 721, longitude: 1440)>
dask.array<shape=(10958, 721, 1440), dtype=float32, chunksize=(31, 721, 1440)>
Coordinates:
  * time       (time) datetime64[ns] 1979-12-25 1979-12-26 ... 2009-12-24
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75

At this point all of the time domain processing is done - we've calculated the 15 day rolling mean of the daily maximum temperature. It's a good idea to plot some data to make sure everything's looking reasonable

In [11]:
rolling.sel(time='19960728').plot()
[########################################] | 100% Completed | 32.6s
[########################################] | 100% Completed | 29.2s
Out[11]:
<matplotlib.collections.QuadMesh at 0x7fe920883898>

90th Percentile by Day of Year

Now it's time to get the 90th percentile for each day in the year.

The simplest way to do this is

rolling.groupby('time.dayofyear').reduce(numpy.percentile, dim='time', q=90)

however there are problems with our giant dataset:

  • By default .reduce() is not Dask-aware - it loads all the data. This can be fixed by adding allow_lazy=True to the arguments
  • numpy.percentile isn't Dask aware either, so we need to add Dask support to the existing function

The Dask array method map_blocks lets us make a Dask-aware version of numpy.percentile - it tells Dask to run a function on each chunk. Since percentile is a reduction along time, we first need to join up all of the chunks along that axis (the rechunk call) and tell Dask that the axis disappears from the output using the drop_axis argument

In [12]:
def dask_percentile(array, axis, q):
    array = array.rechunk({axis: -1})
    return array.map_blocks(
        numpy.percentile,
        axis=axis,
        q=q,
        dtype=array.dtype,
        drop_axis=axis)

doy_p90 = (rolling
           .groupby('time.dayofyear')
           .reduce(dask_percentile, dim='time', q=90,
                   allow_lazy=True))
doy_p90
Out[12]:
<xarray.DataArray 'daily_max_t' (dayofyear: 366, latitude: 721, longitude: 1440)>
dask.array<shape=(366, 721, 1440), dtype=float32, chunksize=(1, 721, 1440)>
Coordinates:
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 -180.0 -179.75 -179.5 ... 179.25 179.5 179.75
  * dayofyear  (dayofyear) int64 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366

Now we're ready to try it out!

In [13]:
doy_p90.sel(dayofyear=39).plot()
[##############                          ] | 35% Completed |  5min 16.6s
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-d2d1ceb6f742> in <module>
----> 1 doy_p90.sel(dayofyear=39).plot()

[... traceback through xarray, dask, numpy and netCDF4 internals trimmed ...]

netCDF4/_netCDF4.pyx in netCDF4._netCDF4._ensure_nc_success()

RuntimeError: NetCDF: HDF error

It's working, but it takes much longer than we'd like for testing, and on VDI you're likely to run into HDF5 errors. Let's gather our progress so far into a function so we can try some workarounds

In [14]:
def rolling_maximum(dataset):
    daily_time = dataset.time.data.reshape((-1, 24))[:,0]
    daily_max_data = dataset.data.reshape((-1, 24,
                                           dataset.shape[1],
                                           dataset.shape[2])
                                         ).max(axis=1)
    
    rolling_time = daily_time - numpy.timedelta64(7, 'D')
    rolling_data = dask.array.overlap.map_overlap(
        daily_max_data,
        func=bottleneck.move_mean,
        window=15,
        axis=0,
        depth=(14,0,0),
        boundary='reflect',
        trim=True,
        dtype=daily_max_data.dtype)
    
    rolling = xarray.DataArray(rolling_data,
                             dims = dataset.dims,
                             coords = {
                                 'time': ('time', rolling_time),
                                 'latitude': dataset.latitude,
                                 'longitude': dataset.longitude,
                             })
    
    return rolling

Chunk sizes

Dask isn't the only thing that splits data into chunks. NetCDF4 files do it too (it helps with compression)

You can look at the .encoding attribute of an Xarray variable to see information about how it is stored in the file

In [17]:
ds.mx2t.encoding
Out[17]:
{'zlib': True,
 'shuffle': True,
 'complevel': 5,
 'fletcher32': False,
 'contiguous': False,
 'chunksizes': (93, 91, 180),
 'source': '/g/data/ub4/era5/netcdf/surface/MX2T/1979/MX2T_era5_global_19790101_19790131.nc',
 'original_shape': (737, 721, 1440),
 'dtype': dtype('int16'),
 'missing_value': -32767,
 '_FillValue': -32767,
 'scale_factor': 0.0016965627572058163,
 'add_offset': 265.9024415135433}

The size of chunks can have a huge effect on performance and memory use. You don't want them too big, as then you'll fill up your memory, and you don't want too many of them or Dask gets bogged down trying to keep track of everything.

Chunk sizes are set by the chunks argument to open_dataset() and open_mfdataset(). A good starting point is to use the chunk sizes from the file. If you're not doing much processing along the time dimension it's also worth trying a small chunk size for time.

In [18]:
ds_c = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/'
                             'MX2T_era5_global_*.nc',
                             chunks={'latitude': 91, 'longitude': 180})
mx2t_c = ds_c.mx2t
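
It's worth checking what chunking you actually ended up with - the underlying Dask array reports it directly (a quick check, not an original notebook cell):

# Largest chunk shape - roughly one file's worth of time steps by 91 x 180 grid points
print(mx2t_c.data.chunksize)

# Total number of chunks Dask has to keep track of
print(mx2t_c.data.npartitions)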

Another thing that can help when testing very large datasets is to run the analysis on just a small section of the full dataset - say a few years of the whole range, only over Australia

In [19]:
mx2t_aus = mx2t_c.sel(time = slice('19800101T0000Z', '19841231T2300'),
                      latitude=slice(0,-60),
                      longitude=slice(100,160))

rolling_c = rolling_maximum(mx2t_aus)

doy_p90_c = (rolling_c
             .groupby('time.dayofyear')
             .reduce(dask_percentile, dim='time', q=90,
                     allow_lazy=True))
In [20]:
doy_p90_c.sel(dayofyear=39).plot()
[########################################] | 100% Completed |  1min  1.0s
[########################################] | 100% Completed | 29.5s
Out[20]:
<matplotlib.collections.QuadMesh at 0x7fe920b04828>

A minute and a half is reasonable, though you may want to reduce the domain even further if you need to debug something

Running on Raijin

With test jobs working interactively we can move on to doing the full analysis. This will take a bit of time and memory, so rather than the shared VDI nodes we'll make a script to run on Raijin.

When running on a compute node we want a bit more control over how Dask runs - for instance the number of workers and the memory limit

import tempfile

import dask.distributed

if __name__ == '__main__':
    client = dask.distributed.Client(
        n_workers=8, threads_per_worker=1,
        memory_limit='4gb', local_dir=tempfile.mkdtemp())

Note that if you're using a client you must put an if __name__ == '__main__' guard around your script - Dask starts worker processes that may re-import it, and without the guard each of them would try to start its own cluster

import tempfile

import bottleneck
import dask.array
import dask.distributed
import numpy
import xarray

# rolling_maximum() and dask_percentile() are the functions defined earlier

if __name__ == '__main__':
    client = dask.distributed.Client(
        n_workers=8, threads_per_worker=1,
        memory_limit='4gb', local_dir=tempfile.mkdtemp())

    ds = xarray.open_mfdataset('/g/data/ub4/era5/netcdf/surface/MX2T/*/'
                               'MX2T_era5_global_*.nc',
                               chunks={'latitude': 91})

    # Trim to full days, with a bit of a buffer for the rolling mean
    ds = ds.sel(time=slice('19791201','20100201'))

    # Perform the rolling mean of daily max and trim the output to the
    # target period
    rolled = rolling_maximum(ds.mx2t).sel(time=slice('19800101','20100101'))

    # Perform the percentile
    doy_p90 = (rolled.groupby('time.dayofyear')
                     .reduce(dask_percentile, dim='time', q=90,
                             allow_lazy=True))

    # Save to a file
    doy_p90 = doy_p90.to_dataset(name='mx2t_doy_p90')
    saver = doy_p90.to_netcdf('mx2t_doy_p90.nc', compute=False)

    # Add a progress bar
    future = client.persist(saver)
    dask.distributed.progress(future)
    future.compute()
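
To actually run this on Raijin you would submit the script through PBS. A minimal sketch of a submission script follows - the script name threshold.py, the project placeholder and the resource requests are illustrative, so adjust them for your own project and job size:

#!/bin/bash
#PBS -P <project>
#PBS -q normal
#PBS -l ncpus=8
#PBS -l mem=32GB
#PBS -l walltime=4:00:00
#PBS -l wd

module use /g/data3/hh5/public/modules
module load conda/analysis3-unstable

python threshold.py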

The end

Slides & code are available from http://climate-cms.wikis.unsw.edu.au/Training

Contact CMS at [email protected] or https://arccss.slack.com