# Copyright 2020 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pandas as pd
import xarray
import numpy_groupies
def _binned_agg(
    array: np.ndarray,
    indices: np.ndarray,
    num_bins: int,
    *,
    func,
    fill_value,
    dtype,
) -> np.ndarray:
    """NumPy helper function for aggregating over bins.

    Reduces the trailing ``indices.ndim`` axes of ``array`` into ``num_bins``
    bins with ``numpy_groupies.aggregate``.

    Args:
      array: values to aggregate; its trailing axes must match
        ``indices.shape``.
      indices: bin number for each element. Entries that are negative (the
        pandas Categorical "no category" code -1) or NaN are skipped.
      num_bins: number of output bins (length of the new last axis).
      func: aggregation forwarded to ``numpy_groupies.aggregate``.
      fill_value: value for bins that receive no elements.
      dtype: output dtype forwarded to ``numpy_groupies.aggregate``.

    Returns:
      Array of shape ``array.shape[:-indices.ndim] + (num_bins,)``.
    """
    # Keep only elements that were assigned to a bin. Categorical codes use -1
    # (not NaN) for "no bin", and numpy_groupies does not support negative
    # group indices. NaN compares False, so float NaN indices are dropped too;
    # errstate silences the invalid-comparison warning that NaN can trigger.
    with np.errstate(invalid='ignore'):
        mask = indices >= 0
    int_indices = indices[mask].astype(int)
    # A boolean mask over the trailing axes flattens them into one last axis.
    return numpy_groupies.aggregate(
        int_indices,
        array[..., mask],
        func=func,
        size=num_bins,
        fill_value=fill_value,
        dtype=dtype,
        axis=-1,
    )
def groupby_bins_agg(
    array: xarray.DataArray,
    group: xarray.DataArray,
    bins,
    func='sum',
    fill_value=0,
    dtype=None,
    **cut_kwargs,
) -> xarray.DataArray:
    """Faster equivalent of Xarray's ``groupby_bins(...)`` plus a reduction.

    ``group`` is discretized with ``pandas.cut`` and ``array`` is then reduced
    over every dimension of ``group`` with ``numpy_groupies.aggregate``.

    Args:
      array: values to aggregate; must have all dimensions of ``group``.
      group: named DataArray whose values define the bins.
      bins: bin specification forwarded to ``pandas.cut``.
      func: aggregation forwarded to ``numpy_groupies.aggregate`` (e.g.
        ``'sum'`` or ``'mean'``).
      fill_value: value written to bins that receive no elements.
      dtype: output dtype forwarded to ``numpy_groupies.aggregate``; ``None``
        lets numpy_groupies choose.
      **cut_kwargs: extra keyword arguments forwarded to ``pandas.cut``.

    Returns:
      DataArray named like ``array`` in which the dimensions of ``group`` are
      replaced by a new ``group.name + '_bins'`` dimension whose coordinate
      holds the bin intervals.

    Raises:
      TypeError: if ``group.name`` is not a string (it names the new dim).
    """
    # TODO: implement this upstream in xarray:
    # https://github.com/pydata/xarray/issues/4473
    if not isinstance(group.name, str):
        raise TypeError(
            f'group must have a string name, got {group.name!r}')
    new_dim_name = group.name + '_bins'

    binned = pd.cut(np.ravel(group), bins, **cut_kwargs)
    num_bins = len(binned.categories)
    # Categorical codes: bin number per element, -1 where no bin matched.
    indices = group.copy(data=binned.codes.reshape(group.shape))

    # Declare the dtype the chunks will actually have, so dask metadata is
    # correct when an explicit ``dtype`` is requested.
    # NOTE(review): with dtype=None, reductions such as 'mean' on integer
    # input presumably produce floats but are still declared as array.dtype
    # here -- confirm against numpy_groupies' dtype rules.
    output_dtype = array.dtype if dtype is None else np.dtype(dtype)

    result = xarray.apply_ufunc(
        _binned_agg, array, indices,
        input_core_dims=[indices.dims, indices.dims],
        output_core_dims=[[new_dim_name]],
        output_dtypes=[output_dtype],
        dask_gufunc_kwargs=dict(
            output_sizes={new_dim_name: num_bins},
        ),
        kwargs={
            'num_bins': num_bins,
            'func': func,
            'fill_value': fill_value,
            'dtype': dtype,
        },
        dask='parallelized',
    )
    result.coords[new_dim_name] = binned.categories
    # apply_ufunc drops the name when the inputs' names differ; keep the
    # input's name like xarray's own groupby reductions do.
    result.name = array.name
    return result
