TF: Full Tensor Product#

Setup and Imports#

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import matplotlib.pyplot as plt
import numpy as np
import plotnine as p9
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
import jax

# Enable 64-bit floating point support in JAX.
jax.config.update("jax_enable_x64", True)
# Demo data set: 600 observations, noise_sd=0.25, fixed seed for reproducibility.
df = gam.demo_data_ta(n=600, noise_sd=0.25, seed=42)
# Grid variant of the demo data (grid=True); used below for plotting its "eta"
# column and as `newdata` for predictions.
df_grid = gam.demo_data_ta(n=5000, grid=True)
# Scatter plot of the observed data: position is (x, y), color is the response z.
plt.figure(figsize=(6, 5))
plt.scatter(df["x"], df["y"], c=df["z"])
plt.xlabel("x")
plt.ylabel("y")
plt.title("2D Color Plot")
# Label the colorbar after the column that is actually mapped to color.
plt.colorbar(label="z")
plt.tight_layout()
plt.show()
../../_images/7aedd143977d7fcfb53a61fc79894e2e0d0754cfd68b69cd4b3d7b48996ebc80.png
# Scatter plot on the grid data: position is (x, y), color is the "eta" column.
grid_fig, grid_ax = plt.subplots(figsize=(6, 5))
grid_points = grid_ax.scatter(df_grid["x"], df_grid["y"], c=df_grid["eta"])
grid_ax.set_xlabel("x")
grid_ax.set_ylabel("y")
grid_ax.set_title("2D Color Plot")
grid_fig.colorbar(grid_points, ax=grid_ax, label="eta")
grid_fig.tight_layout()
plt.show()
../../_images/9c3b6d969216f251b75184b1efa5132bef46c888f324c6c7eb825365fe03cd46.png

Model Definition#

Setup response model#

# Additive predictor used as the location parameter of the response distribution.
loc = gam.AdditivePredictor("$\\mu$")
# Additive predictor used as the scale parameter; jnp.exp is passed as the
# inverse link, so the scale handed to the distribution is positive.
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)


# Observed response: the column df.z with a Normal(loc, scale) distribution.
z = lsl.Var.new_obs(
    value=df.z.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="z",
)
# Term builder created from the observed data frame; terms refer to columns by name.
tb = gam.TermBuilder.from_df(df)
# Marginal terms in x and y with k=12 each.
# NOTE(review): presumably P-spline bases with 12 basis functions; confirm
# against the liesel_gam documentation.
psx = tb.ps("x", k=12)
psy = tb.ps("y", k=12)
# Add the tensor-product term built from the two marginals to the location predictor.
loc += tb.tf(psx, psy)

Build and plot model#

# Build the model graph from the response variable. to_float32=False is passed
# here; 64-bit mode was enabled in JAX during setup.
model = lsl.Model([z], to_float32=False)
# Draw the graph of model variables.
model.plot_vars()
../../_images/835c360d9987876851cd03fb7eea752e9612fb7bd080848bb9dd8feba8adbc4c.png

Run MCMC#

# Engine builder for the model with a fixed seed and four chains.
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

# 3000 warmup transitions; the engine log shows them split into fast and slow
# adaptation epochs.
eb.add_adaptation(3000)  # adaptation instead of burnin, because scales use HMC kernel
# Alternative kept for reference:
# eb.add_burnin(5000)
# 11,000 posterior transitions with thinning=10; the summaries report 4400
# retained posterior draws in total across the four chains.
eb.add_posterior(11_000, thinning=10)

