TP: Thin Plate Spline

Setup and Imports

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import pandas as pd
import plotnine as p9
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
from scipy import stats

# Simulate a heteroscedastic data set: both the mean and the standard
# deviation of y vary smoothly with x.
n_obs = 200
rng = np.random.default_rng(1)  # fixed seed so the example is reproducible
x = rng.uniform(-2, 2, n_obs)

# True log standard deviation as a function of x.
log_sigma = -1.0 + 0.3 * (
    0.5 * x + 15 * stats.norm.pdf(2 * (x - 0.2)) - stats.norm.pdf(x + 0.4)
)
# True mean as a function of x.
mu = -x + np.pi * np.sin(np.pi * x)
# Use np.exp (not jnp.exp) so the whole simulation stays in NumPy: y remains a
# float64 ndarray instead of becoming a jax array, which is float32 by default.
y = mu + np.exp(log_sigma) * rng.normal(0.0, 1.0, n_obs)

df = pd.DataFrame({"y": y, "x": x})
# Scatter plot of the simulated response against x.
(
    p9.ggplot(df)
    + p9.geom_point(p9.aes("x", "y"))
)

Model Definition

# Term builder created from the data frame; the terms below refer to its
# columns by name.
tb = gam.TermBuilder.from_df(df)

# One additive predictor per distribution parameter. The scale predictor is
# given inv_link=jnp.exp, so its terms act on the log scale -- presumably to
# keep the standard deviation positive.
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)


# Observed response: Normal with location and scale supplied by the predictors.
# NOTE: this rebinds `y`, which previously held the simulated response array;
# the observed values are read from `df.y`.
y = lsl.Var.new_obs(
    value=df.y.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)
# Terms are added to the predictors after `y` has been created. Each predictor
# gets its own thin plate spline of "x" with k=20. The two terms show up as
# "tp(x)" and "tp(x)1" in the sampler log and the plotting calls; presumably
# the first call (loc) yields "tp(x)" and the second (scale) "tp(x)1" --
# TODO confirm.
loc += tb.tp("x", k=20)
scale += tb.tp("x", k=20)

Build and plot model

# Build the model from the response variable `y`, then plot its variables.
model = lsl.Model([y])
model.plot_vars()
../../_images/b17e7a5f5302e82065f813ba978b2285c448e85519760205a1b674ebb8bf2809.png

Run MCMC

# Engine builder for the model: 4 chains, fixed seed.
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

# 3000 burn-in transitions, then 10,000 posterior transitions with
# thinning=10. The summary in this file reports sample_size 4000, which is
# consistent with 1000 retained draws per chain.
eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

# Build the engine, run every configured epoch, and collect the results.
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{tp(x)1}$', '$\\tau_{tp(x)1}^2$', '$\\beta_{0,\\mu}$', '$\\beta_{tp(x)}$', '$\\tau_{tp(x)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:04<00:00,  1.47s/chunk]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:03<00:00,  3.27chunk/s]
liesel.goose.engine - INFO - Finished epoch

MCMC summary

