TP: 2D Thin Plate Spline#
Setup and Imports#
import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel_gam as gam
import jax
jax.config.update("jax_enable_x64", True)
df = gam.demo_data_ta(n=600, noise_sd=0.25, seed=42)
df_grid = gam.demo_data_ta(n=5000, grid=True)
plt.figure(figsize=(6, 5))
plt.scatter(df["x"], df["y"], c=df["z"])
plt.xlabel("x")
plt.ylabel("y")
plt.title("2D Color Plot")
plt.colorbar(label="z")
plt.tight_layout()
plt.show()
plt.figure(figsize=(6, 5))
plt.scatter(df_grid["x"], df_grid["y"], c=df_grid["eta"])
plt.xlabel("x")
plt.ylabel("y")
plt.title("2D Color Plot")
plt.colorbar(label="eta")
plt.tight_layout()
plt.show()
Model Definition#
Setup response model#
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)
z = lsl.Var.new_obs(
    value=df.z.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="z",
)
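The scale predictor is created with `inv_link=jnp.exp`, so the additive predictor lives on the log scale and the resulting standard deviation is always positive. The following is a minimal NumPy sketch of that idea, independent of liesel_gam; the array `eta_sigma` holds made-up predictor values chosen only for illustration.

```python
import numpy as np

# Hypothetical values of an unconstrained additive predictor for the scale.
eta_sigma = np.array([-2.0, 0.0, 1.5])

# Applying the inverse link (exp) maps any real number to a positive sigma.
sigma = np.exp(eta_sigma)

print(sigma)
```

A predictor value of zero corresponds to a standard deviation of one, and negative values shrink the scale towards zero without ever reaching it.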
tb = gam.TermBuilder.from_df(df)
loc += tb.tp("x", "y", k=50)
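How `tb.tp` builds its basis internally is not shown here. As a rough illustration of what a 2D thin plate spline is, the sketch below evaluates the classical radial basis r^2 log(r) together with the unpenalized polynomial part (1, x, y) in plain NumPy. The function `tps_kernel`, the random points, and the choice of the first 50 points as knots are all assumptions made for this example; they are not liesel_gam's implementation, and the correspondence between 50 knots and `k=50` is only an analogy.

```python
import numpy as np

def tps_kernel(points, knots):
    """Evaluate the 2D thin plate spline radial basis r^2 * log(r)."""
    # Pairwise Euclidean distances between evaluation points and knots.
    diff = points[:, None, :] - knots[None, :, :]
    r = np.sqrt((diff ** 2).sum(axis=-1))
    # r^2 * log(r) tends to 0 as r -> 0, so set those entries to 0 explicitly.
    safe_r = np.where(r > 0, r, 1.0)
    return np.where(r > 0, r**2 * np.log(safe_r), 0.0)

rng = np.random.default_rng(0)
points = rng.uniform(size=(200, 2))  # hypothetical (x, y) locations
knots = points[:50]                  # hypothetical choice of 50 knots

# Radial part plus the unpenalized polynomial part (1, x, y).
radial = tps_kernel(points, knots)
poly = np.column_stack([np.ones(len(points)), points])
basis = np.hstack([radial, poly])
print(basis.shape)  # (200, 53)
```

Practical implementations usually reduce and reparameterize such a basis before fitting, so the columns above should be read as the underlying idea rather than as the design matrix used in the model.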
Build and plot model#
model = lsl.Model([z], to_float32=False)
model.plot_vars()
Run MCMC#
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{tp(x,y)}$', '$\\tau_{tp(x,y)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:02<00:00, 1.15chunk/s]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:07<00:00, 1.42chunk/s]
liesel.goose.engine - INFO - Finished epoch
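The number of stored posterior draws follows from the settings above. The short calculation below assumes that `thinning=10` keeps every tenth posterior transition in each chain; the variable names are chosen for this example only.

```python
num_chains = 4
posterior_transitions = 10_000
thinning = 10

# Draws kept per chain after thinning, and in total across chains.
draws_per_chain = posterior_transitions // thinning
total_draws = num_chains * draws_per_chain

print(draws_per_chain, total_draws)  # 1000 4000
```

Under that assumption, 4000 draws per parameter are available for the summary statistics.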
MCMC summary#
summary = gs.Summary(results)
diagnostics = (
    summary.to_dataframe()
    .reset_index()
    .loc[:, ["variable", "rhat", "ess_bulk", "ess_tail"]]
    .groupby("variable", as_index=False)
    .agg(
        ess_bulk_min=("ess_bulk", "min"),
        ess_bulk_median=("ess_bulk", "median"),
        ess_tail_min=("ess_tail", "min"),
        ess_tail_median=("ess_tail", "median"),
        rhat_max=("rhat", "max"),
        rhat_median=("rhat", "median"),
    )
)
diagnostics
| | variable | ess_bulk_min | ess_bulk_median | ess_tail_min | ess_tail_median | rhat_max | rhat_median |
|---|---|---|---|---|---|---|---|
| 0 | $\beta_{0,\mu}$ | 3939.966762 | 3939.966762 | 3914.127419 | 3914.127419 | 1.000291 | 1.000291 |
| 1 | $\beta_{0,\sigma}$ | 3831.141762 | 3831.141762 | 3721.044012 | 3721.044012 | 0.999664 | 0.999664 |
| 2 | $\beta_{tp(x,y)}$ | 2659.992156 | 3150.464822 | 3106.133761 | 3520.405328 | 1.002207 | 1.000427 |
| 3 | $\tau_{tp(x,y)}^2$ | 3403.701248 | 3403.701248 | 3570.692099 | 3570.692099 | 1.000384 | 1.000384 |
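To give a feeling for what the `rhat` columns measure, the sketch below computes the classic potential scale reduction factor from between-chain and within-chain variances. It is a simplified version written for illustration (the helper `basic_rhat` and the simulated chains are assumptions of this example); diagnostic tools commonly report a rank-normalized split variant, so the numbers are not expected to match the table exactly.

```python
import numpy as np

def basic_rhat(draws):
    """Classic R-hat for draws of shape (chains, iterations)."""
    n_chains, n_iter = draws.shape
    chain_means = draws.mean(axis=1)
    chain_vars = draws.var(axis=1, ddof=1)
    between = n_iter * chain_means.var(ddof=1)  # between-chain variance B
    within = chain_vars.mean()                  # within-chain variance W
    # Pooled estimate of the posterior variance.
    var_hat = (n_iter - 1) / n_iter * within + between / n_iter
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))       # four chains sampling the same target
stuck = mixed + np.arange(4)[:, None]    # four chains centred at different values

print(basic_rhat(mixed))
print(basic_rhat(stuck))
```

Chains that agree give a value close to 1, as in the table above, while chains centred at different locations give a value clearly above 1.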
gs.plot_trace(results, [n for n in model.parameters if "tau" in n])
<seaborn.axisgrid.FacetGrid at 0x133ad92b0>
samples = results.get_posterior_samples()
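Posterior samples are typically summarized by averaging over chains and draws. The sketch below uses a simulated stand-in for one entry of `samples`, assuming arrays laid out as (chains, draws, parameter dimension); that layout, the key `"beta"`, and the simulated values are assumptions for illustration and should be checked against the actual object.

```python
import numpy as np

rng = np.random.default_rng(2)

# Stand-in for one entry of the samples dictionary:
# 4 chains, 1000 draws per chain, 3 coefficients.
fake_samples = {
    "beta": rng.normal(loc=[0.0, 1.0, -1.0], scale=0.1, size=(4, 1000, 3))
}

beta = fake_samples["beta"]

# Collapse the chain and draw axes to get one summary per coefficient.
post_mean = beta.mean(axis=(0, 1))
lower, upper = np.quantile(beta, [0.025, 0.975], axis=(0, 1))

print(post_mean.shape, lower.shape, upper.shape)
```

The same reduction over the first two axes applies to any quantity derived from the draws, for example a predictor evaluated on a grid.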