NP: P-Spline without linear trend#

Setup and Imports#

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import pandas as pd
import plotnine as p9
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
from scipy import stats

# Simulate a heteroscedastic toy data set: both the mean and the (log)
# standard deviation of y are nonlinear functions of a single covariate x.
n_obs = 200
rng = np.random.default_rng(1)  # fixed seed so the example is reproducible
x = rng.uniform(-2, 2, n_obs)

# True log standard deviation: a linear part in x plus two Gaussian bumps.
log_sigma = -1.0 + 0.3 * (
    0.5 * x + 15 * stats.norm.pdf(2 * (x - 0.2)) - stats.norm.pdf(x + 0.4)
)
# True mean: a decreasing linear trend plus a sine wave.
mu = -x + np.pi * np.sin(np.pi * x)
# np.exp keeps the whole simulation in NumPy, so y stays a float64 ndarray.
# jnp.exp would turn y into a JAX array (float32 unless x64 is enabled)
# before it is stored in the DataFrame.
y = mu + np.exp(log_sigma) * rng.normal(0.0, 1.0, n_obs)

df = pd.DataFrame({"y": y, "x": x})
# Scatter plot of the simulated data; as the last expression of the cell it is
# what the notebook displays.
(p9.ggplot(df) + p9.geom_point(p9.aes("x", "y")))

Model Definition#

Setup response model#

# One additive predictor per parameter of the response distribution. The scale
# predictor uses exp() as inverse link, so the resulting scale is positive.
loc = gam.AdditivePredictor("loc")
scale = gam.AdditivePredictor("scale", inv_link=jnp.exp)

# Observed response: the y column of the simulated data, modelled as normal
# with the two predictors above as location and scale.
observed_y = df["y"].to_numpy()
normal_dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
y = lsl.Var.new_obs(value=observed_y, distribution=normal_dist, name="y")


# Registry wrapping the simulated data frame; both term builders read from it.
registry = gam.PandasRegistry(df)
# One term builder per predictor. The prefixes keep the generated term names
# apart; they show up later as "loc.np(x)" and "scale.np(x)".
tbl = gam.TermBuilder(registry, prefix_names_by="loc.")
tbs = gam.TermBuilder(registry, prefix_names_by="scale.")

# Location predictor: np() term in x with k=20, plus a separate linear term.
loc += tbl.np("x", k=20)
loc += tbl.lin("x")  # np() carries no linear trend (see page title), so the linear effect is added as its own term

# Scale predictor: same structure as the location predictor.
scale += tbs.np("x", k=20)
scale += tbs.lin("x")  # np() carries no linear trend (see page title), so the linear effect is added as its own term

Build and plot model#

# Build the model from the response variable y, then plot its variables
# (rendered as the image below).
model = lsl.Model([y])
model.plot_vars()
../../_images/a2772b814b155ff590ac3f57681644f849060671c5ef11636dc4edf395846cfa.png

Run MCMC#

# Engine builder derived from the model: four chains, fixed seed.
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

# 3000 burn-in transitions, then 10,000 posterior transitions with thinning=10.
# The summary below reports sample_size 4000, which is consistent with
# 1000 retained draws for each of the 4 chains.
eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

# Build the engine, run all configured epochs and collect the results.
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,scale}$', '$\\beta_{scale.lin(scale.X)}$', '$\\beta_{scale.np(x)}$', '$\\tau_{scale.np(x)}^2$', '$\\beta_{0,loc}$', '$\\beta_{loc.lin(loc.X)}$', '$\\beta_{loc.np(x)}$', '$\\tau_{loc.np(x)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:05<00:00,  1.95s/chunk]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:02<00:00,  3.55chunk/s]
liesel.goose.engine - INFO - Finished epoch

MCMC summary#