# Summarize the sampling results. The bare `summary` expression displays the
# per-parameter table (mean, sd, quantiles, effective sample sizes, rhat).
summary = gs.Summary(results)
summary

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,\mu}$ () kernel_03 -0.319492 0.058619 -0.417282 -0.318300 -0.223115 4000 759.178939 1390.140373 1.006856
$\beta_{0,\sigma}$ () kernel_00 -0.633012 0.054018 -0.720973 -0.633687 -0.543151 4000 3654.875916 3506.584384 1.000164
$\beta_{tp(x)1}$ (0,) kernel_01 -0.223218 3.381989 -5.825364 -0.159784 5.198373 4000 3923.407807 3448.435510 1.001015
(1,) kernel_01 -0.651185 3.449494 -6.292287 -0.535440 4.735307 4000 3724.497219 3288.718322 0.999892
(2,) kernel_01 -0.388206 3.363310 -5.801289 -0.358361 5.081929 4000 3755.369546 3496.006638 0.999853
(3,) kernel_01 -0.188393 3.393148 -5.681811 -0.162407 5.377168 4000 3724.284559 3302.598441 1.000447
(4,) kernel_01 0.161879 3.360948 -5.370816 0.231757 5.434961 4000 3286.151142 3084.633690 1.000577
(5,) kernel_01 -0.309639 3.443770 -5.932868 -0.256730 5.230284 4000 3655.642964 3211.616561 1.002177
(6,) kernel_01 0.809840 3.432589 -4.439904 0.633200 6.725374 4000 3575.146040 3400.868275 0.999992
(7,) kernel_01 1.374591 3.457020 -3.672742 1.155379 7.192048 4000 3474.675486 2976.197048 0.999856
(8,) kernel_01 -0.379438 3.039759 -5.486038 -0.272999 4.364275 4000 3730.393011 3703.683340 1.000608
(9,) kernel_01 0.636093 3.227236 -4.391140 0.570598 6.003035 4000 3317.688572 3445.043578 1.000090
(10,) kernel_01 -0.946280 3.241883 -6.405692 -0.721521 3.902848 4000 3391.859257 3348.902498 1.000342
(11,) kernel_01 1.580806 2.706721 -2.671009 1.459306 6.193112 4000 3445.880904 3559.079020 1.000430
(12,) kernel_01 1.673710 2.902806 -2.619197 1.468997 6.849204 4000 2921.664383 3526.974192 1.000376
(13,) kernel_01 -4.757657 2.222360 -8.588407 -4.659715 -1.258762 4000 2426.777267 3540.149249 1.001801
(14,) kernel_01 -0.718017 2.041734 -4.166221 -0.686317 2.561756 4000 3505.243579 3700.864220 0.999957
(15,) kernel_01 5.139132 0.955577 3.641691 5.099299 6.786162 4000 3068.259913 3551.104372 1.001718
(16,) kernel_01 -0.024585 1.100135 -1.825384 0.006827 1.758128 4000 3518.617787 3608.466848 1.000523
(17,) kernel_01 -0.066763 2.001000 -3.337880 -0.092207 3.238401 4000 3797.440515 3576.147203 1.001054
(18,) kernel_01 3.526665 6.968064 -8.181623 3.555306 14.983890 4000 3527.731927 3700.005411 1.000842
$\beta_{tp(x)}$ (0,) kernel_04 17.374926 12.048872 -2.115159 17.097779 37.496637 4000 3418.190214 4058.040563 1.000037
(1,) kernel_04 -8.206913 17.057421 -36.587266 -8.184768 19.207555 4000 3976.469799 3851.250567 1.000226
(2,) kernel_04 -6.139669 16.142166 -33.055743 -6.050435 19.953465 4000 3836.027444 3727.531343 0.999774
(3,) kernel_04 6.755631 12.783772 -14.409870 6.634890 27.528941 4000 4070.418794 3915.357780 0.999657
(4,) kernel_04 4.613989 17.254919 -23.685720 4.607632 32.291230 4000 3848.630123 3984.331627 1.000886
(5,) kernel_04 8.255840 10.909239 -10.050988 8.203017 26.236567 4000 4095.838057 3996.902283 1.000622
(6,) kernel_04 16.417242 13.566302 -5.733943 16.346449 38.926948 4000 3211.902957 3929.810460 0.999822
(7,) kernel_04 -1.876089 10.845407 -19.825790 -1.948507 15.918729 4000 4001.367343 3929.565246 1.000351
(8,) kernel_04 0.418349 9.047864 -14.436178 0.565723 15.286263 4000 3998.999811 3779.979486 0.999745
(9,) kernel_04 -11.152001 9.711149 -27.169579 -11.046268 4.734870 4000 3826.114597 3892.753701 1.000264
(10,) kernel_04 -2.454368 9.581319 -18.239215 -2.447716 13.228080 4000 3765.213797 3755.722132 1.000382
(11,) kernel_04 -12.750726 4.834691 -20.740265 -12.765520 -4.804670 4000 3431.418696 3880.480789 0.999577
(12,) kernel_04 -3.595982 5.778088 -12.961833 -3.660667 6.236989 4000 3410.139668 3652.386708 1.000816
(13,) kernel_04 -58.499371 2.609229 -62.750878 -58.546143 -54.147104 4000 2089.279937 3456.248353 1.001315
(14,) kernel_04 41.967918 3.651897 35.864888 42.005041 48.015962 4000 3880.537990 3776.407019 1.000062
(15,) kernel_04 2.158853 1.358711 -0.054415 2.201960 4.325503 4000 1675.947179 2866.601782 1.001129
(16,) kernel_04 7.572556 1.456800 5.200209 7.597540 9.913282 4000 3845.597316 3809.627550 0.999615
(17,) kernel_04 4.224003 3.875293 -2.297158 4.341357 10.549628 4000 3711.640189 3610.332456 0.999772
(18,) kernel_04 95.929138 11.976497 75.981076 96.160095 115.058223 4000 3751.910039 3728.348665 1.000273
$\tau_{tp(x)1}^2$ () kernel_02 11.939083 10.116852 2.905346 9.234104 30.152539 4000 1687.784985 2201.192031 1.003295
$\tau_{tp(x)}^2$ () kernel_05 460.618439 186.002197 239.091196 419.973419 819.621671 4000 3240.689942 3557.437275 0.999790

Predictions

# Extract the posterior samples and plot each estimated smooth of x.
# "tp(x)" and "tp(x)1" are the names of the two spline terms in the model;
# presumably the former belongs to the location predictor and the latter to
# the scale predictor -- TODO confirm.
samples = results.get_posterior_samples()
gam.plot_1d_smooth(term=model.vars["tp(x)"], samples=samples)
gam.plot_1d_smooth(term=model.vars["tp(x)1"], samples=samples)