Regridding to a HEALPix grid using xESMF

import astropy.coordinates
import cartopy.crs as ccrs
import cdshealpix.nested
import cf_xarray  # noqa: F401
import dask
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xesmf

# Keep attributes through xarray operations and show the attrs / data
# sections collapsed in the dataset reprs.
xr.set_options(keep_attrs=True, display_expand_attrs=False, display_expand_data=False)
<xarray.core.options.set_options at 0x7f2c77106b70>
from distributed import Client

# Create a dask.distributed client with default settings; the repr below
# reports the cluster type and the dashboard URL.
client = Client()
client

Client

Client-65712eda-689d-11f0-92a6-6045bd3360ad

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Cluster Info

Rectilinear grid: the air_temperature example dataset

# Open the tutorial dataset lazily, in dask chunks of 20 time steps ...
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20})
# ... and keep only the first 400 time steps.
ds = ds.isel(time=slice(None, 400))
ds
<xarray.Dataset> Size: 4MB
Dimensions:  (lat: 25, time: 400, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
Data variables:
    air      (time, lat, lon) float64 4MB dask.array<chunksize=(20, 25, 53), meta=np.ndarray>
Attributes: (5)
# Interpolate (xarray's default linear method) onto a regular grid that spans
# the same extent with 20x more points per axis: 53 -> 1060 longitudes and
# 25 -> 500 latitudes. Presumably done to give the regridder a larger source.
upscaled = ds.interp(
    {
        "lon": np.linspace(200, 330, 1060),
        "lat": np.linspace(15, 75, 500),
    }
)
upscaled
<xarray.Dataset> Size: 2GB
Dimensions:  (time: 400, lat: 500, lon: 1060)
Coordinates:
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * lon      (lon) float64 8kB 200.0 200.1 200.2 200.4 ... 329.8 329.9 330.0
  * lat      (lat) float64 4kB 15.0 15.12 15.24 15.36 ... 74.64 74.76 74.88 75.0
Data variables:
    air      (time, lat, lon) float64 2GB dask.array<chunksize=(20, 500, 1060), meta=np.ndarray>
Attributes: (5)
# HEALPix refinement level (depth) of the target grid.
level = 7

# Polygon around the source domain: west -> east along 15 degrees north, then
# east -> west along 75 degrees north.
# NOTE(review): the intermediate vertices along each edge are presumably there
# because polygon edges are great-circle arcs rather than parallels -- confirm.
lon = astropy.coordinates.Longitude(
    [200, 225, 250, 275, 300, 330, 330, 300, 275, 250, 225, 200], unit="degree"
)
lat = astropy.coordinates.Latitude([15] * 6 + [75] * 6, unit="degree")

# Only the cell ids are used; the two other return values are discarded.
cell_ids, _, _ = cdshealpix.nested.polygon_search(lon, lat, depth=level, flat=True)

# Build the target dataset step by step: cell ids -> HEALPix index ->
# latitude / longitude coordinates for every cell.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(
    {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"}
)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 610kB
Dimensions:    (cells: 25400)
Coordinates:
  * cell_ids   (cells) uint64 203kB 33629 33630 33631 ... 131069 131070 131071
    latitude   (cells) float64 203kB 15.09 15.09 15.4 ... 41.01 41.01 41.41
    longitude  (cells) float64 203kB 229.6 228.9 229.2 ... 270.4 269.6 270.0
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
%%time
# Compute bilinear weights from the upscaled rectilinear grid to the
# latitude / longitude points of the HEALPix target. `locstream_out=True`
# makes xESMF treat the target as a list of points instead of a 2-D grid,
# which is why the output grid shape below is (1, 25400).
regridder = xesmf.Regridder(
    upscaled, target_grid, method="bilinear", locstream_out=True
)
regridder
CPU times: user 4.07 s, sys: 163 ms, total: 4.23 s
Wall time: 4.11 s
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_500x1060_1x25400.nc 
Reuse pre-computed weights? False 
Input grid shape:           (500, 1060) 
Output grid shape:          (1, 25400) 
Periodic in longitude?      False
# Apply the weights; the result stays a lazy dask array (see the repr below).
regridded = regridder.regrid_dataset(upscaled, skipna=True, keep_attrs=True)
# Re-create the HEALPix index on the regridded dataset.
# NOTE(review): decode() is called without grid info, so it presumably reads it
# from the attributes of `cell_ids` -- confirm.
regridded = regridded.dggs.decode()
regridded
<xarray.Dataset> Size: 82MB
Dimensions:    (time: 400, cells: 25400)
Coordinates:
  * time       (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * cell_ids   (cells) uint64 203kB 33629 33630 33631 ... 131069 131070 131071
    latitude   (cells) float64 203kB 15.09 15.09 15.4 ... 41.01 41.01 41.41
    longitude  (cells) float64 203kB 229.6 228.9 229.2 ... 270.4 269.6 270.0
Dimensions without coordinates: cells
Data variables:
    air        (time, cells) float64 81MB dask.array<chunksize=(20, 25400), meta=np.ndarray>
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
Attributes: (6)
# Run the lazy regridding and load the result into memory.
computed = regridded.compute()
computed
<xarray.Dataset> Size: 82MB
Dimensions:    (time: 400, cells: 25400)
Coordinates:
  * time       (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * cell_ids   (cells) uint64 203kB 33629 33630 33631 ... 131069 131070 131071
    latitude   (cells) float64 203kB 15.09 15.09 15.4 ... 41.01 41.01 41.41
    longitude  (cells) float64 203kB 229.6 228.9 229.2 ... 270.4 269.6 270.0
Dimensions without coordinates: cells
Data variables:
    air        (time, cells) float64 81MB 295.0 295.0 294.7 ... 288.4 286.9
Indexes:
    cell_ids  HealpixIndex(level=7, indexing_scheme=nested)
Attributes: (6)
# Display the regridded air temperature with xdggs' explore(), drawn
# slightly transparent (alpha=0.8).
computed["air"].dggs.explore(alpha=0.8)

Curvilinear grid: the rasm dataset

# Open the "rasm" tutorial dataset lazily, in dask chunks of 8 time steps.
# Its coordinates `xc` / `yc` are 2-D (y, x) arrays, i.e. the grid is
# curvilinear.
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 8})
ds
<xarray.Dataset> Size: 17MB
Dimensions:  (time: 36, y: 205, x: 275)
Coordinates:
  * time     (time) object 288B 1980-09-16 12:00:00 ... 1983-08-17 00:00:00
    xc       (y, x) float64 451kB dask.array<chunksize=(205, 275), meta=np.ndarray>
    yc       (y, x) float64 451kB dask.array<chunksize=(205, 275), meta=np.ndarray>
Dimensions without coordinates: y, x
Data variables:
    Tair     (time, y, x) float64 16MB dask.array<chunksize=(8, 205, 275), meta=np.ndarray>
Attributes: (11)
# Draw the second time step on a north-polar stereographic map. The 2-D
# `xc` / `yc` coordinates are passed as x / y and declared as PlateCarree
# (longitude / latitude) coordinates via `transform`.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.NorthPolarStereo()})
ds["Tair"].isel(time=1).plot.pcolormesh(
    x="xc", y="yc", ax=ax, transform=ccrs.PlateCarree()
)
<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f2c38971430>
../_images/15082f71550319d15c1955dd1449d176ef34a440f5df3d9ce4a16d675c7f9d9d.png
# HEALPix refinement level (depth) of the target grid.
level = 8

# Select every cell inside a cone centered on the north pole with a radius of
# 90 - 16.5 = 73.5 degrees.
lon = astropy.coordinates.Longitude(0, unit="degree")
lat = astropy.coordinates.Latitude(90, unit="degree")
# NOTE(review): `astropy.units` is never imported explicitly in this notebook;
# this relies on `import astropy.coordinates` loading it. Consider adding
# `import astropy.units` to the import cell.
cell_ids, _, _ = cdshealpix.nested.cone_search(
    lon, lat, radius=(90 - 16.5) << astropy.units.degree, depth=level, flat=True
)

# Build the target dataset step by step: cell ids -> HEALPix index ->
# latitude / longitude coordinates for every cell.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(
    {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"}
)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 7MB
Dimensions:    (cells: 283136)
Coordinates:
  * cell_ids   (cells) uint64 2MB 3573 3574 3575 3577 ... 524285 524286 524287
    latitude   (cells) float64 2MB 16.33 16.33 16.49 16.33 ... 41.41 41.41 41.61
    longitude  (cells) float64 2MB 48.34 47.99 48.16 47.64 ... 270.2 269.8 270.0
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  HealpixIndex(level=8, indexing_scheme=nested)
%%time
# Compute bilinear weights from the curvilinear rasm grid to the HEALPix
# target, which is again treated as a list of points (`locstream_out=True`).
regridder = xesmf.Regridder(ds, target_grid, method="bilinear", locstream_out=True)
regridder
CPU times: user 2.39 s, sys: 45.2 ms, total: 2.44 s
Wall time: 2.39 s
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_205x275_1x283136.nc 
Reuse pre-computed weights? False 
Input grid shape:           (205, 275) 
Output grid shape:          (1, 283136) 
Periodic in longitude?      False
# Apply the weights to the dataset ...
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True)
# ... then re-create the HEALPix index and load the result into memory.
regridded = regridded.dggs.decode().compute()
regridded
/home/runner/work/regridding-ecosystem-overview/regridding-ecosystem-overview/.pixi/envs/default/lib/python3.12/site-packages/distributed/client.py:3383: UserWarning: Sending large graph of size 15.31 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
<xarray.Dataset> Size: 88MB
Dimensions:    (time: 36, cells: 283136)
Coordinates:
  * time       (time) object 288B 1980-09-16 12:00:00 ... 1983-08-17 00:00:00
  * cell_ids   (cells) uint64 2MB 3573 3574 3575 3577 ... 524285 524286 524287
    latitude   (cells) float64 2MB 16.33 16.33 16.49 16.33 ... 41.41 41.41 41.61
    longitude  (cells) float64 2MB 48.34 47.99 48.16 47.64 ... 270.2 269.8 270.0
