# Regridding to a HEALPix grid using xESMF

In [None]:
import astropy.coordinates
import cartopy.crs as ccrs
import cdshealpix.nested
import cf_xarray  # noqa: F401
import dask
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xesmf

# Global xarray settings for this notebook: keep attributes through operations
# and collapse the attribute and data sections of the reprs by default.
xr.set_options(
    display_expand_attrs=False,
    display_expand_data=False,
    keep_attrs=True,
)

In [None]:
from distributed import Client

# Create a Dask distributed client with default arguments (no scheduler
# address is passed). The bare name on the last line shows the client as the
# cell output.
client = Client()
client

## rectilinear grid: the `air_temperature` example dataset

In [None]:
# Open the tutorial dataset lazily, chunked into blocks of 20 steps along
# `time`, then restrict it to the first 400 time steps.
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 20})
ds = ds.isel(time=slice(400))
ds

In [None]:
# Interpolate onto a finer regular grid: 1060 longitudes spanning 200..330
# and 500 latitudes spanning 15..75.
upscaled = ds.interp(
    {
        "lon": np.linspace(200, 330, 1060),
        "lat": np.linspace(15, 75, 500),
    }
)
upscaled

In [None]:
# HEALPix refinement level of the target grid.
level = 7

# Outline the domain (200..330 in longitude, 15..75 in latitude) as a polygon.
# The lat=15 edge is walked towards increasing longitude and the lat=75 edge
# back towards decreasing longitude, with intermediate vertices on both edges.
edge_lons = [200, 225, 250, 275, 300, 330]
lon = astropy.coordinates.Longitude(edge_lons + edge_lons[::-1], unit="degree")
lat = astropy.coordinates.Latitude(
    [15] * len(edge_lons) + [75] * len(edge_lons), unit="degree"
)

# Only the first element of the returned triple (the cell ids) is used.
cell_ids, _, _ = cdshealpix.nested.polygon_search(lon, lat, depth=level, flat=True)

# Wrap the cell ids in a dataset, attach the grid metadata through xdggs and
# derive latitude/longitude coordinates for the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(
    {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"}
)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the upscaled rectilinear grid to the
# HEALPix target. `locstream_out=True` makes xESMF treat the target as a list
# of points instead of a 2D grid. The cell magic above must stay on the first
# line of the cell.
regridder = xesmf.Regridder(
    upscaled, target_grid, method="bilinear", locstream_out=True
)
regridder

In [None]:
# Apply the regridder to the whole dataset (NaN-aware, attributes preserved),
# then decode the DGGS metadata on the result.
# NOTE(review): `decode()` is called without grid parameters, so it presumably
# relies on metadata carried over from the target grid; confirm.
regridded = regridder.regrid_dataset(upscaled, keep_attrs=True, skipna=True)
regridded = regridded.dggs.decode()
regridded

In [None]:
# Evaluate the lazily regridded dataset and keep the in-memory result.
computed = regridded.compute()
computed

In [None]:
# Visualize the regridded air temperature with xdggs' `explore`, using an
# alpha value of 0.8.
computed["air"].dggs.explore(alpha=0.8)

## curvilinear grid: the `rasm` dataset

In [None]:
# Open the `rasm` tutorial dataset lazily, chunked into blocks of 8 steps
# along `time`. This rebinds `ds` for the rest of this section.
ds = xr.tutorial.open_dataset("rasm", chunks={"time": 8})
ds

In [None]:
# Draw `Tair` at time index 1 on a north-polar stereographic map. The 2D
# `xc`/`yc` coordinates are passed to the plot with a PlateCarree transform.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.NorthPolarStereo()))
snapshot = ds["Tair"].isel(time=1)
snapshot.plot.pcolormesh(x="xc", y="yc", transform=ccrs.PlateCarree(), ax=ax)

In [None]:
# HEALPix refinement level of the target grid.
level = 8

# Select every cell inside a cone centred on the north pole (lon=0, lat=90).
# Its radius of 90 - 16.5 = 73.5 degrees reaches down to latitude 16.5.
lon = astropy.coordinates.Longitude(0, unit="degree")
lat = astropy.coordinates.Latitude(90, unit="degree")
radius = (90 - 16.5) << astropy.units.degree

# Only the first element of the returned triple (the cell ids) is used.
cell_ids, _, _ = cdshealpix.nested.cone_search(
    lon, lat, radius=radius, depth=level, flat=True
)

# Wrap the cell ids in a dataset, attach the grid metadata through xdggs and
# derive latitude/longitude coordinates for the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(
    {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"}
)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the curvilinear `rasm` grid to the HEALPix
# target. `locstream_out=True` makes xESMF treat the target as a list of
# points instead of a 2D grid. The cell magic above must stay on the first
# line of the cell.
regridder = xesmf.Regridder(ds, target_grid, method="bilinear", locstream_out=True)
regridder

