TP: Thin Plate Spline
Setup and Imports
import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import pandas as pd
import plotnine as p9
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel_gam as gam
from scipy import stats
# Simulate a heteroscedastic regression data set: both the mean and the
# log standard deviation of the response are nonlinear functions of a
# single covariate x drawn uniformly on [-2, 2].
n_obs = 200
rng = np.random.default_rng(1)
x = rng.uniform(-2, 2, n_obs)
log_sigma = -1.0 + 0.3 * (
    0.5 * x + 15 * stats.norm.pdf(2 * (x - 0.2)) - stats.norm.pdf(x + 0.4)
)
mu = -x + np.pi * np.sin(np.pi * x)
# Use NumPy's exp here, not jnp.exp: mixing a JAX array into the NumPy
# computation turns `y` into a JAX array, which is float32 unless
# jax_enable_x64 is set, silently downcasting the simulated response.
y = mu + np.exp(log_sigma) * rng.normal(0.0, 1.0, n_obs)
df = pd.DataFrame({"y": y, "x": x})
Model Definition
# Term builder bound to the simulated data frame, so that smooth terms can
# refer to columns by name (here "x").
tb = gam.TermBuilder.from_df(df)
# Additive predictors for the location (mu) and scale (sigma) of the
# response. The scale predictor is passed through exp as its inverse link,
# which keeps sigma strictly positive.
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)
# Observed response y ~ Normal(loc, scale), with values taken from df.
# NOTE: this rebinds `y` from the simulated NumPy array to a liesel Var.
y = lsl.Var.new_obs(
    value=df.y.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)