Dimensions without coordinates: cells
Data variables:
    Tair       (time, cells) float64 82MB nan nan nan nan ... nan nan nan nan
Indexes:
    cell_ids  HealpixIndex(level=8, indexing_scheme=nested)
Attributes: (12)
# Display the regridded air temperature with xdggs' explore(), drawn
# slightly transparent (alpha=0.8).
regridded["Tair"].dggs.explore(alpha=0.8)

Curvilinear grid: the ROMS_example dataset

# Open the "ROMS_example" tutorial dataset lazily. The requested chunking
# keeps each horizontal (eta_rho, xi_rho) slice in a single chunk (-1 means
# "do not split this dimension").
# NOTE(review): the dataset has no "time" dimension (it is "ocean_time"), yet
# the repr shows chunks of size 1 along it -- confirm how that key is handled.
ds = xr.tutorial.open_dataset(
    "ROMS_example", chunks={"time": 1, "eta_rho": -1, "xi_rho": -1}
)
ds
<xarray.Dataset> Size: 19MB
Dimensions:     (ocean_time: 2, s_rho: 30, eta_rho: 191, xi_rho: 371)
Coordinates:
    Cs_r        (s_rho) float64 240B dask.array<chunksize=(30,), meta=np.ndarray>
    lon_rho     (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    hc          float64 8B ...
    h           (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    lat_rho     (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    Vtransform  int32 4B ...
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
Dimensions without coordinates: eta_rho, xi_rho
Data variables:
    salt        (ocean_time, s_rho, eta_rho, xi_rho) float32 17MB dask.array<chunksize=(1, 15, 191, 371), meta=np.ndarray>
    zeta        (ocean_time, eta_rho, xi_rho) float32 567kB dask.array<chunksize=(1, 191, 371), meta=np.ndarray>
Attributes: (34)
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Miller()}, figsize=(12, 12))
# Plot one horizontal slice of salinity (first time step, s_rho index 1) on a
# Miller map, using the 2-D lon_rho / lat_rho coordinates declared as
# PlateCarree (longitude / latitude) coordinates.
# The scalar coordinates "hc" and "Vtransform" are removed with `drop_vars`,
# because `drop` is deprecated for dropping variables.
# NOTE(review): they are presumably dropped to keep them out of the
# auto-generated plot title -- confirm.
ds["salt"].isel(ocean_time=0, s_rho=1).drop_vars(["hc", "Vtransform"]).plot.pcolormesh(
    x="lon_rho", y="lat_rho", ax=ax, transform=ccrs.PlateCarree()
)
<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f2c38fe7e30>
../_images/02fe61b3c3e6d701270dedfefccb4521611d20bf04885272503cd734b44570be.png
# Evaluate the extent of the 2-D longitude / latitude coordinates and convert
# the four results to plain Python floats.
min_lon, max_lon, min_lat, max_lat = (
    float(extreme)
    for extreme in dask.compute(
        ds["lon_rho"].min(),
        ds["lon_rho"].max(),
        ds["lat_rho"].min(),
        ds["lat_rho"].max(),
    )
)

# Level-12 HEALPix grid with nested indexing.
grid_info = xdggs.HealpixInfo.from_dict({"level": 12, "indexing_scheme": "nested"})

# The four corners of the bounding box form the search polygon.
lon = astropy.coordinates.Longitude([min_lon, max_lon, max_lon, min_lon], unit="degree")
lat = astropy.coordinates.Latitude([min_lat, min_lat, max_lat, max_lat], unit="degree")
# Only the cell ids are used; the two other return values are discarded.
cell_ids, _, _ = cdshealpix.nested.polygon_search(
    lon, lat, depth=grid_info.level, flat=True
)

# Build the target dataset step by step: cell ids -> HEALPix index ->
# latitude / longitude coordinates for every cell.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(grid_info)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 2MB
Dimensions:    (cells: 96332)
Coordinates:
  * cell_ids   (cells) uint64 771kB 133389141 133389142 ... 133964297 133964298
    latitude   (cells) float64 771kB 27.45 27.45 27.46 ... 30.9 30.91 30.91
    longitude  (cells) float64 771kB 271.9 271.9 271.9 ... 269.1 269.1 269.1
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  HealpixIndex(level=12, indexing_scheme=nested)
%%time
# Compute bilinear weights from the curvilinear ROMS grid to the HEALPix
# target, which is again treated as a list of points (`locstream_out=True`).
regridder = xesmf.Regridder(ds, target_grid, method="bilinear", locstream_out=True)
regridder
CPU times: user 1.25 s, sys: 18.4 ms, total: 1.27 s
Wall time: 1.25 s
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_191x371_1x96332.nc 
Reuse pre-computed weights? False 
Input grid shape:           (191, 371) 
Output grid shape:          (1, 96332) 
Periodic in longitude?      False
# Apply the weights lazily, skipping missing source values.
# NOTE(review): per the xESMF docs, `na_thres=0.5` masks output points whose
# missing-value share of the contributing source cells exceeds 50% -- confirm.
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True, na_thres=0.5)
# Re-create the HEALPix index on the regridded dataset.
regridded = regridded.dggs.decode()
regridded
<xarray.Dataset> Size: 50MB
Dimensions:     (ocean_time: 2, s_rho: 30, cells: 96332)
Coordinates:
    Cs_r        (s_rho) float64 240B dask.array<chunksize=(30,), meta=np.ndarray>
    hc          float64 8B ...
    Vtransform  int32 4B ...
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
  * cell_ids    (cells) uint64 771kB 133389141 133389142 ... 133964297 133964298
    latitude    (cells) float64 771kB 27.45 27.45 27.46 ... 30.9 30.91 30.91
    longitude   (cells) float64 771kB 271.9 271.9 271.9 ... 269.1 269.1 269.1