def make_test_data(t, x, y, seed=0):
    """Build a random signal and its distance from the grid origin.

    Args:
      t: length of the 'time' dimension.
      x: length of the 'x' dimension.
      y: length of the 'y' dimension.
      seed: seed for NumPy's legacy ``RandomState`` (reproducible data).

    Returns:
      ``(signal, distance)``. ``signal`` is a ('time', 'y', 'x') DataArray of
      uniform random values named 'signal'. ``distance`` is
      ``sqrt(x**2 + y**2)`` over the x/y coordinates, named 'distance'.
    """
    # Sizes follow the dimension names: 'y' has length y and 'x' has length x.
    # (Previously these were swapped; identical whenever x == y.)
    values = np.random.RandomState(seed).rand(t, y, x)
    signal = xarray.DataArray(
        values,
        dims=['time', 'y', 'x'],
        coords={
            'time': np.arange(t),
            'y': np.arange(y),
            'x': np.arange(x),
        },
        name='signal')
    distance = ((signal.x ** 2 + signal.y ** 2) ** 0.5).rename('distance')
    return signal, distance
# Sanity check on a small problem: the fast path must agree with xarray's
# built-in groupby_bins reduction for the same bins.
signal, distance = make_test_data(t=2, x=50, y=50)
bins = 10
actual = groupby_bins_agg(signal, distance, bins, func='mean')
# Use the same `bins` variable so both sides always bin identically.
expected = signal.groupby_bins(distance, bins=bins).mean()
xarray.testing.assert_allclose(actual, expected)
actual
<xarray.DataArray (time: 2, distance_bins: 10)> array([[0.51498271, 0.46370372, 0.48996133, 0.51069211, 0.5302821 , 0.50000696, 0.48753868, 0.52151072, 0.48915714, 0.51292164], [0.53229943, 0.52167522, 0.45915308, 0.5293949 , 0.47742068, 0.48900111, 0.48465034, 0.47729889, 0.5059115 , 0.50237199]]) Coordinates: * time (time) int64 0 1 * distance_bins (distance_bins) object (-0.0693, 6.93] ... (62.367, 69.296]
array([[0.51498271, 0.46370372, 0.48996133, 0.51069211, 0.5302821 , 0.50000696, 0.48753868, 0.52151072, 0.48915714, 0.51292164], [0.53229943, 0.52167522, 0.45915308, 0.5293949 , 0.47742068, 0.48900111, 0.48465034, 0.47729889, 0.5059115 , 0.50237199]])
array([0, 1])
array([Interval(-0.0693, 6.93, closed='right'), Interval(6.93, 13.859, closed='right'), Interval(13.859, 20.789, closed='right'), Interval(20.789, 27.719, closed='right'), Interval(27.719, 34.648, closed='right'), Interval(34.648, 41.578, closed='right'), Interval(41.578, 48.508, closed='right'), Interval(48.508, 55.437, closed='right'), Interval(55.437, 62.367, closed='right'), Interval(62.367, 69.296, closed='right')], dtype=object)
# Larger benchmark problem: 20 x 1000 x 1000 random float64 values.
signal, distance = make_test_data(t=20, x=1000, y=1000)
bins = 50
# Size of the signal array in (decimal) megabytes.
signal.nbytes / 1e6
160.0
# Baseline timing: xarray's built-in groupby_bins mean on in-memory data.
%time _ = signal.groupby_bins(distance, bins).mean()
CPU times: user 8.52 s, sys: 674 ms, total: 9.19 s Wall time: 10.3 s
# Same reduction through groupby_bins_agg (pandas.cut + numpy_groupies).
%time _ = groupby_bins_agg(signal, distance, bins, func='mean')
CPU times: user 909 ms, sys: 290 ms, total: 1.2 s Wall time: 1.3 s
# Repeat the comparison on a dask-backed copy, chunked so that each chunk
# holds a single time step.
import dask
dask_signal = signal.chunk({'time': 1})
# Called outside a `with` block, so this setting remains in effect for the
# rest of the session.
dask.config.set(num_workers=4)
dask_signal
<xarray.DataArray 'signal' (time: 20, y: 1000, x: 1000)> dask.array<xarray-<this-array>, shape=(20, 1000, 1000), dtype=float64, chunksize=(1, 1000, 1000), chunktype=numpy.ndarray> Coordinates: * time (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 * y (y) int64 0 1 2 3 4 5 6 7 8 ... 991 992 993 994 995 996 997 998 999 * x (x) int64 0 1 2 3 4 5 6 7 8 ... 991 992 993 994 995 996 997 998 999
|
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
array([ 0, 1, 2, ..., 997, 998, 999])
array([ 0, 1, 2, ..., 997, 998, 999])
# xarray's built-in groupby_bins on the dask array: first time building the
# lazy result, then time computing it (the computed value is not kept).
%time result = dask_signal.groupby_bins(distance, bins).mean()
%time result.compute()
result
CPU times: user 8.13 s, sys: 365 ms, total: 8.49 s Wall time: 8.87 s CPU times: user 1.12 s, sys: 332 ms, total: 1.45 s Wall time: 967 ms
<xarray.DataArray 'signal' (time: 20, distance_bins: 50)> dask.array<transpose, shape=(20, 50), dtype=float64, chunksize=(1, 1), chunktype=numpy.ndarray> Coordinates: * distance_bins (distance_bins) object (-1.413, 28.256] ... (1384.543, 141... * time (time) int64 0 1 2 3 4 5 6 7 8 ... 11 12 13 14 15 16 17 18 19
|
array([Interval(-1.413, 28.256, closed='right'), Interval(28.256, 56.512, closed='right'), Interval(56.512, 84.768, closed='right'), Interval(84.768, 113.024, closed='right'), Interval(113.024, 141.28, closed='right'), Interval(141.28, 169.536, closed='right'), Interval(169.536, 197.792, closed='right'), Interval(197.792, 226.048, closed='right'), Interval(226.048, 254.304, closed='right'), Interval(254.304, 282.56, closed='right'), Interval(282.56, 310.816, closed='right'), Interval(310.816, 339.072, closed='right'), Interval(339.072, 367.328, closed='right'), Interval(367.328, 395.584, closed='right'), Interval(395.584, 423.84, closed='right'), Interval(423.84, 452.096, closed='right'), Interval(452.096, 480.352, closed='right'), Interval(480.352, 508.608, closed='right'), Interval(508.608, 536.864, closed='right'), Interval(536.864, 565.12, closed='right'), Interval(565.12, 593.376, closed='right'), Interval(593.376, 621.632, closed='right'), Interval(621.632, 649.888, closed='right'), Interval(649.888, 678.144, closed='right'), Interval(678.144, 706.4, closed='right'), Interval(706.4, 734.656, closed='right'), Interval(734.656, 762.912, closed='right'), Interval(762.912, 791.168, closed='right'), Interval(791.168, 819.424, closed='right'), Interval(819.424, 847.68, closed='right'), Interval(847.68, 875.936, closed='right'), Interval(875.936, 904.192, closed='right'), Interval(904.192, 932.448, closed='right'), Interval(932.448, 960.704, closed='right'), Interval(960.704, 988.96, closed='right'), Interval(988.96, 1017.216, closed='right'), Interval(1017.216, 1045.472, closed='right'), Interval(1045.472, 1073.728, closed='right'), Interval(1073.728, 1101.983, closed='right'), Interval(1101.983, 1130.239, closed='right'), Interval(1130.239, 1158.495, closed='right'), Interval(1158.495, 1186.751, closed='right'), Interval(1186.751, 1215.007, closed='right'), Interval(1215.007, 1243.263, closed='right'), Interval(1243.263, 1271.519, closed='right'), 
Interval(1271.519, 1299.775, closed='right'), Interval(1299.775, 1328.031, closed='right'), Interval(1328.031, 1356.287, closed='right'), Interval(1356.287, 1384.543, closed='right'), Interval(1384.543, 1412.799, closed='right')], dtype=object)
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
# groupby_bins_agg on the dask array: apply_ufunc(dask='parallelized') maps
# _binned_agg over the chunks lazily; then time the actual computation.
%time result = groupby_bins_agg(dask_signal, distance, bins, func='mean')
%time result.compute()
result
CPU times: user 54.8 ms, sys: 7.46 ms, total: 62.2 ms Wall time: 61.3 ms CPU times: user 884 ms, sys: 191 ms, total: 1.08 s Wall time: 484 ms
<xarray.DataArray (time: 20, distance_bins: 50)> dask.array<transpose, shape=(20, 50), dtype=float64, chunksize=(1, 50), chunktype=numpy.ndarray> Coordinates: * time (time) int64 0 1 2 3 4 5 6 7 8 ... 11 12 13 14 15 16 17 18 19 * distance_bins (distance_bins) object (-1.413, 28.256] ... (1384.543, 141...
|
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
array([Interval(-1.413, 28.256, closed='right'), Interval(28.256, 56.512, closed='right'), Interval(56.512, 84.768, closed='right'), Interval(84.768, 113.024, closed='right'), Interval(113.024, 141.28, closed='right'), Interval(141.28, 169.536, closed='right'), Interval(169.536, 197.792, closed='right'), Interval(197.792, 226.048, closed='right'), Interval(226.048, 254.304, closed='right'), Interval(254.304, 282.56, closed='right'), Interval(282.56, 310.816, closed='right'), Interval(310.816, 339.072, closed='right'), Interval(339.072, 367.328, closed='right'), Interval(367.328, 395.584, closed='right'), Interval(395.584, 423.84, closed='right'), Interval(423.84, 452.096, closed='right'), Interval(452.096, 480.352, closed='right'), Interval(480.352, 508.608, closed='right'), Interval(508.608, 536.864, closed='right'), Interval(536.864, 565.12, closed='right'), Interval(565.12, 593.376, closed='right'), Interval(593.376, 621.632, closed='right'), Interval(621.632, 649.888, closed='right'), Interval(649.888, 678.144, closed='right'), Interval(678.144, 706.4, closed='right'), Interval(706.4, 734.656, closed='right'), Interval(734.656, 762.912, closed='right'), Interval(762.912, 791.168, closed='right'), Interval(791.168, 819.424, closed='right'), Interval(819.424, 847.68, closed='right'), Interval(847.68, 875.936, closed='right'), Interval(875.936, 904.192, closed='right'), Interval(904.192, 932.448, closed='right'), Interval(932.448, 960.704, closed='right'), Interval(960.704, 988.96, closed='right'), Interval(988.96, 1017.216, closed='right'), Interval(1017.216, 1045.472, closed='right'), Interval(1045.472, 1073.728, closed='right'), Interval(1073.728, 1101.983, closed='right'), Interval(1101.983, 1130.239, closed='right'), Interval(1130.239, 1158.495, closed='right'), Interval(1158.495, 1186.751, closed='right'), Interval(1186.751, 1215.007, closed='right'), Interval(1215.007, 1243.263, closed='right'), Interval(1243.263, 1271.519, closed='right'), 
Interval(1271.519, 1299.775, closed='right'), Interval(1299.775, 1328.031, closed='right'), Interval(1328.031, 1356.287, closed='right'), Interval(1356.287, 1384.543, closed='right'), Interval(1384.543, 1412.799, closed='right')], dtype=object)