# Add a thin plate spline in x to each predictor (k=20 presumably sets the
# basis dimension; confirm against liesel_gam docs). Keep this order: the
# sampler output names the terms tp(x) and tp(x)1, presumably in creation
# order (location first, then scale).
loc += tb.tp("x", k=20)
scale += tb.tp("x", k=20)
Build and Plot the Model
# Assemble the model graph starting from the response node; the predictors,
# terms and their parameters are presumably collected as upstream
# dependencies of `y`.
model = lsl.Model([y])
# Visualize the variables of the model graph.
model.plot_vars()
Run MCMC
# Engine builder for 4 chains with a fixed seed for reproducibility. The
# sampling kernels are presumably configured by LieselMCMC from the model's
# parameters (not visible here).
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
# 3000 warmup (burn-in) transitions per chain.
eb.add_burnin(3000)
# 10,000 posterior transitions per chain, keeping every 10th draw:
# 1,000 retained draws per chain, 4,000 across all 4 chains.
eb.add_posterior(10_000, thinning=10)
engine = eb.build()
# Run all configured epochs (burn-in, then posterior) and collect results.
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{tp(x)1}$', '$\\tau_{tp(x)1}^2$', '$\\beta_{0,\\mu}$', '$\\beta_{tp(x)}$', '$\\tau_{tp(x)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:04<00:00, 1.47s/chunk]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:03<00:00, 3.27chunk/s]
liesel.goose.engine - INFO - Finished epoch
MCMC Summary
# Per-parameter posterior summary: mean, sd, quantiles, effective sample
# sizes and R-hat convergence diagnostic.
summary = gs.Summary(results)
# Bare expression: renders the summary when run in a notebook/REPL.
summary
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| $\beta_{0,\mu}$ | () | kernel_03 | -0.319492 | 0.058619 | -0.417282 | -0.318300 | -0.223115 | 4000 | 759.178939 | 1390.140373 | 1.006856 |
| $\beta_{0,\sigma}$ | () | kernel_00 | -0.633012 | 0.054018 | -0.720973 | -0.633687 | -0.543151 | 4000 | 3654.875916 | 3506.584384 | 1.000164 |
| $\beta_{tp(x)1}$ | (0,) | kernel_01 | -0.223218 | 3.381989 | -5.825364 | -0.159784 | 5.198373 | 4000 | 3923.407807 | 3448.435510 | 1.001015 |
| | (1,) | kernel_01 | -0.651185 | 3.449494 | -6.292287 | -0.535440 | 4.735307 | 4000 | 3724.497219 | 3288.718322 | 0.999892 |
| | (2,) | kernel_01 | -0.388206 | 3.363310 | -5.801289 | -0.358361 | 5.081929 | 4000 | 3755.369546 | 3496.006638 | 0.999853 |
| | (3,) | kernel_01 | -0.188393 | 3.393148 | -5.681811 | -0.162407 | 5.377168 | 4000 | 3724.284559 | 3302.598441 | 1.000447 |
| | (4,) | kernel_01 | 0.161879 | 3.360948 | -5.370816 | 0.231757 | 5.434961 | 4000 | 3286.151142 | 3084.633690 | 1.000577 |
| | (5,) | kernel_01 | -0.309639 | 3.443770 | -5.932868 | -0.256730 | 5.230284 | 4000 | 3655.642964 | 3211.616561 | 1.002177 |
| | (6,) | kernel_01 | 0.809840 | 3.432589 | -4.439904 | 0.633200 | 6.725374 | 4000 | 3575.146040 | 3400.868275 | 0.999992 |
| | (7,) | kernel_01 | 1.374591 | 3.457020 | -3.672742 | 1.155379 | 7.192048 | 4000 | 3474.675486 | 2976.197048 | 0.999856 |
| | (8,) | kernel_01 | -0.379438 | 3.039759 | -5.486038 | -0.272999 | 4.364275 | 4000 | 3730.393011 | 3703.683340 | 1.000608 |
| | (9,) | kernel_01 | 0.636093 | 3.227236 | -4.391140 | 0.570598 | 6.003035 | 4000 | 3317.688572 | 3445.043578 | 1.000090 |
| | (10,) | kernel_01 | -0.946280 | 3.241883 | -6.405692 | -0.721521 | 3.902848 | 4000 | 3391.859257 | 3348.902498 | 1.000342 |
| | (11,) | kernel_01 | 1.580806 | 2.706721 | -2.671009 | 1.459306 | 6.193112 | 4000 | 3445.880904 | 3559.079020 | 1.000430 |
| | (12,) | kernel_01 | 1.673710 | 2.902806 | -2.619197 | 1.468997 | 6.849204 | 4000 | 2921.664383 | 3526.974192 | 1.000376 |
| | (13,) | kernel_01 | -4.757657 | 2.222360 | -8.588407 | -4.659715 | -1.258762 | 4000 | 2426.777267 | 3540.149249 | 1.001801 |
| | (14,) | kernel_01 | -0.718017 | 2.041734 | -4.166221 | -0.686317 | 2.561756 | 4000 | 3505.243579 | 3700.864220 | 0.999957 |
| | (15,) | kernel_01 | 5.139132 | 0.955577 | 3.641691 | 5.099299 | 6.786162 | 4000 | 3068.259913 | 3551.104372 | 1.001718 |
| | (16,) | kernel_01 | -0.024585 | 1.100135 | -1.825384 | 0.006827 | 1.758128 | 4000 | 3518.617787 | 3608.466848 | 1.000523 |
| | (17,) | kernel_01 | -0.066763 | 2.001000 | -3.337880 | -0.092207 | 3.238401 | 4000 | 3797.440515 | 3576.147203 | 1.001054 |
| | (18,) | kernel_01 | 3.526665 | 6.968064 | -8.181623 | 3.555306 | 14.983890 | 4000 | 3527.731927 | 3700.005411 | 1.000842 |
| $\beta_{tp(x)}$ | (0,) | kernel_04 | 17.374926 | 12.048872 | -2.115159 | 17.097779 | 37.496637 | 4000 | 3418.190214 | 4058.040563 | 1.000037 |
| | (1,) | kernel_04 | -8.206913 | 17.057421 | -36.587266 | -8.184768 | 19.207555 | 4000 | 3976.469799 | 3851.250567 | 1.000226 |
| | (2,) | kernel_04 | -6.139669 | 16.142166 | -33.055743 | -6.050435 | 19.953465 | 4000 | 3836.027444 | 3727.531343 | 0.999774 |
| | (3,) | kernel_04 | 6.755631 | 12.783772 | -14.409870 | 6.634890 | 27.528941 | 4000 | 4070.418794 | 3915.357780 | 0.999657 |
| | (4,) | kernel_04 | 4.613989 | 17.254919 | -23.685720 | 4.607632 | 32.291230 | 4000 | 3848.630123 | 3984.331627 | 1.000886 |
| | (5,) | kernel_04 | 8.255840 | 10.909239 | -10.050988 | 8.203017 | 26.236567 | 4000 | 4095.838057 | 3996.902283 | 1.000622 |
| | (6,) | kernel_04 | 16.417242 | 13.566302 | -5.733943 | 16.346449 | 38.926948 | 4000 | 3211.902957 | 3929.810460 | 0.999822 |
| | (7,) | kernel_04 | -1.876089 | 10.845407 | -19.825790 | -1.948507 | 15.918729 | 4000 | 4001.367343 | 3929.565246 | 1.000351 |
| | (8,) | kernel_04 | 0.418349 | 9.047864 | -14.436178 | 0.565723 | 15.286263 | 4000 | 3998.999811 | 3779.979486 | 0.999745 |
| | (9,) | kernel_04 | -11.152001 | 9.711149 | -27.169579 | -11.046268 | 4.734870 | 4000 | 3826.114597 | 3892.753701 | 1.000264 |
| | (10,) | kernel_04 | -2.454368 | 9.581319 | -18.239215 | -2.447716 | 13.228080 | 4000 | 3765.213797 | 3755.722132 | 1.000382 |
| | (11,) | kernel_04 | -12.750726 | 4.834691 | -20.740265 | -12.765520 | -4.804670 | 4000 | 3431.418696 | 3880.480789 | 0.999577 |
| | (12,) | kernel_04 | -3.595982 | 5.778088 | -12.961833 | -3.660667 | 6.236989 | 4000 | 3410.139668 | 3652.386708 | 1.000816 |
| | (13,) | kernel_04 | -58.499371 | 2.609229 | -62.750878 | -58.546143 | -54.147104 | 4000 | 2089.279937 | 3456.248353 | 1.001315 |
| | (14,) | kernel_04 | 41.967918 | 3.651897 | 35.864888 | 42.005041 | 48.015962 | 4000 | 3880.537990 | 3776.407019 | 1.000062 |
| | (15,) | kernel_04 | 2.158853 | 1.358711 | -0.054415 | 2.201960 | 4.325503 | 4000 | 1675.947179 | 2866.601782 | 1.001129 |
| | (16,) | kernel_04 | 7.572556 | 1.456800 | 5.200209 | 7.597540 | 9.913282 | 4000 | 3845.597316 | 3809.627550 | 0.999615 |
| | (17,) | kernel_04 | 4.224003 | 3.875293 | -2.297158 | 4.341357 | 10.549628 | 4000 | 3711.640189 | 3610.332456 | 0.999772 |
| | (18,) | kernel_04 | 95.929138 | 11.976497 | 75.981076 | 96.160095 | 115.058223 | 4000 | 3751.910039 | 3728.348665 | 1.000273 |
| $\tau_{tp(x)1}^2$ | () | kernel_02 | 11.939083 | 10.116852 | 2.905346 | 9.234104 | 30.152539 | 4000 | 1687.784985 | 2201.192031 | 1.003295 |
| $\tau_{tp(x)}^2$ | () | kernel_05 | 460.618439 | 186.002197 | 239.091196 | 419.973419 | 819.621671 | 4000 | 3240.689942 | 3557.437275 | 0.999790 |
Predictions
# Retrieve the retained posterior draws (after burn-in and thinning) for use
# in predictions; presumably a mapping from parameter name to an array of
# draws indexed by chain and iteration — confirm against liesel.goose docs.
samples = results.get_posterior_samples()