Regridding to H3 using xESMF

import cartopy.crs as ccrs
import cf_xarray  # noqa: F401
import dask
import h3ronpy
import matplotlib.pyplot as plt
import numpy as np
import shapely
import xarray as xr
import xdggs  # noqa: F401
import xesmf

# Notebook-wide xarray options: keep attributes through operations and
# collapse the attrs / data sections of the reprs.
xr.set_options(
    keep_attrs=True,
    display_expand_attrs=False,
    display_expand_data=False,
)
<xarray.core.options.set_options at 0x7f7c0c7c0ec0>
from distributed import Client

# Create a dask.distributed client with default arguments and display it.
client = Client()
client

Client

Client-543f8752-689d-11f0-9231-6045bd3360ad

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Cluster Info

Rectilinear grid: the air_temperature example dataset

# Open the tutorial dataset lazily, 20 time steps per dask chunk ...
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20})
# ... and keep only the first 400 time steps.
ds = ds.isel(time=slice(None, 400))
ds
<xarray.Dataset> Size: 4MB
Dimensions:  (lat: 25, time: 400, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
Data variables:
    air      (time, lat, lon) float64 4MB dask.array<chunksize=(20, 25, 53), meta=np.ndarray>
Attributes: (5)
# Interpolate onto a finer rectilinear grid: 1060 longitudes x 500 latitudes.
upscaled = ds.interp(
    lon=np.linspace(200, 330, 1060),
    lat=np.linspace(15, 75, 500),
)
# Shift longitudes from the 0..360 convention into -180..180.
upscaled = upscaled.assign_coords(lon=(upscaled["lon"] + 180) % 360 - 180)
upscaled
<xarray.Dataset> Size: 2GB
Dimensions:  (time: 400, lat: 500, lon: 1060)
Coordinates:
  * time     (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * lat      (lat) float64 4kB 15.0 15.12 15.24 15.36 ... 74.64 74.76 74.88 75.0
  * lon      (lon) float64 8kB -160.0 -159.9 -159.8 ... -30.25 -30.12 -30.0
Data variables:
    air      (time, lat, lon) float64 2GB dask.array<chunksize=(20, 500, 1060), meta=np.ndarray>
Attributes: (5)
# H3 resolution of the target grid.
level = 4

# Bounding box (xmin, ymin, xmax, ymax) of the upscaled source grid.
lon_min, lon_max = float(upscaled["lon"].min()), float(upscaled["lon"].max())
lat_min, lat_max = float(upscaled["lat"].min()), float(upscaled["lat"].max())
geom = shapely.box(lon_min, lat_min, lon_max, lat_max)

# H3 cells at `level` selected for the box with the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom, resolution=level, containment_mode=h3ronpy.ContainmentMode.Covers
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, attach the H3 index, then add the
# latitude / longitude coordinates of the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode({"grid_name": "h3", "level": level})
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 901kB
Dimensions:    (cells: 37555)
Coordinates:
  * cell_ids   (cells) uint64 300kB 594512207790735359 ... 596425581361364991
    latitude   (cells) float64 300kB 75.06 75.21 74.94 ... 15.11 14.84 15.18
    longitude  (cells) float64 300kB -109.7 -112.3 -111.2 ... -113.6 -113.9
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  H3Index(level=4)
%%time
# Bilinear regridder from the upscaled rectilinear grid to the H3 cells.
# `locstream_out=True` because the target is a 1D list of locations,
# not a 2D grid.
regridder = xesmf.Regridder(
    ds_in=upscaled,
    ds_out=target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder
CPU times: user 4.44 s, sys: 168 ms, total: 4.6 s
Wall time: 4.47 s
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_500x1060_1x37555.nc 
Reuse pre-computed weights? False 
Input grid shape:           (500, 1060) 
Output grid shape:          (1, 37555) 
Periodic in longitude?      False
# Apply the weights to every variable, skipping NaNs and keeping attributes.
regridded = regridder.regrid_dataset(upscaled, skipna=True, keep_attrs=True)
# Re-decode so that `cell_ids` is backed by the DGGS index again.
regridded = regridded.dggs.decode()
regridded
<xarray.Dataset> Size: 121MB
Dimensions:    (time: 400, cells: 37555)
Coordinates:
  * time       (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * cell_ids   (cells) uint64 300kB 594512207790735359 ... 596425581361364991
    latitude   (cells) float64 300kB 75.06 75.21 74.94 ... 15.11 14.84 15.18
    longitude  (cells) float64 300kB -109.7 -112.3 -111.2 ... -113.6 -113.9
Dimensions without coordinates: cells
Data variables:
    air        (time, cells) float64 120MB dask.array<chunksize=(20, 37555), meta=np.ndarray>
Indexes:
    cell_ids  H3Index(level=4)
Attributes: (6)
# Evaluate the lazy (dask-backed) regridding result into memory and display it.
computed = regridded.compute()
computed
<xarray.Dataset> Size: 121MB
Dimensions:    (time: 400, cells: 37555)
Coordinates:
  * time       (time) datetime64[ns] 3kB 2013-01-01 ... 2013-04-10T18:00:00
  * cell_ids   (cells) uint64 300kB 594512207790735359 ... 596425581361364991
    latitude   (cells) float64 300kB 75.06 75.21 74.94 ... 15.11 14.84 15.18
    longitude  (cells) float64 300kB -109.7 -112.3 -111.2 ... -113.6 -113.9
Dimensions without coordinates: cells
Data variables:
    air        (time, cells) float64 120MB nan nan 245.6 ... 298.0 nan 298.0
Indexes:
    cell_ids  H3Index(level=4)
Attributes: (6)
# Explore the computed `air` variable on its H3 cells via the xdggs accessor
# (alpha=0.8).
computed["air"].dggs.explore(alpha=0.8)

Curvilinear grid: the rasm dataset

# Open the `rasm` tutorial dataset lazily, 8 time steps per dask chunk.
# Note: this rebinds `ds`, replacing the air_temperature dataset.
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 8})
ds
<xarray.Dataset> Size: 17MB
Dimensions:  (time: 36, y: 205, x: 275)
Coordinates:
  * time     (time) object 288B 1980-09-16 12:00:00 ... 1983-08-17 00:00:00
    xc       (y, x) float64 451kB dask.array<chunksize=(205, 275), meta=np.ndarray>
    yc       (y, x) float64 451kB dask.array<chunksize=(205, 275), meta=np.ndarray>
Dimensions without coordinates: y, x
Data variables:
    Tair     (time, y, x) float64 16MB dask.array<chunksize=(8, 205, 275), meta=np.ndarray>
Attributes: (11)
# Draw the second time step of `Tair` on a north-polar stereographic map.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.NorthPolarStereo()))
tair_snapshot = ds["Tair"].isel(time=1)
# `xc` / `yc` are the 2D coordinates; they are handed to cartopy as
# PlateCarree (lon/lat) values.
tair_snapshot.plot.pcolormesh(
    x="xc",
    y="yc",
    ax=ax,
    transform=ccrs.PlateCarree(),
)
<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f7b8c110bc0>
../_images/15082f71550319d15c1955dd1449d176ef34a440f5df3d9ce4a16d675c7f9d9d.png
# H3 resolution of the target grid.
level = 4

# Fixed box (xmin, ymin, xmax, ymax) used to select the target cells.
geom = shapely.box(0, 16.5, 360, 90)

# H3 cells at `level` selected for the box with the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom, resolution=level, containment_mode=h3ronpy.ContainmentMode.Covers
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, attach the H3 index, then add the
# latitude / longitude coordinates of the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode({"grid_name": "h3", "level": level})
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 2MB
Dimensions:    (cells: 103217)
Coordinates:
  * cell_ids   (cells) uint64 826kB 594475159402840063 ... 596505244414771199
    latitude   (cells) float64 826kB 79.24 79.63 79.12 ... 16.29 16.54 16.46
    longitude  (cells) float64 826kB 38.02 37.33 40.12 ... 139.4 140.1 140.6
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  H3Index(level=4)
%%time
# Bilinear regridder from the curvilinear rasm grid to the H3 cells
# (`locstream_out=True`: the target is a 1D list of locations).
regridder = xesmf.Regridder(
    ds_in=ds,
    ds_out=target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder
CPU times: user 1.23 s, sys: 13.5 ms, total: 1.24 s
Wall time: 1.23 s
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_205x275_1x103217.nc 
Reuse pre-computed weights? False 
Input grid shape:           (205, 275) 
Output grid shape:          (1, 103217) 
Periodic in longitude?      False
# Regrid every variable (NaN-aware, attributes kept), restore the DGGS index
# on `cell_ids`, then load the result into memory.
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True)
regridded = regridded.dggs.decode()
regridded = regridded.compute()
regridded
<xarray.Dataset> Size: 32MB
Dimensions:    (time: 36, cells: 103217)
Coordinates:
  * time       (time) object 288B 1980-09-16 12:00:00 ... 1983-08-17 00:00:00
  * cell_ids   (cells) uint64 826kB 594475159402840063 ... 596505244414771199
    latitude   (cells) float64 826kB 79.24 79.63 79.12 ... 16.29 16.54 16.46
    longitude  (cells) float64 826kB 38.02 37.33 40.12 ... 139.4 140.1 140.6
Dimensions without coordinates: cells
Data variables:
    Tair       (time, cells) float64 30MB nan -2.794 nan nan ... nan nan nan nan
Indexes:
    cell_ids  H3Index(level=4)
Attributes: (12)
# Explore the regridded `Tair` variable on its H3 cells via the xdggs accessor
# (alpha=0.8).
regridded["Tair"].dggs.explore(alpha=0.8)

Curvilinear grid: the ROMS_example dataset

# Open the `ROMS_example` tutorial dataset lazily, requesting chunks of one
# step along "time". Note: this rebinds `ds`, replacing the rasm dataset.
ds = xr.tutorial.open_dataset("ROMS_example", chunks={"time": 1})
ds
<xarray.Dataset> Size: 19MB
Dimensions:     (ocean_time: 2, s_rho: 30, eta_rho: 191, xi_rho: 371)
Coordinates:
    Cs_r        (s_rho) float64 240B dask.array<chunksize=(30,), meta=np.ndarray>
    lon_rho     (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    hc          float64 8B ...
    h           (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    lat_rho     (eta_rho, xi_rho) float64 567kB dask.array<chunksize=(191, 371), meta=np.ndarray>
    Vtransform  int32 4B ...
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
Dimensions without coordinates: eta_rho, xi_rho
Data variables:
    salt        (ocean_time, s_rho, eta_rho, xi_rho) float32 17MB dask.array<chunksize=(1, 15, 96, 186), meta=np.ndarray>
    zeta        (ocean_time, eta_rho, xi_rho) float32 567kB dask.array<chunksize=(1, 191, 371), meta=np.ndarray>
Attributes: (34)
# Plot salinity at the first time step and the second `s_rho` level on a
# Miller projection.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Miller()}, figsize=(12, 12))
# Remove the scalar coordinates `hc` and `Vtransform` before plotting.
# `drop_vars` is the supported replacement for the deprecated `drop`, which
# emits a DeprecationWarning when used to drop variables.
ds["salt"].isel(ocean_time=0, s_rho=1).drop_vars(["hc", "Vtransform"]).plot.pcolormesh(
    x="lon_rho", y="lat_rho", ax=ax, transform=ccrs.PlateCarree()
)
/tmp/ipykernel_4657/3005707708.py:2: DeprecationWarning: dropping variables using `drop` is deprecated; use drop_vars.
  ds["salt"].isel(ocean_time=0, s_rho=1).drop(["hc", "Vtransform"]).plot.pcolormesh(
<cartopy.mpl.geocollection.GeoQuadMesh at 0x7f7b8c21af60>
../_images/02fe61b3c3e6d701270dedfefccb4521611d20bf04885272503cd734b44570be.png
# The 2D lon/lat coordinates are dask-backed, so each min/max pair is
# evaluated with a single `dask.compute` call and converted to plain floats.
lon_rho, lat_rho = ds["lon_rho"], ds["lat_rho"]
min_lon, max_lon = (float(v) for v in dask.compute(lon_rho.min(), lon_rho.max()))
min_lat, max_lat = (float(v) for v in dask.compute(lat_rho.min(), lat_rho.max()))

# H3 resolution of the target grid.
level = 6

# Bounding box (xmin, ymin, xmax, ymax) of the source grid.
geom = shapely.box(min_lon, min_lat, max_lon, max_lat)

# H3 cells at `level` selected for the box with the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom, resolution=level, containment_mode=h3ronpy.ContainmentMode.Covers
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, attach the H3 index, then add the
# latitude / longitude coordinates of the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode({"grid_name": "h3", "level": level})
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid
<xarray.Dataset> Size: 155kB
Dimensions:    (cells: 6471)
Coordinates:
  * cell_ids   (cells) uint64 52kB 604679243367972863 ... 604686774727344127
    latitude   (cells) float64 52kB 28.78 29.25 29.15 29.2 ... 27.51 27.57 27.51
    longitude  (cells) float64 52kB -87.72 -87.74 -87.74 ... -89.37 -89.28
Dimensions without coordinates: cells
Data variables:
    *empty*
Indexes:
    cell_ids  H3Index(level=6)
%%time
# Bilinear regridder from the curvilinear ROMS grid to the H3 cells
# (`locstream_out=True`: the target is a 1D list of locations).
regridder = xesmf.Regridder(
    ds_in=ds,
    ds_out=target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder
CPU times: user 513 ms, sys: 8.77 ms, total: 522 ms
Wall time: 537 ms
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_191x371_1x6471.nc 
Reuse pre-computed weights? False 
Input grid shape:           (191, 371) 
Output grid shape:          (1, 6471) 
Periodic in longitude?      False
# NaN-aware regridding with attributes kept; `na_thres=0.5` is forwarded to
# xESMF as the missing-value threshold used together with `skipna=True`.
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True, na_thres=0.5)
# Restore the DGGS index on `cell_ids` and load the result into memory.
regridded = regridded.dggs.decode().compute()
# Mask nulls; `drop=True` trims coordinate labels that only hold masked values.
regridded = regridded.where(regridded.notnull(), drop=True)
regridded
/home/runner/work/regridding-ecosystem-overview/regridding-ecosystem-overview/.pixi/envs/default/lib/python3.12/site-packages/xarray/computation/apply_ufunc.py:450: PerformanceWarning: Regridding is increasing the number of chunks by a factor of 35.0, you might want to specify sizes in `output_chunks` in the regridder call. Default behaviour is to preserve the chunk sizes from the input (96, 186).
  result_vars[name] = func(*variable_args)
<xarray.Dataset> Size: 2MB
Dimensions:     (ocean_time: 2, s_rho: 30, cells: 2980)
Coordinates:
    Cs_r        (s_rho) float64 240B -0.933 -0.8092 ... -0.0005206 -5.758e-05
    hc          float64 8B 20.0
    Vtransform  int32 4B 2
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
  * cell_ids    (cells) uint64 24kB 604679246991851519 ... 604686709094875135
    latitude    (cells) float64 24kB 29.31 29.36 29.36 ... 27.74 27.64 27.69
    longitude   (cells) float64 24kB -87.83 -87.87 -87.8 ... -91.33 -91.36
Dimensions without coordinates: cells
Data variables:
    salt        (ocean_time, s_rho, cells) float64 1MB 35.39 35.72 ... 36.16
    zeta        (ocean_time, cells) float64 48kB -0.2887 -0.2773 ... -0.1436
Indexes:
    cell_ids  H3Index(level=6)
Attributes: (35)
# Explore the regridded `salt` variable on its H3 cells via the xdggs accessor
# (alpha=0.8).
regridded["salt"].dggs.explore(alpha=0.8)