In [None]:
# Regrid the dataset (NaN-aware, attributes preserved), decode the DGGS
# metadata on the result and load everything into memory in one go.
regridded = regridder.regrid_dataset(ds, skipna=True, keep_attrs=True)
regridded = regridded.dggs.decode().compute()
regridded

In [None]:
# Visualize the regridded `Tair` with xdggs' `explore`, using an alpha value
# of 0.8.
regridded["Tair"].dggs.explore(alpha=0.8)

## curvilinear grid: the `ROMS_example` dataset

In [None]:
# Open the `ROMS_example` tutorial dataset lazily: one time step per chunk,
# with each horizontal dimension kept in a single chunk (-1).
# The time dimension of this dataset is `ocean_time` (it is the name used for
# indexing further down in this notebook), so the chunk spec has to use that
# name; a `"time"` key would not match any dimension.
ds = xr.tutorial.open_dataset(
    "ROMS_example", chunks={"ocean_time": 1, "eta_rho": -1, "xi_rho": -1}
)
ds

In [None]:
# Draw `salt` at the first time step and `s_rho` index 1 on a Miller map. The
# 2D `lon_rho`/`lat_rho` coordinates are passed with a PlateCarree transform.
# `hc` and `Vtransform` are removed before plotting; `drop_vars` is used
# because `DataArray.drop` is the legacy spelling of the same operation.
# NOTE(review): presumably these two are scalar coordinates that are not
# wanted on the plot; confirm.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Miller()}, figsize=(12, 12))
ds["salt"].isel(ocean_time=0, s_rho=1).drop_vars(["hc", "Vtransform"]).plot.pcolormesh(
    x="lon_rho", y="lat_rho", ax=ax, transform=ccrs.PlateCarree()
)

In [None]:
# Bounding box of the `rho` grid. All four reductions go through a single
# `dask.compute` call so they are evaluated together, rather than in one
# graph execution for longitude and another for latitude.
min_lon, max_lon, min_lat, max_lat = map(
    float,
    dask.compute(
        ds["lon_rho"].min(),
        ds["lon_rho"].max(),
        ds["lat_rho"].min(),
        ds["lat_rho"].max(),
    ),
)

In [None]:
# Grid parameters of the HEALPix target: level 12, nested indexing. The
# object's `level` is reused for the cell search, and the object itself is
# passed to `dggs.decode` when the target grid is built.
grid_info = xdggs.HealpixInfo.from_dict({"level": 12, "indexing_scheme": "nested"})

In [None]:
# Corners of the lon/lat bounding box, starting at (min_lon, min_lat) and
# walking east, north, then west.
corners = [
    (min_lon, min_lat),
    (max_lon, min_lat),
    (max_lon, max_lat),
    (min_lon, max_lat),
]
lon = astropy.coordinates.Longitude([x for x, _y in corners], unit="degree")
lat = astropy.coordinates.Latitude([y for _x, y in corners], unit="degree")

# Only the first element of the returned triple (the cell ids) is used.
cell_ids, _, _ = cdshealpix.nested.polygon_search(
    lon, lat, flat=True, depth=grid_info.level
)

In [None]:
# Wrap the cell ids in a dataset, attach the grid metadata through xdggs and
# derive latitude/longitude coordinates for the cells.
target_grid = xr.Dataset(coords={"cell_ids": ("cells", cell_ids)})
target_grid = target_grid.dggs.decode(grid_info)
target_grid = target_grid.dggs.assign_latlon_coords()
target_grid

In [None]:
%%time
# Build the bilinear regridder from the curvilinear ROMS grid to the HEALPix
# target. `locstream_out=True` makes xESMF treat the target as a list of
# points instead of a 2D grid. The cell magic above must stay on the first
# line of the cell.
regridder = xesmf.Regridder(ds, target_grid, method="bilinear", locstream_out=True)
regridder

In [None]:
# Regrid the dataset (NaN-aware with `na_thres=0.5`, attributes preserved),
# then decode the DGGS metadata on the result. Nothing is computed yet.
regridded = regridder.regrid_dataset(ds, skipna=True, na_thres=0.5, keep_attrs=True)
regridded = regridded.dggs.decode()
regridded

In [None]:
# Load the regridded data into memory, then mask it by its own not-null flags
# with `drop=True` so labels without valid data are removed.
computed = regridded.compute()
computed = computed.where(computed.notnull(), drop=True)
computed

In [None]:
# Visualize the regridded `salt` with xdggs' `explore`, using an alpha value
# of 0.8.
computed["salt"].dggs.explore(alpha=0.8)