Bias Correction of Climate Model Output#
Paper (Chapter 4)
Notebook Repository
Method Repository
Ongoing Development
Primary Contact: Dr. Jeremy Carter | Notebook Top-to-Bottom Runtime ~ 2 Minutes
Physically based climate models are highly-multidimensional and computationally demanding. Simulation runs are often tuned for specific variables and typically observe a non-insignificant bias in other variables. Depending on what variables your specific research question focuses on it’s often desirable to perform postprocessing bias correction. In-situ measurements of variables can be used to help apply a bias correction. Where the in-situ measurements are sparse it’s important to consider uncertainty, depending on factors such as the underlying spatial covariance between points.
Gaussian processes are used to explicitly model spatial covariance between points and to estimate uncertainties when applying bias correction across the whole domain. A Bayesian hierarchical model is constructed to promote uncertainty propagation across the different components of the model.
Running the Notebook:
To run the notebook it is advised to first clone the repository housing the notebook (’git clone NERC-CEH/ds-toolbox-notebook-biascorrection.git’). This will create a folder in the current directory called ds-toolbox-notebook-biascorrection, which holds all the relevant files including the notebook, environment file and relevant input data. The next step is to create a conda environment with the right packages installed using the clean yml file (’conda env create -f environment-clean.yml’), which creates the jbook_BCA environment that can be activated using ‘conda activate jbook_BCA’. At this point the user can either run code from the notebook in their preferred IDE or via the jupyter interface using the command ‘jupyter notebook’. If using VSCode one recommended approach is to use Python interactive windows and copy and paste snippets of code to run from the notebook.
Data Access:
A partially pre-processed example input dataset is used in this tutorial for simplicity. Raw state-of-the-art climate model output for Antarctica can be accessed via the Antarctic CORDEX project and emailing the primary contacts listed. The automatic weather station data can be accessed via the AntAWS Dataset. The post-inference data for this notebook is available at Bias Correction Walkthrough Tutorial Inference Data and is retrieved in the Parameter Inference section.
Statistical Concepts:
Gaussian processes represent probability distributions over a function of values by considering the covariance between points. For example, consider some spatial surface of mean temperature over a region. The values will vary smoothly over the domain and so values of nearby points are spatially correlated. To interpolate across the surface and to express what the likelihood is of the particular surface of values observed you need to consider the spatial correlation. This is what Gaussian processes do by parameterising the covariance as a function of distance between points. Since Gaussian processes represent a probability distribution over the function you can estimate an expectation and also uncertainty in predictions.
Generalisability:
Gaussian processes are highly generalisable as they don’t assume a fixed functional form of the underlying process. That is many different types of spatial and/or temporal surfaces can be modelled adequately through them, including oscillatory functions for example. They are applicable across domains such as regression, classification, time-series modelling and spatial analysis. Gaussian processes are also highly interpretable and resistant to overfitting due to being defined by a limited number of hyper-parameters. Finally, providing a probabilistic approach, they also capture uncertainty in estimates, important for various real-world applications and data science pipelines leading to decision-making. The specific example in this notebook is generalisable to any scenario involving combining multiple spatio-temporal datasets, particularly when a good-coverage biased dataset is combined with a poor-coverage unbiased dataset.
Warning: Computational Demands
While the notebook runs top-to-bottom in <5mins this is achieved through loading in inference results. The method presented in this notebook is very computationally demanding even for small datasets and was developed with the idea of getting maximum utility out of the limited weather station data available over Antarctica. The methodology is still being worked on and if the reader is interested in contributing/raising issues this would be very welcomed via the repository link and issues tab.
It is suggested that the reader takes away some of the core concepts presented in the pipeline and utilises particular code snippets to apply to their specific area of application. The Bayesian modelling framework is highly-flexible and so specific modelling adaptions should be simple to incorporate.
Introduction#
This notebook demonstrates an approach to bias correction that utilises Gaussian Processes (GPs) and a Bayesian hierarchical framework. The method is applied to bias correcting temperature data from a climate model over Antarctica using in-situ observations from automatic weather stations. Since the weather stations are both spatially and temporally sparse it’s important to capture uncertainty in the correction. Further detail is available at: J.Carter Thesis (Chapter 4). The code is in ongoing development and is available at: Bias Correction Application. The datasets and modules needed to run this notebook are available via the jupterbook_render branch of the repository.
Importing Required Libraries and Loading Data#
The data for this tutorial is stored as NetCDF files. This is a common file format for climate model output and is handled well by the Xarray Python package, which can be thought of as a Pandas equivalent for efficient handling of multidimensional data. Xarray loads data as ‘datasets’ and ‘dataarrays’. The model in this tutorial is defined using the Python packages Numpyro and TinyGP, which are compatible. Numpyro provides an intuitive probabilistic programming language for Bayesian statistics, built ontop of JAX and with NumPy based syntax. TinyGP provides a lightweight and intuitive package for working with Gaussian Process objects.
Show code cell content
# Importing libraries
import os
from urllib.request import urlretrieve
import pickle
import timeit
from tqdm import tqdm
import numpy as np
import xarray as xr
import pandas as pd
from scipy.spatial import distance
from scipy.stats import norm
import arviz as az
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from tinygp import kernels, GaussianProcess
from tinygp.kernels.distance import L2Distance
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
import geopandas as gpd
import seaborn as sns
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
rng_key = jax.random.PRNGKey(1)
jax.config.update("jax_enable_x64", True)
c:\Users\jercar\AppData\Local\miniconda3\envs\jbook_BCA\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
# Loading Data
data_path = os.path.join(os.getcwd(),"data","")
ds_aws = xr.open_dataset(f'{data_path}ds_aws.nc') # Automatic Weather Station Data
ds_climate = xr.open_dataset(f'{data_path}ds_climate.nc') # Climate Model Data
Using Xarray datasets provides nice interactive tables of the multidimensional data:
# Displaying the AWS data
ds_aws
Show code cell output
<xarray.Dataset> Size: 925kB Dimensions: (station: 219, t: 504) Coordinates: * station (station) <U22 19kB 'AGO Site' 'AGO-4' ... 'aws16' 'aws17' glat (station) float64 2kB ... glon (station) float64 2kB ... grid_latitude (station) float64 2kB ... grid_longitude (station) float64 2kB ... year (t) float64 4kB ... month (t) float64 4kB ... * t (t) float64 4kB 1.0 2.0 3.0 4.0 ... 501.0 502.0 503.0 504.0 Data variables: latitude (station) float64 2kB ... elevation (station) float64 2kB ... temperature (station, t) float64 883kB ...
# Displaying the climate model data
ds_climate
Show code cell output
<xarray.Dataset> Size: 45MB Dimensions: (time: 456, grid_longitude: 126, grid_latitude: 98) Coordinates: * grid_longitude (grid_longitude) float64 1kB 152.4 152.9 ... 207.0 207.4 * grid_latitude (grid_latitude) float64 784B -21.39 -20.95 ... 20.84 21.29 * time (time) datetime64[ns] 4kB 1981-01-31 ... 2018-12-31 month (time) int64 4kB ... year (time) int64 4kB ... glon (grid_longitude, grid_latitude) float64 99kB ... glat (grid_longitude, grid_latitude) float64 99kB ... t (time) int64 4kB ... Data variables: temperature (time, grid_longitude, grid_latitude) float64 45MB ... elevation (grid_longitude, grid_latitude) float64 99kB ... latitude (grid_longitude, grid_latitude) float64 99kB ...
# Computing basic summary statistics using the pandas port of xarray objects
print('Summary of Automatic Weather Station Data \n',
ds_aws.to_dataframe().describe()[['elevation','latitude','temperature']])
print('\n Summary of Climate Model Data \n',
ds_climate.to_dataframe().describe()[['elevation','latitude','temperature']])
Show code cell output
Summary of Automatic Weather Station Data
elevation latitude temperature
count 110376.000000 110376.000000 18088.000000
mean 1251.009132 -76.472009 -25.852952
std 1131.183380 5.411963 14.071861
min 5.000000 -90.000000 -71.740000
25% 87.000000 -79.820000 -31.880000
50% 1122.000000 -76.320000 -24.135000
75% 2090.000000 -73.080000 -16.320000
max 4093.000000 -65.240000 1.750000
Summary of Climate Model Data
elevation latitude temperature
count 2.610144e+06 2.610144e+06 2.610144e+06
mean 2.003590e+03 -7.661032e+01 -3.257423e+01
std 1.150357e+03 5.392458e+00 1.458319e+01
min -3.087963e+00 -8.971554e+01 -7.326862e+01
25% 1.045115e+03 -8.060204e+01 -4.373946e+01
50% 2.192235e+03 -7.646343e+01 -3.125523e+01
75% 2.988649e+03 -7.229835e+01 -2.182017e+01
max 4.063502e+03 -6.397399e+01 1.437534e+00
Data Exploration#
Initial data exploration is essential for informing our model construction. It includes examining the spatial and temporal distributions of the data and relationships between variables and potential predictors. To start we’ll examine the spatial distribution of weather stations, plotting over the grid of the climate model.
Spatial Distribution of Weather Stations#
Coordinate Systems:
Note that for plotting we use a specific rotated coordinate system (defined below). The ‘glon’ and ‘glat’ fields are in this coordinate system and it was created to get around issues assoicated with plotting the shapefile when the longitude flips between -180 to 180.
Show code cell source
# Loading ice sheet shapefile
icesheet_shapefile_path = os.path.join(data_path,"icesheet_shapefile","icesheet.shp")
gdf_icesheet = gpd.read_file(icesheet_shapefile_path)
# Defining rotated coordinate system (glon,glat) and converting ice sheet shapefile to rotated coordinates
rotated_coord_system = ccrs.RotatedGeodetic(
13.079999923706055,
0.5199999809265137,
central_rotated_longitude=180.0,
globe=None,
)
gdf_icesheet_rotatedcoords = gdf_icesheet.to_crs(rotated_coord_system)
############################################################################################################
# Defining background map function
def background_map_rotatedcoords(ax):
gdf_icesheet_rotatedcoords.boundary.plot(
ax=ax,
color='k',
linewidth=0.3,
alpha=0.4)
ax.set_axis_off()
# Defining marker size legend function
def markersize_legend(ax, bins, scale_multipler, legend_fontsize=10,loc=3,ncols=1,columnspacing=0.8,handletextpad=0.1,bbox=(0.,0.)):
ax.add_artist(
ax.legend(
handles=[
mlines.Line2D(
[],
[],
color="tab:blue",
markeredgecolor="k",
markeredgewidth=0.3,
lw=0,
marker="o",
markersize=np.sqrt(b*scale_multipler),
label=str(int(b)),
)
for i, b in enumerate(bins)
],
loc=loc,
fontsize = legend_fontsize,
ncols=ncols,
columnspacing=columnspacing,
handletextpad=handletextpad,
bbox_to_anchor=bbox,
framealpha=0,
)
)
############################################################################################################
# Plotting the weather station locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)
background_map_rotatedcoords(ax)
ax.scatter(
ds_aws.glon,
ds_aws.glat,
s=ds_aws.count('t')['temperature']/5,
edgecolor='k',
linewidths=0.5,
)
ax.annotate(
'Number of records:',
xy=(0.08, 0.1), xycoords='axes fraction',
fontsize=10)
markersize_legend(ax, [1,10,30,50,100,200,300,400], scale_multipler=1/5, legend_fontsize=10,loc=3,ncols=9,columnspacing=0.3,handletextpad=-0.4,bbox=(0.08,0.05))
# highlighting individual station
station = 'Manuela'
ax.annotate(
f'{station} Weather Station',
xy=(ds_aws.sel(station = station).glon, ds_aws.sel(station = station).glat), xycoords='data',
xytext=(-140,-15), textcoords='offset points',
arrowprops=dict(arrowstyle="->"),
fontsize=10)
# Plotting the climate model grid
ds_climate['temperature'].mean('time').notnull().plot.pcolormesh(
x='glon',
y='glat',
ax=ax,
alpha=0.05,
add_colorbar=False,
edgecolor='k',
linewidth=0.3,
)
plt.tight_layout()
plt.show()

It is clear that the spatial distribution of weather stations is not uniform over the domain. There are certain regions containing high-density clusters of stations. This will induce a bias in the bias-correction itself when conducted over the whole domain, with corrections skewed towards these regions. If these clusters occurred randomly then modelling the spatial correlation between sites would be adequate to account for the distribution. Instead it is likely these regions were chosen for specific features, such as having anomalously high temperatures, which is difficult to account for in the model. We don’t directly attempt to solve this problem but it is noted as a remaining limitation.
Time Series and Distribution of Single Weather Station (Manuela)#
Comparisons are made between the time series for the Manuela weather station and the nearest grid-cell of the climate model output. The time series represents the average temperature for each month and the values are aggregated from hourly measurements of the raw data (preprocessed in this notebook for simplicity).
Show code cell source
# Computing Nearest Neighbours
ds_climate_stacked = ds_climate.stack(x=('grid_longitude', 'grid_latitude'))
ds_climate_stacked_landonly = ds_climate_stacked.dropna('x')
ox = np.dstack([ds_aws['glon'],ds_aws['glat']])[0]
cx = np.dstack([ds_climate_stacked_landonly['glon'],ds_climate_stacked_landonly['glat']])[0]
nn_indecies = []
for point in ox:
nn_indecies.append(distance.cdist([point], cx).argmin())
ds_climate_nearest_stacked = ds_climate_stacked_landonly.isel(x=nn_indecies)
ds_climate_nearest_stacked = ds_climate_nearest_stacked.assign_coords(nearest_station=("x", ds_aws.station.data))
ds_climate_nearest_stacked = ds_climate_nearest_stacked.swap_dims({"x": "nearest_station"})
# Single Site Full Time Series
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)
station = 'Manuela'
ds_climate_nearest_stacked.sel(nearest_station = station)['temperature'].plot(x="t",
ax=ax,
hue='station',
alpha=0.7,
label='Climate Model Output (Nearest Grid-Cell)',
marker='x',
ms=1,
color='tab:blue',
linewidth=1.0)
ds_aws.sel(station = station)['temperature'].plot(ax=ax,
hue='station',
alpha=0.7,
label=f'{station} Weather Station',
marker='x',
ms=1,
color='tab:orange',
linewidth=1.5)
xticks = np.arange(0,45*12,12*5)
xticklabels = np.arange(1980,2025,5)
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)
ax.set_ylabel('Temperature')
ax.set_xlabel('Time')
ax.legend()
ax.set_title('')
plt.tight_layout()
plt.show()

