# Regridding to H3 using xESMF

In [None]:
import cartopy.crs as ccrs
import cf_xarray  # noqa: F401
import dask
import h3ronpy
import matplotlib.pyplot as plt
import numpy as np
import shapely
import xarray as xr
import xdggs  # noqa: F401
import xesmf

# Notebook-wide xarray options: keep attributes through operations and
# collapse the attrs / data sections in the HTML reprs.
xr.set_options(
    keep_attrs=True,
    display_expand_attrs=False,
    display_expand_data=False,
)

In [None]:
from distributed import Client

# Create a `distributed` client with default arguments; leaving `client` as
# the last expression makes the notebook display it.
client = Client()
client

## rectilinear grid: the `air_temperature` example dataset

In [None]:
# Open the tutorial dataset with dask chunks of 20 steps along `time`,
# then keep only the first 400 time steps.
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20})
ds = ds.isel(time=slice(400))
ds

In [None]:
# Interpolate onto a finer 1060 x 500 lon/lat grid ...
fine_lon = np.linspace(200, 330, 1060)
fine_lat = np.linspace(15, 75, 500)
upscaled = ds.interp(lon=fine_lon, lat=fine_lat)

# ... and wrap the longitudes into the [-180, 180) range.
upscaled = upscaled.assign_coords(lon=(upscaled["lon"] + 180) % 360 - 180)
upscaled

In [None]:
# H3 level of the target grid.
level = 4

# Bounding box of the upscaled source grid.
west, east = float(upscaled["lon"].min()), float(upscaled["lon"].max())
south, north = float(upscaled["lat"].min()), float(upscaled["lat"].max())
geom = shapely.box(west, south, east, north)

# H3 cells selected for the box using the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom,
    resolution=level,
    containment_mode=h3ronpy.ContainmentMode.Covers,
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, decode it as an H3 grid and add the
# latitude / longitude coordinates of the cells.
grid_info = {"grid_name": "h3", "level": level}
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(grid_info).dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the upscaled grid to the H3 cells;
# `locstream_out=True` is passed because the target is a 1D list of cells.
regridder = xesmf.Regridder(
    upscaled,
    target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder

In [None]:
# Apply the regridder to the whole dataset, then decode the result as a
# DGGS dataset.
regridded = regridder.regrid_dataset(upscaled, skipna=True, keep_attrs=True)
regridded = regridded.dggs.decode()
regridded

In [None]:
# Trigger the computation of the regridded dataset and display the result.
computed = regridded.compute()
computed

In [None]:
# Visualize the regridded `air` variable with the `dggs` accessor's `explore`.
computed["air"].dggs.explore(alpha=0.8)

## curvilinear grid: the `rasm` dataset

In [None]:
# Open the `rasm` tutorial dataset with dask chunks of 8 along `time`.
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 8})
ds

In [None]:
# Plot `Tair` at time index 1 on a north-polar stereographic map, using the
# 2D `xc` / `yc` coordinates as longitude / latitude.
polar_projection = ccrs.NorthPolarStereo()
fig, ax = plt.subplots(subplot_kw=dict(projection=polar_projection))

tair_snapshot = ds["Tair"].isel(time=1)
tair_snapshot.plot.pcolormesh(
    x="xc",
    y="yc",
    ax=ax,
    transform=ccrs.PlateCarree(),
)

In [None]:
# H3 level of the target grid.
level = 4

# Fixed lon/lat bounding box used to select the target cells.
west, south, east, north = 0, 16.5, 360, 90
geom = shapely.box(west, south, east, north)

# H3 cells selected for the box using the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom,
    resolution=level,
    containment_mode=h3ronpy.ContainmentMode.Covers,
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, decode it as an H3 grid and add the
# latitude / longitude coordinates of the cells.
grid_info = {"grid_name": "h3", "level": level}
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(grid_info).dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the curvilinear `rasm` grid to the H3
# cells of `target_grid`.
regridder = xesmf.Regridder(
    ds,
    target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder

In [None]:
# Regrid the dataset, decode the result as a DGGS dataset and compute it.
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True)
regridded = regridded.dggs.decode()
regridded = regridded.compute()
regridded

In [None]:
# Visualize the regridded `Tair` variable with the `dggs` accessor's `explore`.
regridded["Tair"].dggs.explore(alpha=0.8)

## curvilinear grid: the `ROMS_example` dataset

In [None]:
# Open the `ROMS_example` tutorial dataset with one time step per dask chunk.
# The time dimension of this dataset is `ocean_time` (it is the name used
# when indexing the data further down), so the chunk spec has to use that
# name; a spec keyed on a dimension the dataset does not have would not
# split the data along time.
ds = xr.tutorial.open_dataset("ROMS_example", chunks={"ocean_time": 1})
ds

In [None]:
# Plot `salt` at the first time step and vertical index 1 on a Miller map,
# using the 2D `lon_rho` / `lat_rho` coordinates.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Miller()}, figsize=(12, 12))
# The `hc` and `Vtransform` coordinates are removed before plotting.
# `drop_vars` is used instead of the deprecated, backward-compatible `drop`.
ds["salt"].isel(ocean_time=0, s_rho=1).drop_vars(["hc", "Vtransform"]).plot.pcolormesh(
    x="lon_rho", y="lat_rho", ax=ax, transform=ccrs.PlateCarree()
)

In [None]:
# Lon/lat extent of the ROMS grid. All four reductions are handed to a single
# `dask.compute` call so they are scheduled together instead of in two
# separate round trips; the results come back in the order they were passed.
min_lon, max_lon, min_lat, max_lat = map(
    float,
    dask.compute(
        ds["lon_rho"].min(),
        ds["lon_rho"].max(),
        ds["lat_rho"].min(),
        ds["lat_rho"].max(),
    ),
)

In [None]:
# H3 level of the target grid.
level = 6

# Bounding box spanning the lon/lat extent computed from the ROMS grid.
geom = shapely.box(min_lon, min_lat, max_lon, max_lat)

# H3 cells selected for the box using the `Covers` containment mode.
covering_cells = h3ronpy.vector.geometry_to_cells(
    geom,
    resolution=level,
    containment_mode=h3ronpy.ContainmentMode.Covers,
)
cell_ids = np.asarray(covering_cells)

# Wrap the cell ids in a dataset, decode it as an H3 grid and add the
# latitude / longitude coordinates of the cells.
grid_info = {"grid_name": "h3", "level": level}
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(grid_info).dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the curvilinear ROMS grid to the H3
# cells of `target_grid`.
regridder = xesmf.Regridder(
    ds,
    target_grid,
    method="bilinear",
    locstream_out=True,
)
regridder

In [None]:
# Regrid with NaN-aware weights (`skipna=True`, `na_thres=0.5`), decode the
# result as a DGGS dataset and compute it.
regridded = regridder.regrid_dataset(ds, keep_attrs=True, skipna=True, na_thres=0.5)
regridded = regridded.dggs.decode().compute()

# Mask null values and drop the labels that end up entirely null.
valid = regridded.notnull()
regridded = regridded.where(valid, drop=True)
regridded

In [None]:
# Visualize the regridded `salt` variable with the `dggs` accessor's `explore`.
regridded["salt"].dggs.explore(alpha=0.8)