Bias Correction of Climate Model Output#

Paper (Chapter 4) | Notebook Repository | Method Repository | Ongoing Development

Primary Contact: Dr. Jeremy Carter | Notebook Top-to-Bottom Runtime ~ 2 Minutes

Challenge:

Physically based climate models are highly multidimensional and computationally demanding. Simulation runs are often tuned for specific variables and typically show a non-negligible bias in others. Depending on which variables a research question focuses on, it is often desirable to apply a post-processing bias correction. In-situ measurements can be used to inform this correction. Where the in-situ measurements are sparse, it is important to consider uncertainty, which depends on factors such as the underlying spatial covariance between points.

Approach:

Gaussian processes are used to explicitly model the spatial covariance between points and to estimate uncertainties when applying the bias correction across the whole domain. A Bayesian hierarchical model is constructed so that uncertainty propagates across the different components of the model.

Introduction#

This notebook demonstrates an approach to bias correction that utilises Gaussian Processes (GPs) and a Bayesian hierarchical framework. The method is applied to bias correcting temperature data from a climate model over Antarctica using in-situ observations from automatic weather stations (AWSs). Since the weather stations are both spatially and temporally sparse, it’s important to capture uncertainty in the correction. Further detail is available at: J. Carter Thesis (Chapter 4). The code is in ongoing development and is available at: Bias Correction Application. The datasets and modules needed to run this notebook are available via the jupterbook_render branch of the repository.

Importing Required Libraries and Loading Data#

The data for this tutorial is stored as NetCDF files. This is a common file format for climate model output and is handled well by the Xarray Python package, which can be thought of as a Pandas equivalent for efficient handling of multidimensional data. Xarray loads data as ‘Datasets’ and ‘DataArrays’. The model in this tutorial is defined using the Python packages Numpyro and TinyGP, which are designed to work together. Numpyro provides an intuitive probabilistic programming language for Bayesian statistics, built on top of JAX with NumPy-based syntax. TinyGP provides a lightweight and intuitive package for working with Gaussian process objects.
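As a minimal, self-contained illustration of how these packages compose (a toy one-dimensional example with synthetic data, unrelated to the Antarctic datasets used below), a TinyGP GaussianProcess can be exposed to Numpyro as a likelihood via its numpyro_dist() helper:

# Toy sketch: GP regression written with Numpyro + TinyGP on synthetic 1D data
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from tinygp import kernels, GaussianProcess

# Synthetic observations
X = jnp.linspace(0.0, 10.0, 25)[:, None]
y = jnp.sin(X[:, 0]) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (25,))

def toy_gp_model(X, y=None):
    # Priors on the kernel hyperparameters and observation noise
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    variance = numpyro.sample("variance", dist.LogNormal(0.0, 1.0))
    noise = numpyro.sample("noise", dist.HalfNormal(1.0))
    # TinyGP builds the GP marginal likelihood, exposed to Numpyro as a distribution
    kernel = variance * kernels.Matern32(scale=lengthscale)
    gp = GaussianProcess(kernel, X, diag=noise**2)
    numpyro.sample("y", gp.numpyro_dist(), obs=y)

mcmc = MCMC(NUTS(toy_gp_model), num_warmup=500, num_samples=500)
mcmc.run(jax.random.PRNGKey(1), X, y=y)
mcmc.print_summary()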

Hide code cell content
# Importing libraries
import os
from urllib.request import urlretrieve
import pickle 
import timeit
from tqdm import tqdm
import numpy as np
import xarray as xr
import pandas as pd
from scipy.spatial import distance
from scipy.stats import norm
import arviz as az
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from tinygp import kernels, GaussianProcess
from tinygp.kernels.distance import L2Distance
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

import geopandas as gpd
import seaborn as sns
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

rng_key = jax.random.PRNGKey(1)
jax.config.update("jax_enable_x64", True)
# Loading Data
data_path = os.path.join(os.getcwd(),"data","")

ds_aws = xr.open_dataset(f'{data_path}ds_aws.nc') # Automatic Weather Station Data
ds_climate = xr.open_dataset(f'{data_path}ds_climate.nc') # Climate Model Data

Using Xarray datasets provides nice interactive tables of the multidimensional data:

# Displaying the AWS data
ds_aws 
Hide code cell output
<xarray.Dataset> Size: 925kB
Dimensions:         (station: 219, t: 504)
Coordinates:
  * station         (station) <U22 19kB 'AGO Site' 'AGO-4' ... 'aws16' 'aws17'
    glat            (station) float64 2kB ...
    glon            (station) float64 2kB ...
    grid_latitude   (station) float64 2kB ...
    grid_longitude  (station) float64 2kB ...
    year            (t) float64 4kB ...
    month           (t) float64 4kB ...
  * t               (t) float64 4kB 1.0 2.0 3.0 4.0 ... 501.0 502.0 503.0 504.0
Data variables:
    latitude        (station) float64 2kB ...
    elevation       (station) float64 2kB ...
    temperature     (station, t) float64 883kB ...
# Displaying the climate model data
ds_climate
Hide code cell output
<xarray.Dataset> Size: 45MB
Dimensions:         (time: 456, grid_longitude: 126, grid_latitude: 98)
Coordinates:
  * grid_longitude  (grid_longitude) float64 1kB 152.4 152.9 ... 207.0 207.4
  * grid_latitude   (grid_latitude) float64 784B -21.39 -20.95 ... 20.84 21.29
  * time            (time) datetime64[ns] 4kB 1981-01-31 ... 2018-12-31
    month           (time) int64 4kB ...
    year            (time) int64 4kB ...
    glon            (grid_longitude, grid_latitude) float64 99kB ...
    glat            (grid_longitude, grid_latitude) float64 99kB ...
    t               (time) int64 4kB ...
Data variables:
    temperature     (time, grid_longitude, grid_latitude) float64 45MB ...
    elevation       (grid_longitude, grid_latitude) float64 99kB ...
    latitude        (grid_longitude, grid_latitude) float64 99kB ...
# Computing basic summary statistics by converting the xarray objects to pandas DataFrames
print('Summary of Automatic Weather Station Data \n',
      ds_aws.to_dataframe().describe()[['elevation','latitude','temperature']])
print('\n Summary of Climate Model Data \n',
      ds_climate.to_dataframe().describe()[['elevation','latitude','temperature']])
Hide code cell output
Summary of Automatic Weather Station Data 
            elevation       latitude   temperature
count  110376.000000  110376.000000  18088.000000
mean     1251.009132     -76.472009    -25.852952
std      1131.183380       5.411963     14.071861
min         5.000000     -90.000000    -71.740000
25%        87.000000     -79.820000    -31.880000
50%      1122.000000     -76.320000    -24.135000
75%      2090.000000     -73.080000    -16.320000
max      4093.000000     -65.240000      1.750000

 Summary of Climate Model Data 
           elevation      latitude   temperature
count  2.610144e+06  2.610144e+06  2.610144e+06
mean   2.003590e+03 -7.661032e+01 -3.257423e+01
std    1.150357e+03  5.392458e+00  1.458319e+01
min   -3.087963e+00 -8.971554e+01 -7.326862e+01
25%    1.045115e+03 -8.060204e+01 -4.373946e+01
50%    2.192235e+03 -7.646343e+01 -3.125523e+01
75%    2.988649e+03 -7.229835e+01 -2.182017e+01
max    4.063502e+03 -6.397399e+01  1.437534e+00

Data Exploration#

Initial data exploration is essential for informing our model construction. It includes examining the spatial and temporal distributions of the data and relationships between variables and potential predictors. To start we’ll examine the spatial distribution of weather stations, plotting over the grid of the climate model.

Spatial Distribution of Weather Stations#

Hide code cell source
# Loading ice sheet shapefile 
icesheet_shapefile_path = os.path.join(data_path,"icesheet_shapefile","icesheet.shp")
gdf_icesheet = gpd.read_file(icesheet_shapefile_path)

# Defining rotated coordinate system (glon,glat) and converting ice sheet shapefile to rotated coordinates
rotated_coord_system = ccrs.RotatedGeodetic(
    13.079999923706055,
    0.5199999809265137,
    central_rotated_longitude=180.0,
    globe=None,
)
gdf_icesheet_rotatedcoords = gdf_icesheet.to_crs(rotated_coord_system)

############################################################################################################

# Defining background map function
def background_map_rotatedcoords(ax):
    gdf_icesheet_rotatedcoords.boundary.plot(
        ax=ax,
        color='k',
        linewidth=0.3,
        alpha=0.4)
    ax.set_axis_off()

# Defining marker size legend function
def markersize_legend(ax, bins, scale_multipler, legend_fontsize=10,loc=3,ncols=1,columnspacing=0.8,handletextpad=0.1,bbox=(0.,0.)):
    ax.add_artist(
        ax.legend(
            handles=[
                mlines.Line2D(
                    [],
                    [],
                    color="tab:blue",
                    markeredgecolor="k",
                    markeredgewidth=0.3,
                    lw=0,
                    marker="o",
                    markersize=np.sqrt(b*scale_multipler),
                    label=str(int(b)),
                )
                for i, b in enumerate(bins)
            ],
            loc=loc,
            fontsize = legend_fontsize,
            ncols=ncols,
            columnspacing=columnspacing,
            handletextpad=handletextpad,
            bbox_to_anchor=bbox,
            framealpha=0,
        )
    )

############################################################################################################

# Plotting the weather station locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)
background_map_rotatedcoords(ax)
ax.scatter(
    ds_aws.glon,
    ds_aws.glat,
    s=ds_aws.count('t')['temperature']/5,
    edgecolor='k',
    linewidths=0.5,
)
ax.annotate(
    'Number of records:',
    xy=(0.08, 0.1), xycoords='axes fraction',
    fontsize=10)
markersize_legend(ax, [1,10,30,50,100,200,300,400], scale_multipler=1/5, legend_fontsize=10,loc=3,ncols=9,columnspacing=0.3,handletextpad=-0.4,bbox=(0.08,0.05))

# highlighting individual station
station = 'Manuela'
ax.annotate(
    f'{station} Weather Station',
    xy=(ds_aws.sel(station = station).glon, ds_aws.sel(station = station).glat), xycoords='data',
    xytext=(-140,-15), textcoords='offset points',
    arrowprops=dict(arrowstyle="->"),
    fontsize=10)

# Plotting the climate model grid
ds_climate['temperature'].mean('time').notnull().plot.pcolormesh(
    x='glon',
    y='glat',
    ax=ax,
    alpha=0.05,
    add_colorbar=False,
    edgecolor='k',
    linewidth=0.3,
)

plt.tight_layout()
plt.show()
[Figure: AWS locations (marker size = number of records) plotted over the climate model grid, with the Manuela station highlighted]

It is clear that the spatial distribution of weather stations is not uniform over the domain. There are certain regions containing high-density clusters of stations. This will induce a bias in the bias-correction itself when conducted over the whole domain, with corrections skewed towards these regions. If these clusters occurred randomly then modelling the spatial correlation between sites would be adequate to account for the distribution. Instead it is likely these regions were chosen for specific features, such as having anomalously high temperatures, which is difficult to account for in the model. We don’t directly attempt to solve this problem but it is noted as a remaining limitation.

Time Series and Distribution of Single Weather Station (Manuela)#

Comparisons are made between the time series for the Manuela weather station and the nearest grid-cell of the climate model output. The time series represent the average temperature for each month; the monthly values are aggregated from the hourly measurements of the raw data (this aggregation was preprocessed for this notebook for simplicity).

Hide code cell source
# Computing Nearest Neighbours
ds_climate_stacked = ds_climate.stack(x=('grid_longitude', 'grid_latitude'))
ds_climate_stacked_landonly = ds_climate_stacked.dropna('x')
ox = np.dstack([ds_aws['glon'],ds_aws['glat']])[0]
cx = np.dstack([ds_climate_stacked_landonly['glon'],ds_climate_stacked_landonly['glat']])[0]
nn_indecies = []
for point in ox:
    nn_indecies.append(distance.cdist([point], cx).argmin())
ds_climate_nearest_stacked = ds_climate_stacked_landonly.isel(x=nn_indecies)
ds_climate_nearest_stacked = ds_climate_nearest_stacked.assign_coords(nearest_station=("x", ds_aws.station.data))
ds_climate_nearest_stacked = ds_climate_nearest_stacked.swap_dims({"x": "nearest_station"})

# Single Site Full Time Series
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)
station = 'Manuela'
ds_climate_nearest_stacked.sel(nearest_station = station)['temperature'].plot(x="t",
                                                                              ax=ax,
                                                                              hue='station',
                                                                              alpha=0.7,
                                                                              label='Climate Model Output (Nearest Grid-Cell)',
                                                                              marker='x',
                                                                              ms=1,
                                                                              color='tab:blue',
                                                                              linewidth=1.0)

ds_aws.sel(station = station)['temperature'].plot(ax=ax,
                                                hue='station',
                                                alpha=0.7,
                                                label=f'{station} Weather Station',
                                                marker='x',
                                                ms=1,
                                                color='tab:orange',
                                                linewidth=1.5)
xticks = np.arange(0,45*12,12*5)
xticklabels = np.arange(1980,2025,5)
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)
ax.set_ylabel('Temperature')
ax.set_xlabel('Time')
ax.legend()
ax.set_title('')
plt.tight_layout()
plt.show()
[Figure: full monthly time series for the Manuela weather station and the nearest climate model grid-cell]

The Manuela weather station has one of the highest numbers of temperature records, spanning 1984-2021. The time series for the climate model spans 1981-2019. It’s clear that the variance in the time series is dominated by the seasonal cycle and that any bias in, for example, the mean will have some seasonal dependency. The PDFs for the two time series are plotted below.

Hide code cell source
# Probability Density Function (all months)
station = 'Manuela'

fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)
ds_aws.sel(station=station).to_dataframe()[['temperature']].hist(bins=40,
                                         ax=ax,
                                         edgecolor='k',
                                         linewidth=0.2,
                                         grid=False,
                                         density=1,
                                         alpha=0.7,
                                         label = f'{station} Weather Station',
                                         )
ds_climate_nearest_stacked.sel(nearest_station=station).to_dataframe()[['temperature']].hist(bins=40,
                                         ax=ax,
                                         edgecolor='k',
                                         linewidth=0.2,
                                         grid=False,
                                         density=1,
                                         alpha=0.7,
                                        label = 'Climate Model Output (Nearest Grid-Cell)',
                                         )
ax.annotate('All Months',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
plt.legend()
plt.tight_layout()
plt.show()
[Figure: probability density functions of temperature (all months) for the Manuela weather station and the nearest climate model grid-cell]

In this tutorial we define bias with respect to differences between the parameters that describe the PDFs of the time series. The PDF above is multi-modal, reflecting the seasonality of the data, meaning quite a few parameters would be needed to adequately describe the distribution. Simply using the mean would have limited value: the winter peak is ~5° higher for the climate model output while the summer peaks are approximately equal. The common approach is to split the time series up by month and define the bias for each month separately. The PDFs for individual months are approximately Gaussian (see below), so the bias can be defined in terms of differences in the mean and variance parameters between the datasets.

Hide code cell source
# Probability Density Function by Month
station = 'Manuela'

fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)

sns.histplot(data=ds_aws.sel(station=station).to_dataframe()[['month','temperature']],
            x='temperature',
            hue='month',
            bins=40,
            ax=ax,
            edgecolor='k',
            linewidth=0.2,
            kde=True,
            palette='Paired',
)

ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
plt.legend(['January','February','March','April','May','June','July','August','September','October','November','December'])
plt.tight_layout()
plt.show()
[Figure: per-month temperature distributions for the Manuela weather station]
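As a quick numerical companion to the figure above, the following snippet (reusing the ds_aws and ds_climate_nearest_stacked objects defined earlier) computes the empirical month-by-month bias in the mean and standard deviation for the Manuela station, taken as the nearest grid-cell value minus the station value:

# Empirical month-wise bias (nearest grid-cell minus station) for a single station
station = 'Manuela'
rows = []
for m in range(1, 13):
    aws_m = ds_aws['temperature'].sel(station=station).where(ds_aws['month'] == m, drop=True)
    clim_m = ds_climate_nearest_stacked['temperature'].sel(nearest_station=station).where(
        ds_climate_nearest_stacked['month'] == m, drop=True)
    rows.append({'month': m,
                 'mean_bias': float(clim_m.mean() - aws_m.mean()),
                 'std_bias': float(clim_m.std() - aws_m.std())})
print(pd.DataFrame(rows).round(2))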

In this tutorial, we’ll focus on just applying bias correction to the monthly June time series, shown below for the Manuela station.

Hide code cell source
month = 6 
station = 'Manuela'

fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=100)

ax = axs[0]
ds_climate_nearest_stacked.sel(nearest_station = station).where(ds_climate_nearest_stacked['month']==month,drop=True)['temperature'].plot(
                                                                                                                                x='t',
                                                                                                                                ax=ax,
                                                                                                                                hue='station',
                                                                                                                                alpha=0.7,
                                                                                                                                label='Climate Model Output (Nearest Grid-Cell)',
                                                                                                                                marker='x',
                                                                                                                                ms=1,
                                                                                                                                color='tab:orange',
                                                                                                                                linewidth=1.0)
ds_aws.sel(station = station).where(ds_aws['month']==month,drop=True)['temperature'].plot(ax=ax,
                                                                                                hue='station',
                                                                                                alpha=0.7,
                                                                                                label=f'{station} Weather Station',
                                                                                                marker='x',
                                                                                                ms=1,
                                                                                                color='tab:blue',
                                                                                                linewidth=1.5)

ax.annotate('June',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)#,rotation = 90)
ax.set_ylabel('Temperature')
ax.set_xlabel('Time')
ax.legend(fontsize=8)
ax.set_title('')

ax=axs[1]
ds_aws.sel(station=station).where(ds_aws['month']==month,drop=True).to_dataframe()[['temperature']].hist(bins=12,
                                         ax=ax,
                                         edgecolor='k',
                                         linewidth=0.2,
                                         grid=False,
                                         density=1,
                                         alpha=0.7,
                                         label = f'{station} Weather Station',
                                         )
ds_climate_nearest_stacked.sel(nearest_station=station).where(ds_climate_nearest_stacked['month']==month,drop=True).to_dataframe()[['temperature']].hist(bins=12,
                                         ax=ax,
                                         edgecolor='k',
                                         linewidth=0.2,
                                         grid=False,
                                         density=1,
                                         alpha=0.7,
                                        label = 'Climate Model Output (Nearest Grid-Cell)',
                                         )
ax.annotate('June',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
ax.legend(fontsize=8)

plt.tight_layout()
plt.show()
[Figure: June time series and histograms for the Manuela weather station and the nearest climate model grid-cell]

When examining the June time series, it’s clear that the climate model performs well at capturing the yearly variability and is highly correlated with the weather station output. However, there’s a significant bias in the mean of approximately 4°, indicating the utility of applying a bias correction.

Pairplots#

Since we’re interested in evaluating bias in the mean and variance of the June time series, it’s useful to think about predictors that influence these metrics. The two most obvious predictors are elevation and latitude, both of which have a well-understood physical justification for their impact on the mean temperature.

Hide code cell source
# Filtering to June only
ds_aws_filtered = ds_aws.sel(t = ds_aws.month == 6)
ds_climate_stacked_landonly_filtered = ds_climate_stacked_landonly.sel(time = ds_climate_stacked_landonly.month == 6)
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked.sel(time = ds_climate_nearest_stacked.month == 6)

# Filtering AWS by number of records and updating nearest grid-cells to match
ds_aws_filtered['records'] = ds_aws_filtered.count('t')['temperature']
stations_recordsfilter = ds_aws_filtered.where(ds_aws_filtered['records']>5,drop=True)['station'].data

ds_aws_filtered = ds_aws_filtered.sel(station=stations_recordsfilter)
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked_filtered.sel(nearest_station=stations_recordsfilter)

# Computing June Mean and Standard Deviation
ds_aws_filtered['june_mean_temperature'] = ds_aws_filtered.mean('t')['temperature']
ds_aws_filtered['june_std_temperature'] = ds_aws_filtered.std('t')['temperature']
ds_climate_stacked_landonly_filtered['june_mean_temperature'] = ds_climate_stacked_landonly_filtered.mean('time')['temperature']
ds_climate_stacked_landonly_filtered['june_std_temperature'] = ds_climate_stacked_landonly_filtered.std('time')['temperature']
ds_climate_nearest_stacked_filtered['june_mean_temperature'] = ds_climate_nearest_stacked_filtered.mean('time')['temperature']
ds_climate_nearest_stacked_filtered['june_std_temperature'] = ds_climate_nearest_stacked_filtered.std('time')['temperature']

# Transforming data for plotting with seaborn PairGrid 
vars = ['elevation','latitude','june_mean_temperature','june_std_temperature']
df_climate_filtered = ds_climate_stacked_landonly_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_climate_nearest_filtered = ds_climate_nearest_stacked_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_aws_filtered = ds_aws_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_climate_filtered['source'] = 'Climate Model'
df_climate_nearest_filtered['source'] = 'Climate Model Nearest'
df_aws_filtered['source'] = 'AWS'
df_combined = pd.concat([df_climate_filtered,df_climate_nearest_filtered,df_aws_filtered],axis=0).reset_index(drop=True)

# Plotting PairGrid with regression lines
g = sns.PairGrid(df_combined, hue='source',diag_sharey=False, corner=True)

reg_kws = {'scatter': False, 'line_kws':{'linewidth':1}}
g.map_lower(sns.regplot,**reg_kws)
g.add_legend(bbox_to_anchor=(0.8,0.8),markerscale=3)

g.hue_kws = {'marker':['+','x','*'],'s':[2,5,2],'alpha':[0.2,1,1]}
scatter_kws = {'linewidth':0.8}
g.map_lower(plt.scatter,**scatter_kws)

hist_kws = {'common_norm':False,'stat':'density'}
g.map_diag(sns.histplot,**hist_kws)

plt.show()
[Figure: pairplot of elevation, latitude, June mean temperature and June standard deviation for the climate model, the nearest grid-cells and the AWS data]

The pairplots bring up some interesting features:

  • Comparing the histograms of the ‘Climate Model’ data (all grid-cells over Antarctica) with the ‘Climate Model Nearest’ data (only the grid-cells closest to the weather stations) indicates that the AWS sites are not a particularly representative sample of the whole Antarctic region. There’s a higher proportion of sites at zero elevation, clusters of sites at particular locations (and so latitudes), and disproportionately high numbers of sites in regions with relatively high mean temperatures and relatively low standard deviations.

  • There are clear relationships between mean temperature and both elevation and latitude. The slope of the linear relationship does not seem too strongly impacted by the particular subsample of AWS locations.

  • There’s only a weak relationship between the standard deviation in temperature and elevation or latitude. As a result, we’ll leave out these predictors when estimating the standard deviation across the domain.

  • The relationship between elevation and mean temperature appears quite different at zero-elevation sites, which could be linked with various factors such as the proximity and influence of the nearby sea. This points to the potential utility of incorporating a distance-to-the-coast predictor, although for this tutorial we leave it out.

Examining spatial covariance after removing influence of elevation and latitude#

The spatial pattern in mean temperature is currently dominated by the relationship with elevation and latitude. While these are clearly important predictors for mean temperature, it’s expected that there will be various other important factors that impact mean temperature but are harder to account for (e.g. the funnelling of wind down valleys). One way of at least partially accounting for these factors is to model the spatial covariance between sites after removing the influence of elevation and latitude. That is, nearby sites are likely to be highly correlated because the factors impacting them are similar, whereas the further apart two sites are, the less correlated they will be (different valleys with different wind patterns, etc.).

Here we’ll plot the spatial pattern in the mean temperature after removing the linear influence of elevation and latitude. Additionally, we’ll plot the spatial pattern in the log of the standard deviation (taking the log to put the metric on the \((-\infty,\infty)\) domain). Since the relationships of the standard deviation with elevation and latitude appeared weak, we ignore these predictors for this metric and simply remove a constant to get the zero-mean log(standard deviations).
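Writing this out explicitly (with hats denoting the empirical June statistics at each site and an overline denoting an average over sites), the residuals plotted below are:

\[r_{\mu}(s) = \hat{\mu}(s) - \left(\beta_0 + \beta_1 x_{ele}(s) + \beta_2 x_{lat}(s)\right)\]
\[r_{\tilde{\sigma}}(s) = \log\hat{\sigma}(s) - \overline{\log\hat{\sigma}}\]

where \(\beta_0,\beta_1,\beta_2\) are the coefficients of the linear regression fit on the scaled predictors in the hidden cell below.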

Hide code cell content
# Modelling the linear relationship between elevation, latitude and June mean temperature

# Defining predictors and scaling
predictors = ['elevation','latitude']
scaled_predictors = [i+'_scaled' for i in predictors]
target = 'june_mean_temperature'

scaler = StandardScaler()
df_aws_filtered_scaled_predictors = pd.DataFrame(scaler.fit_transform(df_aws_filtered[predictors]),columns=scaled_predictors)
df_climate_filtered_scaled_predictors = pd.DataFrame(scaler.transform(df_climate_filtered[predictors]),columns=scaled_predictors)

# Linear Regression AWS
print("Linear Regression AWS:")
model = LinearRegression()
model.fit(df_aws_filtered_scaled_predictors,df_aws_filtered[target])
df_aws_filtered['june_mean_temperature_predicted_lr'] = model.predict(df_aws_filtered_scaled_predictors)
feature_importance = pd.Series(model.coef_, index=scaled_predictors)
print('intercept:',model.intercept_)
print(feature_importance.sort_values(ascending=False,key=abs))

# Linear Regression Climate Model
print("\n Linear Regression Climate Model:")
model = LinearRegression()
model.fit(df_climate_filtered_scaled_predictors,df_climate_filtered[target])
df_climate_filtered['june_mean_temperature_predicted_lr'] = model.predict(df_climate_filtered_scaled_predictors)
feature_importance = pd.Series(model.coef_, index=scaled_predictors)
print('intercept:',model.intercept_)
print(feature_importance.sort_values(ascending=False,key=abs))

ds_climate_stacked_landonly_filtered['june_mean_temperature_predicted_lr'] = (
    ('x'),
    df_climate_filtered['june_mean_temperature_predicted_lr'])

ds_aws_filtered['june_mean_temperature_predicted_lr'] = (
    ('station'),
    df_aws_filtered['june_mean_temperature_predicted_lr'])

ds_climate_stacked_landonly_filtered['june_mean_temperature_residual_lr']=ds_climate_stacked_landonly_filtered['june_mean_temperature']-ds_climate_stacked_landonly_filtered['june_mean_temperature_predicted_lr']
ds_aws_filtered['june_mean_temperature_residual_lr']=ds_aws_filtered['june_mean_temperature']-ds_aws_filtered['june_mean_temperature_predicted_lr']

ds_climate_stacked = xr.merge([ds_climate_stacked,ds_climate_stacked_landonly_filtered])
Linear Regression AWS:
intercept: -33.625017170326736
elevation_scaled   -10.160912
latitude_scaled      1.952785
dtype: float64

 Linear Regression Climate Model:
intercept: -33.0945203068335
elevation_scaled   -10.223010
latitude_scaled      2.895803
dtype: float64
Hide code cell content
# Computing the zero-mean log(standard deviation)
ds_aws_filtered['june_logstd_temperature']=np.log(ds_aws_filtered['june_std_temperature'])
ds_climate_stacked['june_logstd_temperature']=np.log(ds_climate_stacked['june_std_temperature'])
ds_climate_nearest_stacked_filtered['june_logstd_temperature']=np.log(ds_climate_nearest_stacked_filtered['june_std_temperature'])
ds_aws_filtered['june_logstd_temperature_residual_constant'] = ds_aws_filtered['june_logstd_temperature'] - ds_aws_filtered['june_logstd_temperature'].mean()
ds_climate_stacked['june_logstd_temperature_residual_constant'] = ds_climate_stacked['june_logstd_temperature'] - ds_climate_stacked['june_logstd_temperature'].mean()
ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant']=ds_climate_nearest_stacked_filtered['june_logstd_temperature'] - ds_climate_nearest_stacked_filtered['june_logstd_temperature'].mean()
Hide code cell source
# Plotting the linear regression prediction and residual
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)#,frameon=False)

metrics = ['june_mean_temperature_predicted_lr','june_mean_temperature_residual_lr','june_logstd_temperature','june_logstd_temperature_residual_constant']
cmaps = ['viridis','RdBu','viridis','RdBu']
vminmaxs = [(-60,-15),(-10,10),(0.4,1.6),(-0.6,0.6)]
labels = ['June Mean Temperature Predicted (Linear Regression)','June Mean Temperature Residual (Linear Regression)',
          'June Log(Standard Deviation)','June Zero-Mean Log(Standard Deviation)']

for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
    background_map_rotatedcoords(ax)
    ds_climate_stacked[metric].unstack().plot.pcolormesh(
        x='glon',
        y='glat',
        ax=ax,
        alpha=0.9,
        vmin=vminmax[0],
        vmax=vminmax[1],
        cmap=cmap,
        cbar_kwargs = {'fraction':0.030,
                    'pad':0.02,
                    'label':label}
    )

    ax.scatter(
        ds_aws_filtered['glon'],
        ds_aws_filtered['glat'],
        marker="o",
        c=ds_aws_filtered[metric],
        cmap=cmap,
        edgecolor="w",
        linewidth=1.0,
        vmin=vminmax[0],
        vmax=vminmax[1],
    )

    ax.set_axis_off()
    ax.set_title('')
                       
plt.tight_layout()
[Figure: maps of the June mean temperature predicted by the linear regression, its residual, the June log(standard deviation) and the zero-mean log(standard deviation), with AWS values overlaid]

It’s clear that once we remove the linear dependency of temperature on elevation and latitude we get quite a different spatial structure for the mean temperature. This is the spatial structure we’ll want to model using Gaussian processes, where the covariance between nearby sites is captured by parameterising the covariance as a function of distance. In our model we’ll assume a single lengthscale for the covariance; that is, we assume the covariance between nearby sites decays over the same distance wherever you are located over Antarctica. While this assumption is clearly broken in certain areas (e.g. covariance in steep regions near the coast behaves differently to flat regions inland), it provides a starting point and is a lot simpler than considering non-stationary lengthscales across the region (although this is possible). It’s important to consider how this will impact our results; we expect one of the main effects is that the noise term estimated for our Gaussian process will be relatively high, to account for the sharp variations between nearby sites in particularly steep regions.
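As a small illustration of what the single-lengthscale assumption means, the sketch below (with made-up variance and lengthscale values, not the inferred ones) evaluates a stationary Matérn 3/2 kernel at increasing separations; under stationarity the covariance between two sites depends only on this separation, wherever the pair sits over Antarctica:

# Illustrative only: covariance decay of a stationary Matern 3/2 kernel with distance.
# The variance and lengthscale values here are made up, not inferred from the data.
import jax.numpy as jnp
from tinygp import kernels
from tinygp.kernels.distance import L2Distance

variance, lengthscale = 2.0, 5.0  # hypothetical values (rotated-coordinate degrees)
kernel = variance * kernels.Matern32(scale=lengthscale, distance=L2Distance())

origin = jnp.zeros((1, 2))                                  # a reference site
separations = jnp.array([0.0, 1.0, 2.5, 5.0, 10.0, 25.0])
others = jnp.stack([separations, jnp.zeros_like(separations)], axis=1)
for d, c in zip(separations, kernel(origin, others)[0]):
    print(f"separation {float(d):5.1f} -> covariance {float(c):6.3f}")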

It’s also important to note that the spatial patterns in the AWS data and the climate model data are similar for each metric. That is, the climate model does a reasonable job of capturing the more complex dependencies of the mean and log(standard deviation) of temperature. To exploit this, our model includes a shared latent Gaussian process between the datasets, so predictions of the unbiased values are made conditioning on both datasets.
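The sketch below illustrates the idea with synthetic zero-mean fields and made-up kernel hyperparameters (it shows only the underlying algebra and is not the notebook’s inference code): the AWS-like and climate-model-like observations share the latent covariance, the climate-model block carries an additional bias covariance, and the latent (unbiased) field at a new location is predicted by conditioning on the joint vector of both sets of observations.

# Minimal sketch of conditioning a shared latent GP on both data sources
# (synthetic residual-scale example with made-up hyperparameters)
import jax
import jax.numpy as jnp
from tinygp import kernels
from tinygp.kernels.distance import L2Distance

kernel_L = 3.0 * kernels.Matern32(scale=5.0, distance=L2Distance())   # shared latent field
kernel_B = 1.0 * kernels.Matern32(scale=10.0, distance=L2Distance())  # bias field (climate only)
jitter = 0.1  # hypothetical noise variance

s_y = jax.random.uniform(jax.random.PRNGKey(0), (30, 2), minval=-20.0, maxval=20.0)  # AWS-like sites
s_z = jax.random.uniform(jax.random.PRNGKey(1), (80, 2), minval=-20.0, maxval=20.0)  # grid-like sites
s_new = jnp.array([[0.0, 0.0]])                                                       # prediction location

# Joint covariance of [y; z]: z shares the latent field and adds the bias field
K = jnp.block([
    [kernel_L(s_y, s_y) + jitter * jnp.eye(30), kernel_L(s_y, s_z)],
    [kernel_L(s_z, s_y), kernel_L(s_z, s_z) + kernel_B(s_z, s_z) + jitter * jnp.eye(80)],
])

# Synthetic zero-mean observations, just to exercise the algebra
obs = jax.random.multivariate_normal(jax.random.PRNGKey(2), jnp.zeros(110), K)

# Cross-covariance of the *latent* (unbiased) field at s_new with both data sources
k_star = jnp.concatenate([kernel_L(s_new, s_y), kernel_L(s_new, s_z)], axis=1)

latent_mean = k_star @ jnp.linalg.solve(K, obs)
latent_var = kernel_L(s_new, s_new) - k_star @ jnp.linalg.solve(K, k_star.T)
print('latent mean at s_new:', latent_mean, ' latent variance:', latent_var)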

Examining spatial covariance in the bias#

It’s also useful to explore the spatial structure of the bias (both in the mean and log(standard deviation) of June temperature). In this exploratory analysis we do this by examining the empirical values of the metrics from the AWS data and the nearest climate model grid-cells.

For the mean temperature, this is done after accounting for the linear relationship with elevation and latitude (that is, we examine the spatial structure in the bias of the residuals). For the log(standard deviation), we examine the bias in the zero-meaned values for each dataset.

Hide code cell content
# Recomputing nearest neighbours
ds_climate_nearest_stacked = ds_climate_stacked_landonly_filtered.isel(x=nn_indecies)
ds_climate_nearest_stacked = ds_climate_nearest_stacked.assign_coords(nearest_station=("x", ds_aws.station.data))
ds_climate_nearest_stacked = ds_climate_nearest_stacked.swap_dims({"x": "nearest_station"})
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked.sel(nearest_station=stations_recordsfilter)

# Recomputing the zero-mean log(standard deviation)
ds_climate_nearest_stacked_filtered['june_logstd_temperature']=np.log(ds_climate_nearest_stacked_filtered['june_std_temperature'])
ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant']=ds_climate_nearest_stacked_filtered['june_logstd_temperature'] - ds_climate_nearest_stacked_filtered['june_logstd_temperature'].mean()

# Evaluating bias in residuals from linear regression for the mean temperature
ds_climate_nearest_stacked_filtered['bias_june_mean_temperature_residual_lr'] = (
    ('nearest_station'),
    ds_climate_nearest_stacked_filtered['june_mean_temperature_residual_lr'].data - ds_aws_filtered['june_mean_temperature_residual_lr'].data)

# Evaluating bias for the log(std) temperature
ds_climate_nearest_stacked_filtered['bias_june_logstd_temperature_residual_constant'] = (
    ('nearest_station'),
    ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant'].data - ds_aws_filtered['june_logstd_temperature_residual_constant'].data)
Hide code cell source
# Plotting the linear regression residual and bias in the residual for the mean temperature
fig, axs = plt.subplots(1, 2, figsize=(15, 5),dpi=100)#,frameon=False)

for ax in axs:
    background_map_rotatedcoords(ax)
    ax.set_axis_off()
    ax.set_title('')

ax=axs[0]

plot1 = ax.scatter(
    ds_climate_nearest_stacked_filtered['glon'],
    ds_climate_nearest_stacked_filtered['glat'],
    marker="o",
    c=ds_climate_nearest_stacked_filtered['bias_june_mean_temperature_residual_lr'],
    cmap='RdBu',
    edgecolor="w",
    linewidth=1.0,
    vmin=-5,
    vmax=5,
)

plt.colorbar(plot1, ax=ax, fraction=0.03, pad=0.02,label='Bias in June Mean Temperature Residual \n (Linear Regression)')#, orientation='horizontal', label=metric)

ax=axs[1]

plot2 = ax.scatter(
    ds_climate_nearest_stacked_filtered['glon'],
    ds_climate_nearest_stacked_filtered['glat'],
    marker="o",
    c=ds_climate_nearest_stacked_filtered['bias_june_logstd_temperature_residual_constant'],
    cmap='RdBu',
    edgecolor="w",
    linewidth=1.0,
    vmin=-0.5,
    vmax=0.5,
)

plt.colorbar(plot2, ax=ax, fraction=0.03, pad=0.02,label='Bias in June Log(Std) Temperature Residual \n (Constant)')#, orientation='horizontal', label=metric)

plt.tight_layout()
[Figure: maps of the bias in the June mean temperature residual and in the zero-mean June log(standard deviation) at the AWS locations]

There’s clearly a spatial covariance pattern in the bias for both the mean and the log(standard deviation). The lengthscale over which the covariance of the bias decays appears longer than for the raw values of the metrics in each dataset.

Examining relationships between variables and the bias in parameters#

It’s interesting to check whether there are any obvious relationships between the bias in the parameter values and predictors such as elevation and latitude. Additionally, it’s worth checking whether there’s a relationship between the biased and unbiased parameter values. We do this through a partial pairplot, as shown below.

Hide code cell source
# Transforming data for plotting with seaborn PairGrid 
vars = ['elevation','latitude','june_mean_temperature','june_std_temperature','bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant']
x_vars = ['bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant','elevation','latitude','june_mean_temperature','june_std_temperature']
y_vars = ['bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant']
df_bias_climate = ds_climate_nearest_stacked_filtered[vars].to_dataframe()[vars].reset_index(drop=True)

# Plotting PairGrid with regression lines
g = sns.PairGrid(df_bias_climate, x_vars=x_vars,y_vars=y_vars,height=4,diag_sharey=False)#, corner=True)

reg_kws = {'scatter': False, 'line_kws':{'linewidth':1}}
g.map_offdiag(sns.regplot,**reg_kws)

g.hue_kws = {'marker':['+'],'s':[5],'alpha':[1.0]}
scatter_kws = {'linewidth':0.8}
g.map_offdiag(plt.scatter,**scatter_kws)

hist_kws = {'common_norm':False,'stat':'density'}
g.map_diag(sns.histplot,**hist_kws)

plt.show()
[Figure: pairplot of the bias metrics against elevation, latitude, June mean temperature and June standard deviation]

The above pairplot only shows weak relationships between the bias in parameters and the other variables. Therefore, in the model we’ll assume the bias is generated from an independent underlying process.

Data Preprocessing#

The main preprocessing steps are scaling the elevation and latitude predictors and removing any AWS sites with two or fewer June temperature records. We also define a random subsample of the climate model grid-cells, which we’ll use for inference on the parameters of the Gaussian processes in order to reduce the computational demands. Additionally, all the data is moved out of Xarray and into a dictionary of device arrays (JAX’s equivalent of NumPy arrays) in the shapes that the inference package Numpyro expects.

Hide code cell content
# Filtering to June records and stations with more than 2 records
aws_june_filter = ds_aws.where(ds_aws['month']==6,drop=True)['t'].data
ds_aws_preprocessed = ds_aws.sel(t=aws_june_filter)
aws_stations_recordsfilter = ds_aws_preprocessed.where(ds_aws_preprocessed['temperature'].count(['t'])>2,drop=True)['station'].data
ds_aws_preprocessed = ds_aws_preprocessed.sel(station=aws_stations_recordsfilter)

climate_june_filter = ds_climate.where(ds_climate['month']==6,drop=True)['time'].data
ds_climate_preprocessed = ds_climate.sel(time=climate_june_filter)
ds_climate_preprocessed = ds_climate_preprocessed.stack(x=('grid_longitude', 'grid_latitude'))
ds_climate_preprocessed = ds_climate_preprocessed.dropna('x')
random_sample = np.random.choice(np.arange(len(ds_climate_preprocessed['x'])), size=100, replace=False)
ds_climate_preprocessed_sample = ds_climate_preprocessed.isel(x=random_sample)

# Scaling latitude and elevation 
lat_scalar = StandardScaler()
ele_scalar = StandardScaler()
ds_aws_preprocessed['latitude_scaled'] = (['station'],  lat_scalar.fit_transform(ds_aws_preprocessed['latitude'].data.reshape(-1,1))[:,0])
ds_aws_preprocessed['elevation_scaled'] = (['station'], ele_scalar.fit_transform(ds_aws_preprocessed['elevation'].data.reshape(-1,1))[:,0])
ds_climate_preprocessed['latitude_scaled'] = (['x'], lat_scalar.transform(ds_climate_preprocessed['latitude'].data.reshape(-1,1))[:,0])
ds_climate_preprocessed['elevation_scaled'] = (['x'], ele_scalar.transform(ds_climate_preprocessed['elevation'].data.reshape(-1,1))[:,0])

# Transforming into dictionary of device arrays
ox = jnp.array(np.dstack([ds_aws_preprocessed['glon'],ds_aws_preprocessed['glat']]))[0]
odata = jnp.array(ds_aws_preprocessed['temperature'].values).transpose()
olat = jnp.array(ds_aws_preprocessed['latitude'].values)
oele = jnp.array(ds_aws_preprocessed['elevation'].values)
olat_scaled = jnp.array(ds_aws_preprocessed['latitude_scaled'].values)
oele_scaled = jnp.array(ds_aws_preprocessed['elevation_scaled'].values)

cx = jnp.array(np.dstack([ds_climate_preprocessed['glon'],ds_climate_preprocessed['glat']]))[0]
cdata = jnp.array(ds_climate_preprocessed.transpose()['temperature'].values).transpose()
clat = jnp.array(ds_climate_preprocessed['latitude'].values)
cele = jnp.array(ds_climate_preprocessed['elevation'].values)
clat_scaled = jnp.array(ds_climate_preprocessed['latitude_scaled'].values)
cele_scaled = jnp.array(ds_climate_preprocessed['elevation_scaled'].values)

cx_subsample = cx[random_sample]
cdata_subsample = cdata[:,random_sample]
cele_subsample = cele[random_sample]
clat_subsample = clat[random_sample]
cele_scaled_subsample = cele_scaled[random_sample]
clat_scaled_subsample = clat_scaled[random_sample]

data_dictionary = {
    'ds_aws_preprocessed':ds_aws_preprocessed,
    'ds_climate_preprocessed':ds_climate_preprocessed,
    'ds_climate_preprocessed_sample':ds_climate_preprocessed_sample,
    'ox':ox,
    'odata':odata,
    'olat':olat,
    'oele':oele,
    'olat_scaled':jnp.array(olat_scaled),
    'oele_scaled':jnp.array(oele_scaled),
    'cx':cx,
    'cdata':cdata,
    'clat':clat,
    'cele':cele,
    'clat_scaled':jnp.array(clat_scaled),
    'cele_scaled':jnp.array(cele_scaled),
    'cx_subsample':cx_subsample,
    'cdata_subsample':cdata_subsample,
    'cele_subsample':cele_subsample,
    'clat_subsample':clat_subsample,
    'cele_scaled_subsample':jnp.array(cele_scaled_subsample),
    'clat_scaled_subsample':jnp.array(clat_scaled_subsample),
    'ele_scaler':ele_scalar,
    'lat_scaler':lat_scalar,
    'random_sample':random_sample,
}

It’s worth plotting the locations of the sampled climate model grid-cells and it’s also worth performing a quick sanity check that the data is in the right format:

Hide code cell content
#Sanity Check

print('Shapes:')
for key in data_dictionary.keys():
    if key not in ['ds_aws_preprocessed','ds_climate_preprocessed','ds_climate_preprocessed_sample','ele_scaler','lat_scaler']:
        print(f'{key} shape: {data_dictionary[key].shape}')

print('\n Values:')
for key in data_dictionary.keys():
    if key not in ['ds_aws_preprocessed','ds_climate_preprocessed','ds_climate_preprocessed_sample','ele_scaler','lat_scaler']:
        if key=='odata':
            print(f'{key} min={np.nanmin(data_dictionary[key]):.1f}, mean={np.nanmean(data_dictionary[key]):.1f}, max={np.nanmax(data_dictionary[key]):.1f}')
        else:
            print(f'{key} min={data_dictionary[key].min():.1f}, mean={data_dictionary[key].mean():.1f}, max={data_dictionary[key].max():.1f}')

print('\n Types:')
for key in data_dictionary.keys():
    print(f'{key} type: {type(data_dictionary[key])}')
Shapes:
ox shape: (156, 2)
odata shape: (42, 156)
olat shape: (156,)
oele shape: (156,)
olat_scaled shape: (156,)
oele_scaled shape: (156,)
cx shape: (5724, 2)
cdata shape: (38, 5724)
clat shape: (5724,)
cele shape: (5724,)
clat_scaled shape: (5724,)
cele_scaled shape: (5724,)
cx_subsample shape: (100, 2)
cdata_subsample shape: (38, 100)
cele_subsample shape: (100,)
clat_subsample shape: (100,)
cele_scaled_subsample shape: (100,)
clat_scaled_subsample shape: (100,)
random_sample shape: (100,)

 Values:
ox min=-24.0, mean=-1.7, max=21.5
odata min=-70.3, mean=-33.4, max=-7.9
olat min=-90.0, mean=-76.4, max=-65.2
oele min=5.0, mean=1246.4, max=4093.0
olat_scaled min=-2.6, mean=-0.0, max=2.1
oele_scaled min=-1.1, mean=0.0, max=2.5
cx min=-24.9, mean=2.5, max=24.4
cdata min=-73.3, mean=-39.6, max=-7.0
clat min=-89.7, mean=-76.6, max=-64.0
cele min=-3.1, mean=2003.6, max=4063.5
clat_scaled min=-2.5, mean=-0.0, max=2.4
cele_scaled min=-1.1, mean=0.7, max=2.5
cx_subsample min=-19.6, mean=2.4, max=23.5
cdata_subsample min=-69.7, mean=-40.6, max=-11.4
cele_subsample min=41.3, mean=2133.2, max=3999.3
clat_subsample min=-86.8, mean=-76.1, max=-66.5
cele_scaled_subsample min=-1.1, mean=0.8, max=2.5
clat_scaled_subsample min=-2.0, mean=0.1, max=1.9
random_sample min=56.0, mean=2963.7, max=5709.0

 Types:
ds_aws_preprocessed type: <class 'xarray.core.dataset.Dataset'>
ds_climate_preprocessed type: <class 'xarray.core.dataset.Dataset'>
ds_climate_preprocessed_sample type: <class 'xarray.core.dataset.Dataset'>
ox type: <class 'jaxlib.xla_extension.ArrayImpl'>
odata type: <class 'jaxlib.xla_extension.ArrayImpl'>
olat type: <class 'jaxlib.xla_extension.ArrayImpl'>
oele type: <class 'jaxlib.xla_extension.ArrayImpl'>
olat_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
oele_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cx type: <class 'jaxlib.xla_extension.ArrayImpl'>
cdata type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cx_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cdata_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_scaled_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_scaled_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
ele_scaler type: <class 'sklearn.preprocessing._data.StandardScaler'>
lat_scaler type: <class 'sklearn.preprocessing._data.StandardScaler'>
random_sample type: <class 'numpy.ndarray'>
Hide code cell content
# Plotting the subsample locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)
background_map_rotatedcoords(ax)
ax.scatter(
    ds_climate_preprocessed_sample['glon'],
    ds_climate_preprocessed_sample['glat'],
    marker='s',
    s=10,
    edgecolor='k',
    linewidths=0.5,
)
plt.show()
[Figure: locations of the randomly sampled climate model grid-cells]

Defining the model#

Let \(Y(s,t)\) represent the random variable for the June temperature from the AWS data at location \(s\) and time \(t\), and let \(Z(s,t)\) represent the equivalent for the climate model output. We treat the marginal distributions over time as Normal, such that \(Y(s)\sim \mathcal{N}(\mu_Y(s),\sigma_Y(s))\) and \(Z(s)\sim \mathcal{N}(\mu_Z(s),\sigma_Z(s))\).

To model the spatial covariance in the parameters we use Gaussian processes. A log transformation is applied to the standard deviation (\(\tilde{\sigma}=\log(\sigma)\)) so that it lies on the \((-\infty,\infty)\) sample space of a Gaussian process. Shared latent Gaussian processes are considered between the datasets, as well as an independent Gaussian process that generates the bias in the climate model output. The parameters for the climate model are then the sum of an unbiased and a bias component, \(\mu_Z(s)=\mu_Y(s)+\mu_B(s)\) and \(\tilde{\sigma}_Z(s)=\tilde{\sigma}_Y(s)+\tilde{\sigma}_B(s)\), where each component is modelled as generated from a Gaussian process. The Gaussian processes are parameterised by a mean function and a covariance function. The mean function for \(\mu_Y(s)\) is linear with respect to elevation and latitude, while for the other parameters it is a constant. The covariance function is taken as a Matérn 3/2 kernel with a lengthscale \(l\) and variance \(v\).

\[\mu_Y(S) \sim \mathcal{GP}(m_{\mu_Y}(s)|\beta_{\mu_Y},k(s,s'|l_{\mu_Y},v_{\mu_Y}))\]
\[\mu_B(S) \sim \mathcal{GP}(m_{\mu_B},k(s,s'|l_{\mu_B},v_{\mu_B}))\]
\[\tilde{\sigma}_Y(S) \sim \mathcal{GP}(m_{\tilde{\sigma}_Y},k(s,s'|l_{\tilde{\sigma}_Y},v_{\tilde{\sigma}_Y}))\]
\[\tilde{\sigma}_B(S) \sim \mathcal{GP}(m_{\tilde{\sigma}_B},k(s,s'|l_{\tilde{\sigma}_B},v_{\tilde{\sigma}_B}))\]

The plate diagram below shows the relational dependence between the parameters of the model.

[Plate diagram of the model]
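To make the generative structure concrete, the sketch below draws one synthetic realisation of the model at hypothetical locations using made-up hyperparameter values; it is purely illustrative and is not used anywhere in the inference:

# Illustrative generative draw from the hierarchical model at hypothetical locations.
# All locations, predictors and hyperparameter values below are made up.
import jax
import jax.numpy as jnp
from tinygp import kernels, GaussianProcess
from tinygp.kernels.distance import L2Distance

keys = jax.random.split(jax.random.PRNGKey(0), 9)
n = 50
s = jax.random.uniform(keys[0], (n, 2), minval=-20.0, maxval=20.0)  # site locations
ele = jax.random.uniform(keys[1], (n,))                              # scaled elevation
lat = jax.random.uniform(keys[2], (n,))                              # scaled latitude

def gp_draw(key, variance, lengthscale):
    # Zero-mean Matern 3/2 GP draw at the sites s
    kernel = variance * kernels.Matern32(scale=lengthscale, distance=L2Distance())
    return GaussianProcess(kernel, s, diag=1e-4).sample(key)

# Latent (unbiased) fields: linear mean function for mu_Y, constants for the rest
mu_Y = (-33.0 - 10.0 * ele + 2.0 * lat) + gp_draw(keys[3], 4.0, 5.0)
logsigma_Y = 0.9 + gp_draw(keys[4], 0.1, 5.0)

# Independent bias fields added to the climate model parameters
mu_B = 2.0 + gp_draw(keys[5], 1.0, 10.0)
logsigma_B = 0.1 + gp_draw(keys[6], 0.05, 10.0)
mu_Z, logsigma_Z = mu_Y + mu_B, logsigma_Y + logsigma_B

# One June 'observation' per site from each data source
y = mu_Y + jnp.exp(logsigma_Y) * jax.random.normal(keys[7], (n,))
z = mu_Z + jnp.exp(logsigma_Z) * jax.random.normal(keys[8], (n,))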

Splitting up the model for computation#

It’s quite common when using Gaussian processes to fit the mean function independently of the covariance function. That is to say, a mean function is first fit to the data, and the resulting zero-mean residuals are then fit using the GP implementation that handles the covariances between points. There are various reasons for this, such as making the covariance matrix better conditioned for inference.

In this notebook we’ll split up the model into a first component estimating the parameters \(\mu_Y(s_Y)\), \(\mu_Z(s_Z)\), \(\tilde{\sigma}_Y(s_Y)\) and \(\tilde{\sigma}_Z(s_Z)\) at the AWS and climate model grid cell locations, along with the global parameter values for the mean functions of the latent Gaussian processes \(\beta_{0,\mu_Y}\), \(\beta_{1,\mu_Y}\), \(\beta_{2,\mu_Y}\), \(m_{\mu_B}\), \(m_{\tilde{\sigma}_Y}\) and \(m_{\tilde{\sigma}_B}\). Then the second component will use the residuals \(r_{\mu_Y}(s_Y)\), \(r_{\mu_Z}(s_Z)\), \(r_{\tilde{\sigma}_Y}(s_Y)\) and \(r_{\tilde{\sigma}_Z}(s_Z)\) to estimate the parameters of the covariance function for the latent Gaussian processes \(l_{\mu_Y}\), \(v_{\mu_Y}\), \(l_{\mu_B}\), \(v_{\mu_B}\), \(l_{\tilde{\sigma}_Y}\), \(v_{\tilde{\sigma}_Y}\), \(l_{\tilde{\sigma}_B}\) and \(v_{\tilde{\sigma}_B}\).
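A minimal sketch of this two-stage idea is given below, reusing the ds_aws_filtered object from the exploration above with made-up kernel hyperparameters in place of inferred ones: a linear mean function is fit first, and a zero-mean Matérn 3/2 GP is then conditioned on the residuals to predict them, with uncertainty, at a couple of hypothetical locations. The actual inference over both components, with priors on all of these parameters, follows in the next sections.

# Minimal two-stage sketch: (1) linear mean function, (2) GP on the residuals.
# Kernel hyperparameters and query locations here are hypothetical, not inferred.
import numpy as np
import jax.numpy as jnp
from sklearn.linear_model import LinearRegression
from tinygp import kernels, GaussianProcess
from tinygp.kernels.distance import L2Distance

# Stage 1: fit a linear mean function to the AWS June means (objects from the exploration above)
X_pred = np.stack([ds_aws_filtered['elevation'].values, ds_aws_filtered['latitude'].values], axis=1)
y_mean = ds_aws_filtered['june_mean_temperature'].values
linear = LinearRegression().fit(X_pred, y_mean)
residuals = jnp.array(y_mean - linear.predict(X_pred))

# Stage 2: zero-mean GP over the residuals in the rotated spatial coordinates
x_obs = jnp.array(np.stack([ds_aws_filtered['glon'].values, ds_aws_filtered['glat'].values], axis=1))
kernel = 4.0 * kernels.Matern32(scale=5.0, distance=L2Distance())  # hypothetical variance, lengthscale
gp = GaussianProcess(kernel, x_obs, diag=1.0)                       # hypothetical noise variance

# Condition on the observed residuals and predict at a couple of hypothetical locations
x_new = jnp.array([[0.0, 0.0], [5.0, -5.0]])
_, cond_gp = gp.condition(residuals, x_new)
print('predicted residual mean:', cond_gp.loc)
print('predicted residual std :', jnp.sqrt(cond_gp.variance))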

Parameter Inference#

Inference is performed on the two components of the model separately. To limit the runtime of this notebook, we provide the code for running the inference but perform the actual inference separately and load in the output to examine. The Python scripts for running the inference and saving the output are available via the jupterbook_render branch of the repository. Loading in the inference data:

url = ("https://zenodo.org/records/14779669/files/data_dictionary.pkl?download=1")
filename = f'{data_path}data_dictionary.pkl'
if os.path.exists(filename):
    print('File already exists')
else:
    urlretrieve(url, filename)
    print('File downloaded')
File already exists
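Before using the downloaded file, it’s worth a quick inspection of what it contains; the exact keys depend on how the inference scripts saved their output, so we simply list them rather than assume a structure:

# Inspecting the downloaded pickle before use; `filename` and `pickle` are defined above
with open(filename, 'rb') as f:
    loaded_inference = pickle.load(f)

print(type(loaded_inference))
if isinstance(loaded_inference, dict):
    for key, value in loaded_inference.items():
        print(f'{key}: {type(value)}')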

Inference on parameters of the mean function#

Utilising the Numpyro python package to define the model for component 1:

Hide code cell content
# The model for predicting the mean and logvar for each dataset as well as the parameters for the meanfunction giving domain-wide behaviour
def meanfunc_model(data_dictionary):
    """
    Function for defining the GP mean function model for the temperature data
    Args:
        data_dictionary (python dictionary): dictionary holding the data needed for the model
    """
    omean_b0 = numpyro.sample("omean_b0",data_dictionary['omean_b0_prior'])
    omean_b1 = numpyro.sample("omean_b1",data_dictionary['omean_b1_prior'])
    omean_b2 = numpyro.sample("omean_b2",data_dictionary['omean_b2_prior'])
    omean_noise = numpyro.sample("omean_noise",data_dictionary['omean_noise_prior'])
    omean_func = omean_b0 + omean_b1*data_dictionary['oele_scaled'] + omean_b2*data_dictionary['olat_scaled']
    omean = numpyro.sample("omean",dist.Normal(omean_func, omean_noise))

    ologvar_b0 = numpyro.sample("ologvar_b0",data_dictionary['ologvar_b0_prior'])
    ologvar_noise = numpyro.sample("ologvar_noise",data_dictionary['ologvar_noise_prior'])
    ologvar_func = ologvar_b0 * jnp.ones(data_dictionary['ox'].shape[0])
    ologvar = numpyro.sample("ologvar",dist.Normal(ologvar_func, ologvar_noise))
    ovar = jnp.exp(ologvar)

    obs_mask = (jnp.isnan(data_dictionary['odata'])==False)
    numpyro.sample("AWS Temperature", dist.Normal(omean, jnp.sqrt(ovar)).mask(obs_mask), obs=data_dictionary["odata"])

    cmean_b0 = numpyro.sample("cmean_b0",data_dictionary['cmean_b0_prior'])
    cmean_noise = numpyro.sample("cmean_noise",data_dictionary['cmean_noise_prior'])
    cmean_func = cmean_b0 + omean_b1*data_dictionary['cele_scaled'] + omean_b2*data_dictionary['clat_scaled']
    cmean = numpyro.sample("cmean",dist.Normal(cmean_func, cmean_noise))

    clogvar_b0 = numpyro.sample("clogvar_b0",data_dictionary['clogvar_b0_prior'])
    clogvar_noise = numpyro.sample("clogvar_noise",data_dictionary['clogvar_noise_prior'])
    clogvar_func = clogvar_b0 * jnp.ones(data_dictionary['cx'].shape[0])
    clogvar = numpyro.sample("clogvar",dist.Normal(clogvar_func, clogvar_noise))
    cvar = jnp.exp(clogvar)

    numpyro.sample("Climate Temperature", dist.Normal(cmean, jnp.sqrt(cvar)), obs=data_dictionary["cdata"])

In the model definition, ‘omean’ represents the mean June temperature estimate at each AWS location, while ‘cmean’ is the equivalent for the climate model output at the grid-cell locations. Similarly, ‘ologvar’ and ‘clogvar’ represent the log-variance for each dataset at the respective locations. The model follows these equations:

\[\mu_Y(s)=m_{\mu_Y}(s)+r_{\mu_Y} \hspace{5em} \mu_Z(s)=m_{\mu_Z}(s)+r_{\mu_Z}\]
\[\tilde{\sigma}_Y(s)=m_{\tilde{\sigma}_Y}(s)+r_{\tilde{\sigma}_Y} \hspace{5em} \tilde{\sigma}_Z(s)=m_{\tilde{\sigma}_Z}(s)+r_{\tilde{\sigma}_Z}\]
\[m_{\mu_Y}(s)=\beta_{0,\mu_Y} + \beta_{1,\mu} \cdot x_{ele}(s) + \beta_{2,\mu} \cdot x_{lat}(s)\]
\[m_{\mu_Z}(s)=\beta_{0,\mu_Z} + \beta_{1,\mu} \cdot x_{ele}(s) + \beta_{2,\mu} \cdot x_{lat}(s)\]

Defining a function for running the Bayesian inference on the model parameters given the data:

Hide code cell content
# A function to run inference on the model
def run_inference(
    model, rng_key, num_warmup, num_samples, num_chains, *args, **kwargs
):
    """
    Helper function for doing MCMC inference
    Args:
        model (python function): function that follows numpyros syntax
        rng_key (np array): PRNGKey for reproducible results
        num_warmup (int): Number of MCMC steps for warmup
        num_samples (int): Number of MCMC samples to take of parameters after warmup
        num_chains (int): Number of chains to run in parallel (or sequentially without GPU)
        *args: Additional arguments to pass to the model
        **kwargs: Additional keyword arguments to pass to the model
    Returns:
        MCMC numpyro instance (class object): An MCMC class object with functions such as .get_samples() and .run()
    """
    starttime = timeit.default_timer()

    kernel = NUTS(model)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
    )

    mcmc.run(rng_key, *args, **kwargs)

    mcmc.print_summary()
    print("Time Taken:", timeit.default_timer() - starttime)
    return mcmc

As we’re utilising a Bayesian framework, we have to define prior distributions for the parameters. In theory the prior should represent our state of knowledge before observing the data, and so, in the absence of additional knowledge, should be fully non-informative; in practice it’s common to at least set sensible bounds using the exploratory analysis to limit the complexity of the inference space. Here, we compute some basic metrics, which we then use to help set sensible priors.

Hide code cell content
print('Useful Metrics for Priors: \n',
      f"""mean odata:
      min={np.nanmin(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
      mean={np.nanmean(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
      max={np.nanmax(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
      var={np.nanvar(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
      \n""",
      f"""logvar odata:
      min={np.nanmin(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
      mean={np.nanmean(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
      max={np.nanmax(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
      var={np.nanvar(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
      \n""",
      f"""mean cdata:
      min={np.nanmin(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
      mean={np.nanmean(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
      max={np.nanmax(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
      var={np.nanvar(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
      \n""",
      f"""logvar cdata:
      min={np.nanmin(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
      mean={np.nanmean(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
      max={np.nanmax(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
      var={np.nanvar(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
      \n""",
)
Useful Metrics for Priors: 
 mean odata:
      min=-65.0,
      mean=-33.4,
      max=-13.4,
      var=155.0,
      
 logvar odata:
      min=-4.9,
      mean=1.7,
      max=4.1,
      var=1.1,
      
 mean cdata:
      min=-62.8,
      mean=-39.6,
      max=-12.6,
      var=154.6,
      
 logvar cdata:
      min=0.6,
      mean=2.1,
      max=3.3,
      var=0.2,
      
Hide code cell content
# Setting priors
data_dictionary.update({
    "omean_b0_prior": dist.Normal(-33.0, 10.0),
    "omean_b1_prior": dist.Normal(0.0, 10.0),
    "omean_b2_prior": dist.Normal(0.0, 10.0),
    "omean_noise_prior": dist.Uniform(1e-2, 10.0),
    "ologvar_b0_prior": dist.Normal(5, 5.0),
    "ologvar_noise_prior": dist.Uniform(1e-3, 2.0),
})

data_dictionary.update({
    "cmean_b0_prior": dist.Normal(-39.0, 10.0),
    "cmean_noise_prior": dist.Uniform(1e-2, 10.0),
    "clogvar_b0_prior": dist.Normal(5, 2.0),
    "clogvar_noise_prior": dist.Uniform(1e-3, 2.0),
})
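
As a quick sanity check on what a given prior implies, it can help to draw samples from it directly and compare against the summary metrics above (a minimal sketch, reusing ‘rng_key’ and the imports from the top of the notebook):

# Drawing samples from the prior on the AWS mean intercept to inspect its implied range
prior_samples = data_dictionary["omean_b0_prior"].sample(rng_key, (1000,))
print(f"omean_b0 prior: mean={float(prior_samples.mean()):.1f}, std={float(prior_samples.std()):.1f}")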

The code for running the inference and saving estimates of the parameter posterior distributions is given below. Note that the code is commented out: for computational reasons we load in pre-computed output rather than running the inference within this notebook. The output is saved using the ArviZ package, which handles inference data from many different probabilistic programming packages and can be used to produce useful summary metrics and plots.

Hide code cell content
'''
# %% Running inference
mcmc = run_inference(meanfunc_model, rng_key, 1000, 2000,4, data_dictionary)

idata = az.from_numpyro(mcmc,
                coords={
                "station": data_dictionary['ds_aws_preprocessed']['station'],
                "x": data_dictionary['ds_climate_preprocessed']['x'],
    },
                dims={"clogvar": ["x"],
                      "cmean": ["x"],
                      "ologvar": ["station"],
                      "omean": ["station"],})
meanfunc_posterior = idata.posterior


# Computing the residuals from the mean function model parameters
meanfunc_posterior = meanfunc_posterior.assign_coords({'oele_scaled':('station', data_dictionary['oele_scaled']),
                        'olat_scaled':('station', data_dictionary['olat_scaled']),
                        'cele_scaled':('x', data_dictionary['cele_scaled']),
                        'clat_scaled':('x', data_dictionary['clat_scaled'])})
meanfunc_posterior['omean_func'] = meanfunc_posterior['omean_b0']+meanfunc_posterior['omean_b1']*meanfunc_posterior['oele_scaled']+meanfunc_posterior['omean_b2']*meanfunc_posterior['olat_scaled']
meanfunc_posterior['cmean_func'] = meanfunc_posterior['cmean_b0']+meanfunc_posterior['omean_b1']*meanfunc_posterior['cele_scaled']+meanfunc_posterior['omean_b2']*meanfunc_posterior['clat_scaled']
meanfunc_posterior['ologvar_func'] = meanfunc_posterior['ologvar_b0']
meanfunc_posterior['clogvar_func'] = meanfunc_posterior['clogvar_b0']
meanfunc_posterior['omean_func_residual'] = meanfunc_posterior['omean']-meanfunc_posterior['omean_func']
meanfunc_posterior['cmean_func_residual'] = meanfunc_posterior['cmean']-meanfunc_posterior['cmean_func']
meanfunc_posterior['ologvar_func_residual'] = meanfunc_posterior['ologvar']-meanfunc_posterior['ologvar_func']
meanfunc_posterior['clogvar_func_residual'] = meanfunc_posterior['clogvar']-meanfunc_posterior['clogvar_func']

data_dictionary['meanfunc_posterior'] = meanfunc_posterior
'''

with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
    data_dictionary = pickle.load(f)

meanfunc_posterior = data_dictionary['meanfunc_posterior'] 

MCMC is an approximate procedure and it is difficult to assess directly whether the samples returned from the inference accurately capture the true posterior distributions of the parameters. The effective sample size (ESS) and r_hat diagnostics provide some indication: we’re looking for values of r_hat in the range (1.0, 1.05) and an ESS that is comparable to the total number of samples, which is what we get in this run:

Hide code cell source
az.summary(meanfunc_posterior[['omean_b0',
            'omean_b1',
            'omean_b2',
            'omean_noise',
            'ologvar_b0',
            'ologvar_noise',
            'cmean_b0',
            'cmean_noise',
            'clogvar_b0',
            ]],hdi_prob=0.95)
                  mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
omean_b0       -33.416  0.539   -34.523    -32.402      0.005    0.007   13604.0    5746.0    1.0
omean_b1       -10.059  0.069   -10.193     -9.926      0.001    0.001   11527.0    6119.0    1.0
omean_b2         3.100  0.068     2.963      3.234      0.001    0.001   16561.0    6066.0    1.0
omean_noise      6.740  0.402     5.978      7.514      0.004    0.006   12338.0    5185.0    1.0
ologvar_b0       2.043  0.066     1.917      2.174      0.001    0.001    6133.0    6362.0    1.0
ologvar_noise    0.584  0.063     0.467      0.710      0.001    0.001    2558.0    4297.0    1.0
cmean_b0       -32.623  0.084   -32.796    -32.464      0.001    0.001   11574.0    5437.0    1.0
cmean_noise      5.263  0.050     5.164      5.361      0.000    0.001   16003.0    5160.0    1.0
clogvar_b0       2.146  0.006     2.135      2.158      0.000    0.000    9323.0    6671.0    1.0
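
If only the diagnostics themselves are needed, ArviZ also exposes them directly (a minimal sketch, assuming ‘meanfunc_posterior’ is already in memory as above):

# Computing r_hat and the bulk effective sample size for a subset of parameters
rhat = az.rhat(meanfunc_posterior[['omean_b0', 'omean_b1', 'omean_b2']])
ess_bulk = az.ess(meanfunc_posterior[['omean_b0', 'omean_b1', 'omean_b2']], method='bulk')
print(rhat)
print(ess_bulk)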

Inference on parameters of the residual and covariance functions#

For component 2 of the model, the residuals of the mean and log-variance are treated as independent, so inference can be conducted separately for each. Again utilising the Numpyro Python package to define the model gives:

Hide code cell content
# Helper functions to be used in the residual model

def diagonal_noise(coord, noise):
    return jnp.diag(jnp.full(coord.shape[0], noise))

def generate_obs_conditional_climate_dist(
    ox, cx, cdata, ckernel, cdiag, okernel, odiag
):
    y2 = cdata
    u1 = jnp.full(ox.shape[0], 0)
    u2 = jnp.full(cx.shape[0], 0)
    k11 = okernel(ox, ox) + diagonal_noise(ox, odiag)
    k12 = okernel(ox, cx)
    k21 = okernel(cx, ox)
    k22 = ckernel(cx, cx) + diagonal_noise(cx, cdiag)
    k22i = jnp.linalg.inv(k22)
    u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
    l22 = jnp.linalg.cholesky(k22)
    l22i = jnp.linalg.inv(l22)
    p21 = jnp.matmul(l22i, k21)
    k1g2 = k11 - jnp.matmul(p21.T, p21)
    mvn_dist = dist.MultivariateNormal(u1g2, k1g2)
    return mvn_dist

# The residual model for the mean temperature

def residual_model(data_dictionary,metric):
    """
    Example model where the climate data is generated from 2 GPs,
    one of which also generates the observations and one of
    which generates bias in the climate model.
    """
    meanfunc_posterior = data_dictionary['meanfunc_posterior']
    omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
    omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
    cmeanfunc_residual_exp_subsample = meanfunc_posterior[f'c{metric}_func_residual'].isel(x=data_dictionary['random_sample']).mean(['draw','chain']).data

    kern_var = numpyro.sample("kern_var", data_dictionary[f'o{metric}_func_residual_kvprior'])
    lengthscale = numpyro.sample("lengthscale", data_dictionary[f'o{metric}_func_residual_klprior'])
    kernel = kern_var * kernels.Matern32(lengthscale,L2Distance())
    noise = numpyro.sample("noise", data_dictionary[f'o{metric}_func_residual_nprior'])
    var_obs = omeanfunc_residual_var
    
    bkern_var = numpyro.sample("bkern_var", data_dictionary[f'b{metric}_func_residual_kvprior'])
    blengthscale = numpyro.sample("blengthscale", data_dictionary[f'b{metric}_func_residual_klprior'])
    bkernel = bkern_var * kernels.Matern32(blengthscale,L2Distance())
    bnoise = numpyro.sample("bnoise", data_dictionary[f'b{metric}_func_residual_nprior'])

    ckernel = kernel + bkernel
    cnoise = noise + bnoise 
    cgp = GaussianProcess(ckernel, data_dictionary["cx_subsample"], diag=cnoise, mean=0)
    numpyro.sample("climate_temperature",
                   cgp.numpyro_dist(),
                   obs=cmeanfunc_residual_exp_subsample)

    obs_conditional_climate_dist = generate_obs_conditional_climate_dist(
        data_dictionary["ox"],
        data_dictionary["cx_subsample"],
        cmeanfunc_residual_exp_subsample,
        ckernel,
        cnoise,
        kernel,
        var_obs+noise
    )
    numpyro.sample(
        "obs_temperature",
        obs_conditional_climate_dist,
        obs=omeanfunc_residual_exp
    )

# Function for running the MCMC inference and generating the posterior distributions for the model

def generate_posterior_residual_model(data_dictionary,
                                      metric,
                                      rng_key,
                                      num_warmup,
                                      num_samples,
                                      num_chains):
    mcmc_residual_model = run_inference(
        residual_model,
        rng_key,
        num_warmup,
        num_samples,
        num_chains,
        data_dictionary,
        metric
    )
    idata_residual_model = az.from_numpyro(mcmc_residual_model)
    data_dictionary[f"idata_residual_model_{metric}"] = idata_residual_model

In the model definition for the residuals, the observations are the expectations of the residuals computed from inference on the mean function model in component 1, that is \(E[r_{\mu_Y}(s_Y)]\) and \(E[r_{\mu_Z}(s_Z)]\). The uncertainty in these values is captured through their variances, \(V[r_{\mu_Y}(s_Y)]\) and \(V[r_{\mu_Z}(s_Z)]\), although \(V[r_{\mu_Z}(s_Z)]\) is not included in the model definition as it is a negligible quantity. The model follows the equations:

\[r_{\mu_Y}(s) \sim \mathcal{GP}(0,k(s,s'|l_{\mu_Y},v_{\mu_Y},n_{\mu_Y})) \hspace{2em} r_{\mu_B}(s) \sim \mathcal{GP}(0,k(s,s'|l_{\mu_B},v_{\mu_B},n_{\mu_B}))\]
\[r_{\mu_Z}(s) = r_{\mu_Y}(s)+r_{\mu_B}(s)\]
\[k(s,s')=v\left(1+\dfrac{\sqrt{3}|s-s'|}{l}\right)\exp\left(-\dfrac{\sqrt{3}|s-s'|}{l}\right)\]
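
As a minimal sketch of how this kernel corresponds to the TinyGP objects used in the code (reusing the imports from the top of the notebook; the variance, lengthscale and locations below are purely illustrative):

# Building and evaluating the Matern-3/2 kernel from the equation above with TinyGP
v, l = 2.0, 5.0                          # kernel variance and lengthscale (illustrative values)
kernel = v * kernels.Matern32(l, L2Distance())
s = jnp.array([[0.0, 0.0], [3.0, 4.0]])  # two locations in the 2D coordinate system
print(kernel(s, s))                      # the 2x2 covariance matrix k(s, s')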

Again we choose sensible priors for the parameters of the model, with reasonable bounds based on some basic summary metrics:

Hide code cell content
# Useful metrics for priors
meanfunc_posterior = data_dictionary['meanfunc_posterior']
exp_omean_func_residual = meanfunc_posterior['omean_func_residual'].mean(['draw','chain'])
exp_ologvar_func_residual = meanfunc_posterior['ologvar_func_residual'].mean(['draw','chain'])
ox_ranges = data_dictionary['ox'].max(axis=0)-data_dictionary['ox'].min(axis=0)
print('Useful Metrics for Priors: \n',
      f"""
      Mean Obs:
      min={exp_omean_func_residual.min():.1f},
      mean={exp_omean_func_residual.mean():.1f},
      max={exp_omean_func_residual.max():.1f},
      var={exp_omean_func_residual.var():.1f},
      """,
      f"""
      LogVar Obs:
      min={exp_ologvar_func_residual.min():.1f},
      mean={exp_ologvar_func_residual.mean():.1f},
      max={exp_ologvar_func_residual.max():.1f},
      var={exp_ologvar_func_residual.var():.1f},
      """,
      f'\n Obs Axis Ranges={ox_ranges}'
)
Useful Metrics for Priors: 
 
      Mean Obs:
      min=-16.9,
      mean=-0.0,
      max=17.3,
      var=43.1,
       
      LogVar Obs:
      min=-0.9,
      mean=-0.0,
      max=1.2,
      var=0.2,
       
 Obs Axis Ranges=[45.47068137 34.36497791]
Hide code cell content
# Setting priors
lengthscale_max = 20

data_dictionary['omean_func_residual_kvprior'] = dist.Uniform(0.1,100.0)
data_dictionary['omean_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['omean_func_residual_nprior'] = dist.Uniform(0.1,20.0)

data_dictionary['bmean_func_residual_kvprior'] = dist.Uniform(0.1,100.0)
data_dictionary['bmean_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['bmean_func_residual_nprior'] = dist.Uniform(0.1,20.0)

data_dictionary['ologvar_func_residual_kvprior'] = dist.Uniform(0.01,1.0)
data_dictionary['ologvar_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['ologvar_func_residual_nprior'] = dist.Uniform(0.01,1.0)

data_dictionary['blogvar_func_residual_kvprior'] = dist.Uniform(0.01,1.0)
data_dictionary['blogvar_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['blogvar_func_residual_nprior'] = dist.Uniform(0.01,1.0)

The code for running the inference and saving estimates of the parameter posterior distributions is given below. Note that the code is commented out: for computational reasons we load in pre-computed output rather than running the inference within this notebook.

Hide code cell content
# Running the inference
'''
generate_posterior_residual_model(data_dictionary,
                                'mean',
                                rng_key,
                                1000,
                                1000,
                                4)

generate_posterior_residual_model(data_dictionary,
                                'logvar',
                                rng_key,
                                1000,
                                1000,
                                4)

'''

with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
    data_dictionary = pickle.load(f)

idata_residual_model_mean = data_dictionary["idata_residual_model_mean"]
idata_residual_model_logvar = data_dictionary["idata_residual_model_logvar"]
Hide code cell source
print(r"Parameter Inference for Model of Mean Residuals")
display(az.summary(idata_residual_model_mean.posterior,hdi_prob=0.95))
print(r"Parameter Inference for Model of LogVar Residuals")
display(az.summary(idata_residual_model_logvar.posterior,hdi_prob=0.95))
Parameter Inference for Model of Mean Residuals
                 mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
bkern_var       5.838  7.004     0.102     18.185      0.125    0.311    2838.0    1879.0    1.0
blengthscale   12.991  4.687     4.797     19.994      0.075    0.053    3111.0    1959.0    1.0
bnoise          0.604  0.510     0.101      1.622      0.008    0.013    2689.0    1391.0    1.0
kern_var       35.067  9.148    20.810     53.265      0.263    0.537    2067.0    1227.0    1.0
lengthscale     3.431  0.638     2.308      4.641      0.016    0.019    2015.0    1571.0    1.0
noise           5.468  1.086     3.384      7.638      0.020    0.018    3129.0    2633.0    1.0
Parameter Inference for Model of LogVar Residuals
                 mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
bkern_var       0.067  0.064     0.010      0.184      0.001    0.003    2827.0    1632.0    1.0
blengthscale   13.684  4.219     6.214     19.997      0.076    0.048    2552.0    1915.0    1.0
bnoise          0.013  0.003     0.010      0.019      0.000    0.000    2291.0    1534.0    1.0
kern_var        0.125  0.075     0.044      0.250      0.002    0.006    1939.0    1116.0    1.0
lengthscale     5.598  1.638     3.142      8.781      0.045    0.075    1974.0    1236.0    1.0
noise           0.012  0.002     0.010      0.017      0.000    0.000    3033.0    1774.0    1.0

The posterior distributions for the parameters of the model appear sensible at first glance. The r_hat statistic indicates that the chains have converged and mixed well, as desired, and the effective sample size is of comparable magnitude to the total number of samples taken.

Making posterior predictive estimates of the unbiased mean and variance across the domain#

From the previous section we have:

  • Posterior estimates of the mean and log-variance at the AWS sites and climate model grid cells, that is \(\mu_Y(s_Y)\), \(\tilde{\sigma}_Y(s_Y)\), \(\mu_Z(s_Z)\) and \(\tilde{\sigma}_Z(s_Z)\), as well as estimates of the mean function and residual components: \(m_{\mu_Y}(s_Y)\), \(m_{\tilde{\sigma}_Y}(s_Y)\), \(m_{\mu_Z}(s_Z)\), \(m_{\tilde{\sigma}_Z}(s_Z)\) and \(r_{\mu_Y}(s_Y)\), \(r_{\tilde{\sigma}_Y}(s_Y)\), \(r_{\mu_Z}(s_Z)\), \(r_{\tilde{\sigma}_Z}(s_Z)\).

  • Posterior estimates of the parameters of the mean functions, so \(\beta_{0,\mu_Y}\), \(\beta_{1,\mu}\), \(\beta_{2,\mu}\), \(\beta_{0,\mu_Z}\), \(\beta_{0,\tilde{\sigma}_Y}\) and \(\beta_{0,\tilde{\sigma}_Z}\).

  • Posterior estimates of the parameters of the covariance functions, so \(l_{\mu_Y}\), \(v_{\mu_Y}\), \(n_{\mu_Y}\), \(l_{\mu_B}\), \(v_{\mu_B}\) and \(n_{\mu_B}\).

Now we want to make estimates of the posterior distributions of the unbiased mean and log-variance at the climate model locations to use for bias correction, so \(\mu_Y(s_Z)\) and \(\tilde{\sigma}_Y(s_Z)\). This is known as the posterior predictive and involves conditioning on both the inferred parameters of the model and the observed data. Since we split up the hierarchical model, to get \(\mu_Y(s_Z)\) and \(\tilde{\sigma}_Y(s_Z)\) we’ll get estimates for \(m_{\mu_Y}(s_Z)\) and \(m_{\tilde{\sigma}_Y}(s_Z)\) from the first component of the model, then \(r_{\mu_Y}(s_Z)\) and \(r_{\tilde{\sigma}_Y}(s_Z)\) from the second component.
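
That is, the posterior predictive estimates at the climate model grid-cell locations are assembled from the two components as:

\[\mu_Y(s_Z)=m_{\mu_Y}(s_Z)+r_{\mu_Y}(s_Z) \hspace{5em} \tilde{\sigma}_Y(s_Z)=m_{\tilde{\sigma}_Y}(s_Z)+r_{\tilde{\sigma}_Y}(s_Z)\]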

Sampling the predictive distribution of the mean function is simple:

Hide code cell content
# Generating posterior predictive estimates of the mean function
meanfunc_posterior['mean_unbiased_meanfunc_predictive'] = (meanfunc_posterior['omean_b0']+
                                                           meanfunc_posterior['omean_b1']*meanfunc_posterior['cele_scaled']+
                                                           meanfunc_posterior['omean_b2']*meanfunc_posterior['clat_scaled'])

meanfunc_posterior['mean_biased_meanfunc_predictive'] = (meanfunc_posterior['cmean_b0']-
                                                         meanfunc_posterior['omean_b0'])
    
meanfunc_posterior['logvar_unbiased_meanfunc_predictive'] = (meanfunc_posterior['ologvar_b0'])

meanfunc_posterior['logvar_biased_meanfunc_predictive'] = (meanfunc_posterior['clogvar_b0']-
                                                         meanfunc_posterior['ologvar_b0'])    

Sampling the predictive distribution of the residual is more difficult, so we construct some helper functions below.
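
These functions apply standard Gaussian process conditioning, the same computation used in ‘generate_obs_conditional_climate_dist’ above: for a joint Gaussian over the new locations (block 1) and the observed data (block 2),

\[\mu_{1|2}=\mu_1+K_{12}K_{22}^{-1}(y_2-\mu_2) \hspace{5em} \Sigma_{1|2}=K_{11}-K_{12}K_{22}^{-1}K_{21}\]

In the functions below the cross-covariance block \(K_{12}\) is built from the truth kernel when predicting the unbiased residuals, and from the bias kernel (with zero covariance to the in-situ observations) when predicting the bias.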

Hide code cell content
# %% Defining function for generating posterior predictive realisations of the residuals

def generate_truth_predictive_dist(nx, data_dictionary, metric, posterior_param_realisation):
    kern_var_realisation = posterior_param_realisation["kern_var_realisation"]
    lengthscale_realisation = posterior_param_realisation["lengthscale_realisation"]
    noise_realisation = posterior_param_realisation["noise_realisation"]

    bkern_var_realisation = posterior_param_realisation["bkern_var_realisation"]
    blengthscale_realisation = posterior_param_realisation["blengthscale_realisation"]
    bnoise_realisation = posterior_param_realisation["bnoise_realisation"]

    meanfunc_posterior = data_dictionary['meanfunc_posterior']
    omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
    omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
    cmeanfunc_residual_exp = meanfunc_posterior[f'c{metric}_func_residual'].mean(['draw','chain']).data

    ox = data_dictionary["ox"]
    cx = data_dictionary["cx"]
    odata = omeanfunc_residual_exp
    odata_var = omeanfunc_residual_var
    cdata = cmeanfunc_residual_exp
    kernelo = kern_var_realisation * kernels.Matern32(lengthscale_realisation,L2Distance())
    kernelb = bkern_var_realisation * kernels.Matern32(blengthscale_realisation,L2Distance())

    noise = noise_realisation + odata_var
    bnoise = bnoise_realisation
    cnoise = noise_realisation + bnoise

    jitter = 1e-5

    y2 = jnp.hstack([odata, cdata])
    u1 = jnp.full(nx.shape[0], 0)
    u2 = jnp.hstack(
        [jnp.full(ox.shape[0], 0), jnp.full(cx.shape[0], 0)]
    )
    k11 = kernelo(nx, nx) + diagonal_noise(nx, jitter)
    k12 = jnp.hstack([kernelo(nx, ox), kernelo(nx, cx)])
    k21 = jnp.vstack([kernelo(ox, nx), kernelo(cx, nx)])
    k22_upper = jnp.hstack(
        [kernelo(ox, ox) + diagonal_noise(ox, noise), kernelo(ox, cx)]
    )
    k22_lower = jnp.hstack(
        [
            kernelo(cx, ox),
            kernelo(cx, cx) + kernelb(cx, cx) + diagonal_noise(cx, cnoise),
        ]
    )
    k22 = jnp.vstack([k22_upper, k22_lower])
    k22i = jnp.linalg.inv(k22)

    u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
    k1g2 = k11 - jnp.matmul(jnp.matmul(k12, k22i), k21)
    mvn = dist.MultivariateNormal(u1g2, k1g2)
    return mvn

def generate_bias_predictive_dist(nx, data_dictionary, metric, posterior_param_realisation):
    kern_var_realisation = posterior_param_realisation["kern_var_realisation"]
    lengthscale_realisation = posterior_param_realisation["lengthscale_realisation"]
    noise_realisation = posterior_param_realisation["noise_realisation"]

    bkern_var_realisation = posterior_param_realisation["bkern_var_realisation"]
    blengthscale_realisation = posterior_param_realisation["blengthscale_realisation"]
    bnoise_realisation = posterior_param_realisation["bnoise_realisation"]

    meanfunc_posterior = data_dictionary['meanfunc_posterior']
    omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
    omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
    cmeanfunc_residual_exp = meanfunc_posterior[f'c{metric}_func_residual'].mean(['draw','chain']).data

    ox = data_dictionary["ox"]
    cx = data_dictionary["cx"]
    odata = omeanfunc_residual_exp
    odata_var = omeanfunc_residual_var
    cdata = cmeanfunc_residual_exp
    kernelo = kern_var_realisation * kernels.Matern32(lengthscale_realisation,L2Distance())
    kernelb = bkern_var_realisation * kernels.Matern32(blengthscale_realisation,L2Distance())

    noise = noise_realisation + odata_var
    bnoise = bnoise_realisation
    cnoise = noise_realisation + bnoise

    jitter = 1e-5

    y2 = jnp.hstack([odata, cdata])
    u1 = jnp.full(nx.shape[0], 0)
    u2 = jnp.hstack(
        [jnp.full(ox.shape[0], 0), jnp.full(cx.shape[0], 0)]
    )
    k11 = kernelb(nx, nx) + diagonal_noise(nx, jitter)
    k12 = jnp.hstack([jnp.full((len(nx), len(ox)), 0), kernelb(nx, cx)])
    k21 = jnp.vstack([jnp.full((len(ox), len(nx)), 0), kernelb(cx, nx)])
    k22_upper = jnp.hstack(
        [kernelo(ox, ox) + diagonal_noise(ox, noise), kernelo(ox, cx)]
    )
    k22_lower = jnp.hstack(
        [
            kernelo(cx, ox),
            kernelo(cx, cx) + kernelb(cx, cx) + diagonal_noise(cx, cnoise),
        ]
    )
    k22 = jnp.vstack([k22_upper, k22_lower])
    k22i = jnp.linalg.inv(k22)
    u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
    l22 = jnp.linalg.cholesky(k22)
    l22i = jnp.linalg.inv(l22)
    p21 = jnp.matmul(l22i, k21)
    k1g2 = k11 - jnp.matmul(p21.T, p21)
    mvn = dist.MultivariateNormal(u1g2, k1g2)
    return mvn

def generate_posterior_predictive_realisations_dualprocess(
    nx,
    data_dictionary,
    metric,
    num_parameter_realisations,
    num_posterior_pred_realisations,
    rng_key
):
    posterior = data_dictionary[f"idata_residual_model_{metric}"].posterior
    truth_posterior_predictive_realisations = []
    bias_posterior_predictive_realisations = []
    iteration = 0
    for i in tqdm(np.random.randint(posterior.draw.shape, size=num_parameter_realisations)):
        posterior_param_realisation = {
            "iteration": i,
            "kern_var_realisation": posterior["kern_var"].data[0, :][i],
            "lengthscale_realisation": posterior["lengthscale"].data[0, :][i],
            "noise_realisation": posterior["noise"].data[0, :][i],
            "bkern_var_realisation": posterior["bkern_var"].data[0, :][i],
            "blengthscale_realisation": posterior["blengthscale"].data[0, :][i],
            "bnoise_realisation": posterior["bnoise"].data[0, :][i],
        }

        truth_predictive_dist = generate_truth_predictive_dist(
            nx, data_dictionary, metric, posterior_param_realisation
        )
        bias_predictive_dist = generate_bias_predictive_dist(
            nx, data_dictionary, metric, posterior_param_realisation
        )
        iteration += 1

        truth_predictive_realisations = truth_predictive_dist.sample(
            rng_key, sample_shape=(num_posterior_pred_realisations,)
        )
        rng_key, rng_key_ = random.split(rng_key)
        bias_predictive_realisations = bias_predictive_dist.sample(
            rng_key, sample_shape=(num_posterior_pred_realisations,)
        )
        rng_key, rng_key_ = random.split(rng_key)

        truth_posterior_predictive_realisations.append(truth_predictive_realisations)
        bias_posterior_predictive_realisations.append(bias_predictive_realisations)

    truth_posterior_predictive_realisations = jnp.array(
        truth_posterior_predictive_realisations
    )
    bias_posterior_predictive_realisations = jnp.array(
        bias_posterior_predictive_realisations
    )

    residual_pospred_ds = xr.Dataset(
    coords = {'hyperparameter_draws':range(num_parameter_realisations),
              'gp_draws':range(num_posterior_pred_realisations),
              'nx':range(len(nx)),
              },
    data_vars = {f'unbiased_{metric}_residual_postpred':(['hyperparameter_draws','gp_draws','nx'], truth_posterior_predictive_realisations),
                 f'bias_{metric}_residual_postpred':(['hyperparameter_draws','gp_draws','nx'], bias_posterior_predictive_realisations),
                }
    )
    return residual_pospred_ds

Generating the residual posterior predictive samples:

Hide code cell content
# Generating posterior predictive realisations of the residuals
'''
residual_postpred_mean = generate_posterior_predictive_realisations_dualprocess(
    data_dictionary["cx"],
    data_dictionary,
    "mean",
    100,
    5,
    rng_key
)

residual_postpred_logvar = generate_posterior_predictive_realisations_dualprocess(
    data_dictionary["cx"],
    data_dictionary,
    "logvar",
    100,
    5,
    rng_key
)
'''

with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
    data_dictionary = pickle.load(f)

residual_postpred_mean = data_dictionary["residual_postpred_mean"]
residual_postpred_logvar = data_dictionary["residual_postpred_logvar"]

Computing the expectations and uncertainties:

Hide code cell content
# Meanfunction posterior predictive estimates expectation and uncertainty
mean_unbiased_meanfunc_predictive_exp = meanfunc_posterior['mean_unbiased_meanfunc_predictive'].mean(['draw','chain'])
mean_unbiased_meanfunc_predictive_var = meanfunc_posterior['mean_unbiased_meanfunc_predictive'].var(['draw','chain'])
mean_biased_meanfunc_predictive_exp = meanfunc_posterior['mean_biased_meanfunc_predictive'].mean(['draw','chain'])
mean_biased_meanfunc_predictive_var = meanfunc_posterior['mean_biased_meanfunc_predictive'].var(['draw','chain'])

logvar_unbiased_meanfunc_predictive_exp = meanfunc_posterior['logvar_unbiased_meanfunc_predictive'].mean(['draw','chain'])
logvar_unbiased_meanfunc_predictive_var = meanfunc_posterior['logvar_unbiased_meanfunc_predictive'].var(['draw','chain'])
logvar_biased_meanfunc_predictive_exp = meanfunc_posterior['logvar_biased_meanfunc_predictive'].mean(['draw','chain'])
logvar_biased_meanfunc_predictive_var = meanfunc_posterior['logvar_biased_meanfunc_predictive'].var(['draw','chain'])

# Residuals posterior predictive estimates expectation and uncertainty
mean_unbiased_residual_postpred_exp = residual_postpred_mean['unbiased_mean_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
mean_unbiased_residual_postpred_var = residual_postpred_mean['unbiased_mean_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
mean_biased_residual_postpred_exp = residual_postpred_mean['bias_mean_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
mean_biased_residual_postpred_var = residual_postpred_mean['bias_mean_residual_postpred'].var(['gp_draws','hyperparameter_draws'])

logvar_unbiased_residual_postpred_exp = residual_postpred_logvar['unbiased_logvar_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
logvar_unbiased_residual_postpred_var = residual_postpred_logvar['unbiased_logvar_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
logvar_biased_residual_postpred_exp = residual_postpred_logvar['bias_logvar_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
logvar_biased_residual_postpred_var = residual_postpred_logvar['bias_logvar_residual_postpred'].var(['gp_draws','hyperparameter_draws'])

# Noise in residuals posterior predictive estimates expectation and uncertainty
mean_unbiased_residual_postpred_noise = idata_residual_model_mean.posterior['noise'].mean()
mean_biased_residual_postpred_noise = idata_residual_model_mean.posterior['bnoise'].mean()
logvar_unbiased_residual_postpred_noise = idata_residual_model_logvar.posterior['noise'].mean()
logvar_biased_residual_postpred_noise = idata_residual_model_logvar.posterior['bnoise'].mean()

# Combined posterior predictive estimates expectation and uncertainty

mean_unbiased_predictive_exp = mean_unbiased_meanfunc_predictive_exp.data + mean_unbiased_residual_postpred_exp.data
mean_unbiased_predictive_var = mean_unbiased_meanfunc_predictive_var.data + mean_unbiased_residual_postpred_var.data + mean_unbiased_residual_postpred_noise.data
mean_biased_predictive_exp = mean_biased_meanfunc_predictive_exp.data + mean_biased_residual_postpred_exp.data
mean_biased_predictive_var = mean_biased_meanfunc_predictive_var.data + mean_biased_residual_postpred_var.data + mean_biased_residual_postpred_noise.data

logvar_unbiased_predictive_exp = logvar_unbiased_meanfunc_predictive_exp.data + logvar_unbiased_residual_postpred_exp.data
logvar_unbiased_predictive_var = logvar_unbiased_meanfunc_predictive_var.data + logvar_unbiased_residual_postpred_var.data + logvar_unbiased_residual_postpred_noise.data
logvar_biased_predictive_exp = logvar_biased_meanfunc_predictive_exp.data + logvar_biased_residual_postpred_exp.data
logvar_biased_predictive_var = logvar_biased_meanfunc_predictive_var.data + logvar_biased_residual_postpred_var.data + logvar_biased_residual_postpred_noise.data
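
In combining the two components, the expectations add and, assuming the components are approximately independent, so do the variances, with the expected noise term from the residual model included as an additional contribution. For the mean (the log-variance case is analogous):

\[E[\mu_Y(s_Z)]=E[m_{\mu_Y}(s_Z)]+E[r_{\mu_Y}(s_Z)] \hspace{3em} V[\mu_Y(s_Z)]\approx V[m_{\mu_Y}(s_Z)]+V[r_{\mu_Y}(s_Z)]+E[n_{\mu_Y}]\]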

Results#

To plot the predictions of the unbiased mean and log(variance), as well as the bias, we combine the posterior predictives with a dataset containing the relevant coordinates for plotting:

Hide code cell content
# Incorporating results into dataset with coordinates 
ds_climate_stacked_landonly_filtered['mean_unbiased_meanfunc_predictive_exp'] = (['x'],mean_unbiased_meanfunc_predictive_exp.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_meanfunc_predictive_var'] = (['x'],mean_unbiased_meanfunc_predictive_var.data)
ds_climate_stacked_landonly_filtered['mean_biased_meanfunc_predictive_exp'] = (['x'],np.full(5724,mean_biased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['mean_biased_meanfunc_predictive_var'] = (['x'],np.full(5724,mean_biased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['logvar_unbiased_meanfunc_predictive_exp'] = (['x'],np.full(5724,logvar_unbiased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['logvar_unbiased_meanfunc_predictive_var'] = (['x'],np.full(5724,logvar_unbiased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['logvar_biased_meanfunc_predictive_exp'] = (['x'],np.full(5724,logvar_biased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['logvar_biased_meanfunc_predictive_var'] = (['x'],np.full(5724,logvar_biased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['mean_unbiased_residual_postpred_exp'] = (['x'],mean_unbiased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_residual_postpred_var'] = (['x'],mean_unbiased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['mean_biased_residual_postpred_exp'] = (['x'],mean_biased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['mean_biased_residual_postpred_var'] = (['x'],mean_biased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['logvar_unbiased_residual_postpred_exp'] = (['x'],logvar_unbiased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['logvar_unbiased_residual_postpred_var'] = (['x'],logvar_unbiased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['logvar_biased_residual_postpred_exp'] = (['x'],logvar_biased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['logvar_biased_residual_postpred_var'] = (['x'],logvar_biased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_exp'] = (['x'],mean_unbiased_predictive_exp)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_var'] = (['x'],mean_unbiased_predictive_var)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_2sigma'] = (['x'],2*np.sqrt(mean_unbiased_predictive_var))
ds_climate_stacked_landonly_filtered['mean_biased_predictive_exp'] = (['x'],mean_biased_predictive_exp)
ds_climate_stacked_landonly_filtered['mean_biased_predictive_var'] = (['x'],mean_biased_predictive_var)
ds_climate_stacked_landonly_filtered['mean_biased_predictive_2sigma'] = (['x'],2*np.sqrt(mean_biased_predictive_var))
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_exp'] = (['x'],logvar_unbiased_predictive_exp)
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_var'] = (['x'],logvar_unbiased_predictive_var)
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_2sigma'] = (['x'],2*np.sqrt(logvar_unbiased_predictive_var))
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_exp'] = (['x'],logvar_biased_predictive_exp)
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_var'] = (['x'],logvar_biased_predictive_var)
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_2sigma'] = (['x'],2*np.sqrt(logvar_biased_predictive_var))

ds_climate_stacked = xr.merge([ds_climate_stacked,ds_climate_stacked_landonly_filtered])

The expectation of the unbiased June mean temperature and of the bias at the climate model grid-cell locations is given below, along with the associated uncertainty. The values are the result of summing the posterior predictive estimates from both components of the model.

Hide code cell source
# Plotting the predictions for the unbiased June mean temperature and the bias.
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)

metrics = ['mean_unbiased_predictive_exp','mean_biased_predictive_exp','mean_unbiased_predictive_2sigma','mean_biased_predictive_2sigma']
cmaps = ['viridis','RdBu','plasma','plasma']
vminmaxs = [(-60,-15),(-3,3),(5,6),(2.0,3.5)]
labels = ['Expectation Unbiased June Mean Temperature','Expectation Bias June Mean Temperature',
          'Uncertainty Unbiased June Mean Temperature','Uncertainty Bias June Mean Temperature']

for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
    background_map_rotatedcoords(ax)
    ds_climate_stacked[metric].unstack().plot.pcolormesh(
        x='glon',
        y='glat',
        ax=ax,
        alpha=0.9,
        vmin=vminmax[0],
        vmax=vminmax[1],
        cmap=cmap,
        cbar_kwargs = {'fraction':0.030,
                    'pad':0.02,
                    'label':label}
    )

    ax.scatter(
        ds_aws_filtered['glon'],
        ds_aws_filtered['glat'],
        marker="o",
        edgecolor="w",
        facecolor="none",
        linewidth=1.0,
    )

    ax.set_axis_off()
    ax.set_title('')
                       
plt.tight_layout()
[Figure: Maps of the expectation and uncertainty (2σ) of the unbiased June mean temperature and of the bias across the domain, with AWS locations overlaid.]

The equivalent for the log(variance) is also given below:

Hide code cell source
# Plotting the predictions for the unbiased June mean temperature and the bias.
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)

metrics = ['logvar_unbiased_predictive_exp','logvar_biased_predictive_exp','logvar_unbiased_predictive_2sigma','logvar_biased_predictive_2sigma']
cmaps = ['viridis','RdBu','plasma','plasma']
vminmaxs = [(1.4,3.0),(-0.25,0.25),(0.3,0.45),(0.3,0.45)]
labels = ['Expectation Unbiased June Log(Var) Temperature','Expectation Bias June Log(Var) Temperature',
          'Uncertainty Unbiased June Log(Var) Temperature','Uncertainty Bias June Log(Var) Temperature']

for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
    background_map_rotatedcoords(ax)
    ds_climate_stacked[metric].unstack().plot.pcolormesh(
        x='glon',
        y='glat',
        ax=ax,
        alpha=0.9,
        vmin=vminmax[0],
        vmax=vminmax[1],
        cmap=cmap,
        cbar_kwargs = {'fraction':0.030,
                    'pad':0.02,
                    'label':label}
    )

    ax.scatter(
        ds_aws_filtered['glon'],
        ds_aws_filtered['glat'],
        marker="o",
        edgecolor="w",
        facecolor="none",
        linewidth=1.0,
    )

    ax.set_axis_off()
    ax.set_title('')
                       
plt.tight_layout()
[Figure: Equivalent maps of the expectation and uncertainty (2σ) of the unbiased June temperature log(variance) and of the bias, with AWS locations overlaid.]

There are various interesting things to note about these predictions:

  • The expectation for the bias varies smoothly across the domain, which is the result of modelling it with a single lengthscale in the latent Gaussian process. In reality we know the bias will vary more sharply across the region; however, with limited data, capturing the large-scale patterns is already a good start, and the smoothness makes it clear we’re not overfitting to the data. Resistance to overfitting is a common benefit of Gaussian processes.

  • The uncertainty is dominated by a constant noise term across the domain, which makes sense since we expect sharp variations that are difficult to predict with the limited weather station data available.

  • The uncertainty is lowest near clusters of weather stations. This makes sense and is the result of modelling the spatial covariance; away from the weather stations we have limited information about the bias.

Quantile Mapping#

The final step, after estimating the unbiased PDF parameter values across the domain, is to apply the bias correction to the original climate model spatio-temporal output. To do this we utilise quantile mapping, which maps the climate model data from its own CDF onto the CDF defined by the unbiased PDF parameters.

For every record \(j\) at each location \(i\) we apply the mapping:

\[\hat{z}_{s_{i,j}} = F_{Y_{s_i}}^{-1}(F_{Z_{s_i}}(z_{s_{i,j}}))\]

where \(F_{Z_{s_i}}\) is the cumulative distribution function (CDF) of the climate model data and \(F_{Y_{s_i}}^{-1}\) is the inverse CDF of the unbiased distribution.
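
As a minimal sketch of this mapping for a single location, under normal distributions and with purely illustrative parameter values (the actual correction below instead samples the unbiased parameters from their posterior predictives):

# Quantile mapping: push each value through the climate model CDF and the unbiased inverse CDF
z = np.array([-35.0, -30.0, -25.0])  # climate model temperatures at one location (illustrative)
mu_z, sigma_z = -30.0, 3.0           # biased (climate model) mean and standard deviation (illustrative)
mu_y, sigma_y = -28.0, 4.0           # unbiased mean and standard deviation (illustrative)
z_hat = norm.ppf(norm.cdf(z, mu_z, sigma_z), mu_y, sigma_y)
print(z_hat)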

Instead of using the expectations of the posterior predictive parameter estimates, realisations of the posterior predictives of the unbiased parameters are used and quantile mapping is performed with each realisation separately. This propagates the uncertainty, producing multiple realisations of the corrected timeseries and thus an uncertainty band on the correction. To illustrate this, six AWS sites are chosen and the timeseries from the AWS records are plotted alongside the timeseries for the original and corrected climate model output from the nearest grid cell. Additionally, the original and bias-corrected climate model timeseries are shown for a site isolated from any nearby AWSs.

Taking this selection of six weather station sites dotted around the domain plus one isolated site, we demonstrate the bias-corrected time series at the nearest climate model grid-cell for each. We start by showing the locations of the chosen sites, then plot the time series for the weather station at each site, the climate model output of the nearest grid-cell and the bias-corrected climate model output for the nearest grid-cell.

Hide code cell content
# Including mean expectation for climate model data 
ds_climate_stacked_landonly_filtered['exp_mean_climate'] = (['x'],meanfunc_posterior['cmean'].mean(['draw','chain']).data)
ds_climate_stacked_landonly_filtered['exp_logvar_climate'] = (['x'],meanfunc_posterior['clogvar'].mean(['draw','chain']).data)

# Computing nearest grid-cells to weather stations
ox = np.dstack([ds_aws_filtered['glon'],ds_aws_filtered['glat']])[0]
cx = np.dstack([ds_climate_stacked_landonly['glon'],ds_climate_stacked_landonly['glat']])[0]
nn_indices = []
for point in ox:
    nn_indices.append(distance.cdist([point], cx).argmin())
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered.isel(x=nn_indices)
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered_nn.assign_coords(nearest_station=("x", ds_aws_filtered.station.data))
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered_nn.swap_dims({"x": "nearest_station"})
Hide code cell source
# Plotting selected station locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)

stations = ['Henry','Manuela','Butler Island','Byrd','Relay Station','Dome C']
ds = ds_aws_filtered.sel(station = stations)
ds_isolated = ds_climate_stacked_landonly_filtered.isel(x=2050)

ax.scatter(
    ds.glon,
    ds.glat,
    s=30,
    marker='*',
    edgecolor='k',
    linewidths=0.5,
)

ax.scatter(
    ds_isolated.glon,
    ds_isolated.glat,
    s=30,
    marker='*',
    edgecolor='k',
    linewidths=0.5,
)

gdf_icesheet_rotatedcoords.boundary.plot(ax=ax, color="k", linewidth=0.1)

ax.set_title('')
ax.set_axis_off()

offsets = [[10, 5],[10, 5],[-8, 10],[-8, 10],[-20, -15],[0,10]]
for station,offset in zip(stations,offsets):
    ds_station = ds.sel(station = station)
    ax.annotate(
        f'{ds_station.station.data}',
        xy=(ds_station.glon, ds_station.glat), xycoords='data',
        xytext=(offset[0],offset[1]), textcoords='offset points',
        arrowprops=dict(arrowstyle="->"),
        fontsize=10)
    
ax.annotate(
    f'Isolated Location',
    xy=(ds_isolated.glon, ds_isolated.glat), xycoords='data',
    xytext=(-25,10), textcoords='offset points',
    arrowprops=dict(arrowstyle="->"),
    fontsize=10)
    
plt.tight_layout()
[Figure: Map of the ice sheet boundary showing the locations of the six selected AWS sites and the isolated location.]
Hide code cell source
stations = ['Henry','Manuela','Butler Island','Byrd','Relay Station','Dome C','Isolated Location']
labels = ['a.','b.','c.','d.','e.','f.','g.']
fig, axs = plt.subplots(7, 1, figsize=(10, 10),dpi=300)#,frameon=False)

for ax,station,label in zip(axs[:-1],stations[:-1],labels[:-1]):
    ds_climate = ds_climate_stacked_landonly_filtered_nn.sel(nearest_station=station)
    ds_aws = ds_aws_filtered.sel(station=station)

    cdf_c = norm.cdf(
        ds_climate['temperature'],
        ds_climate['exp_mean_climate'],
        np.sqrt(np.exp(ds_climate['exp_logvar_climate'])),
    )

    mean_unbiased_dist = dist.Normal(ds_climate['mean_unbiased_predictive_exp'].data,np.sqrt(ds_climate['mean_unbiased_predictive_var'].data))
    logvar_unbiased_dist = dist.Normal(ds_climate['logvar_unbiased_predictive_exp'].data,np.sqrt(ds_climate['logvar_unbiased_predictive_var'].data))

    mean_unbiased_samples = mean_unbiased_dist.sample(rng_key,(1000,))
    logvar_unbiased_samples = logvar_unbiased_dist.sample(rng_key,(1000,))

    c_corrected = norm.ppf(
        cdf_c,
        mean_unbiased_samples.reshape(-1, 1),
        np.sqrt(np.exp(logvar_unbiased_samples)).reshape(-1, 1),
    )

    ax.annotate(label+station,xy=(0.01,1.02),xycoords='axes fraction')
    ds_climate['temperature'].plot(x='year',
                               ax=ax,
                               marker='+',
                               linestyle="-",
                               linewidth=0.8,
                               zorder=2,
                               label='Climate Model Output')
    ds_aws['temperature'].plot(x='year',
                            ax=ax,
                            marker='x',
                            linestyle="-",
                            linewidth=0.8,
                            zorder=2,
                            label='Nearest AWS Output')
    ax.plot(ds_climate['year'],
        c_corrected.mean(axis=0),
        color='k',
        marker='+',
        linestyle="-",
        linewidth=0.8,
        zorder=1,
        label='Bias Corrected Output Expectation')
    ax.fill_between(
        ds_climate['year'],
        c_corrected.mean(axis=0)
        + 3 * c_corrected.std(axis=0),
        c_corrected.mean(axis=0)
        - 3 * c_corrected.std(axis=0),
        interpolate=True,
        color="k",
        alpha=0.5,
        label=r"Bias Corrected Output Uncertainty 3$\sigma$",
        linewidth=0.5,
        facecolor="none",
        edgecolor="k",
        linestyle=(0, (5, 2)),
    )
    for corrected_timeseries in c_corrected[::10]:
        ax.plot(
            ds_climate['year'],
            corrected_timeseries,
            color="k",
            alpha=0.2,
            linestyle="-",
            linewidth=0.2,
            zorder=1,
        )

ax=axs[-1]
ds_isolated = ds_climate_stacked_landonly_filtered.isel(x=2050)
cdf_c = norm.cdf(
    ds_isolated['temperature'],
    ds_isolated['exp_mean_climate'],
    np.sqrt(np.exp(ds_isolated['exp_logvar_climate'])),
)

mean_unbiased_dist = dist.Normal(ds_isolated['mean_unbiased_predictive_exp'].data,np.sqrt(ds_isolated['mean_unbiased_predictive_var'].data))
logvar_unbiased_dist = dist.Normal(ds_isolated['logvar_unbiased_predictive_exp'].data,np.sqrt(ds_isolated['logvar_unbiased_predictive_var'].data))


mean_unbiased_samples = mean_unbiased_dist.sample(rng_key,(1000,))
logvar_unbiased_samples = logvar_unbiased_dist.sample(rng_key,(1000,))

c_corrected = norm.ppf(
    cdf_c,
    mean_unbiased_samples.reshape(-1, 1),
    np.sqrt(np.exp(logvar_unbiased_samples)).reshape(-1, 1),
)

ax.annotate('g. Isolated Location',xy=(0.01,1.02),xycoords='axes fraction')

ds_isolated['temperature'].plot(x='year',
                               ax=ax,
                               marker='+',
                               linestyle="-",
                               linewidth=0.8,
                               zorder=2,
                               label='Climate Model Output')

ax.plot(ds_isolated['year'],
    c_corrected.mean(axis=0),
    color='k',
    marker='+',
    linestyle="-",
    linewidth=0.8,
    zorder=1,
    label='Bias Corrected Output Expectation')
ax.fill_between(
    ds_isolated['year'],
    c_corrected.mean(axis=0)
    + 3 * c_corrected.std(axis=0),
    c_corrected.mean(axis=0)
    - 3 * c_corrected.std(axis=0),
    interpolate=True,
    color="k",
    alpha=0.5,
    label=r"Bias Corrected Output Uncertainty 3$\sigma$",
    linewidth=0.5,
    facecolor="none",
    edgecolor="k",
    linestyle=(0, (5, 2)),
)
for corrected_timeseries in c_corrected[::10]:
    ax.plot(
        ds_isolated['year'],
        corrected_timeseries,
        color="k",
        alpha=0.2,
        linestyle="-",
        linewidth=0.2,
        zorder=1,
    )

for ax in axs:
    ax.set_title('')
    ax.set_xlim([1978,2022])

for ax in axs[:-1]:
    ax.set_xlabel('')
    ax.set_xticklabels('')

handles, labels = ax.get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    fontsize=8,
    bbox_to_anchor=(0.5, -0.02),
    ncols=4,
    loc=10,
)
plt.tight_layout()
[Figure: Panels a–g: June mean temperature time series at the six selected AWS sites and the isolated location, showing the climate model output, the nearest AWS record where available, and the bias-corrected output with its 3σ uncertainty band.]

The bias correction seems to be performing reasonably, with a general shift of the time series towards the nearby unbiased weather station data and realistic uncertainty bands. It’s also worth noting some interesting features visible in the weather station data, including:

  • An apparent temperature anomaly in the Butler Island records;

  • An apparent shift in the mean after a break in the Relay Station record.

It is likely these are instrumentation errors rather than physically meaningful phenomena that need to be captured. They would be best resolved with further data preprocessing, and they highlight the importance of understanding specific features of real-world data that are hard to account for in modelling. In general the assumption of time-invariant bias seems reasonable, and the model adds value by improving the agreement of the timeseries with the observed values and by introducing an uncertainty that can be propagated through physical ice models, providing a range of possible outcomes that might in turn feed into various cost-decision analysis schemes.

Conclusions#

Hopefully this notebook provides some interesting ideas that are generalisable to many areas of application including:

  • Where we care about uncertainty in the end output and creating a flexible modelling framework that propagates uncertainty in the different components and is explicit about the assumptions made through prior distributions.

  • Where we’re trying to combine multiple datasets with different spatial and temporal coverage.

  • Where we want to explicitly model the underlying covariance between points.

The main limitation of the methodology presented is the computational complexity of the inference. The model in its current form runs slowly even for small datasets such as the one presented in this notebook, which makes iterative improvements and careful testing of the results challenging. There are various options for speeding up the code that still need exploring, and if you’re interested in contributing please do so through the repository link. Another important limitation of the code as it stands is that it currently lacks rigorous unit testing of functions and the documentation needs improving. As such, it is suggested to take away components of the pipeline to use in your own projects and to supplement them with the official examples from the relevant package repositories such as Numpyro and TinyGP.