Dimensions without coordinates: cells
Data variables:
    salt        (ocean_time, s_rho, cells) float64 46MB dask.array<chunksize=(1, 15, 96332), meta=np.ndarray>
    zeta        (ocean_time, cells) float64 2MB dask.array<chunksize=(1, 96332), meta=np.ndarray>
Indexes:
    cell_ids  HealpixIndex(level=12, indexing_scheme=nested)
Attributes: (35)
# Load the regridded data into memory ...
computed = regridded.compute()
# ... and mask with the not-null condition, dropping the labels left without
# data (the cell count shrinks from 96332 to 45625 in the repr below).
computed = computed.where(lambda loaded: loaded.notnull(), drop=True)
computed
<xarray.Dataset> Size: 24MB
Dimensions:     (ocean_time: 2, s_rho: 30, cells: 45625)
Coordinates:
    Cs_r        (s_rho) float64 240B -0.933 -0.8092 ... -0.0005206 -5.758e-05
    hc          float64 8B 20.0
    Vtransform  int32 4B 2
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
  * cell_ids    (cells) uint64 365kB 133406431 133406443 ... 133798305 133798306
    latitude    (cells) float64 365kB 27.73 27.71 27.71 ... 29.97 29.98 29.98
    longitude   (cells) float64 365kB 269.2 269.1 269.1 ... 266.2 266.2 266.2
Dimensions without coordinates: cells
Data variables:
    salt        (ocean_time, s_rho, cells) float64 22MB 34.94 34.93 ... 12.12
    zeta        (ocean_time, cells) float64 730kB -0.1275 -0.1262 ... -0.01867
Indexes:
    cell_ids  HealpixIndex(level=12, indexing_scheme=nested)
Attributes: (35)
# Display the regridded salinity with xdggs' explore(), drawn slightly
# transparent (alpha=0.8).
computed["salt"].dggs.explore(alpha=0.8)