engine = eb.build()
# Run all configured epochs (warmup, then posterior) and collect the results.
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{tf(x,y)}$', 'ln($\\tau_{ps(y)}^2$)', 'ln($\\tau_{ps(x)}^2$)', '$\\beta_{ps(y)}$', '$\\beta_{ps(x)}$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 300 transitions, 25 jitted together
100%|████████████████████████████████████████| 12/12 [00:08<00:00,  1.48chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 26, 17, 24, 18 / 300 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 18, 27, 13, 17 / 300 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 611.68chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 4, 4, 4, 2 / 25 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 5, 4, 5, 5 / 25 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
100%|████████████████████████████████████████| 2/2 [00:00<00:00, 1006.55chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 5, 4, 5, 7 / 50 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 4, 5, 5, 6 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
100%|████████████████████████████████████████| 4/4 [00:00<00:00, 1244.51chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 10, 9, 7, 13 / 100 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 9, 7, 7, 8 / 100 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
100%|█████████████████████████████████████████| 8/8 [00:00<00:00, 121.60chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 17, 15, 14, 24 / 200 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 20, 16, 20, 12 / 200 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 400 transitions, 25 jitted together
100%|████████████████████████████████████████| 16/16 [00:00<00:00, 28.33chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 33, 28, 19, 29 / 400 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 28, 20, 29, 29 / 400 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 1325 transitions, 25 jitted together
100%|████████████████████████████████████████| 53/53 [00:02<00:00, 18.84chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 40, 45, 45, 50 / 1325 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 49, 50, 50, 47 / 1325 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 600 transitions, 25 jitted together
100%|████████████████████████████████████████| 24/24 [00:01<00:00, 21.02chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_03: 31, 23, 31, 27 / 600 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_04: 29, 29, 36, 39 / 600 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 11000 transitions, 25 jitted together
100%|██████████████████████████████████████| 440/440 [00:28<00:00, 15.24chunk/s]
liesel.goose.engine - INFO - Finished epoch

MCMC summary#

summary = gs.Summary(results)

# One row per parameter element, reduced to the convergence diagnostics.
diagnostic_columns = ["variable", "rhat", "ess_bulk", "ess_tail"]
per_element = summary.to_dataframe().reset_index()[diagnostic_columns]

# Collapse to one row per variable: minimum and median effective sample sizes,
# maximum and median R-hat over that variable's elements.
diagnostics = per_element.groupby("variable", as_index=False).agg(
    ess_bulk_min=("ess_bulk", "min"),
    ess_bulk_median=("ess_bulk", "median"),
    ess_tail_min=("ess_tail", "min"),
    ess_tail_median=("ess_tail", "median"),
    rhat_max=("rhat", "max"),
    rhat_median=("rhat", "median"),
)
diagnostics
variable ess_bulk_min ess_bulk_median ess_tail_min ess_tail_median rhat_max rhat_median
0 $\beta_{0,\mu}$ 3667.153507 3667.153507 4143.494652 4143.494652 1.000081 1.000081
1 $\beta_{0,\sigma}$ 3237.887897 3237.887897 4234.954461 4234.954461 1.000708 1.000708
2 $\beta_{ps(x)}$ 2654.456086 2950.001567 3329.684058 3773.720469 1.003448 1.002228
3 $\beta_{ps(y)}$ 1715.689416 3458.595736 2951.163691 3706.319678 1.001731 1.000516
4 $\beta_{tf(x,y)}$ 562.751795 1610.338148 288.872165 2004.276465 1.007987 1.002640
5 ln($\tau_{ps(x)}^2$) 357.905826 357.905826 1083.738474 1083.738474 1.009642 1.009642
6 ln($\tau_{ps(y)}^2$) 210.415649 210.415649 352.794218 352.794218 1.010037 1.010037
# Table of sampler errors per kernel and phase (the output below lists
# divergent transitions for warmup and posterior separately).
summary.error_df()
count sample_size sample_size_total relative
kernel error_code error_msg phase
kernel_03 1 divergent transition warmup 630 12000 12000 0.0525
posterior 0 4400 44000 0.0
kernel_04 1 divergent transition warmup 648 12000 12000 0.054
posterior 0 4400 44000 0.0
# Trace plots for every model parameter whose name contains "tau".
tau_parameter_names = [name for name in model.parameters if "tau" in name]
gs.plot_trace(results, tau_parameter_names)
../../_images/6ff346bae08ba33eb9522c1fc8c9937d61aea186eb255bdcb843fff840c00e4b.png
<seaborn.axisgrid.FacetGrid at 0x149885450>

Predictions#

# Posterior samples extracted from the MCMC results; reused for plotting and prediction.
samples = results.get_posterior_samples()

Plot#

# Plot the tensor-product term "tf(x,y)" from the posterior samples with ngrid=100.
gam.plot_2d_smooth(model.vars["tf(x,y)"], samples, ngrid=100)

Predict variables at new x and y values#

# Grid coordinates, used both as prediction inputs and as plotting columns.
grid_x = df_grid.x.to_numpy()
grid_y = df_grid.y.to_numpy()

# Evaluate the tensor-product term and the location predictor on the grid
# for every posterior sample.
predictions = model.predict(
    samples=samples,
    predict=["tf(x,y)", "$\\mu$"],
    newdata={"x": grid_x, "y": grid_y},
)

# Posterior mean and quantiles of the predictions as a flat data frame.
prediction_stats = gs.SamplesSummary(predictions, which=["mean", "quantiles"])
predictions_summary = prediction_stats.to_dataframe().reset_index()

# Attach the grid coordinates, repeated once per entry in `predictions`.
# Assumes the summary holds one row per grid point for each predicted
# variable, in grid order.
n_predicted = len(predictions)
predictions_summary["x"] = np.tile(grid_x, n_predicted)
predictions_summary["y"] = np.tile(grid_y, n_predicted)
predictions_summary.head()
variable var_fqn var_index sample_size mean q_0.05 q_0.5 q_0.95 x y
0 $\mu$ $\mu$[0] (0,) 4400 0.183780 -0.960436 0.194830 1.308365 0.000000 0.0
1 $\mu$ $\mu$[1] (1,) 4400 0.577247 -0.326586 0.575744 1.464964 0.014286 0.0
2 $\mu$ $\mu$[2] (2,) 4400 0.931156 0.203578 0.931666 1.651620 0.028571 0.0
3 $\mu$ $\mu$[3] (3,) 4400 1.230964 0.641704 1.226999 1.833733 0.042857 0.0
4 $\mu$ $\mu$[4] (4,) 4400 1.463612 0.973566 1.462210 1.967166 0.057143 0.0

Plot predictions#

# Tile plot of the posterior mean of the tensor-product term over the grid.
select = predictions_summary["variable"].isin(["tf(x,y)"])
(p9.ggplot(predictions_summary[select]) + p9.geom_tile(p9.aes("x", "y", fill="mean")))
# Tile plot of the posterior mean of the location predictor over the grid.
select = predictions_summary["variable"].isin(["$\\mu$"])
(p9.ggplot(predictions_summary[select]) + p9.geom_tile(p9.aes("x", "y", fill="mean")))
# Slice of the location predictor at y == 0.0, plotted against x:
# posterior mean as a line, 5%-95% quantile band as a ribbon.
select = predictions_summary["variable"].isin(["$\\mu$"])
(
    p9.ggplot(predictions_summary[select].query("y == 0.0"))
    + p9.geom_ribbon(
        p9.aes("x", ymin="q_0.05", ymax="q_0.95", fill="variable"), alpha=0.3
    )
    + p9.geom_line(p9.aes("x", "mean"))
    + p9.facet_wrap("~variable", scales="free_y", ncol=1)
    + p9.guides(fill="none")
)
# Slice of the location predictor at x == 0.0, plotted against y:
# posterior mean as a line, 5%-95% quantile band as a ribbon.
select = predictions_summary["variable"].isin(["$\\mu$"])
(
    p9.ggplot(predictions_summary[select].query("x == 0.0"))
    + p9.geom_ribbon(
        p9.aes("y", ymin="q_0.05", ymax="q_0.95", fill="variable"), alpha=0.3
    )
    + p9.geom_line(p9.aes("y", "mean"))
    + p9.facet_wrap("~variable", scales="free_y", ncol=1)
    + p9.guides(fill="none")
)