The Manuela weather station has one of the highest numbers of temperature records, spanning from 1984-2021. The time series for the climate model spans 1981-2019. It’s clear that the variance in the time series are dominated by the seasonal cycle and that any bias in for example the mean will have some seasonal dependency. The PDFs for the 2 time series are plot below.
Show code cell source
# Probability Density Function (all months)
station = 'Manuela'
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)
ds_aws.sel(station=station).to_dataframe()[['temperature']].hist(bins=40,
ax=ax,
edgecolor='k',
linewidth=0.2,
grid=False,
density=1,
alpha=0.7,
label = f'{station} Weather Station',
)
ds_climate_nearest_stacked.sel(nearest_station=station).to_dataframe()[['temperature']].hist(bins=40,
ax=ax,
edgecolor='k',
linewidth=0.2,
grid=False,
density=1,
alpha=0.7,
label = 'Climate Model Output (Nearest Grid-Cell)',
)
ax.annotate('All Months',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
plt.legend()
plt.tight_layout()
plt.show()

In this tutorial we define bias with respect to differences between the parameters that describe the PDFs of the time series. The PDF above is multi-modal, reflecting the seasonality of the data, meaning we’d need to use quite a few parameters to adequately describe the distribution. Simply using the mean would have limited value as we can see that the winter peak is ~5° higher for the climate model output while the summer peaks are approximately equal. The common approach here is to simply split the time series up by the month, focusing on defining bias for each month separately. The PDFs for individual months are approximately Gaussian (see below) and so the bias can be defined in terms of differences in the mean and variance parameters between the datasets.
Show code cell source
# Probability Density Function by Month
station = 'Manuela'
fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=100)
sns.histplot(data=ds_aws.sel(station=station).to_dataframe()[['month','temperature']],
x='temperature',
hue='month',
bins=40,
ax=ax,
edgecolor='k',
linewidth=0.2,
kde=True,
palette='Paired',
)
ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
plt.legend(['January','February','March','April','May','June','July','August','September','October','November','December'])
plt.tight_layout()
plt.show()

In this tutorial, we’ll focus on just applying bias correction to the monthly June time series, shown below for the Manuela station.
Show code cell source
month = 6
station = 'Manuela'
fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=100)
ax = axs[0]
ds_climate_nearest_stacked.sel(nearest_station = station).where(ds_climate_nearest_stacked['month']==month,drop=True)['temperature'].plot(
x='t',
ax=ax,
hue='station',
alpha=0.7,
label='Climate Model Output (Nearest Grid-Cell)',
marker='x',
ms=1,
color='tab:orange',
linewidth=1.0)
ds_aws.sel(station = station).where(ds_aws['month']==month,drop=True)['temperature'].plot(ax=ax,
hue='station',
alpha=0.7,
label=f'{station} Weather Station',
marker='x',
ms=1,
color='tab:blue',
linewidth=1.5)
ax.annotate('June',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)#,rotation = 90)
ax.set_ylabel('Temperature')
ax.set_xlabel('Time')
ax.legend(fontsize=8)
ax.set_title('')
ax=axs[1]
ds_aws.sel(station=station).where(ds_aws['month']==month,drop=True).to_dataframe()[['temperature']].hist(bins=12,
ax=ax,
edgecolor='k',
linewidth=0.2,
grid=False,
density=1,
alpha=0.7,
label = f'{station} Weather Station',
)
ds_climate_nearest_stacked.sel(nearest_station=station).where(ds_climate_nearest_stacked['month']==month,drop=True).to_dataframe()[['temperature']].hist(bins=12,
ax=ax,
edgecolor='k',
linewidth=0.2,
grid=False,
density=1,
alpha=0.7,
label = 'Climate Model Output (Nearest Grid-Cell)',
)
ax.annotate('June',xy=(0.03,0.95),xycoords='axes fraction')
ax.set_title('')
ax.set_xlabel('Temperature')
ax.set_ylabel('Density')
ax.legend(fontsize=8)
plt.tight_layout()
plt.show()

When examining the June time series, it’s clear that the climate model performs well at capturing the yearly variability and is highly correlated with the weather station output. However, there’s a significant bias in the mean of approximately 4°, indicating the utility of applying a bias correction.
Pairplots#
Since we’re interested in evaluating bias in the mean and variance of the June time series, it’s useful to think about predictors that have an influence on these metrics. The two most obvious predictors are elevation and latitude, both of which have well-understood physical justification for the impact on the mean temperature.
Filtering AWS Data:
We filter to just June records as discussed above and also to only stations with more than 5 records. This limits the influence of uncertainty in the empirical estimates used for exploratory analysis.
Nearest Grid Cells:
We’ve included a subcategory ‘Climate Model Nearest’ that represent values for just the climate model grid-cells nearest to each AWS. This is to get a feel for how representative the AWS samples are of Antarctic-wide region (e.g. it might be that differences between the datasets are just the result of the particular sample of locations with weather stations).
Show code cell source
# Filtering to June only
ds_aws_filtered = ds_aws.sel(t = ds_aws.month == 6)
ds_climate_stacked_landonly_filtered = ds_climate_stacked_landonly.sel(time = ds_climate_stacked_landonly.month == 6)
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked.sel(time = ds_climate_nearest_stacked.month == 6)
# Filtering AWS by number of records and updating nearest grid-cells to match
ds_aws_filtered['records'] = ds_aws_filtered.count('t')['temperature']
stations_recordsfilter = ds_aws_filtered.where(ds_aws_filtered['records']>5,drop=True)['station'].data
ds_aws_filtered = ds_aws_filtered.sel(station=stations_recordsfilter)
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked_filtered.sel(nearest_station=stations_recordsfilter)
# Computing June Mean and Standard Deviation
ds_aws_filtered['june_mean_temperature'] = ds_aws_filtered.mean('t')['temperature']
ds_aws_filtered['june_std_temperature'] = ds_aws_filtered.std('t')['temperature']
ds_climate_stacked_landonly_filtered['june_mean_temperature'] = ds_climate_stacked_landonly_filtered.mean('time')['temperature']
ds_climate_stacked_landonly_filtered['june_std_temperature'] = ds_climate_stacked_landonly_filtered.std('time')['temperature']
ds_climate_nearest_stacked_filtered['june_mean_temperature'] = ds_climate_nearest_stacked_filtered.mean('time')['temperature']
ds_climate_nearest_stacked_filtered['june_std_temperature'] = ds_climate_nearest_stacked_filtered.std('time')['temperature']
# Transforming data for plotting with seaborn PairGrid
vars = ['elevation','latitude','june_mean_temperature','june_std_temperature']
df_climate_filtered = ds_climate_stacked_landonly_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_climate_nearest_filtered = ds_climate_nearest_stacked_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_aws_filtered = ds_aws_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
df_climate_filtered['source'] = 'Climate Model'
df_climate_nearest_filtered['source'] = 'Climate Model Nearest'
df_aws_filtered['source'] = 'AWS'
df_combined = pd.concat([df_climate_filtered,df_climate_nearest_filtered,df_aws_filtered],axis=0).reset_index(drop=True)
# Plotting PairGrid with regression lines
g = sns.PairGrid(df_combined, hue='source',diag_sharey=False, corner=True)
reg_kws = {'scatter': False, 'line_kws':{'linewidth':1}}
g.map_lower(sns.regplot,**reg_kws)
g.add_legend(bbox_to_anchor=(0.8,0.8),markerscale=3)
g.hue_kws = {'marker':['+','x','*'],'s':[2,5,2],'alpha':[0.2,1,1]}
scatter_kws = {'linewidth':0.8}
g.map_lower(plt.scatter,**scatter_kws)
hist_kws = {'common_norm':False,'stat':'density'}
g.map_diag(sns.histplot,**hist_kws)
plt.show()

The pairplots bring up some interesting features:
Comparisons of the histograms between the ‘Climate Model’ data (all grid-cells over Antarctica) and the ‘Climate Model Nearest’ data (only grid-cells closest to the weather stations) indicate that the AWS sites are not a particularly representative sample of the whole Antarctic region. There’s a higher proportion of sites at zero elevation, clusters of sites at particular locations (and so latitudes), disproportionality high numbers of sites at regions with relatively high mean temperatures and relatively low standard deviations.
There are clear relationships between mean temperature with elevation and latitude. The slope of the linear relationship does not seem too strongly impacted by the particular subsample of AWS locations.
There’s only a weak relationship between the standard deviation in temperature with elevation and latitude. As a result we’ll leave out these predictors when estimating the standard deviation across the domain.
The behaviour of the relationship between elevation with mean temperature appears quite different at zero elevation sites, which could be linked with various factors such as the proximity and impact of the nearby sea on zero elevation sites. This potentially indicates at the utility of incorporating a distance to the coast predictor, although for this tutorial we leave this out.
Examining spatial covariance after removing influence of elevation and latitude#
The spatial pattern in mean temperature is currently dominated by the relationship with elevation and latitude. While these are clearly important predictors for mean temperature, it’s expected that there’ll be various other important factors that impact mean temperature but are harder to account for (e.g. the funnelling of wind down valleys will impact the mean temperature). One way of at least partially accounting for these factors is to model the spatial covariance between sites after removing the influence of elevation and latitude. That is that nearby sites are likely to be highly correlated as the factors impacting them are similar, whereas the further away you go the less correlated the sites will be (different valleys with different wind patterns etc).
Here we’ll plot the spatial pattern in the mean temperature after removing the linear influence of elevation and latitude. Additionally, we’ll plot the spatial pattern in the log of the standard deviation (taking the log to get the metric on the -\(\infty,\infty\) domain). Since the relationship between elevation and latitude with standard deviation appeared weak, we’ll ignore these predictors for this metric and simply remove a constant to get the zero-mean log(standard deviations).
Show code cell content
# Modelling the linear relationship between elevation, latitude and June mean temperature
# Defining predictors and scaling
predictors = ['elevation','latitude']
scaled_predictors = [i+'_scaled' for i in predictors]
target = 'june_mean_temperature'
scaler = StandardScaler()
df_aws_filtered_scaled_predictors = pd.DataFrame(scaler.fit_transform(df_aws_filtered[predictors]),columns=scaled_predictors)
df_climate_filtered_scaled_predictors = pd.DataFrame(scaler.transform(df_climate_filtered[predictors]),columns=scaled_predictors)
# Linear Regression AWS
print("Linear Regression AWS:")
model = LinearRegression()
model.fit(df_aws_filtered_scaled_predictors,df_aws_filtered[target])
df_aws_filtered['june_mean_temperature_predicted_lr'] = model.predict(df_aws_filtered_scaled_predictors)
feature_importance = pd.Series(model.coef_, index=scaled_predictors)
print('intercept:',model.intercept_)
print(feature_importance.sort_values(ascending=False,key=abs))
# Linear Regression Climate Model
print("\n Linear Regression Climate Model:")
model = LinearRegression()
model.fit(df_climate_filtered_scaled_predictors,df_climate_filtered[target])
df_climate_filtered['june_mean_temperature_predicted_lr'] = model.predict(df_climate_filtered_scaled_predictors)
feature_importance = pd.Series(model.coef_, index=scaled_predictors)
print('intercept:',model.intercept_)
print(feature_importance.sort_values(ascending=False,key=abs))
ds_climate_stacked_landonly_filtered['june_mean_temperature_predicted_lr'] = (
('x'),
df_climate_filtered['june_mean_temperature_predicted_lr'])
ds_aws_filtered['june_mean_temperature_predicted_lr'] = (
('station'),
df_aws_filtered['june_mean_temperature_predicted_lr'])
ds_climate_stacked_landonly_filtered['june_mean_temperature_residual_lr']=ds_climate_stacked_landonly_filtered['june_mean_temperature']-ds_climate_stacked_landonly_filtered['june_mean_temperature_predicted_lr']
ds_aws_filtered['june_mean_temperature_residual_lr']=ds_aws_filtered['june_mean_temperature']-ds_aws_filtered['june_mean_temperature_predicted_lr']
ds_climate_stacked = xr.merge([ds_climate_stacked,ds_climate_stacked_landonly_filtered])
Linear Regression AWS:
intercept: -33.625017170326736
elevation_scaled -10.160912
latitude_scaled 1.952785
dtype: float64
Linear Regression Climate Model:
intercept: -33.0945203068335
elevation_scaled -10.223010
latitude_scaled 2.895803
dtype: float64
Show code cell content
# Computing the zero-mean log(standard deviation)
ds_aws_filtered['june_logstd_temperature']=np.log(ds_aws_filtered['june_std_temperature'])
ds_climate_stacked['june_logstd_temperature']=np.log(ds_climate_stacked['june_std_temperature'])
ds_climate_nearest_stacked_filtered['june_logstd_temperature']=np.log(ds_climate_nearest_stacked_filtered['june_std_temperature'])
ds_aws_filtered['june_logstd_temperature_residual_constant'] = ds_aws_filtered['june_logstd_temperature'] - ds_aws_filtered['june_logstd_temperature'].mean()
ds_climate_stacked['june_logstd_temperature_residual_constant'] = ds_climate_stacked['june_logstd_temperature'] - ds_climate_stacked['june_logstd_temperature'].mean()
ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant']=ds_climate_nearest_stacked_filtered['june_logstd_temperature'] - ds_climate_nearest_stacked_filtered['june_logstd_temperature'].mean()
Show code cell source
# Plotting the linear regression prediction and residual
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)#,frameon=False)
metrics = ['june_mean_temperature_predicted_lr','june_mean_temperature_residual_lr','june_logstd_temperature','june_logstd_temperature_residual_constant']
cmaps = ['viridis','RdBu','viridis','RdBu']
vminmaxs = [(-60,-15),(-10,10),(0.4,1.6),(-0.6,0.6)]
labels = ['June Mean Temperature Predicted (Linear Regression)','June Mean Temperature Residual (Linear Regression)',
'June Log(Standard Deviation)','June Zero-Mean Log(Standard Deviation)']
for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
background_map_rotatedcoords(ax)
ds_climate_stacked[metric].unstack().plot.pcolormesh(
x='glon',
y='glat',
ax=ax,
alpha=0.9,
vmin=vminmax[0],
vmax=vminmax[1],
cmap=cmap,
cbar_kwargs = {'fraction':0.030,
'pad':0.02,
'label':label}
)
ax.scatter(
ds_aws_filtered['glon'],
ds_aws_filtered['glat'],
marker="o",
c=ds_aws_filtered[metric],
cmap=cmap,
edgecolor="w",
linewidth=1.0,
vmin=vminmax[0],
vmax=vminmax[1],
)
ax.set_axis_off()
ax.set_title('')
plt.tight_layout()

