# !/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
import xarray as xr
import rasterio.features
import stackstac
import pystac_client
import planetary_computer
import xrspatial.multispectral as ms


# In[3]:


# from dask_gateway import GatewayCluster

# # Just for speed. This runs fine on a single machine.
# cluster = GatewayCluster()
# cluster.scale(24)
# client = cluster.get_client()
# cluster


# ## Problem setup
#
# Our task is to create a cloud-free composite of some Sentinel-2 imagery by taking a median over time. The "easy" way of doing things is to use `stackstac` to create the DataArray and call `.median(dim="time")`.

# In[92]:


area_of_interest = {
    "type": "Polygon",
    "coordinates": [
        [
            [-122.27508544921875, 47.54687159892238],
            [-121.96128845214844, 47.54687159892238],
            [-121.96128845214844, 47.745787772920934],
            [-122.27508544921875, 47.745787772920934],
            [-122.27508544921875, 47.54687159892238],
        ]
    ],
}

bbox = rasterio.features.bounds(area_of_interest)
stac = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
search = stac.search(
    bbox=bbox,
    # datetime="2018-01-01/2020-12-31",
    datetime="2020-06-01/2020-12-31",
    collections=["sentinel-2-l2a"],
    limit=500,  # fetch items in batches of 500
    query={"eo:cloud_cover": {"lt": 25}},
)

items = list(search.get_items())
print(len(items))

signed_items = [planetary_computer.sign(item).to_dict() for item in items]

bad_data = (
    stackstac.stack(
        signed_items,
        assets=["B04", "B03", "B02"],  # red, green, blue
        chunksize=2048,
    )
    .where(lambda x: x > 0, other=np.nan)  # Sentinel-2 uses 0 as nodata
    # .assign_coords(band=lambda x: x.common_name.rename("band"))  # use common names
)
bad_data


# Unfortunately, this `median` computation requires an expensive shuffle. Currently we have chunk sizes of `(time=1, band=1, y=2048, x=2048)`. To compute the median, all of the values along `time` need to be in the same chunk. Dask will automatically rechunk to something like `(time=36, band=1, y=256, x=256)`. This means we won't blow up memory on any single task, but it requires shuffling a bunch of data, which can be dangerous to your cluster's health.

# In[93]:


bad_result = bad_data.median(dim="time")
bad_result


# As a secondary issue, this task graph is also relatively complex. The `bad_result` DataArray has ~25,000 tasks, which reflects the intermediate `rechunk` operations. Can we do better?
#
# ## Rechunking "pushdown" / fusion
#
# What if we loaded the data the way we want it, rather than rechunking after the fact? GDAL / COGs support windowed reads, so it's technically feasible to "push" the rechunking operation into the I/O operation.
#
# In the "old" workflow, each task roughly looks like:
#
# 1. Load data from a single file.
# 2. Rechunk that data into smaller pieces:
#    - Some small pieces are sent to other workers.
#    - Some small pieces are kept locally, to build a contiguous "timeseries" for the small window by
#      requesting data from other workers and concatenating it.
#
# In the proposed workflow, each data-reading task instead looks like:
#
# 1. Load data from *many* files, reading a small window from each file in the timeseries and concatenating the results.
#
# And that's it. No rechunking, no communication, and (hopefully) no memory issues.
#
# Unfortunately, the user-level *code* for this is much more complicated than a simple `stackstac.stack(...).median(dim="time")`!
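# To make the proposed workflow concrete, here's a minimal sketch (not the implementation used below, which relies on `stackstac`) of what a single "pushed-down" read task could do: read the *same* small window from every COG in the timeseries and stack the results along time. The item/asset handling is simplified, and `read_timeseries_window` is a made-up name for illustration.

# In[ ]:


# Illustrative only: a hypothetical read task for the proposed workflow.
import numpy as np
import rasterio
from rasterio.windows import Window


def read_timeseries_window(hrefs, window=Window(0, 0, 256, 256)):
    """Read one small window from each COG and stack along time.

    `hrefs` would be the (signed) asset URLs for a single band, ordered by time.
    No rechunking and no inter-worker communication is needed afterwards.
    """
    layers = []
    for href in hrefs:
        with rasterio.open(href) as src:
            layers.append(src.read(1, window=window))
    return np.stack(layers)  # shape: (time, y, x)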
# I'm hopeful that Dask could become aware of this type of optimization in the future, and push this kind of rechunk down to the I/O layer when the I/O layer supports it (COG does; I think there's proposed work in Zarr to support something like this). Then we might be able to have our cake (store a single copy of the data) and eat it too (fast-ish access for both spatial and timeseries analysis).
#
# The next section implements this strategy. If you're just interested in the result, I recommend skipping to the [comparison](#Comparison) below.
#
# First we need the slices. We know we want chunks of size something like `(len(items), 1, ..., ...)`. We can let Dask figure out the spatial chunk sizes for us.

# In[94]:


import dask.array

# Use a throwaway Dask array just to let "auto" chunking pick the spatial chunk
# sizes, then turn those chunks into slices for windowed reads.
chunks = dask.array.empty(bad_data.shape, chunks=(-1, 1, "auto", "auto"))[0, 0].chunks
slices = dask.array.core.slices_from_chunks(chunks)
slices[:2]


# We have a few ways to handle reading the actual data: `stackstac`, `rioxarray`, rasterio directly, etc. We'll use `stackstac`, since it handles the time dimension and other metadata for us. It supports cropping with a `bounds` argument, so we need to find the bounding box (in the native CRS) for each window.

# In[95]:


bounds = []

for x_slice, y_slice in slices:
    x = bad_data.x[x_slice]
    y = bad_data.y[y_slice]
    # TODO: our meta is off by 1; maybe backwards?
    x_range = x.min().item(), x.max().item() + 0.1
    y_range = y.min().item(), y.max().item() + 0.1
    bounds.append((x_range[0], y_range[0], x_range[1] + 1, y_range[1] + 1))


# Now we need a DataArray per window. This is a bit awkward, since the *data* will be coming from a lazy / delayed Dask Array, but the metadata needs to be computed eagerly. This could be improved, but it seems to work.

# In[132]:


@dask.delayed
def read_window(stac_items, asset, bounds):
    """Eagerly read one (band, window) timeseries and return it as a NumPy array."""
    return (
        stackstac.stack(stac_items.to_dict()["features"], assets=[asset], bounds=bounds)
        .compute(scheduler="single-threaded")
        .where(lambda x: x > 0)
        .data
    )


def read_metadata(stac_items, asset, bounds):
    """Build the (lazy) DataArray for a (band, window) pair, just for its metadata."""
    return stackstac.stack(stac_items.to_dict()["features"], assets=[asset], bounds=bounds).where(
        lambda x: x > 0
    )


# In[133]:


import pystac

get_ipython().run_line_magic('time', 'ic = pystac.ItemCollection(signed_items, clone_items=True)')


# Next, let's build the list of DataArrays. We'll have one DataArray per "timeseries" / window.
#
# This section needs to be optimized / parallelized. We're nominally just operating on metadata here, but it's still taking too long for interactive analysis.

# In[134]:


get_ipython().run_cell_magic('time', '', 'import dask.array as da\nimport xarray as xr\n\n# might want to parallelize this\ntimeseries = []\nassets = ["B04", "B03", "B02"]\n\nfor bound in bounds:\n    metas = [read_metadata(ic, asset, bound) for asset in assets]\n    data = [read_window(ic, asset, bound) for asset in assets]\n    xarrays = []\n    for m, d in zip(metas, data):\n        xarrays.append(xr.DataArray(\n            da.from_delayed(d, shape=m.shape, dtype=m.dtype),\n            coords=m.coords,\n            dims=m.dims,\n            attrs=m.attrs,\n        ))\n    ts = xr.concat(xarrays, dim="band")\n    timeseries.append(ts)\n')


# Now that we have a `List[DataArray]`, we need to assemble them into a single large DataArray. To do that, we'll make a series of "strips", where each strip shares the same x coordinates. We'll then concatenate those strips together along `x` to get our final DataArray.
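# Before running the assembly on the real windows, here's a tiny self-contained illustration of the pattern (with made-up data, not part of the original workflow). One subtlety: `itertools.groupby` only merges *consecutive* elements, so this relies on `timeseries` being ordered such that windows sharing an x-range are adjacent.

# In[ ]:


# Hypothetical mini-example: assemble four 2x2 blocks into a 4x4 grid.
import itertools

import numpy as np
import xarray as xr


def _block(ys, xs):
    """A dummy 2x2 block with explicit y/x coordinates."""
    return xr.DataArray(np.random.rand(2, 2), coords={"y": ys, "x": xs}, dims=["y", "x"])


blocks = [
    _block([0, 1], [0, 1]), _block([2, 3], [0, 1]),  # same x-range -> one "strip"
    _block([0, 1], [2, 3]), _block([2, 3], [2, 3]),  # second "strip"
]

strips = [
    xr.concat(list(v), dim="y")
    for _, v in itertools.groupby(blocks, key=lambda b: b.x[0].item())
]
assembled = xr.concat(strips, dim="x")
assembled.shape  # (4, 4)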
# In[135]:


import itertools

key = lambda arr: arr.x[0].item()

strips = [
    xr.concat(list(v), dim="y")
    for _, v in itertools.groupby(timeseries, key=key)
]
good_data = xr.concat(strips, dim="x")
good_data


# As a quick sanity check, compare loading a single `(time, 1, 732, 732)` block from each version:

# In[136]:


a = bad_data.data.rechunk((-1, 1, 732, 732)).blocks[0, 0, 0, 0]


# In[137]:


b = good_data.data.blocks[0, 0, 0, 0]


# In[139]:


get_ipython().run_line_magic('time', 'aa = a.compute()')
get_ipython().run_line_magic('time', 'bb = b.compute()')


# And now we can compute the median:

# In[41]:


good_result = good_data.median(dim="time")
good_result


# ## Comparison

# A few things to point out:
#
# 1. The number of tasks is much smaller: about 5,000, compared to ~25,000 before.
# 2. The computation represented by this task graph is *much* simpler. Previously, each output chunk of the `median` required *many* input chunks, as data was gathered from the many arrays that were read individually.

# In[13]:


client.wait_for_workers(24)


# In[14]:


bad_result.data.blocks[0, 0, 0].visualize(optimize_graph=True)


# By contrast, each output chunk of the `good_result` median is extremely simple:

# In[15]:


good_result.data.blocks[0, 0, 0].visualize(optimize_graph=True)


# Simpler tasks mean less (or no) communication, which tends to be the bane of Dask clusters everywhere.
#
# That said, it's not a panacea. This kind of "read small bits from many files" is slower than reading a large byte stream and then using Dask to move data around. I haven't profiled it yet, but it's likely some combination of 1) per-file overhead (high-latency filesystem, parsing metadata, etc.) and 2) HTTPS vs. TCP. So if your computation *does* work with the straightforward rechunking approach, it may well be the faster option.

# In[16]:


from distributed import performance_report


# In[ ]:


with performance_report("good-median.html"):
    a = good_result.compute()


# In[ ]:


with performance_report("bad-median.html"):
    b = bad_result.compute()


# In[ ]:


cluster.close()


# In[ ]:


assert a.equals(b)


# ...
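# As a possible final step (not in the original notebook), the otherwise-unused `xrspatial.multispectral` import at the top could be used to render the cloud-free composite as a true-color image. A minimal sketch, assuming `good_result` keeps its band dimension ordered `(B04, B03, B02)` = (red, green, blue):

# In[ ]:


# Hypothetical: render the composite with xrspatial's true_color helper.
rgb = good_result.compute()   # (band, y, x), bands ordered red, green, blue
image = ms.true_color(*rgb)   # unpacks along the band dim; returns an RGBA DataArray
image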