Regridding model data with xESMF

Import Python packages

# suppress warnings
import warnings
warnings.filterwarnings('ignore') # don't output warnings

import os
# import packages
import xarray as xr
xr.set_options(display_style='html')
import intake
import cftime
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import numpy as np
import xesmf as xe
from cmcrameri import cm

%matplotlib inline

Open CMIP6 online catalog

cat_url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
col = intake.open_esm_datastore(cat_url)
col

pangeo-cmip6 catalog with 7619 dataset(s) from 517624 asset(s):

unique
activity_id 18
institution_id 36
source_id 88
experiment_id 170
member_id 657
table_id 37
variable_id 709
grid_label 10
zstore 517624
dcpp_init_year 60
version 710

Get data in xarray

Search for datasets of the od550aer variable (aerosol optical depth at 550 nm)

cat = col.search(experiment_id=['historical'], variable_id='od550aer', member_id=['r1i1p1f1'], grid_label='gn')
cat.df
    activity_id  institution_id     source_id        experiment_id  member_id  table_id  variable_id  grid_label  zstore                                              dcpp_init_year  version
0   CMIP         NCAR               CESM2-WACCM      historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/NCAR/CESM2-WACCM/histori...  NaN             20190227
1   CMIP         NCAR               CESM2            historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/NCAR/CESM2/historical/r1...  NaN             20190308
2   CMIP         CCCma              CanESM5          historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/CCCma/CanESM5/historical...  NaN             20190429
3   CMIP         HAMMOZ-Consortium  MPI-ESM-1-2-HAM  historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/HAMMOZ-Consortium/MPI-ES...  NaN             20190627
4   CMIP         MPI-M              MPI-ESM1-2-HR    historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/hist...  NaN             20190710
5   CMIP         MPI-M              MPI-ESM1-2-LR    historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/hist...  NaN             20190710
6   CMIP         NCC                NorESM2-LM       historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/NCC/NorESM2-LM/historica...  NaN             20190815
7   CMIP         BCC                BCC-ESM1         historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/BCC/BCC-ESM1/historical/...  NaN             20190918
8   CMIP         NCAR               CESM2-WACCM-FV2  historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/NCAR/CESM2-WACCM-FV2/his...  NaN             20191120
9   CMIP         NCAR               CESM2-FV2        historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/NCAR/CESM2-FV2/historica...  NaN             20191120
10  CMIP         MRI                MRI-ESM2-0       historical     r1i1p1f1   AERmon    od550aer     gn          gs://cmip6/CMIP6/CMIP/MRI/MRI-ESM2-0/historica...  NaN             20200207
cat.df['source_id'].unique()
array(['CESM2-WACCM', 'CESM2', 'CanESM5', 'MPI-ESM-1-2-HAM',
       'MPI-ESM1-2-HR', 'MPI-ESM1-2-LR', 'NorESM2-LM', 'BCC-ESM1',
       'CESM2-WACCM-FV2', 'CESM2-FV2', 'MRI-ESM2-0'], dtype=object)
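
Conceptually, the search keeps only the catalog rows whose columns match every criterion, where a list means "any of these values". The sketch below mimics that filtering in plain Python. It is not intake-esm code: the helper `search_rows` is hypothetical, and the third row (member `r2i1p1f1`) is made up only to show an entry that does not match.

```python
# Plain-Python illustration of multi-criteria filtering (not intake-esm code).
rows = [
    {"source_id": "CESM2", "experiment_id": "historical",
     "member_id": "r1i1p1f1", "grid_label": "gn"},
    {"source_id": "NorESM2-LM", "experiment_id": "historical",
     "member_id": "r1i1p1f1", "grid_label": "gn"},
    # Hypothetical row, added only to show a non-matching entry
    {"source_id": "NorESM2-LM", "experiment_id": "historical",
     "member_id": "r2i1p1f1", "grid_label": "gn"},
]


def search_rows(rows, **criteria):
    """Keep rows where each column equals the value or is in the list of values."""
    def matches(row):
        for column, wanted in criteria.items():
            allowed = wanted if isinstance(wanted, (list, tuple)) else [wanted]
            if row[column] not in allowed:
                return False
        return True
    return [row for row in rows if matches(row)]


selected = search_rows(rows, experiment_id=["historical"],
                       member_id=["r1i1p1f1"], grid_label="gn")
print([row["source_id"] for row in selected])
```

Only the two `r1i1p1f1` rows survive, which is why each model appears exactly once in the table above.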

Create dictionary from the list of datasets we found

  • This step may take several minutes so be patient!

dset_dict = cat.to_dataset_dict(zarr_kwargs={'use_cftime':True})
--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'
100.00% [11/11 01:42<00:00]
list(dset_dict.keys())
['CMIP.MPI-M.MPI-ESM1-2-LR.historical.AERmon.gn',
 'CMIP.BCC.BCC-ESM1.historical.AERmon.gn',
 'CMIP.NCAR.CESM2-WACCM.historical.AERmon.gn',
 'CMIP.MPI-M.MPI-ESM1-2-HR.historical.AERmon.gn',
 'CMIP.MRI.MRI-ESM2-0.historical.AERmon.gn',
 'CMIP.CCCma.CanESM5.historical.AERmon.gn',
 'CMIP.NCAR.CESM2.historical.AERmon.gn',
 'CMIP.NCC.NorESM2-LM.historical.AERmon.gn',
 'CMIP.HAMMOZ-Consortium.MPI-ESM-1-2-HAM.historical.AERmon.gn',
 'CMIP.NCAR.CESM2-WACCM-FV2.historical.AERmon.gn',
 'CMIP.NCAR.CESM2-FV2.historical.AERmon.gn']
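
Because the keys follow the template printed by `to_dataset_dict`, a key can be split back into its catalog fields. The sketch below pairs the template with one key from the list above; it assumes that no field value contains a dot, which holds for all eleven keys here.

```python
# Field names taken from the key template printed above
template = "activity_id.institution_id.source_id.experiment_id.table_id.grid_label"
key = "CMIP.NCC.NorESM2-LM.historical.AERmon.gn"

# Pair each field name with the matching piece of the key
fields = dict(zip(template.split("."), key.split(".")))
print(fields["source_id"], fields["table_id"])
```

The regridding loop further down uses the same idea when it takes `key.split('.')[2]` as the model name.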

Select model and visualize a single date

  • Use the data as an xarray Dataset to make a simple plot

ds = dset_dict['CMIP.NCC.NorESM2-LM.historical.AERmon.gn']
ds
<xarray.Dataset>
Dimensions:     (bnds: 2, lat: 96, lon: 144, member_id: 1, time: 1980)
Coordinates:
  * lat         (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 86.21 88.11 90.0
    lat_bnds    (lat, bnds) float64 dask.array<chunksize=(96, 2), meta=np.ndarray>
  * lon         (lon) float64 0.0 2.5 5.0 7.5 10.0 ... 350.0 352.5 355.0 357.5
    lon_bnds    (lon, bnds) float64 dask.array<chunksize=(144, 2), meta=np.ndarray>
  * time        (time) object 1850-01-16 12:00:00 ... 2014-12-16 12:00:00
    time_bnds   (time, bnds) object dask.array<chunksize=(1980, 2), meta=np.ndarray>
    wavelength  float64 ...
  * member_id   (member_id) <U8 'r1i1p1f1'
Dimensions without coordinates: bnds
Data variables:
    od550aer    (member_id, time, lat, lon) float32 dask.array<chunksize=(1, 990, 96, 144), meta=np.ndarray>
Attributes: (12/52)
    Conventions:               CF-1.7 CMIP-6.2
    activity_id:               CMIP
    branch_method:             Hybrid-restart from year 1600-01-01 of piControl
    branch_time:               0.0
    branch_time_in_child:      0.0
    branch_time_in_parent:     430335.0
    ...                        ...
    title:                     NorESM2-LM output prepared for CMIP6
    tracking_id:               hdl:21.14100/efd7a56e-94a8-47f5-b3d8-06ae02268...
    variable_id:               od550aer
    variant_label:             r1i1p1f1
    intake_esm_varname:        ['od550aer']
    intake_esm_dataset_key:    CMIP.NCC.NorESM2-LM.historical.AERmon.gn

Plot on NorthPolarStereo and set the latitude limit

def polarCentral_set_latlim(lat_lims, ax):
    ax.set_extent([-180, 180, lat_lims[0], lat_lims[1]], ccrs.PlateCarree())
    # Compute a circle in axes coordinates, which we can use as a boundary
    # for the map. We can pan/zoom as much as we like - the boundary will be
    # permanently circular.
    theta = np.linspace(0, 2*np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)

    ax.set_boundary(circle, transform=ax.transAxes)
fig = plt.figure(1, figsize=[10,10])

# Fix the colour scale limits
minval = 0
maxval = 0.3

ax = plt.subplot(1, 1, 1, projection=ccrs.NorthPolarStereo())
ax.coastlines()
ax.gridlines()
polarCentral_set_latlim([50,90], ax)
ds['od550aer'].sel(time=cftime.DatetimeNoLeap(1985, 1, 16, 12, 0, 0, 0)).plot(ax=ax, vmin=minval, vmax=maxval, transform=ccrs.PlateCarree(), cmap=cm.oslo_r)
<matplotlib.collections.QuadMesh at 0x7fbfbc1aec70>
[Figure: od550aer from NorESM2-LM for January 1985 on a NorthPolarStereo map, 50°N to 90°N]
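
The circular map boundary built in `polarCentral_set_latlim` is a ring of 100 points in axes coordinates, where (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes. The sketch below rebuilds the same vertices with the standard library only, as a quick check that they lie on a circle of radius 0.5 around the axes centre and never leave the axes.

```python
import math

center, radius = (0.5, 0.5), 0.5
n_points = 100

# Same parametrisation as in polarCentral_set_latlim: (sin(theta), cos(theta))
thetas = [2 * math.pi * i / (n_points - 1) for i in range(n_points)]
verts = [(center[0] + radius * math.sin(t), center[1] + radius * math.cos(t))
         for t in thetas]

# Every vertex is at distance `radius` from the centre ...
distances = [math.hypot(x - center[0], y - center[1]) for x, y in verts]
on_circle = max(abs(d - radius) for d in distances) < 1e-12

# ... and stays inside the unit square of axes coordinates
eps = 1e-12
inside = all(-eps <= x <= 1 + eps and -eps <= y <= 1 + eps for x, y in verts)
print(on_circle, inside)
```

Since the boundary is defined relative to the axes rather than to the data, it stays circular whatever extent is set.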

Get attributes (unique identifier)

ds.attrs['tracking_id']
'hdl:21.14100/efd7a56e-94a8-47f5-b3d8-06ae02268192\nhdl:21.14100/0c3683e7-1c3f-45d6-bbc7-414f68e7a801\nhdl:21.14100/a83a3a96-0d16-4f3e-aa88-a68f00e1ce2e\nhdl:21.14100/7a629db0-dda1-445d-a496-2e77c9c7c20a\nhdl:21.14100/388888f8-7ee4-467f-aca9-8a6e45361f55\nhdl:21.14100/7aa3797a-b0e7-427d-a209-efed00cd1724\nhdl:21.14100/33d3ae45-cb42-47e5-8c08-76df86d298a8\nhdl:21.14100/10f9c9e3-3d54-494a-b153-ea63c5b584c7\nhdl:21.14100/2a3a1f67-8890-4e89-b6de-84b2b13cf70e\nhdl:21.14100/447d4151-8161-461a-a66e-f21144baabf6\nhdl:21.14100/c1bd13af-f8b2-4b18-bd2c-34b13a4921dc\nhdl:21.14100/eee79e49-dd70-4d8f-b195-c885aca26e3a\nhdl:21.14100/7cd2e526-a94c-4004-bfe8-32832a9df6d7\nhdl:21.14100/b2421ba2-c8c9-44c0-84af-f9b09cb87759\nhdl:21.14100/f87b3cf5-e68d-40f2-a289-667b4cf7d15f\nhdl:21.14100/a6542f1a-0567-4e7a-b063-f5609f017d69\nhdl:21.14100/01e4e8d6-b84e-4d25-9f50-4a9ea5588f07'
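
The attribute is a single string in which the individual handles are separated by newline characters (17 of them in the output above). A minimal sketch of splitting it into a list, using only the first two handles copied from that output:

```python
# First two handles copied from the output above; the real attribute holds more
tracking_id = (
    "hdl:21.14100/efd7a56e-94a8-47f5-b3d8-06ae02268192\n"
    "hdl:21.14100/0c3683e7-1c3f-45d6-bbc7-414f68e7a801"
)

handles = tracking_id.split("\n")                   # one entry per handle
uuids = [h.split("/", 1)[1] for h in handles]       # part after the handle prefix
print(len(handles), uuids[0])
```

On the full dataset the same split would be applied to `ds.attrs['tracking_id']`.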

Regrid CMIP6 data to common NorESM2-LM grid

  • Select a time range

  • We use squeeze to drop dimensions of length one, e.g. member_id='r1i1p1f1'

starty = 1985; endy = 1986
year_range = range(starty, endy+1)
# Read in the output grid from NorESM
ds_out = ds.sel(time = ds.time.dt.year.isin(year_range)).squeeze()
ds_out
<xarray.Dataset>
Dimensions:     (bnds: 2, lat: 96, lon: 144, time: 24)
Coordinates:
  * lat         (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 86.21 88.11 90.0
    lat_bnds    (lat, bnds) float64 dask.array<chunksize=(96, 2), meta=np.ndarray>
  * lon         (lon) float64 0.0 2.5 5.0 7.5 10.0 ... 350.0 352.5 355.0 357.5
    lon_bnds    (lon, bnds) float64 dask.array<chunksize=(144, 2), meta=np.ndarray>
  * time        (time) object 1985-01-16 12:00:00 ... 1986-12-16 12:00:00
    time_bnds   (time, bnds) object dask.array<chunksize=(24, 2), meta=np.ndarray>
    wavelength  float64 ...
    member_id   <U8 'r1i1p1f1'
Dimensions without coordinates: bnds
Data variables:
    od550aer    (time, lat, lon) float32 dask.array<chunksize=(24, 96, 144), meta=np.ndarray>
Attributes: (12/52)
    Conventions:               CF-1.7 CMIP-6.2
    activity_id:               CMIP
    branch_method:             Hybrid-restart from year 1600-01-01 of piControl
    branch_time:               0.0
    branch_time_in_child:      0.0
    branch_time_in_parent:     430335.0
    ...                        ...
    title:                     NorESM2-LM output prepared for CMIP6
    tracking_id:               hdl:21.14100/efd7a56e-94a8-47f5-b3d8-06ae02268...
    variable_id:               od550aer
    variant_label:             r1i1p1f1
    intake_esm_varname:        ['od550aer']
    intake_esm_dataset_key:    CMIP.NCC.NorESM2-LM.historical.AERmon.gn
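
The time selection keeps every time step whose year is in `year_range`. The sketch below reproduces the bookkeeping in plain Python on a monthly axis of (year, month) pairs covering 1850 to 2014; it stands in for the cftime axis of the real dataset and is only meant to show why 1980 time steps shrink to 24.

```python
starty, endy = 1985, 1986
year_range = range(starty, endy + 1)   # range() excludes its stop value, hence +1

# Monthly time axis of the historical run, 1850-2014, as (year, month) pairs
time_axis = [(year, month) for year in range(1850, 2015) for month in range(1, 13)]

# Equivalent of ds.time.dt.year.isin(year_range): keep steps whose year matches
selected = [(y, m) for (y, m) in time_axis if y in year_range]
print(len(time_axis), len(selected), selected[0], selected[-1])
```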
# create dictionary for regridded data
ds_regrid_dict = dict()
for key in dset_dict.keys():
    print(key)
    ds_in = dset_dict[key]
    ds_in = ds_in.sel(time = ds_in.time.dt.year.isin(year_range)).squeeze()
    regridder = xe.Regridder(ds_in, ds_out, 'bilinear')
    # Apply regridder to data
    # the entire dataset can be processed at once
    ds_in_regrid = regridder(ds_in, keep_attrs=True)
    # Save to netcdf file
    model = key.split('.')[2]
    filename = 'od550aer_AERmon.nc'
    savepath = 'CMIP6_hist/{}'.format(model)
    nc_out = os.path.join(savepath, filename)
    os.makedirs(savepath, exist_ok=True) 
    ds_in_regrid.to_netcdf(nc_out)
    # create dataset with all models
    ds_regrid_dict[model] = ds_in_regrid
    print('file written: {}'.format(nc_out))
CMIP.MPI-M.MPI-ESM1-2-LR.historical.AERmon.gn
file written: CMIP6_hist/MPI-ESM1-2-LR/od550aer_AERmon.nc
CMIP.BCC.BCC-ESM1.historical.AERmon.gn
file written: CMIP6_hist/BCC-ESM1/od550aer_AERmon.nc
CMIP.NCAR.CESM2-WACCM.historical.AERmon.gn
file written: CMIP6_hist/CESM2-WACCM/od550aer_AERmon.nc
CMIP.MPI-M.MPI-ESM1-2-HR.historical.AERmon.gn
file written: CMIP6_hist/MPI-ESM1-2-HR/od550aer_AERmon.nc
CMIP.MRI.MRI-ESM2-0.historical.AERmon.gn
file written: CMIP6_hist/MRI-ESM2-0/od550aer_AERmon.nc
CMIP.CCCma.CanESM5.historical.AERmon.gn
file written: CMIP6_hist/CanESM5/od550aer_AERmon.nc
CMIP.NCAR.CESM2.historical.AERmon.gn
file written: CMIP6_hist/CESM2/od550aer_AERmon.nc
CMIP.NCC.NorESM2-LM.historical.AERmon.gn
file written: CMIP6_hist/NorESM2-LM/od550aer_AERmon.nc
CMIP.HAMMOZ-Consortium.MPI-ESM-1-2-HAM.historical.AERmon.gn
file written: CMIP6_hist/MPI-ESM-1-2-HAM/od550aer_AERmon.nc
CMIP.NCAR.CESM2-WACCM-FV2.historical.AERmon.gn
file written: CMIP6_hist/CESM2-WACCM-FV2/od550aer_AERmon.nc
CMIP.NCAR.CESM2-FV2.historical.AERmon.gn
file written: CMIP6_hist/CESM2-FV2/od550aer_AERmon.nc
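
To give an idea of what the 'bilinear' method does, the sketch below interpolates a value at one target point from the four surrounding source grid points on a small rectilinear grid. This is a simplified illustration and not xESMF's implementation: xESMF delegates the weight computation to ESMF and handles the sphere and periodic longitudes. The function `bilinear` is hypothetical, the grid values are made up, and the target point is assumed to lie inside the source grid.

```python
def bilinear(lons, lats, values, lon, lat):
    """Interpolate values[j][i] (row j = latitude, column i = longitude) at (lon, lat)."""
    # Locate the grid cell that contains the target point
    i = max(k for k in range(len(lons) - 1) if lons[k] <= lon)
    j = max(k for k in range(len(lats) - 1) if lats[k] <= lat)
    # Fractional position of the point inside that cell (0..1)
    tx = (lon - lons[i]) / (lons[i + 1] - lons[i])
    ty = (lat - lats[j]) / (lats[j + 1] - lats[j])
    # Weighted mean of the four surrounding grid points
    return ((1 - tx) * (1 - ty) * values[j][i]
            + tx * (1 - ty) * values[j][i + 1]
            + (1 - tx) * ty * values[j + 1][i]
            + tx * ty * values[j + 1][i + 1])


# Made-up source grid: 3 longitudes x 2 latitudes
lons = [0.0, 2.5, 5.0]
lats = [50.0, 52.0]
values = [[0.10, 0.20, 0.30],
          [0.20, 0.30, 0.40]]

center_value = bilinear(lons, lats, values, 1.25, 51.0)   # middle of the first cell
node_value = bilinear(lons, lats, values, 2.5, 50.0)      # exactly on a grid point
print(center_value, node_value)
```

At a cell centre the result is the plain average of the four corners, and on a grid point it returns that point's value, so regridding a model onto its own grid (as happens for NorESM2-LM in the loop above) should leave the field essentially unchanged.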

Concatenate all models

_ds = list(ds_regrid_dict.values())
_coord = list(ds_regrid_dict.keys())
ds_out_regrid = xr.concat(objs=_ds, dim=_coord, coords="all").rename({'concat_dim':'model'})
ds_out_regrid
<xarray.Dataset>
Dimensions:    (lat: 96, lon: 144, model: 11, nbnd: 2, time: 24)
Coordinates:
  * time       (time) object 1985-01-15 12:00:00 ... 1986-12-15 12:00:00
    time_bnds  (model, time, nbnd) object dask.array<chunksize=(1, 24, 2), meta=np.ndarray>
    member_id  (model) <U8 'r1i1p1f1' 'r1i1p1f1' ... 'r1i1p1f1' 'r1i1p1f1'
  * lon        (lon) float64 0.0 2.5 5.0 7.5 10.0 ... 350.0 352.5 355.0 357.5
  * lat        (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 86.21 88.11 90.0
  * model      (model) <U15 'MPI-ESM1-2-LR' 'BCC-ESM1' ... 'CESM2-FV2'
Dimensions without coordinates: nbnd
Data variables:
    od550aer   (model, time, lat, lon) float64 dask.array<chunksize=(1, 24, 96, 144), meta=np.ndarray>
Attributes: (12/49)
    Conventions:             CF-1.7 CMIP-6.2
    activity_id:             CMIP
    branch_method:           standard
    branch_time_in_child:    674885.0
    branch_time_in_parent:   10950.0
    case_id:                 1559
    ...                      ...
    variable_id:             od550aer
    variant_info:            CMIP6 CESM2-FV2 historical experiment (1850-2014...
    variant_label:           r1i1p1f1
    intake_esm_varname:      ['od550aer']
    intake_esm_dataset_key:  CMIP.NCAR.CESM2-FV2.historical.AERmon.gn
    regrid_method:           bilinear

Compute seasonal mean of all regridded models

ds_seas = ds_out_regrid.mean('model', keep_attrs=True, skipna = True).groupby('time.season').mean('time', keep_attrs=True, skipna = True)
ds_seas
<xarray.Dataset>
Dimensions:   (lat: 96, lon: 144, season: 4)
Coordinates:
  * lon       (lon) float64 0.0 2.5 5.0 7.5 10.0 ... 350.0 352.5 355.0 357.5
  * lat       (lat) float64 -90.0 -88.11 -86.21 -84.32 ... 86.21 88.11 90.0
  * season    (season) object 'DJF' 'JJA' 'MAM' 'SON'
Data variables:
    od550aer  (season, lat, lon) float64 dask.array<chunksize=(1, 96, 144), meta=np.ndarray>
Attributes: (12/49)
    Conventions:             CF-1.7 CMIP-6.2
    activity_id:             CMIP
    branch_method:           standard
    branch_time_in_child:    674885.0
    branch_time_in_parent:   10950.0
    case_id:                 1559
    ...                      ...
    variable_id:             od550aer
    variant_info:            CMIP6 CESM2-FV2 historical experiment (1850-2014...
    variant_label:           r1i1p1f1
    intake_esm_varname:      ['od550aer']
    intake_esm_dataset_key:  CMIP.NCAR.CESM2-FV2.historical.AERmon.gn
    regrid_method:           bilinear
ds_seas['od550aer'].min().compute(), ds_seas['od550aer'].max().compute()
(<xarray.DataArray 'od550aer' ()>
 array(0.0031875),
 <xarray.DataArray 'od550aer' ()>
 array(2.65120286))
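
Grouping by `time.season` assigns each time step to DJF, MAM, JJA or SON according to its month and then averages within each group, across all years and without weighting by month length. The sketch below shows that grouping for a single grid cell in plain Python. The monthly values are made up (value = month / 100), and the code illustrates the idea rather than xarray's internals.

```python
from collections import defaultdict

SEASON_OF_MONTH = {12: "DJF", 1: "DJF", 2: "DJF",
                   3: "MAM", 4: "MAM", 5: "MAM",
                   6: "JJA", 7: "JJA", 8: "JJA",
                   9: "SON", 10: "SON", 11: "SON"}

# Illustrative monthly values for one grid cell (made up): value = month / 100
monthly = [(month, month / 100) for month in range(1, 13)]

# Collect the values of each season, then average them
groups = defaultdict(list)
for month, value in monthly:
    groups[SEASON_OF_MONTH[month]].append(value)

seasonal_mean = {season: sum(vals) / len(vals) for season, vals in groups.items()}
print(sorted(seasonal_mean))
```

Note that December is grouped with January and February of the same calendar year, so with two years of data DJF is not a mean over two contiguous winters.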

Save seasonal mean in a new netCDF file and in the current Galaxy history

ds_seas.to_netcdf('CMIP6_hist/od550aer_seasonal.nc')
!put -p CMIP6_hist/od550aer_seasonal.nc -t netcdf

Visualize final results (seasonal mean for all models)

import matplotlib
proj_plot = ccrs.Mercator()

p = ds_seas['od550aer'].plot(x='lon', y='lat', transform=ccrs.PlateCarree(),
                             aspect=ds_seas.dims["lon"] / ds_seas.dims["lat"],  # for a sensible figsize
                             subplot_kws={"projection": proj_plot},
                             col='season', col_wrap=2, robust=True, cmap='PiYG')
# We have to set the map's options on all four axes
for ax,i in zip(p.axes.flat,  ds_seas.season.values):
    ax.coastlines()
    ax.set_title('Season '+i, fontsize=18)
fig = matplotlib.pyplot.gcf()
fig.set_size_inches(18.5, 10.5)
fig.savefig('od550aer_seasonal_mean.png', dpi=100)
[Figure: seasonal mean od550aer averaged over all regridded models, one panel per season (DJF, JJA, MAM, SON)]