It’s clear once we remove the linear dependency of temperature with elevation and latitude we get quite a different spatial structure for the mean temperature. This is the spatial structure we’ll want to model using Gaussian processes, where the covariance between nearby sites is captured by parameterising the covariance as a function of distance. In our model we’ll assume a single length scale for the covariance, that is to say we assume the covariance between nearby sites decays at the same distance wherever you are located over Antarctica. While this assumption is clearly broken in certain areas (e.g. covariance in steep regions near to the coast behaves differently to flat regions inland), it still provides a start and is a lot simpler than the approach of considering non-stationary lengthscales across the region (although this is possible). It’s important to consider how this will impact our results and we expect one of the main influences is that the noise term estimated for our Gaussian process will be relatively high to account for the sharp variations between nearby sites in particularly steep regions.
It’s also important to note that the spatial patterns in both the AWS data and climate model data are similar for each metric. That is the climate model is doing a reasonable job at capturing the more complex dependencies of mean temperature and log(standard dev) in temperature. To utilise this, in our model we consider a shared latent Gaussian process between the datasets and so predictions of the unbiased values are made conditioning on both datasets.
Examining spatial covariance in the bias#
It’s also useful to explore the spatial structure of the bias (both in the mean and log(standard deviation) of June temperature). In this exploratory analysis we do this by examining the empirical values of the metrics from the AWS data and the nearest climate model grid-cells.
For the mean temperature, this is done after accounting for the linear relationship with elevation and latitude (that is we examine the spatial structure in the bias of the residuals). For the log(standard deviation), we examine bias in the zero-meaned values for each dataset.
Show code cell content
# Recomputing nearest neighbours
ds_climate_nearest_stacked = ds_climate_stacked_landonly_filtered.isel(x=nn_indecies)
ds_climate_nearest_stacked = ds_climate_nearest_stacked.assign_coords(nearest_station=("x", ds_aws.station.data))
ds_climate_nearest_stacked = ds_climate_nearest_stacked.swap_dims({"x": "nearest_station"})
ds_climate_nearest_stacked_filtered = ds_climate_nearest_stacked.sel(nearest_station=stations_recordsfilter)
# Recomputing the zero-mean log(standard deviation)
ds_climate_nearest_stacked_filtered['june_logstd_temperature']=np.log(ds_climate_nearest_stacked_filtered['june_std_temperature'])
ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant']=ds_climate_nearest_stacked_filtered['june_logstd_temperature'] - ds_climate_nearest_stacked_filtered['june_logstd_temperature'].mean()
# Evaluating bias in residuals from linear regression for the mean temperature
ds_climate_nearest_stacked_filtered['bias_june_mean_temperature_residual_lr'] = (
('nearest_station'),
ds_climate_nearest_stacked_filtered['june_mean_temperature_residual_lr'].data - ds_aws_filtered['june_mean_temperature_residual_lr'].data)
# Evaluating bias for the log(std) temperature
ds_climate_nearest_stacked_filtered['bias_june_logstd_temperature_residual_constant'] = (
('nearest_station'),
ds_climate_nearest_stacked_filtered['june_logstd_temperature_residual_constant'].data - ds_aws_filtered['june_logstd_temperature_residual_constant'].data)
Show code cell source
# Plotting the linear regression residual and bias in the residual for the mean temperature
fig, axs = plt.subplots(1, 2, figsize=(15, 5),dpi=100)#,frameon=False)
for ax in axs:
background_map_rotatedcoords(ax)
ax.set_axis_off()
ax.set_title('')
ax=axs[0]
plot1 = ax.scatter(
ds_climate_nearest_stacked_filtered['glon'],
ds_climate_nearest_stacked_filtered['glat'],
marker="o",
c=ds_climate_nearest_stacked_filtered['bias_june_mean_temperature_residual_lr'],
cmap='RdBu',
edgecolor="w",
linewidth=1.0,
vmin=-5,
vmax=5,
)
plt.colorbar(plot1, ax=ax, fraction=0.03, pad=0.02,label='Bias in June Mean Temperature Residual \n (Linear Regression)')#, orientation='horizontal', label=metric)
ax=axs[1]
plot2 = ax.scatter(
ds_climate_nearest_stacked_filtered['glon'],
ds_climate_nearest_stacked_filtered['glat'],
marker="o",
c=ds_climate_nearest_stacked_filtered['bias_june_logstd_temperature_residual_constant'],
cmap='RdBu',
edgecolor="w",
linewidth=1.0,
vmin=-0.5,
vmax=0.5,
)
plt.colorbar(plot2, ax=ax, fraction=0.03, pad=0.02,label='Bias in June Log(Std) Temperature Residual \n (Constant)')#, orientation='horizontal', label=metric)
plt.tight_layout()

There’s clearly a spatial covariance pattern in the bias for both the mean and log(standard dev.). The length scale at which the covariance decays for the bias appears longer than for the raw value of the metrics of each dataset.
Examining relationships between variables and the bias in parameters#
It’s interesting to check whether there’s any obvious relationships between the biased parameter values and predictors such as elevation and latitude. Additionally, it’s worth checking if there’s a relationship between the biased and unbiased parameter values. We do this through a partial pairplot as shown below.
Show code cell source
# Transforming data for plotting with seaborn PairGrid
vars = ['elevation','latitude','june_mean_temperature','june_std_temperature','bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant']
x_vars = ['bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant','elevation','latitude','june_mean_temperature','june_std_temperature']
y_vars = ['bias_june_mean_temperature_residual_lr','bias_june_logstd_temperature_residual_constant']
df_bias_climate = ds_climate_nearest_stacked_filtered[vars].to_dataframe()[vars].reset_index(drop=True)
# Plotting PairGrid with regression lines
g = sns.PairGrid(df_bias_climate, x_vars=x_vars,y_vars=y_vars,height=4,diag_sharey=False)#, corner=True)
reg_kws = {'scatter': False, 'line_kws':{'linewidth':1}}
g.map_offdiag(sns.regplot,**reg_kws)
g.hue_kws = {'marker':['+'],'s':[5],'alpha':[1.0]}
scatter_kws = {'linewidth':0.8}
g.map_offdiag(plt.scatter,**scatter_kws)
hist_kws = {'common_norm':False,'stat':'density'}
g.map_diag(sns.histplot,**hist_kws)
plt.show()