# Summarise the sampling results; the bare name on the last line makes the
# notebook display the table (posterior moments, quantiles, ESS and R-hat).
summary = gs.Summary(results)
summary

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,loc}$ () kernel_04 -0.262565 0.063026 -0.367806 -0.261719 -0.159686 4000 731.135465 1315.568448 1.003694
$\beta_{0,scale}$ () kernel_00 -0.644294 0.054516 -0.733098 -0.644662 -0.552065 4000 3679.561860 3768.375510 0.999715
$\beta_{loc.lin(loc.X)}$ (0,) kernel_05 -1.699004 0.028991 -1.746209 -1.699281 -1.650748 4000 1596.645440 2457.752385 1.002202
$\beta_{loc.np(x)}$ (0,) kernel_06 -0.023772 0.058601 -0.117582 -0.024000 0.072800 4000 3673.579489 3853.021331 1.000588
(1,) kernel_06 -0.024977 0.052526 -0.111056 -0.025245 0.059996 4000 4152.731755 3821.546991 1.000870
(2,) kernel_06 -0.056511 0.060965 -0.157582 -0.056107 0.042630 4000 3896.934944 3357.358740 1.000126
(3,) kernel_06 0.005730 0.040782 -0.062335 0.005211 0.073685 4000 3757.991069 3580.980547 1.000284
(4,) kernel_06 -0.007901 0.055968 -0.102620 -0.006762 0.081456 4000 3818.878681 3847.971846 1.000081
(5,) kernel_06 0.045408 0.040415 -0.020446 0.045365 0.111851 4000 3669.943702 3972.009447 0.999675
(6,) kernel_06 -0.023389 0.050966 -0.106563 -0.021877 0.057516 4000 3719.711294 3930.528930 0.999972
(7,) kernel_06 -0.013281 0.049341 -0.093760 -0.013230 0.068373 4000 3725.867059 3725.745145 0.999609
(8,) kernel_06 0.007076 0.033208 -0.047084 0.007049 0.061971 4000 3858.359498 3812.729644 0.999782
(9,) kernel_06 -0.011697 0.038719 -0.076326 -0.010973 0.050388 4000 3286.633610 3087.640627 1.000076
(10,) kernel_06 -0.018240 0.028765 -0.065644 -0.018369 0.029211 4000 3387.247658 3765.233153 1.000099
(11,) kernel_06 -0.019275 0.026975 -0.063971 -0.019060 0.025981 4000 3139.965313 3852.364610 1.000345
(12,) kernel_06 -0.010433 0.020907 -0.044253 -0.010316 0.023129 4000 3051.205113 3625.121863 1.000090
(13,) kernel_06 0.043470 0.015434 0.018575 0.043568 0.068089 4000 2867.159664 3659.797307 1.000412
(14,) kernel_06 -0.202883 0.012822 -0.224227 -0.203045 -0.181614 4000 3125.674756 3558.386312 0.999988
(15,) kernel_06 0.018861 0.007626 0.006146 0.018904 0.031303 4000 2328.286296 3321.083852 1.000501
(16,) kernel_06 -0.019449 0.004636 -0.026990 -0.019497 -0.011909 4000 3262.832509 3827.522054 0.999820
(17,) kernel_06 -0.000499 0.001527 -0.002959 -0.000502 0.001958 4000 2209.676668 3449.875344 1.000911
$\beta_{scale.lin(scale.X)}$ (0,) kernel_01 0.308974 0.049332 0.229038 0.308665 0.392484 4000 4112.142533 3611.075758 1.001403
$\beta_{scale.np(x)}$ (0,) kernel_02 -0.030080 0.036138 -0.090191 -0.028474 0.026542 4000 3402.077869 3377.567459 1.001110
(1,) kernel_02 0.014640 0.034437 -0.040182 0.013937 0.070212 4000 3226.629280 3562.209234 1.000468
(2,) kernel_02 -0.007568 0.038298 -0.073435 -0.006822 0.052785 4000 3350.123627 3387.046762 1.001403
(3,) kernel_02 0.010039 0.036874 -0.050209 0.008992 0.071646 4000 3767.399897 3683.884391 1.000068
(4,) kernel_02 0.003103 0.035499 -0.055316 0.003609 0.061267 4000 3571.414176 3762.777028 1.000248
(5,) kernel_02 -0.000883 0.034998 -0.060303 -0.000923 0.054520 4000 3772.873127 3834.187148 0.999781
(6,) kernel_02 -0.022797 0.034212 -0.080215 -0.021482 0.032424 4000 3481.451038 3582.676970 1.000341
(7,) kernel_02 0.012474 0.032472 -0.038521 0.011810 0.066931 4000 3617.493122 3922.096288 1.000383
(8,) kernel_02 0.003202 0.029793 -0.045754 0.003420 0.051959 4000 3397.225682 3859.807403 1.000337
(9,) kernel_02 0.007653 0.027542 -0.036162 0.006799 0.053635 4000 3481.877072 3666.544849 1.000064
(10,) kernel_02 -0.007116 0.026728 -0.050537 -0.007288 0.035927 4000 3481.777823 3546.634577 0.999608
(11,) kernel_02 0.003584 0.022366 -0.033361 0.003476 0.040707 4000 3242.625520 3738.237714 1.000315
(12,) kernel_02 -0.028370 0.020249 -0.061930 -0.028054 0.004187 4000 3187.345846 3651.581387 1.000063
(13,) kernel_02 0.019051 0.017124 -0.008231 0.018610 0.047378 4000 3502.302557 3800.172761 1.000053
(14,) kernel_02 -0.001877 0.012511 -0.021904 -0.002114 0.018657 4000 3259.037202 3619.829076 1.000479
(15,) kernel_02 0.012170 0.007718 -0.000280 0.012275 0.024904 4000 3203.367996 3689.792445 0.999534
(16,) kernel_02 0.001885 0.005395 -0.006974 0.001918 0.010648 4000 3235.225329 3689.129479 1.000310
(17,) kernel_02 0.007870 0.001818 0.004747 0.007895 0.010839 4000 3060.010313 3333.807839 1.000105
$\tau_{loc.np(x)}^2$ () kernel_07 0.004900 0.002147 0.002425 0.004450 0.008883 4000 3667.804651 3727.488991 1.001028
$\tau_{scale.np(x)}^2$ () kernel_03 0.001506 0.000774 0.000672 0.001319 0.002985 4000 3321.349810 3666.860804 1.000857

MCMC trace plots#

# Trace plots for the sampled parameters, used to inspect mixing visually.
gs.plot_trace(results)
../../_images/fd002d9f5297aeb4ec27da6d17e9b50f5482714cc271702f31e585cce4b307bb.png
<seaborn.axisgrid.FacetGrid at 0x136711d30>

Predictions#

# Posterior draws (burn-in excluded) used for the plots below.
samples = results.get_posterior_samples()
# One plot per np() term: first the location predictor, then the scale
# predictor. Terms are looked up in the model by their prefixed names.
gam.plot_1d_smooth(term=model.vars["loc.np(x)"], samples=samples)
gam.plot_1d_smooth(term=model.vars["scale.np(x)"], samples=samples)