Regridding using the grid-weights library

Regridding using the grid-weights library

import astropy.coordinates
import cdshealpix.nested
import cf_xarray  # noqa: F401
import geoarrow.rust.core as geoarrow
import grid_weights.api as grid_weights
import numpy as np
import xarray as xr
import xdggs  # noqa: F401
from grid_indexing import infer_cell_geometries

# Keep variable attributes through operations and compact the HTML/text reprs
# (collapsed attrs, summarized data) for the notebook displays below.
xr.set_options(keep_attrs=True, display_expand_attrs=False, display_expand_data=False)
<xarray.core.options.set_options at 0x7f982f795f40>
from distributed import Client

# Start a local dask cluster so the chunked arrays below are computed in
# parallel; displaying `client` renders the dashboard link.
client = Client()
client

Client

Client-de0671eb-689c-11f0-90e0-6045bd3360ad

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Cluster Info

rectilinear grid: the air_temperature example dataset

# Load the small NCEP air-temperature tutorial dataset, chunked along time,
# and keep only the first 400 time steps to bound the example's size.
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20}).isel(
    time=slice(None, 400)
)
ds
<xarray.Dataset> Size: 4MB
Dimensions:  (lat: 25, time: 400, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
Data variables:
    air      (time, lat, lon) float64 4MB dask.array<chunksize=(20, 25, 53), meta=np.ndarray>
Attributes: (5)
# Upsample the coarse 25x53 grid to 250x530 by interpolation, convert the
# longitudes from the [0, 360) convention to [-180, 180), and re-chunk so the
# spatial domain is split into four tiles (2 chunks per spatial dimension).
upscaled = (
    ds.interp(lon=np.linspace(200, 330, 530), lat=np.linspace(15, 76, 250))
    .assign_coords(lon=lambda ds: (ds["lon"] + 180) % 360 - 180)
    .chunk({"lon": 265, "lat": 125})
)
upscaled
<xarray.Dataset> Size: 424MB
Dimensions:  (time: 400, lat: 250, lon: 530)
Coordinates:
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * lat      (lat) float64 2kB 15.0 15.24 15.49 15.73 ... 75.27 75.51 75.76 76.0
  * lon      (lon) float64 4kB -160.0 -159.8 -159.5 ... -30.49 -30.25 -30.0
Data variables:
    air      (time, lat, lon) float64 424MB dask.array<chunksize=(20, 125, 265), meta=np.ndarray>
Attributes: (5)
# Healpix refinement level of the target grid.
level = 7
# Polygon outlining the data region (lon 200-330, lat 15-75): the vertices
# trace the southern edge eastward, then the northern edge back westward.
lon = astropy.coordinates.Longitude(
    [200, 225, 250, 275, 300, 330, 330, 300, 275, 250, 225, 200], unit="degree"
)
lat = astropy.coordinates.Latitude(
    [15, 15, 15, 15, 15, 15, 75, 75, 75, 75, 75, 75], unit="degree"
)
# Ids of all level-7 nested-healpix cells intersecting the polygon
# (only the first element of the returned tuple is needed here).
cell_ids, _, _ = cdshealpix.nested.polygon_search(lon, lat, depth=level, flat=True)

# Wrap the cell ids in a dataset, attach the xdggs healpix index, and derive
# per-cell latitude/longitude coordinates for later inspection/plotting.
target_grid = (
    xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
    .dggs.decode({"grid_name": "healpix", "level": level, "indexing_scheme": "nested"})
    .dggs.assign_latlon_coords()
)
target_grid
<xarray.Dataset> Size: 610kB
Dimensions:    (cells: 25400)
Coordinates:
  * cell_ids   (cells) uint64 203kB 33629 33630 33631 ... 131069 131070 131071
    latitude   (cells) float64 203kB 15.09 15.09 15.4 ... 41.01 41.01 41.41
    longitude  (cells) float64 203kB 229.6 228.9 229.2 ... 270.4 269.6 270.0
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
# Infer one polygon per source pixel and convert to shapely; the flat array is
# reshaped to (lon, lat) = (530, 250) — assumes infer_cell_geometries returns
# cells in lon-major order matching that shape (TODO: confirm).
source_geoms_ = geoarrow.to_shapely(infer_cell_geometries(upscaled)).reshape((530, 250))
# Wrap as a DataArray carrying the source coordinates, chunked identically to
# `upscaled` so geometries and data align chunk-for-chunk.
source_geoms = xr.DataArray(
    source_geoms_, dims=["lon", "lat"], coords=upscaled[["lon", "lat"]].coords
).chunk({"lon": 265, "lat": 125})
source_geoms
<xarray.DataArray (lon: 530, lat: 250)> Size: 1MB
dask.array<chunksize=(265, 125), meta=np.ndarray>
Coordinates:
  * lon      (lon) float64 4kB -160.0 -159.8 -159.5 ... -30.49 -30.25 -30.0
  * lat      (lat) float64 2kB 15.0 15.24 15.49 15.73 ... 75.27 75.51 75.76 76.0
# Boundary polygon of each target healpix cell, chunked into ~5 pieces
# (25400 cells / 5100 per chunk), with the dggs index re-attached after chunking.
target_geoms = target_grid.dggs.cell_boundaries().chunk({"cells": 5100}).dggs.decode()
target_geoms
<xarray.DataArray (cells: 25400)> Size: 203kB
dask.array<chunksize=(5100,), meta=np.ndarray>
Coordinates:
  * cell_ids  (cells) uint64 203kB dask.array<chunksize=(5100,), meta=np.ndarray>
Dimensions without coordinates: cells
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
%%time
# Pick a regridding algorithm per data variable, falling back to first-order
# conservative regridding for anything not explicitly configured.
algorithms = grid_weights.Algorithms.by_variable(upscaled, default="conservative")
# Spatially index the source geometries, then query which source cells each
# target cell overlaps, for every distinct method in use.
indexed_cells = grid_weights.create_index(source_geoms).query(
    target_geoms, methods=algorithms.unique()
)
# Compute the (lazy) source-to-target weight matrix from the overlaps.
weights = grid_weights.weights(source_geoms, target_geoms, indexed_cells)
weights
/home/runner/work/regridding-ecosystem-overview/regridding-ecosystem-overview/.pixi/envs/default/lib/python3.12/site-packages/distributed/client.py:3383: UserWarning: Sending large graph of size 12.89 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
CPU times: user 3.33 s, sys: 135 ms, total: 3.46 s
Wall time: 5.7 s
<xarray.Dataset> Size: 27GB
Dimensions:          (target_cells: 25400, source_lon: 530, source_lat: 250)
Coordinates:
  * source_lon       (source_lon) float64 4kB -160.0 -159.8 ... -30.25 -30.0
  * source_lat       (source_lat) float64 2kB 15.0 15.24 15.49 ... 75.76 76.0
  * target_cell_ids  (target_cells) uint64 203kB dask.array<chunksize=(5100,), meta=np.ndarray>
Dimensions without coordinates: target_cells
Data variables:
    conservative     (target_cells, source_lon, source_lat) float64 27GB dask.array<chunksize=(5100, 265, 125), meta=np.ndarray>
Indexes:
    target_cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
Attributes: (1)
%%time
# Apply the weights to the source data; this only builds the dask graph,
# hence the millisecond wall time — nothing is computed yet.
regridded = algorithms.regrid(upscaled, weights)
regridded
CPU times: user 31.7 ms, sys: 977 μs, total: 32.7 ms
Wall time: 31.9 ms
<xarray.Dataset> Size: 81MB
Dimensions:   (time: 400, cells: 25400)
Coordinates:
  * cell_ids  (cells) uint64 203kB dask.array<chunksize=(5100,), meta=np.ndarray>
  * time      (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
Dimensions without coordinates: cells
Data variables:
    air       (time, cells) float64 81MB dask.array<chunksize=(20, 5100), meta=np.ndarray>
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
Attributes: (5)
%%time
# Trigger the actual computation on the cluster and load the regridded
# result into memory.
computed = regridded.compute()
computed
/home/runner/work/regridding-ecosystem-overview/regridding-ecosystem-overview/.pixi/envs/default/lib/python3.12/site-packages/distributed/client.py:3383: UserWarning: Sending large graph of size 16.22 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
CPU times: user 6.72 s, sys: 320 ms, total: 7.04 s
Wall time: 22.9 s
<xarray.Dataset> Size: 81MB
Dimensions:   (time: 400, cells: 25400)
Coordinates:
  * cell_ids  (cells) uint64 203kB 33629 33630 33631 ... 131069 131070 131071
  * time      (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
Dimensions without coordinates: cells
Data variables:
    air       (time, cells) float64 81MB 294.9 295.0 294.7 ... 288.9 288.4 286.9
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
Attributes: (5)
# Interactive map of the regridded air temperature on the healpix cells,
# with slightly transparent cell fills.
computed["air"].dggs.explore(alpha=0.8)