The above pairplot only shows weak relationships between the bias in parameters and the other variables. Therefore, in the model we’ll assume the bias is generated from an independent underlying process.
Data Preprocessing#
The main pre-processing steps we’ll do are simply scaling the elevation and latitude predictors and removing any AWS sites with only 2 records for average June temperature. We also define a random subsample of the climate model grid-cells, which we’ll use for inference on the parameters of the Gaussian processes in order to reduce the computational demands. Additionally, all the data is transformed away from Xarray and into a dictionary of device arrays (JAX versions of Numpy arrays) of the right shape that the inference package Numpyro expects.
Computational Complexity and Subsampling:
While Gaussian processes are extremely useful and flexible, their high computational demand of is often described as their Achilles’ heel. Typically, the time for inference scales as the cube of the number of data points, which is a result of taking the inverse of the covariance matrix when computing the likelihood of a set of data for given hyper-parameters. However, there are various methods for improving the speed of inference, such as sparse variational Gaussian processes (SVGPs). These are not utilised in this notebook as they require some additional complexity. Instead we do the sub-optimal procedure of simply only using a subsample of the climate model data for inference. This approach is clearly sub-optimal as it doesn’t target the regions of most interest, however, hopefully it’ll still produce sensible results and be sufficient for this tutorial example.
Additional Preprocessing:
Note that some pre-filtering has already been conducted to the data made available for the notebook, such as removing weather stations located on islands and filtering the climate model output to land onlu etc. These steps can be found in the preprocessing.py file of the notebook repository. Additionally, while not done in this notebook, it’s important to mention that typically feature engineering is an important component of preprocessing and we could for example derive a distance to the coast predictor.
Show code cell content
# Filtering to June records and stations with more than 2 records
aws_june_filter = ds_aws.where(ds_aws['month']==6,drop=True)['t'].data
ds_aws_preprocessed = ds_aws.sel(t=aws_june_filter)
aws_stations_recordsfilter = ds_aws_preprocessed.where(ds_aws_preprocessed['temperature'].count(['t'])>2,drop=True)['station'].data
ds_aws_preprocessed = ds_aws_preprocessed.sel(station=aws_stations_recordsfilter)
climate_june_filter = ds_climate.where(ds_climate['month']==6,drop=True)['time'].data
ds_climate_preprocessed = ds_climate.sel(time=climate_june_filter)
ds_climate_preprocessed = ds_climate_preprocessed.stack(x=('grid_longitude', 'grid_latitude'))
ds_climate_preprocessed = ds_climate_preprocessed.dropna('x')
random_sample = np.random.choice(np.arange(len(ds_climate_preprocessed['x'])), size=100, replace=False)
ds_climate_preprocessed_sample = ds_climate_preprocessed.isel(x=random_sample)
# Scaling latitude and elevation
lat_scalar = StandardScaler()
ele_scalar = StandardScaler()
ds_aws_preprocessed['latitude_scaled'] = (['station'], lat_scalar.fit_transform(ds_aws_preprocessed['latitude'].data.reshape(-1,1))[:,0])
ds_aws_preprocessed['elevation_scaled'] = (['station'], ele_scalar.fit_transform(ds_aws_preprocessed['elevation'].data.reshape(-1,1))[:,0])
ds_climate_preprocessed['latitude_scaled'] = (['x'], lat_scalar.transform(ds_climate_preprocessed['latitude'].data.reshape(-1,1))[:,0])
ds_climate_preprocessed['elevation_scaled'] = (['x'], ele_scalar.transform(ds_climate_preprocessed['elevation'].data.reshape(-1,1))[:,0])
# Transforming into dictionary of device arrays
ox = jnp.array(np.dstack([ds_aws_preprocessed['glon'],ds_aws_preprocessed['glat']]))[0]
odata = jnp.array(ds_aws_preprocessed['temperature'].values).transpose()
olat = jnp.array(ds_aws_preprocessed['latitude'].values)
oele = jnp.array(ds_aws_preprocessed['elevation'].values)
olat_scaled = jnp.array(ds_aws_preprocessed['latitude_scaled'].values)
oele_scaled = jnp.array(ds_aws_preprocessed['elevation_scaled'].values)
cx = jnp.array(np.dstack([ds_climate_preprocessed['glon'],ds_climate_preprocessed['glat']]))[0]
cdata = jnp.array(ds_climate_preprocessed.transpose()['temperature'].values).transpose()
clat = jnp.array(ds_climate_preprocessed['latitude'].values)
cele = jnp.array(ds_climate_preprocessed['elevation'].values)
clat_scaled = jnp.array(ds_climate_preprocessed['latitude_scaled'].values)
cele_scaled = jnp.array(ds_climate_preprocessed['elevation_scaled'].values)
cx_subsample = cx[random_sample]
cdata_subsample = cdata[:,random_sample]
cele_subsample = cele[random_sample]
clat_subsample = clat[random_sample]
cele_scaled_subsample = cele_scaled[random_sample]
clat_scaled_subsample = clat_scaled[random_sample]
data_dictionary = {
'ds_aws_preprocessed':ds_aws_preprocessed,
'ds_climate_preprocessed':ds_climate_preprocessed,
'ds_climate_preprocessed_sample':ds_climate_preprocessed_sample,
'ox':ox,
'odata':odata,
'olat':olat,
'oele':oele,
'olat_scaled':jnp.array(olat_scaled),
'oele_scaled':jnp.array(oele_scaled),
'cx':cx,
'cdata':cdata,
'clat':clat,
'cele':cele,
'clat_scaled':jnp.array(clat_scaled),
'cele_scaled':jnp.array(cele_scaled),
'cx_subsample':cx_subsample,
'cdata_subsample':cdata_subsample,
'cele_subsample':cele_subsample,
'clat_subsample':clat_subsample,
'cele_scaled_subsample':jnp.array(cele_scaled_subsample),
'clat_scaled_subsample':jnp.array(clat_scaled_subsample),
'ele_scaler':ele_scalar,
'lat_scaler':lat_scalar,
'random_sample':random_sample,
}
It’s worth plotting the locations of the sampled climate model grid-cells and it’s also worth performing a quick sanity check that the data is in the right format:
Show code cell content
#Sanity Check
print('Shapes:')
for key in data_dictionary.keys():
if key not in ['ds_aws_preprocessed','ds_climate_preprocessed','ds_climate_preprocessed_sample','ele_scaler','lat_scaler']:
print(f'{key} shape: {data_dictionary[key].shape}')
print('\n Values:')
for key in data_dictionary.keys():
if key not in ['ds_aws_preprocessed','ds_climate_preprocessed','ds_climate_preprocessed_sample','ele_scaler','lat_scaler']:
if key=='odata':
print(f'{key} min={np.nanmin(data_dictionary[key]):.1f}, mean={np.nanmean(data_dictionary[key]):.1f}, max={np.nanmax(data_dictionary[key]):.1f}')
else:
print(f'{key} min={data_dictionary[key].min():.1f}, mean={data_dictionary[key].mean():.1f}, max={data_dictionary[key].max():.1f}')
print('\n Types:')
for key in data_dictionary.keys():
print(f'{key} type: {type(data_dictionary[key])}')
Shapes:
ox shape: (156, 2)
odata shape: (42, 156)
olat shape: (156,)
oele shape: (156,)
olat_scaled shape: (156,)
oele_scaled shape: (156,)
cx shape: (5724, 2)
cdata shape: (38, 5724)
clat shape: (5724,)
cele shape: (5724,)
clat_scaled shape: (5724,)
cele_scaled shape: (5724,)
cx_subsample shape: (100, 2)
cdata_subsample shape: (38, 100)
cele_subsample shape: (100,)
clat_subsample shape: (100,)
cele_scaled_subsample shape: (100,)
clat_scaled_subsample shape: (100,)
random_sample shape: (100,)
Values:
ox min=-24.0, mean=-1.7, max=21.5
odata min=-70.3, mean=-33.4, max=-7.9
olat min=-90.0, mean=-76.4, max=-65.2
oele min=5.0, mean=1246.4, max=4093.0
olat_scaled min=-2.6, mean=-0.0, max=2.1
oele_scaled min=-1.1, mean=0.0, max=2.5
cx min=-24.9, mean=2.5, max=24.4
cdata min=-73.3, mean=-39.6, max=-7.0
clat min=-89.7, mean=-76.6, max=-64.0
cele min=-3.1, mean=2003.6, max=4063.5
clat_scaled min=-2.5, mean=-0.0, max=2.4
cele_scaled min=-1.1, mean=0.7, max=2.5
cx_subsample min=-19.6, mean=2.4, max=23.5
cdata_subsample min=-69.7, mean=-40.6, max=-11.4
cele_subsample min=41.3, mean=2133.2, max=3999.3
clat_subsample min=-86.8, mean=-76.1, max=-66.5
cele_scaled_subsample min=-1.1, mean=0.8, max=2.5
clat_scaled_subsample min=-2.0, mean=0.1, max=1.9
random_sample min=56.0, mean=2963.7, max=5709.0
Types:
ds_aws_preprocessed type: <class 'xarray.core.dataset.Dataset'>
ds_climate_preprocessed type: <class 'xarray.core.dataset.Dataset'>
ds_climate_preprocessed_sample type: <class 'xarray.core.dataset.Dataset'>
ox type: <class 'jaxlib.xla_extension.ArrayImpl'>
odata type: <class 'jaxlib.xla_extension.ArrayImpl'>
olat type: <class 'jaxlib.xla_extension.ArrayImpl'>
oele type: <class 'jaxlib.xla_extension.ArrayImpl'>
olat_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
oele_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cx type: <class 'jaxlib.xla_extension.ArrayImpl'>
cdata type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_scaled type: <class 'jaxlib.xla_extension.ArrayImpl'>
cx_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cdata_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
cele_scaled_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
clat_scaled_subsample type: <class 'jaxlib.xla_extension.ArrayImpl'>
ele_scaler type: <class 'sklearn.preprocessing._data.StandardScaler'>
lat_scaler type: <class 'sklearn.preprocessing._data.StandardScaler'>
random_sample type: <class 'numpy.ndarray'>
Show code cell content
# Plotting the subsample locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)
background_map_rotatedcoords(ax)
ax.scatter(
ds_climate_preprocessed_sample['glon'],
ds_climate_preprocessed_sample['glat'],
marker='s',
s=10,
edgecolor='k',
linewidths=0.5,
)
plt.show()

Defining the model#
Let \(Y(s,t)\) represent the random variable for the June temperature from the AWS data at a particular time and location. Let \(Z(s,t)\) represent the equivalent but for the climate model output. We treat the marginal distribution of \(Y(s,t)\) and \(Z(s,t)\) as Normal, such that \(Y(s)\sim \mathcal{N}(\mu_Y(s),\sigma_Y(s))\) and \(Z(s)\sim \mathcal{N}(\mu_Z(s),\sigma_Z(s))\).
To model the spatial covariance in the parameters we use Gaussian processes. A log transformation is applied to the standard deviation (\(\tilde{\sigma}=log(\sigma)\)) so that it’s on the -\(\infty,\infty\) sample space of a Gaussian process. Shared latent Gaussian processes are considered between the datasets as well as an independent Gaussian process that generates the bias in the climate model output. The parameters for the climate model are then considered as the sum of an unbiased and biased component \(\mu_Z(s)=\mu_Y(s)+\mu_B(s)\) and \(\tilde{\sigma}_Z(s)=\tilde{\sigma}_Y(s)+\tilde{\sigma}_B(s)\), where each component is modelled as generated from a Gaussian process. The Gaussian processes are parameterised by a mean function and covariance function. The mean function for \(\mu_Y(s)\) is considered linear with respect to elevation and latitude, while for the other parameters it is considered a constant. The covariance function is taken as a Matern3/2 kernel with a lengthscale \(l\) and variance \(v\).
The plate diagram below shows the relational dependence between the parameters of the model.
Mean Function and Covariance Function:
The mean function for the parameter \(\mu_Y(s)\) is taken as \(m_{\mu_Y}(s)=\beta_{0,\mu_Y} + \beta_{1,\mu_Y} \cdot x_{ele}(s) + \beta_{2,\mu_Y} \cdot x_{lat}(s)\).
The covariance function is taken as a Matern3/2 kernel for each parameter, that is: \(k_{M^{3/2}}(s,s')=v\left(1+\dfrac{\sqrt{3}|s-s'|}{l}exp\left({-\dfrac{\sqrt{3}|s-s'|}{l}}\right)\right)\). The distance \(|s-s'|\) between points is computed from the ‘grid_latitude’ and ‘grid_longitude’ coordinates, which are approximately Euclidean across the domain. The matern3/2 kernel is a popular choice for real-world data applications, see this blog by Andy Jones for some additional info: The Matérn class of covariance functions .
Splitting up the model for computation#
It’s quite common when using Gaussian processes to fit the mean function independently of the covariance function. That is to say a mean function is fit to the data initially and then the zero-mean data is fit using the GP implementation that handles covariances between points. There are various reasons for this, such as making the covariance matrix more well-conditioned for inference.
In this notebook we’ll split up the model into a first component estimating the parameters \(\mu_Y(s_Y)\), \(\mu_Z(s_Z)\), \(\tilde{\sigma}_Y(s_Y)\) and \(\tilde{\sigma}_Z(s_Z)\) at the AWS and climate model grid cell locations, along with the global parameter values for the mean functions of the latent Gaussian processes \(\beta_{0,\mu_Y}\), \(\beta_{1,\mu_Y}\), \(\beta_{2,\mu_Y}\), \(m_{\mu_B}\), \(m_{\tilde{\sigma}_Y}\) and \(m_{\tilde{\sigma}_B}\). Then the second component will use the residuals \(r_{\mu_Y}(s_Y)\), \(r_{\mu_Z}(s_Z)\), \(r_{\tilde{\sigma}_Y}(s_Y)\) and \(r_{\tilde{\sigma}_Z}(s_Z)\) to estimate the parameters of the covariance function for the latent Gaussian processes \(l_{\mu_Y}\), \(v_{\mu_Y}\), \(l_{\mu_B}\), \(v_{\mu_B}\), \(l_{\tilde{\sigma}_Y}\), \(v_{\tilde{\sigma}_Y}\), \(l_{\tilde{\sigma}_B}\) and \(v_{\tilde{\sigma}_B}\).
Modelling assumptions and splitting up the model:
It’s important to consider the impact of splitting up the hierarchical model and what assumptions it makes. The main assumption we make is that the noise in \(\mu_Y(s_Y)\), \(\mu_Z(s_Z)\), \(\tilde{\sigma}_Y(s_Y)\) and \(\tilde{\sigma}_Z(s_Z)\) is Gaussian, which seems reasonable. However, if we expected the posterior distributions of these parameters to be highly skewed and non-Gaussian then splitting up the model wouldn’t adequately capture this and keeping a hierarchical model would better represent the uncertainty in the final predictive distributions.
Model Diagram Component 1: Mean Function:
Model Diagram Component 2: Residuals and Covariance
Parameter Inference#
Inference is performed on the two components of the model separately. To limit the runtime of this notebook, we’ll provide the code for running the inference but will perform the actual inference separately and load in the output to examine. The python scripts for running the inference separately and saving the output are available via the jupterbook_render branch of the repository. Loading in the inference data:
url = ("https://zenodo.org/records/14779669/files/data_dictionary.pkl?download=1")
filename = f'{data_path}data_dictionary.pkl'
if os.path.exists(filename):
print('File already exists')
else:
urlretrieve(url, filename)
print('File downloaded')
File already exists
Inference on parameters of the mean function#
Utilising the Numpyro python package to define the model for component 1:
Show code cell content
# The model for predicting the mean and logvar for each dataset as well as the parameters for the meanfunction giving domain-wide behaviour
def meanfunc_model(data_dictionary):
"""
Function for defining the GP mean function model for the temperature data
Args:
data_dictionary (python dictionary): dictionary holding the data needed for the model
"""
omean_b0 = numpyro.sample("omean_b0",data_dictionary['omean_b0_prior'])
omean_b1 = numpyro.sample("omean_b1",data_dictionary['omean_b1_prior'])
omean_b2 = numpyro.sample("omean_b2",data_dictionary['omean_b2_prior'])
omean_noise = numpyro.sample("omean_noise",data_dictionary['omean_noise_prior'])
omean_func = omean_b0 + omean_b1*data_dictionary['oele_scaled'] + omean_b2*data_dictionary['olat_scaled']
omean = numpyro.sample("omean",dist.Normal(omean_func, omean_noise))
ologvar_b0 = numpyro.sample("ologvar_b0",data_dictionary['ologvar_b0_prior'])
ologvar_noise = numpyro.sample("ologvar_noise",data_dictionary['ologvar_noise_prior'])
ologvar_func = ologvar_b0 * jnp.ones(data_dictionary['ox'].shape[0])
ologvar = numpyro.sample("ologvar",dist.Normal(ologvar_func, ologvar_noise))
ovar = jnp.exp(ologvar)
obs_mask = (jnp.isnan(data_dictionary['odata'])==False)
numpyro.sample("AWS Temperature", dist.Normal(omean, jnp.sqrt(ovar)).mask(obs_mask), obs=data_dictionary["odata"])
cmean_b0 = numpyro.sample("cmean_b0",data_dictionary['cmean_b0_prior'])
cmean_noise = numpyro.sample("cmean_noise",data_dictionary['cmean_noise_prior'])
cmean_func = cmean_b0 + omean_b1*data_dictionary['cele_scaled'] + omean_b2*data_dictionary['clat_scaled']
cmean = numpyro.sample("cmean",dist.Normal(cmean_func, cmean_noise))
clogvar_b0 = numpyro.sample("clogvar_b0",data_dictionary['clogvar_b0_prior'])
clogvar_noise = numpyro.sample("clogvar_noise",data_dictionary['clogvar_noise_prior'])
clogvar_func = clogvar_b0 * jnp.ones(data_dictionary['cx'].shape[0])
clogvar = numpyro.sample("clogvar",dist.Normal(clogvar_func, clogvar_noise))
cvar = jnp.exp(clogvar)
numpyro.sample("Climate Temperature", dist.Normal(cmean, jnp.sqrt(cvar)), obs=data_dictionary["cdata"])
In the model definition ‘omean’ represents the mean June temperature estimate at each AWS location, while ‘cmean’ is the the equivalent for the climate model output at the grid-cell locations. Similarly, ‘ologvar’ and ‘clogvar’ represents for the log-variance for each dataset and location respectively. The model follows the following equations:
Defining a function for running the Bayesian inference on the model parameters given the data:
Show code cell content
# A function to run inference on the model
def run_inference(
model, rng_key, num_warmup, num_samples, num_chains, *args, **kwargs
):
"""
Helper function for doing MCMC inference
Args:
model (python function): function that follows numpyros syntax
rng_key (np array): PRNGKey for reproducible results
num_warmup (int): Number of MCMC steps for warmup
num_samples (int): Number of MCMC samples to take of parameters after warmup
num_chains (int): Number of chains to run in parallel (or sequentially without GPU)
*args: Additional arguments to pass to the model
**kwargs: Additional keyword arguments to pass to the model
distance_matrix_values(jax device array): matrix of distances between sites, shape [#sites,#sites]
Returns:
MCMC numpyro instance (class object): An MCMC class object with functions such as .get_samples() and .run()
"""
starttime = timeit.default_timer()
kernel = NUTS(model)
mcmc = MCMC(
kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(rng_key, *args, **kwargs)
mcmc.print_summary()
print("Time Taken:", timeit.default_timer() - starttime)
return mcmc
As we’re utilising a Bayesian framework, we’ll have to define some prior distributions for the parameters. While in theory the prior distribution should represent our state of knowledge before observing the data and so if we have no additional knowledge should be fully non-informative, in practice it’s common to at least set sensible bounds using the exploratory analysis to limit the complexity of the inference space. Here, we compute some basic metrics, which we then use to help set sensible priors.
Show code cell content
print('Useful Metrics for Priors: \n',
f"""mean odata:
min={np.nanmin(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
mean={np.nanmean(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
max={np.nanmax(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
var={np.nanvar(np.nanmean(data_dictionary['odata'],axis=0)):.1f},
\n""",
f"""logvar odata:
min={np.nanmin(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
mean={np.nanmean(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
max={np.nanmax(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
var={np.nanvar(np.log(np.nanvar(data_dictionary['odata'],axis=0))):.1f},
\n""",
f"""mean cdata:
min={np.nanmin(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
mean={np.nanmean(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
max={np.nanmax(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
var={np.nanvar(np.nanmean(data_dictionary['cdata'],axis=0)):.1f},
\n""",
f"""logvar cdata:
min={np.nanmin(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
mean={np.nanmean(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
max={np.nanmax(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
var={np.nanvar(np.log(np.nanvar(data_dictionary['cdata'],axis=0))):.1f},
\n""",
)
Useful Metrics for Priors:
mean odata:
min=-65.0,
mean=-33.4,
max=-13.4,
var=155.0,
logvar odata:
min=-4.9,
mean=1.7,
max=4.1,
var=1.1,
mean cdata:
min=-62.8,
mean=-39.6,
max=-12.6,
var=154.6,
logvar cdata:
min=0.6,
mean=2.1,
max=3.3,
var=0.2,
Show code cell content
# Setting priors
data_dictionary.update({
"omean_b0_prior": dist.Normal(-33.0, 10.0),
"omean_b1_prior": dist.Normal(0.0, 10.0),
"omean_b2_prior": dist.Normal(0.0, 10.0),
"omean_noise_prior": dist.Uniform(1e-2, 10.0),
"ologvar_b0_prior": dist.Normal(5, 5.0),
"ologvar_noise_prior": dist.Uniform(1e-3, 2.0),
})
data_dictionary.update({
"cmean_b0_prior": dist.Normal(-39.0, 10.0),
"cmean_noise_prior": dist.Uniform(1e-2, 10.0),
"clogvar_b0_prior": dist.Normal(5, 2.0),
"clogvar_noise_prior": dist.Uniform(1e-3, 2.0),
})
The code for running the inference and saving estimates of the parameter posterior distributions is given below. Note that the code is commented out as instead of running the inference within this notebook, for computational reasons we’ll simply load in the output. The output is saved using the ArViZ package, which handles inference data from lots of different probabilistic programming packages and can be used to produce nice summary metrics and plots.
Show code cell content
'''
# %% Running inference
mcmc = run_inference(meanfunc_model, rng_key, 1000, 2000,4, data_dictionary)
idata = az.from_numpyro(mcmc,
coords={
"station": data_dictionary['ds_aws_preprocessed']['station'],
"x": data_dictionary['ds_climate_preprocessed']['x'],
},
dims={"clogvar": ["x"],
"cmean": ["x"],
"ologvar": ["station"],
"omean": ["station"],})
meanfunc_posterior = idata.posterior
# Computing the residuals from the mean function model parameters
meanfunc_posterior = meanfunc_posterior.assign_coords({'oele_scaled':('station', data_dictionary['oele_scaled']),
'olat_scaled':('station', data_dictionary['olat_scaled']),
'cele_scaled':('x', data_dictionary['cele_scaled']),
'clat_scaled':('x', data_dictionary['clat_scaled'])})
meanfunc_posterior['omean_func'] = meanfunc_posterior['omean_b0']+meanfunc_posterior['omean_b1']*meanfunc_posterior['oele_scaled']+meanfunc_posterior['omean_b2']*meanfunc_posterior['olat_scaled']
meanfunc_posterior['cmean_func'] = meanfunc_posterior['cmean_b0']+meanfunc_posterior['omean_b1']*meanfunc_posterior['cele_scaled']+meanfunc_posterior['omean_b2']*meanfunc_posterior['clat_scaled']
meanfunc_posterior['ologvar_func'] = meanfunc_posterior['ologvar_b0']
meanfunc_posterior['clogvar_func'] = meanfunc_posterior['clogvar_b0']
meanfunc_posterior['omean_func_residual'] = meanfunc_posterior['omean']-meanfunc_posterior['omean_func']
meanfunc_posterior['cmean_func_residual'] = meanfunc_posterior['cmean']-meanfunc_posterior['cmean_func']
meanfunc_posterior['ologvar_func_residual'] = meanfunc_posterior['ologvar']-meanfunc_posterior['ologvar_func']
meanfunc_posterior['clogvar_func_residual'] = meanfunc_posterior['clogvar']-meanfunc_posterior['clogvar_func']
data_dictionary['meanfunc_posterior'] = meanfunc_posterior
'''
with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
data_dictionary = pickle.load(f)
meanfunc_posterior = data_dictionary['meanfunc_posterior']
MCMC is an approximate procedure and it is difficult to assess directly whether the samples returned from the inference accurately capture the true posterior distributions of the parameters. The effective sample size (ESS) and r_hat diagnostics provide some indication, we’re looking for values of r_hat in the range (1.0, 1.05) and ESS that are comparable to the total number of samples, which we get in this run:
Show code cell source
az.summary(meanfunc_posterior[['omean_b0',
'omean_b1',
'omean_b2',
'omean_noise',
'ologvar_b0',
'ologvar_noise',
'cmean_b0',
'cmean_noise',
'clogvar_b0',
]],hdi_prob=0.95)
mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
omean_b0 | -33.416 | 0.539 | -34.523 | -32.402 | 0.005 | 0.007 | 13604.0 | 5746.0 | 1.0 |
omean_b1 | -10.059 | 0.069 | -10.193 | -9.926 | 0.001 | 0.001 | 11527.0 | 6119.0 | 1.0 |
omean_b2 | 3.100 | 0.068 | 2.963 | 3.234 | 0.001 | 0.001 | 16561.0 | 6066.0 | 1.0 |
omean_noise | 6.740 | 0.402 | 5.978 | 7.514 | 0.004 | 0.006 | 12338.0 | 5185.0 | 1.0 |
ologvar_b0 | 2.043 | 0.066 | 1.917 | 2.174 | 0.001 | 0.001 | 6133.0 | 6362.0 | 1.0 |
ologvar_noise | 0.584 | 0.063 | 0.467 | 0.710 | 0.001 | 0.001 | 2558.0 | 4297.0 | 1.0 |
cmean_b0 | -32.623 | 0.084 | -32.796 | -32.464 | 0.001 | 0.001 | 11574.0 | 5437.0 | 1.0 |
cmean_noise | 5.263 | 0.050 | 5.164 | 5.361 | 0.000 | 0.001 | 16003.0 | 5160.0 | 1.0 |
clogvar_b0 | 2.146 | 0.006 | 2.135 | 2.158 | 0.000 | 0.000 | 9323.0 | 6671.0 | 1.0 |
Inference on parameters of the residual and covariance functions#
For component 2 of the model, the residuals of the mean and log(standard dev.) are treated as independent and so inference can be conducted separately for each. Again utilising the Numpyro python package to define the model gives:
Show code cell content
# Helper functions to be used in the residual model
def diagonal_noise(coord, noise):
return jnp.diag(jnp.full(coord.shape[0], noise))
def generate_obs_conditional_climate_dist(
ox, cx, cdata, ckernel, cdiag, okernel, odiag
):
y2 = cdata
u1 = jnp.full(ox.shape[0], 0)
u2 = jnp.full(cx.shape[0], 0)
k11 = okernel(ox, ox) + diagonal_noise(ox, odiag)
k12 = okernel(ox, cx)
k21 = okernel(cx, ox)
k22 = ckernel(cx, cx) + diagonal_noise(cx, cdiag)
k22i = jnp.linalg.inv(k22)
u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
l22 = jnp.linalg.cholesky(k22)
l22i = jnp.linalg.inv(l22)
p21 = jnp.matmul(l22i, k21)
k1g2 = k11 - jnp.matmul(p21.T, p21)
mvn_dist = dist.MultivariateNormal(u1g2, k1g2)
return mvn_dist
# The residual model for the mean temperature
def residual_model(data_dictionary,metric):
"""
Example model where the climate data is generated from 2 GPs,
one of which also generates the observations and one of
which generates bias in the climate model.
"""
meanfunc_posterior = data_dictionary['meanfunc_posterior']
omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
cmeanfunc_residual_exp_subsample = meanfunc_posterior[f'c{metric}_func_residual'].isel(x=data_dictionary['random_sample']).mean(['draw','chain']).data
kern_var = numpyro.sample("kern_var", data_dictionary[f'o{metric}_func_residual_kvprior'])
lengthscale = numpyro.sample("lengthscale", data_dictionary[f'o{metric}_func_residual_klprior'])
kernel = kern_var * kernels.Matern32(lengthscale,L2Distance())
noise = numpyro.sample("noise", data_dictionary[f'o{metric}_func_residual_nprior'])
var_obs = omeanfunc_residual_var
bkern_var = numpyro.sample("bkern_var", data_dictionary[f'b{metric}_func_residual_kvprior'])
blengthscale = numpyro.sample("blengthscale", data_dictionary[f'b{metric}_func_residual_klprior'])
bkernel = bkern_var * kernels.Matern32(blengthscale,L2Distance())
bnoise = numpyro.sample("bnoise", data_dictionary[f'b{metric}_func_residual_nprior'])
ckernel = kernel + bkernel
cnoise = noise + bnoise
cgp = GaussianProcess(ckernel, data_dictionary["cx_subsample"], diag=cnoise, mean=0)
numpyro.sample("climate_temperature",
cgp.numpyro_dist(),
obs=cmeanfunc_residual_exp_subsample)
obs_conditional_climate_dist = generate_obs_conditional_climate_dist(
data_dictionary["ox"],
data_dictionary["cx_subsample"],
cmeanfunc_residual_exp_subsample,
ckernel,
cnoise,
kernel,
var_obs+noise
)
numpyro.sample(
"obs_temperature",
obs_conditional_climate_dist,
obs=omeanfunc_residual_exp
)
# Function for running the MCMC inference and generating the posterior distributions for the model
def generate_posterior_residual_model(data_dictionary,
metric,
rng_key,
num_warmup,
num_samples,
num_chains):
mcmc_residual_model = run_inference(
residual_model,
rng_key,
num_warmup,
num_samples,
num_chains,
data_dictionary,
metric
)
idata_residual_model = az.from_numpyro(mcmc_residual_model)
data_dictionary[f"idata_residual_model_{metric}"] = idata_residual_model
In the model definition for the residuals the observations are the expectations of the residuals computed from inference of the meanfunction model in component 1. That is \(E[r_{\mu_Y}(s_Y)]\) and \(E[r_{\mu_Z}(s_Z)]\). The uncertainty in these values is captured through the variance, so \(V[r_{\mu_Y}(s_Y)]\) and \(V[r_{\mu_Z}(s_Z)]\). Although, we don’t include \(V[r_{\mu_Z}(s_Z)]\) in the model definition as it’s an insignificant quantity. The model follows the equations:
Function ‘generate_obs_conditional_climate_dist’:
The function ‘generate_obs_conditional_climate_dist’ is defined as the parameters of \(\mathcal{GP}(0|,k(s,s'|l_{\mu_Y},v_{\mu_Y},n_{\mu_Y}))\) and \(\mathcal{GP}(0|,k(s,s'|l_{\mu_B},v_{\mu_B},n_{\mu_B}))\) are conditional on both \(r_{\mu_Y}(s_Y)\) and \(r_{\mu_Z}(s_Z)\). So we generate a distribution for the probability of the unbiased AWS data given the already conditioned on climate model data.
Again choosing sensible priors for the parameters of the model that have reasonable bounds based on some basic summary metrics:
Show code cell content
# Useful metrics for priors
meanfunc_posterior = data_dictionary['meanfunc_posterior']
exp_omean_func_residual = meanfunc_posterior['omean_func_residual'].mean(['draw','chain'])
exp_ologvar_func_residual = meanfunc_posterior['ologvar_func_residual'].mean(['draw','chain'])
ox_ranges = data_dictionary['ox'].max(axis=0)-data_dictionary['ox'].min(axis=0)
print('Useful Metrics for Priors: \n',
f"""
Mean Obs:
min={exp_omean_func_residual.min():.1f},
mean={exp_omean_func_residual.mean():.1f},
max={exp_omean_func_residual.max():.1f},
var={exp_omean_func_residual.var():.1f},
""",
f"""
LogVar Obs:
min={exp_ologvar_func_residual.min():.1f},
mean={exp_ologvar_func_residual.mean():.1f},
max={exp_ologvar_func_residual.max():.1f},
var={exp_ologvar_func_residual.var():.1f},
""",
f'\n Obs Axis Ranges={ox_ranges}'
)
Useful Metrics for Priors:
Mean Obs:
min=-16.9,
mean=-0.0,
max=17.3,
var=43.1,
LogVar Obs:
min=-0.9,
mean=-0.0,
max=1.2,
var=0.2,
Obs Axis Ranges=[45.47068137 34.36497791]
Show code cell content
# Setting priors
lengthscale_max = 20
data_dictionary['omean_func_residual_kvprior'] = dist.Uniform(0.1,100.0)
data_dictionary['omean_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['omean_func_residual_nprior'] = dist.Uniform(0.1,20.0)
data_dictionary['bmean_func_residual_kvprior'] = dist.Uniform(0.1,100.0)
data_dictionary['bmean_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['bmean_func_residual_nprior'] = dist.Uniform(0.1,20.0)
data_dictionary['ologvar_func_residual_kvprior'] = dist.Uniform(0.01,1.0)
data_dictionary['ologvar_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['ologvar_func_residual_nprior'] = dist.Uniform(0.01,1.0)
data_dictionary['blogvar_func_residual_kvprior'] = dist.Uniform(0.01,1.0)
data_dictionary['blogvar_func_residual_klprior'] = dist.Uniform(1,lengthscale_max)
data_dictionary['blogvar_func_residual_nprior'] = dist.Uniform(0.01,1.0)
The code for running the inference and saving estimates of the parameter posterior distributions is given below. Note that the code is commented out as instead of running the inference within this notebook, for computational reasons we’ll simply load in the output.
Show code cell content
# Running the inference
'''
generate_posterior_residual_model(data_dictionary,
'mean',
rng_key,
1000,
1000,
4)
generate_posterior_residual_model(data_dictionary,
'logvar',
rng_key,
1000,
1000,
4)
'''
with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
data_dictionary = pickle.load(f)
idata_residual_model_mean = data_dictionary["idata_residual_model_mean"]
idata_residual_model_logvar = data_dictionary["idata_residual_model_logvar"]
Show code cell source
print(r"Parameter Inference for Model of Mean Residuals")
display(az.summary(idata_residual_model_mean.posterior,hdi_prob=0.95))
print(r"Parameter Inference for Model of LogVar Residuals")
display(az.summary(idata_residual_model_logvar.posterior,hdi_prob=0.95))
Parameter Inference for Model of Mean Residuals
mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
bkern_var | 5.838 | 7.004 | 0.102 | 18.185 | 0.125 | 0.311 | 2838.0 | 1879.0 | 1.0 |
blengthscale | 12.991 | 4.687 | 4.797 | 19.994 | 0.075 | 0.053 | 3111.0 | 1959.0 | 1.0 |
bnoise | 0.604 | 0.510 | 0.101 | 1.622 | 0.008 | 0.013 | 2689.0 | 1391.0 | 1.0 |
kern_var | 35.067 | 9.148 | 20.810 | 53.265 | 0.263 | 0.537 | 2067.0 | 1227.0 | 1.0 |
lengthscale | 3.431 | 0.638 | 2.308 | 4.641 | 0.016 | 0.019 | 2015.0 | 1571.0 | 1.0 |
noise | 5.468 | 1.086 | 3.384 | 7.638 | 0.020 | 0.018 | 3129.0 | 2633.0 | 1.0 |
Parameter Inference for Model of LogVar Residuals
mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
bkern_var | 0.067 | 0.064 | 0.010 | 0.184 | 0.001 | 0.003 | 2827.0 | 1632.0 | 1.0 |
blengthscale | 13.684 | 4.219 | 6.214 | 19.997 | 0.076 | 0.048 | 2552.0 | 1915.0 | 1.0 |
bnoise | 0.013 | 0.003 | 0.010 | 0.019 | 0.000 | 0.000 | 2291.0 | 1534.0 | 1.0 |
kern_var | 0.125 | 0.075 | 0.044 | 0.250 | 0.002 | 0.006 | 1939.0 | 1116.0 | 1.0 |
lengthscale | 5.598 | 1.638 | 3.142 | 8.781 | 0.045 | 0.075 | 1974.0 | 1236.0 | 1.0 |
noise | 0.012 | 0.002 | 0.010 | 0.017 | 0.000 | 0.000 | 3033.0 | 1774.0 | 1.0 |
The posterior distributions for the parameters of the model appear sensible on first glance. The r-hat statistic indicates that the chains are independent as desired and the effective sample size is of comparable magnitude to the total number of samples taken.
Making posterior predictive estimates of the unbiased mean and variance across the domain#
From the previous section we have:
Posterior estimates of the mean and log-variance at the AWS sites and climate model grid cells, that is \(\mu_Y(s_Y)\), \(\tilde{\sigma}_Y(s_Y)\), \(\mu_Z(s_Z)\) and \(\tilde{\sigma}_Z(s_Z)\). As well as the estimates for the meanfunction and residual components: \(m_{\mu_Y}(s_Y)\), \(m_{\tilde{\sigma}_Y}(s_Y)\), \(m_{\mu_Z}(s_Z)\), \(m_{\tilde{\sigma}_Z}(s_Z)\) and \(r_{\mu_Y}(s_Y)\), \(r_{\tilde{\sigma}_Y}(s_Y)\), \(r_{\mu_Z}(s_Z)\) and \(r_{\tilde{\sigma}_Z}(s_Z)\).
Posterior estimates of the parameters of the mean functions, so \(\beta_{0,\mu_Y}\), \(\beta_{1,\mu}\), \(\beta_{2,\mu}\), \(\beta_{0,\mu_Z}\), \(\beta_{0,\tilde{\sigma}_Y}\) and \(\beta_{0,\tilde{\sigma}_Z}\).
Posterior estimates of the parameters of the covariance functions, so \(l_{\mu_Y}\), \(v_{\mu_Y}\), \(n_{\mu_Y}\), \(l_{\mu_B}\), \(v_{\mu_B}\) and \(n_{\mu_B}\).
Now we want to make estimates of the posterior distributions of the unbiased mean and log-variance at the climate model locations to use for bias correction, so \(\mu_Y(s_Z)\) and \(\tilde{\sigma}_Y(s_Z)\). This is known as the posterior predictive and involves conditioning on both the inferred parameters of the model and the observed data. Since we split up the hierarchical model, to get \(\mu_Y(s_Z)\) and \(\tilde{\sigma}_Y(s_Z)\) we’ll get estimates for \(m_{\mu_Y}(s_Z)\) and \(m_{\tilde{\sigma}_Y}(s_Z)\) from the first component of the model, then \(r_{\mu_Y}(s_Z)\) and \(r_{\tilde{\sigma}_Y}(s_Z)\) from the second component.
Posterior predictive calculation:
Let the collection of parameters for the model be: \(\phi = [m_{\mu_Y}(s_Y), m_{\mu_Z}(s_Z), r_{\mu_Y}(s_Y), r_{\mu_Z}(s_Z), \beta_{0,\mu_Y}, \beta_{1,\mu}, \beta_{2,\mu}, \beta_{0,\mu_Z}, l_{\mu_Y}, v_{\mu_Y}, n_{\mu_Y}, l_{\mu_B}, v_{\mu_B}, n_{\mu_B}]\)
Then the posterior predictive distributions is given by:
Where \(p(m_{\mu_Y}(\hat{s})|\beta_{0,\mu_Y}, \beta_{1,\mu}, \beta_{2,\mu}, \beta_{0,\mu_Z})\) is the predictive distribution of the meanfunction and \(p(r_{\mu_Y}(\hat{s})|r_{\mu_Y}(s_Y), r_{\mu_Z}(s_Z),l_{\mu_Y}, v_{\mu_Y}, n_{\mu_Y}, l_{\mu_B}, v_{\mu_B}, n_{\mu_B})\) is the predictive distribution of the residual function. The distribution \(p(\theta|Y(s_Y),Z(s_Z))\) is simply the posterior distribution for the parameters of our model, which we get samples from using MCMC.
To get samples of \(\mu_Y(\hat{s})\) from the posterior predictive distribution we take samples from the predictive distribution of the meanfunction and the predictive distribution of the residual using parameters sampled from the posterior. That is for each sample from our MCMC inference, we define the predictive distributions and then take a sample.
Sampling the predictive distribution of the mean function is simple:
Show code cell content
# Generating posterior predictive estimates of the mean function
meanfunc_posterior['mean_unbiased_meanfunc_predictive'] = (meanfunc_posterior['omean_b0']+
meanfunc_posterior['omean_b1']*meanfunc_posterior['cele_scaled']+
meanfunc_posterior['omean_b2']*meanfunc_posterior['clat_scaled'])
meanfunc_posterior['mean_biased_meanfunc_predictive'] = (meanfunc_posterior['cmean_b0']-
meanfunc_posterior['omean_b0'])
meanfunc_posterior['logvar_unbiased_meanfunc_predictive'] = (meanfunc_posterior['ologvar_b0'])
meanfunc_posterior['logvar_biased_meanfunc_predictive'] = (meanfunc_posterior['clogvar_b0']-
meanfunc_posterior['ologvar_b0'])
Sampling the predictive distribution of the residual is more difficult and we construct some functions to help.
Show code cell content
# %% Defining function for generating posterior predictive realisations of the residuals
def generate_truth_predictive_dist(nx, data_dictionary, metric, posterior_param_realisation):
kern_var_realisation = posterior_param_realisation["kern_var_realisation"]
lengthscale_realisation = posterior_param_realisation["lengthscale_realisation"]
noise_realisation = posterior_param_realisation["noise_realisation"]
bkern_var_realisation = posterior_param_realisation["bkern_var_realisation"]
blengthscale_realisation = posterior_param_realisation["blengthscale_realisation"]
bnoise_realisation = posterior_param_realisation["bnoise_realisation"]
meanfunc_posterior = data_dictionary['meanfunc_posterior']
omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
cmeanfunc_residual_exp = meanfunc_posterior[f'c{metric}_func_residual'].mean(['draw','chain']).data
ox = data_dictionary["ox"]
cx = data_dictionary["cx"]
odata = omeanfunc_residual_exp
odata_var = omeanfunc_residual_var
cdata = cmeanfunc_residual_exp
kernelo = kern_var_realisation * kernels.Matern32(lengthscale_realisation,L2Distance())
kernelb = bkern_var_realisation * kernels.Matern32(blengthscale_realisation,L2Distance())
noise = noise_realisation + odata_var
bnoise = bnoise_realisation
cnoise = noise_realisation + bnoise
jitter = 1e-5
y2 = jnp.hstack([odata, cdata])
u1 = jnp.full(nx.shape[0], 0)
u2 = jnp.hstack(
[jnp.full(ox.shape[0], 0), jnp.full(cx.shape[0], 0)]
)
k11 = kernelo(nx, nx) + diagonal_noise(nx, jitter)
k12 = jnp.hstack([kernelo(nx, ox), kernelo(nx, cx)])
k21 = jnp.vstack([kernelo(ox, nx), kernelo(cx, nx)])
k22_upper = jnp.hstack(
[kernelo(ox, ox) + diagonal_noise(ox, noise), kernelo(ox, cx)]
)
k22_lower = jnp.hstack(
[
kernelo(cx, ox),
kernelo(cx, cx) + kernelb(cx, cx) + diagonal_noise(cx, cnoise),
]
)
k22 = jnp.vstack([k22_upper, k22_lower])
k22 = k22
k22i = jnp.linalg.inv(k22)
u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
k1g2 = k11 - jnp.matmul(jnp.matmul(k12, k22i), k21)
mvn = dist.MultivariateNormal(u1g2, k1g2)
return mvn
def generate_bias_predictive_dist(nx, data_dictionary, metric, posterior_param_realisation):
kern_var_realisation = posterior_param_realisation["kern_var_realisation"]
lengthscale_realisation = posterior_param_realisation["lengthscale_realisation"]
noise_realisation = posterior_param_realisation["noise_realisation"]
bkern_var_realisation = posterior_param_realisation["bkern_var_realisation"]
blengthscale_realisation = posterior_param_realisation["blengthscale_realisation"]
bnoise_realisation = posterior_param_realisation["bnoise_realisation"]
meanfunc_posterior = data_dictionary['meanfunc_posterior']
omeanfunc_residual_exp = meanfunc_posterior[f'o{metric}_func_residual'].mean(['draw','chain']).data
omeanfunc_residual_var = meanfunc_posterior[f'o{metric}_func_residual'].var(['draw','chain']).data
cmeanfunc_residual_exp = meanfunc_posterior[f'c{metric}_func_residual'].mean(['draw','chain']).data
ox = data_dictionary["ox"]
cx = data_dictionary["cx"]
odata = omeanfunc_residual_exp
odata_var = omeanfunc_residual_var
cdata = cmeanfunc_residual_exp
kernelo = kern_var_realisation * kernels.Matern32(lengthscale_realisation,L2Distance())
kernelb = bkern_var_realisation * kernels.Matern32(blengthscale_realisation,L2Distance())
noise = noise_realisation + odata_var
bnoise = bnoise_realisation
cnoise = noise_realisation + bnoise
jitter = 1e-5
y2 = jnp.hstack([odata, cdata])
u1 = jnp.full(nx.shape[0], 0)
u2 = jnp.hstack(
[jnp.full(ox.shape[0], 0), jnp.full(cx.shape[0], 0)]
)
k11 = kernelb(nx, nx) + diagonal_noise(nx, jitter)
k12 = jnp.hstack([jnp.full((len(nx), len(ox)), 0), kernelb(nx, cx)])
k21 = jnp.vstack([jnp.full((len(ox), len(nx)), 0), kernelb(cx, nx)])
k22_upper = jnp.hstack(
[kernelo(ox, ox) + diagonal_noise(ox, noise), kernelo(ox, cx)]
)
k22_lower = jnp.hstack(
[
kernelo(cx, ox),
kernelo(cx, cx) + kernelb(cx, cx) + diagonal_noise(cx, cnoise),
]
)
k22 = jnp.vstack([k22_upper, k22_lower])
k22 = k22
k22i = jnp.linalg.inv(k22)
u1g2 = u1 + jnp.matmul(jnp.matmul(k12, k22i), y2 - u2)
l22 = jnp.linalg.cholesky(k22)
l22i = jnp.linalg.inv(l22)
p21 = jnp.matmul(l22i, k21)
k1g2 = k11 - jnp.matmul(p21.T, p21)
mvn = dist.MultivariateNormal(u1g2, k1g2)
return mvn
def generate_posterior_predictive_realisations_dualprocess(
nx,
data_dictionary,
metric,
num_parameter_realisations,
num_posterior_pred_realisations,
rng_key
):
posterior = data_dictionary[f"idata_residual_model_{metric}"].posterior
truth_posterior_predictive_realisations = []
bias_posterior_predictive_realisations = []
iteration = 0
for i in tqdm(np.random.randint(posterior.draw.shape, size=num_parameter_realisations)):
posterior_param_realisation = {
"iteration": i,
"kern_var_realisation": posterior["kern_var"].data[0, :][i],
"lengthscale_realisation": posterior["lengthscale"].data[0, :][i],
"noise_realisation": posterior["noise"].data[0, :][i],
"bkern_var_realisation": posterior["bkern_var"].data[0, :][i],
"blengthscale_realisation": posterior["blengthscale"].data[0, :][i],
"bnoise_realisation": posterior["bnoise"].data[0, :][i],
}
truth_predictive_dist = generate_truth_predictive_dist(
nx, data_dictionary, metric, posterior_param_realisation
)
bias_predictive_dist = generate_bias_predictive_dist(
nx, data_dictionary, metric, posterior_param_realisation
)
iteration += 1
truth_predictive_realisations = truth_predictive_dist.sample(
rng_key, sample_shape=(num_posterior_pred_realisations,)
)
rng_key, rng_key_ = random.split(rng_key)
bias_predictive_realisations = bias_predictive_dist.sample(
rng_key, sample_shape=(num_posterior_pred_realisations,)
)
rng_key, rng_key_ = random.split(rng_key)
truth_posterior_predictive_realisations.append(truth_predictive_realisations)
bias_posterior_predictive_realisations.append(bias_predictive_realisations)
truth_posterior_predictive_realisations = jnp.array(
truth_posterior_predictive_realisations
)
bias_posterior_predictive_realisations = jnp.array(
bias_posterior_predictive_realisations
)
residual_pospred_ds = xr.Dataset(
coords = {'hyperparameter_draws':range(num_parameter_realisations),
'gp_draws':range(num_posterior_pred_realisations),
'nx':range(len(nx)),
},
data_vars = {f'unbiased_{metric}_residual_postpred':(['hyperparameter_draws','gp_draws','nx'], truth_posterior_predictive_realisations),
f'bias_{metric}_residual_postpred':(['hyperparameter_draws','gp_draws','nx'], bias_posterior_predictive_realisations),
}
)
return residual_pospred_ds
Generating the residual posterior predictive samples:
Show code cell content
# Generating posterior predictive realisations of the residuals
'''
residual_postpred_mean = generate_posterior_predictive_realisations_dualprocess(
data_dictionary["cx"],
data_dictionary,
"mean",
100,
5,
rng_key
)
residual_postpred_logvar = generate_posterior_predictive_realisations_dualprocess(
data_dictionary["cx"],
data_dictionary,
"logvar",
100,
5,
rng_key
)
'''
with open(f'{data_path}data_dictionary.pkl', 'rb') as f:
data_dictionary = pickle.load(f)
residual_postpred_mean = data_dictionary["residual_postpred_mean"]
residual_postpred_logvar = data_dictionary["residual_postpred_logvar"]
Computing the expectations and uncertainties:
Show code cell content
# Meanfunction posterior predictive estimates expectation and uncertainty
mean_unbiased_meanfunc_predictive_exp = meanfunc_posterior['mean_unbiased_meanfunc_predictive'].mean(['draw','chain'])
mean_unbiased_meanfunc_predictive_var = meanfunc_posterior['mean_unbiased_meanfunc_predictive'].var(['draw','chain'])
mean_biased_meanfunc_predictive_exp = meanfunc_posterior['mean_biased_meanfunc_predictive'].mean(['draw','chain'])
mean_biased_meanfunc_predictive_var = meanfunc_posterior['mean_biased_meanfunc_predictive'].var(['draw','chain'])
logvar_unbiased_meanfunc_predictive_exp = meanfunc_posterior['logvar_unbiased_meanfunc_predictive'].mean(['draw','chain'])
logvar_unbiased_meanfunc_predictive_var = meanfunc_posterior['logvar_unbiased_meanfunc_predictive'].var(['draw','chain'])
logvar_biased_meanfunc_predictive_exp = meanfunc_posterior['logvar_biased_meanfunc_predictive'].mean(['draw','chain'])
logvar_biased_meanfunc_predictive_var = meanfunc_posterior['logvar_biased_meanfunc_predictive'].var(['draw','chain'])
# Residuals posterior predictive estimates expectation and uncertainty
mean_unbiased_residual_postpred_exp = residual_postpred_mean['unbiased_mean_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
mean_unbiased_residual_postpred_var = residual_postpred_mean['unbiased_mean_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
mean_biased_residual_postpred_exp = residual_postpred_mean['bias_mean_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
mean_biased_residual_postpred_var = residual_postpred_mean['bias_mean_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
logvar_unbiased_residual_postpred_exp = residual_postpred_logvar['unbiased_logvar_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
logvar_unbiased_residual_postpred_var = residual_postpred_logvar['unbiased_logvar_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
logvar_biased_residual_postpred_exp = residual_postpred_logvar['bias_logvar_residual_postpred'].mean(['gp_draws','hyperparameter_draws'])
logvar_biased_residual_postpred_var = residual_postpred_logvar['bias_logvar_residual_postpred'].var(['gp_draws','hyperparameter_draws'])
# Noise in residuals posterior predictive estimates expectation and uncertainty
mean_unbiased_residual_postpred_noise = idata_residual_model_mean.posterior['noise'].mean()
mean_biased_residual_postpred_noise = idata_residual_model_mean.posterior['bnoise'].mean()
logvar_unbiased_residual_postpred_noise = idata_residual_model_logvar.posterior['noise'].mean()
logvar_biased_residual_postpred_noise = idata_residual_model_logvar.posterior['bnoise'].mean()
# Combined posterior predictive estimates expectation and uncertainty
mean_unbiased_predictive_exp = mean_unbiased_meanfunc_predictive_exp.data + mean_unbiased_residual_postpred_exp.data
mean_unbiased_predictive_var = mean_unbiased_meanfunc_predictive_var.data + mean_unbiased_residual_postpred_var.data + mean_unbiased_residual_postpred_noise.data
mean_biased_predictive_exp = mean_biased_meanfunc_predictive_exp.data + mean_biased_residual_postpred_exp.data
mean_biased_predictive_var = mean_biased_meanfunc_predictive_var.data + mean_biased_residual_postpred_var.data + mean_biased_residual_postpred_noise.data
logvar_unbiased_predictive_exp = logvar_unbiased_meanfunc_predictive_exp.data + logvar_unbiased_residual_postpred_exp.data
logvar_unbiased_predictive_var = logvar_unbiased_meanfunc_predictive_var.data + logvar_unbiased_residual_postpred_var.data + logvar_unbiased_residual_postpred_noise.data
logvar_biased_predictive_exp = logvar_biased_meanfunc_predictive_exp.data + logvar_biased_residual_postpred_exp.data
logvar_biased_predictive_var = logvar_biased_meanfunc_predictive_var.data + logvar_biased_residual_postpred_var.data + logvar_biased_residual_postpred_noise.data
Results#
To plot the predictions of the unbiased mean and log(variance), as well as the bias, we combine the posterior predictives with a dataset containing the relevant coordinates for plotting:
Show code cell content
# Incorporating results into dataset with coordinates
ds_climate_stacked_landonly_filtered['mean_unbiased_meanfunc_predictive_exp'] = (['x'],mean_unbiased_meanfunc_predictive_exp.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_meanfunc_predictive_var'] = (['x'],mean_unbiased_meanfunc_predictive_var.data)
ds_climate_stacked_landonly_filtered['mean_biased_meanfunc_predictive_exp'] = (['x'],np.full(5724,mean_biased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['mean_biased_meanfunc_predictive_var'] = (['x'],np.full(5724,mean_biased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['logvar_unbiased_meanfunc_predictive_exp'] = (['x'],np.full(5724,logvar_unbiased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['logvar_unbiased_meanfunc_predictive_var'] = (['x'],np.full(5724,logvar_unbiased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['logvar_biased_meanfunc_predictive_exp'] = (['x'],np.full(5724,logvar_biased_meanfunc_predictive_exp.data))
ds_climate_stacked_landonly_filtered['logvar_biased_meanfunc_predictive_var'] = (['x'],np.full(5724,logvar_biased_meanfunc_predictive_var.data))
ds_climate_stacked_landonly_filtered['mean_unbiased_residual_postpred_exp'] = (['x'],mean_unbiased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_residual_postpred_var'] = (['x'],mean_unbiased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['mean_biased_residual_postpred_exp'] = (['x'],mean_biased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['mean_biased_residual_postpred_var'] = (['x'],mean_biased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['logvar_unbiased_residual_postpred_exp'] = (['x'],logvar_unbiased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['logvar_unbiased_residual_postpred_var'] = (['x'],logvar_unbiased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['logvar_biased_residual_postpred_exp'] = (['x'],logvar_biased_residual_postpred_exp.data)
ds_climate_stacked_landonly_filtered['logvar_biased_residual_postpred_var'] = (['x'],logvar_biased_residual_postpred_var.data)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_exp'] = (['x'],mean_unbiased_predictive_exp)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_var'] = (['x'],mean_unbiased_predictive_var)
ds_climate_stacked_landonly_filtered['mean_unbiased_predictive_2sigma'] = (['x'],2*np.sqrt(mean_unbiased_predictive_var))
ds_climate_stacked_landonly_filtered['mean_biased_predictive_exp'] = (['x'],mean_biased_predictive_exp)
ds_climate_stacked_landonly_filtered['mean_biased_predictive_var'] = (['x'],mean_biased_predictive_var)
ds_climate_stacked_landonly_filtered['mean_biased_predictive_2sigma'] = (['x'],2*np.sqrt(mean_biased_predictive_var))
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_exp'] = (['x'],logvar_unbiased_predictive_exp)
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_var'] = (['x'],logvar_unbiased_predictive_var)
ds_climate_stacked_landonly_filtered['logvar_unbiased_predictive_2sigma'] = (['x'],2*np.sqrt(logvar_unbiased_predictive_var))
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_exp'] = (['x'],logvar_biased_predictive_exp)
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_var'] = (['x'],logvar_biased_predictive_var)
ds_climate_stacked_landonly_filtered['logvar_biased_predictive_2sigma'] = (['x'],2*np.sqrt(logvar_biased_predictive_var))
ds_climate_stacked = xr.merge([ds_climate_stacked,ds_climate_stacked_landonly_filtered])
The expectation of the unbiased and biased June mean temperature at the climate model grid-cell locations is then given below along with the uncertainty. The values are the result of summing the posterior predictive estimates from both components of the model.
Show code cell source
# Plotting the predictions for the unbiased June mean temperature and the bias.
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)
metrics = ['mean_unbiased_predictive_exp','mean_biased_predictive_exp','mean_unbiased_predictive_2sigma','mean_biased_predictive_2sigma']
cmaps = ['viridis','RdBu','plasma','plasma']
vminmaxs = [(-60,-15),(-3,3),(5,6),(2.0,3.5)]
labels = ['Expectation Unbiased June Mean Temperature','Expectation Bias June Mean Temperature',
'Uncertainty Unbiased June Mean Temperature','Uncertainty Bias June Mean Temperature']
for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
background_map_rotatedcoords(ax)
ds_climate_stacked[metric].unstack().plot.pcolormesh(
x='glon',
y='glat',
ax=ax,
alpha=0.9,
vmin=vminmax[0],
vmax=vminmax[1],
cmap=cmap,
cbar_kwargs = {'fraction':0.030,
'pad':0.02,
'label':label}
)
ax.scatter(
ds_aws_filtered['glon'],
ds_aws_filtered['glat'],
marker="o",
edgecolor="w",
facecolor="none",
linewidth=1.0,
)
ax.set_axis_off()
ax.set_title('')
plt.tight_layout()

The equivalent for the log(variance) is also given below:
Show code cell source
# Plotting the predictions for the unbiased June mean temperature and the bias.
fig, axs = plt.subplots(2, 2, figsize=(15, 10),dpi=100)
metrics = ['logvar_unbiased_predictive_exp','logvar_biased_predictive_exp','logvar_unbiased_predictive_2sigma','logvar_biased_predictive_2sigma']
cmaps = ['viridis','RdBu','plasma','plasma']
vminmaxs = [(1.4,3.0),(-0.25,0.25),(0.3,0.45),(0.3,0.45)]
labels = ['Expectation Unbiased June Log(Var) Temperature','Expectation Bias June Log(Var) Temperature',
'Uncertainty Unbiased June Log(Var) Temperature','Uncertainty Bias June Log(Var) Temperature']
for ax,metric,vminmax,cmap,label in zip(axs.ravel(),metrics,vminmaxs,cmaps,labels):
background_map_rotatedcoords(ax)
ds_climate_stacked[metric].unstack().plot.pcolormesh(
x='glon',
y='glat',
ax=ax,
alpha=0.9,
vmin=vminmax[0],
vmax=vminmax[1],
cmap=cmap,
cbar_kwargs = {'fraction':0.030,
'pad':0.02,
'label':label}
)
ax.scatter(
ds_aws_filtered['glon'],
ds_aws_filtered['glat'],
marker="o",
edgecolor="w",
facecolor="none",
linewidth=1.0,
)
ax.set_axis_off()
ax.set_title('')
plt.tight_layout()

There are various interesting things to note about these predictions:
The expectation for the bias varies smoothly across the domain, which is the result of modelling it with a single lengthscale in the latent Gaussian process. In reality we know the bias will vary more sharply across the region, however with limited data capturing the large scale patterns is already a good start and the smoothness makes it clear we’re not overfitting to the data. Being resistant to overfitting is a common benefit of Gaussian processes.
The uncertainty is dominated by a constant noise term across the domain, which makes sense since we expect sharp variations that are difficult to predict with the limited weather station data available.
The uncertainty is lowest nearby clusters of weather stations. This makes sense and is the result of modelling the spatial covariance. Away from weather stations we have limited information about the bias.
Quantile Mapping#
The final step after estimating the unbiased PDF parameter values across the domain is to actually apply the bias correction to the original climate model spatio-temporal output. To do this we utilise quantile mapping, which maps the climate model data onto the CDF represented by the unbiased PDF parameters.
For every record \(j\) at each location \(i\) we apply the mapping: $\(\hat{z}_{s_{i,j}} = F_{Y_{s_i}}^{-1}(F_{Z_{s_i}}(z_{s_{i,j}}))\)$
Where \(F_{Y_{s_i}}^{-1}\) is the unbiased inverse cumulative density function and \(F_{Z_{s_i}}\) is the CDF for the climate model data.
Instead of using the expectation of posterior predictive parameter estimates, realisations of the posterior predictives of the unbiased parameters are utilised and quantile mapping performed with each realisation separately. This allows propagation of uncertainty, producing multiple realisations for the corrected timeseries and thus an uncertainty band on the correction. To illustrate this six AWS sites are chosen and the timeseries from the AWS records plot alongside the timeseries for the original and corrected climate model output from the nearest grid cell. Additionally, the original and bias corrected climate model timeseries are shown for a site isolated from any nearby AWSs.
Taking a selection of 6 weather station sites dotted around the domain and 1 isolated site, we’ll demonstrate the bias corrected time series at each of the nearest climate model grid-cells. To start with we’ll show the locations of the different chosen sites. Then we’ll illustrate the time series for the weather stations at the sites, the time series for the climate model output of the nearest grid-cell and the time series of the bias corrected climate model output for the nearest grid-cell.
Show code cell content
# Including mean expectation for climate model data
ds_climate_stacked_landonly_filtered['exp_mean_climate'] = (['x'],meanfunc_posterior['cmean'].mean(['draw','chain']).data)
ds_climate_stacked_landonly_filtered['exp_logvar_climate'] = (['x'],meanfunc_posterior['clogvar'].mean(['draw','chain']).data)
# Computing nearest grid-cells to weather stations
ox = np.dstack([ds_aws_filtered['glon'],ds_aws_filtered['glat']])[0]
cx = np.dstack([ds_climate_stacked_landonly['glon'],ds_climate_stacked_landonly['glat']])[0]
nn_indecies = []
for point in ox:
nn_indecies.append(distance.cdist([point], cx).argmin())
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered.isel(x=nn_indecies)
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered_nn.assign_coords(nearest_station=("x", ds_aws_filtered.station.data))
ds_climate_stacked_landonly_filtered_nn = ds_climate_stacked_landonly_filtered_nn.swap_dims({"x": "nearest_station"})
Show code cell source
# Plotting selected station locations
fig, ax = plt.subplots(1, 1, figsize=(10, 10),dpi=100)#,frameon=False)
stations = ['Henry','Manuela','Butler Island','Byrd','Relay Station','Dome C']
ds = ds_aws_filtered.sel(station = stations)
ds_isolated = ds_climate_stacked_landonly_filtered.isel(x=2050)
ax.scatter(
ds.glon,
ds.glat,
s=30,
marker='*',
edgecolor='k',
linewidths=0.5,
)
ax.scatter(
ds_isolated.glon,
ds_isolated.glat,
s=30,
marker='*',
edgecolor='k',
linewidths=0.5,
)
gdf_icesheet_rotatedcoords.boundary.plot(ax=ax, color="k", linewidth=0.1)
ax.set_title('')
ax.set_axis_off()
offsets = [[10, 5],[10, 5],[-8, 10],[-8, 10],[-20, -15],[0,10]]
for station,offset in zip(stations,offsets):
ds_station = ds.sel(station = station)
ax.annotate(
f'{ds_station.station.data}',
xy=(ds_station.glon, ds_station.glat), xycoords='data',
xytext=(offset[0],offset[1]), textcoords='offset points',
arrowprops=dict(arrowstyle="->"),
fontsize=10)
ax.annotate(
f'Isolated Location',
xy=(ds_isolated.glon, ds_isolated.glat), xycoords='data',
xytext=(-25,10), textcoords='offset points',
arrowprops=dict(arrowstyle="->"),
fontsize=10)
plt.tight_layout()

Show code cell source
stations = ['Henry','Manuela','Butler Island','Byrd','Relay Station','Dome C','Isolated Location']
labels = ['a.','b.','c.','d.','e.','f.','g.']
fig, axs = plt.subplots(7, 1, figsize=(10, 10),dpi=300)#,frameon=False)
for ax,station,label in zip(axs[:-1],stations[:-1],labels[:-1]):
ds_climate = ds_climate_stacked_landonly_filtered_nn.sel(nearest_station=station)
ds_aws = ds_aws_filtered.sel(station=station)
cdf_c = norm.cdf(
ds_climate['temperature'],
ds_climate['exp_mean_climate'],
np.sqrt(np.exp(ds_climate['exp_logvar_climate'])),
)
mean_unbiased_dist = dist.Normal(ds_climate['mean_unbiased_predictive_exp'].data,np.sqrt(ds_climate['mean_unbiased_predictive_var'].data))
logvar_unbiased_dist = dist.Normal(ds_climate['logvar_unbiased_predictive_exp'].data,np.sqrt(ds_climate['logvar_unbiased_predictive_var'].data))
mean_unbiased_samples = mean_unbiased_dist.sample(rng_key,(1000,))
logvar_unbiased_samples = logvar_unbiased_dist.sample(rng_key,(1000,))
c_corrected = norm.ppf(
cdf_c,
mean_unbiased_samples.reshape(-1, 1),
np.sqrt(np.exp(logvar_unbiased_samples)).reshape(-1, 1),
)
ax.annotate(label+station,xy=(0.01,1.02),xycoords='axes fraction')
ds_climate['temperature'].plot(x='year',
ax=ax,
marker='+',
linestyle="-",
linewidth=0.8,
zorder=2,
label='Climate Model Ouput')
ds_aws['temperature'].plot(x='year',
ax=ax,
marker='x',
linestyle="-",
linewidth=0.8,
zorder=2,
label='Nearest AWS Ouput')
ax.plot(ds_climate['year'],
c_corrected.mean(axis=0),
color='k',
marker='+',
linestyle="-",
linewidth=0.8,
zorder=1,
label='Bias Corrected Ouput Expectation')
ax.fill_between(
ds_climate['year'],
c_corrected.mean(axis=0)
+ 3 * c_corrected.std(axis=0),
c_corrected.mean(axis=0)
- 3 * c_corrected.std(axis=0),
interpolate=True,
color="k",
alpha=0.5,
label="Bias Corrected Output Uncertainty 3$\sigma$",
linewidth=0.5,
facecolor="none",
edgecolor="k",
linestyle=(0, (5, 2)),
)
for corrected_timeseries in c_corrected[::10]:
ax.plot(
ds_climate['year'],
corrected_timeseries,
color="k",
alpha=0.2,
linestyle="-",
linewidth=0.2,
zorder=1,
)
ax=axs[-1]
ds_isolated = ds_climate_stacked_landonly_filtered.isel(x=2050)
cdf_c = norm.cdf(
ds_isolated['temperature'],
ds_isolated['exp_mean_climate'],
np.sqrt(np.exp(ds_isolated['exp_logvar_climate'])),
)
mean_unbiased_dist = dist.Normal(ds_isolated['mean_unbiased_predictive_exp'].data,np.sqrt(ds_isolated['mean_unbiased_predictive_var'].data))
logvar_unbiased_dist = dist.Normal(ds_isolated['logvar_unbiased_predictive_exp'].data,np.sqrt(ds_isolated['logvar_unbiased_predictive_var'].data))
mean_unbiased_samples = mean_unbiased_dist.sample(rng_key,(1000,))
logvar_unbiased_samples = logvar_unbiased_dist.sample(rng_key,(1000,))
c_corrected = norm.ppf(
cdf_c,
mean_unbiased_samples.reshape(-1, 1),
np.sqrt(np.exp(logvar_unbiased_samples)).reshape(-1, 1),
)
ax.annotate('g. Isolated Location',xy=(0.01,1.02),xycoords='axes fraction')
ds_isolated['temperature'].plot(x='year',
ax=ax,
marker='+',
linestyle="-",
linewidth=0.8,
zorder=2,
label='Climate Model Ouput')
ax.plot(ds_isolated['year'],
c_corrected.mean(axis=0),
color='k',
marker='+',
linestyle="-",
linewidth=0.8,
zorder=1,
label='Bias Corrected Ouput Expectation')
ax.fill_between(
ds_isolated['year'],
c_corrected.mean(axis=0)
+ 3 * c_corrected.std(axis=0),
c_corrected.mean(axis=0)
- 3 * c_corrected.std(axis=0),
interpolate=True,
color="k",
alpha=0.5,
label="Bias Corrected Output Uncertainty 3$\sigma$",
linewidth=0.5,
facecolor="none",
edgecolor="k",
linestyle=(0, (5, 2)),
)
for corrected_timeseries in c_corrected[::10]:
ax.plot(
ds_isolated['year'],
corrected_timeseries,
color="k",
alpha=0.2,
linestyle="-",
linewidth=0.2,
zorder=1,
)
for ax in axs:
ax.set_title('')
ax.set_xlim([1978,2022])
for ax in axs[:-1]:
ax.set_xlabel('')
ax.set_xticklabels('')
handles, labels = ax.get_legend_handles_labels()
fig.legend(
handles,
labels,
fontsize=8,
bbox_to_anchor=(0.5, -0.02),
ncols=4,
loc=10,
)
plt.tight_layout()

The bias correction seems to be performing reasonably, with a general shift of the time series towards that of the nearby unbiased weather station data and realistic uncertainty bands. It’s also worth noting some interesting features visible in the weather station data including:
An apparent temperature anomaly in the Butler Island records;
An apparent shift in the mean after a break for the Relay station.
It is likely these are instrumentation errors rather than physically meaningful phenomenon that need to be captured. This would be best resolved with further data preprocessing and highlights the importance of understanding specific features of the real-world data that are hard to account for in modelling. In general the assumption of time-invariant bias seems reasonable and the model is adding value by improving the timeseries agreement with the observed values and by introducing an uncertainty that can be propagated further through physical ice models, providing some range of possible outcomes, which in turn might be used in various cost-decision analysis schemes.
Conclusions#
Hopefully this notebook provides some interesting ideas that are generalisable to many areas of application including:
Where we care about uncertainty in the end output and creating a flexible modelling framework that propagates uncertainty in the different components and is explicit about the assumptions made through prior distributions.
Where we’re trying to combine multiple datasets with different spatial and temporal coverage.
Where we want to explicitly model the underlying covariance between points.
The main limitations of the methodology presented are in the computational complexity of inference. The model in its current form runs slowly even for small datasets such as presented in this notebook. This makes iterative improvements and careful testing of the results challenging. There are various options to speed up the code that still need exploring and if you’re interested in contributing please do so through the repository link. Another important limitation of the code as it stands is that it’s currently lacking rigorous unit testing of functions and that the documentation needs improving. As such it is suggested to takeaway components of the pipeline to utilise on your own projects and to supplement that with official examples from the relevant package repositories such as Numpyro and TinyGP.