RI: Random Intercept

Setup and Imports

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam

# import data from R
from ryp import r, to_py

r("library(mgcv)")
r("data(columb)")
r("data(columb.polys)")

columb = to_py("columb", format="pandas").reset_index()
polys = to_py("columb.polys", format="numpy")
Loading required package: nlme
This is mgcv 1.9-3. For overview type 'help("mgcv-package")'.
gam.plot_polys(region="district", which=["crime"], df=columb, polys=polys)

Model Definition

Set up response model

df = columb

# standardize the response to mean 0 and standard deviation 1,
# which puts it on a scale that is easier for the sampler
df["crime"] = (df["crime"] - df["crime"].mean()) / df["crime"].std()

loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)

# initialize the intercepts at the sample mean and the log standard deviation
# of the response; after standardization both are approximately 0
loc.intercept.value = df.crime.mean()
scale.intercept.value = jnp.log(df.crime.std())


y = lsl.Var.new_obs(
    value=df.crime.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)

tb = gam.TermBuilder.from_df(df)
loc += tb.ri("district", factor_scale=True)
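The model set up above can be written out roughly as follows. This is a sketch rather than the library's documented specification: the symbols $\gamma_j$ and $d(i)$ are notation introduced here, and the last line assumes that `factor_scale=True` means the random intercepts are expressed as a scale $\tau$ times standardized coefficients. That reading is consistent with the posterior standard deviations near 1 reported for `$\beta_{ri(district)}$` in the summary, but it is not confirmed by the source. The prior on $\tau^2$ is not shown in the source and is left unspecified.

```latex
\begin{align*}
y_i &\sim \mathcal{N}(\mu_i, \sigma^2), \\
\mu_i &= \beta_{0,\mu} + \gamma_{d(i)}, \\
\log \sigma &= \beta_{0,\sigma}, \\
\gamma_j &= \tau \, \beta_j, \qquad \beta_j \sim \mathcal{N}(0, 1), \quad j = 1, \dots, 49.
\end{align*}
```

Here $d(i)$ denotes the district of observation $i$, and the 49 coefficients correspond to the indices `(0,)` to `(48,)` in the summary table.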

Build and plot model

model = lsl.Model([y])
model.plot_vars()
liesel.model.model - INFO - Converted dtype of Value(name="y_value").value
liesel.model.model - INFO - Converted dtype of Value(name="$\beta_{0,\mu}$_value").value
[Figure: model graph produced by model.plot_vars()]

Run MCMC

eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{ri(district)}$', '$\\tau_{ri(district)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:03<00:00,  1.28s/chunk]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:01<00:00,  8.88chunk/s]
liesel.goose.engine - INFO - Finished epoch
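As a quick sanity check on the sampling schedule, the number of retained draws can be worked out by hand. The sketch below uses only the settings from the cell above and assumes that `thinning=10` keeps every tenth posterior transition, which is consistent with the `sample_size` of 4000 reported in the summary.

```python
# Sampling schedule from the engine builder settings above.
num_chains = 4
burnin = 3000
posterior = 10_000
thinning = 10

# Assumption: thinning keeps every `thinning`-th posterior transition.
kept_per_chain = posterior // thinning
total_kept = kept_per_chain * num_chains

# Total number of transitions computed across all chains.
total_transitions = (burnin + posterior) * num_chains

print(kept_per_chain, total_kept, total_transitions)  # 1000 4000 52000
```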

MCMC summary

summary = gs.Summary(results)
summary

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,\mu}$ () kernel_01 -0.005114 0.145367 -0.243600 -0.003515 0.238386 4000 4188.683621 3967.090929 1.000383
$\beta_{0,\sigma}$ () kernel_00 -0.000040 0.107251 -0.171180 -0.001376 0.180301 4000 2746.958717 2764.413367 1.000328
$\beta_{ri(district)}$ (0,) kernel_02 -0.114520 1.013351 -1.776968 -0.124597 1.566036 4000 3416.221126 3735.908558 1.000933
(1,) kernel_02 -0.117665 1.002974 -1.791761 -0.127745 1.517511 4000 3264.898699 3810.867258 1.000232
(2,) kernel_02 -0.019843 1.003361 -1.684396 -0.031249 1.629325 4000 3522.489255 3707.318795 1.000291
(3,) kernel_02 0.006285 0.983708 -1.601071 0.008543 1.612648 4000 3476.530424 3738.541850 1.000166
(4,) kernel_02 0.090275 0.972542 -1.503284 0.098644 1.680782 4000 3316.099908 3633.753244 1.001332
(5,) kernel_02 -0.066115 0.999573 -1.679540 -0.081496 1.595152 4000 3536.744382 3765.834930 0.999882
(6,) kernel_02 -0.237052 1.014906 -1.917265 -0.234726 1.427050 4000 2884.461308 3630.016225 1.000964
(7,) kernel_02 0.022843 0.995122 -1.581131 0.018598 1.678300 4000 3356.598942 3931.357816 0.999812
(8,) kernel_02 -0.024984 0.979623 -1.679014 0.007145 1.561869 4000 3262.258745 3892.776173 0.999910
(9,) kernel_02 -0.004644 0.997477 -1.669456 -0.003536 1.662232 4000 3280.931903 3700.907189 1.000410
(10,) kernel_02 0.187606 0.982266 -1.413700 0.176972 1.764643 4000 3171.840884 3283.873203 1.002087
(11,) kernel_02 0.142257 0.982789 -1.473042 0.150151 1.782089 4000 3441.517823 3605.312268 1.000575
(12,) kernel_02 0.057404 0.981682 -1.594380 0.064655 1.673046 4000 3030.690831 3230.407205 1.001477
(13,) kernel_02 0.132081 1.007395 -1.525492 0.117020 1.839210 4000 3086.934991 3560.734483 1.000830
(14,) kernel_02 0.103865 0.992644 -1.510785 0.107614 1.705495 4000 3481.887299 3633.815683 0.999431
(15,) kernel_02 0.140757 0.995787 -1.460193 0.138621 1.760465 4000 2693.980971 2983.575567 0.999771
(16,) kernel_02 0.029202 0.992062 -1.579299 0.001288 1.655426 4000 3120.448275 3579.729374 0.999752
(17,) kernel_02 0.074492 0.978627 -1.560794 0.069119 1.667364 4000 3502.576451 3813.028105 1.000503
(18,) kernel_02 0.124898 1.013200 -1.541736 0.120315 1.767122 4000 3342.462512 3932.532976 1.001116
(19,) kernel_02 -0.221489 0.993773 -1.834133 -0.199446 1.361556 4000 2850.053212 3737.042192 1.000819
(20,) kernel_02 0.028995 0.981227 -1.580984 0.020811 1.644505 4000 3372.729809 3500.899917 0.999673
(21,) kernel_02 -0.024141 0.976827 -1.604377 -0.027562 1.597849 4000 3246.404692 3441.047827 1.001504
(22,) kernel_02 -0.109212 0.993421 -1.757174 -0.099849 1.490357 4000 3138.551202 3737.813432 1.001187
(23,) kernel_02 0.021594 1.002394 -1.596089 0.020058 1.663528 4000 3288.702620 3699.102801 1.000253
(24,) kernel_02 0.212803 1.000352 -1.436119 0.214356 1.868925 4000 3369.214662 3405.297049 1.000385
(25,) kernel_02 0.039339 1.002009 -1.645125 0.054513 1.693271 4000 3156.929198 3563.902520 1.000621
(26,) kernel_02 0.112780 0.981519 -1.499247 0.125068 1.739748 4000 3423.025647 3776.053329 0.999761
(27,) kernel_02 0.114932 1.002042 -1.583529 0.140368 1.733516 4000 3487.933207 3201.791924 0.999937
(28,) kernel_02 0.147575 1.004352 -1.501877 0.153652 1.774238 4000 2733.098370 3329.716033 1.000988
(29,) kernel_02 0.219882 0.998673 -1.397360 0.221177 1.861081 4000 2701.017705 3795.854366 0.999908
(30,) kernel_02 -0.116383 0.997538 -1.735624 -0.128141 1.574517 4000 3217.890179 3801.157941 1.001203
(31,) kernel_02 -0.080508 1.002403 -1.707431 -0.091880 1.594098 4000 3299.231083 3735.996022 1.000038
(32,) kernel_02 0.071810 0.981285 -1.546547 0.080775 1.713290 4000 3274.200728 3581.658592 0.999757
(33,) kernel_02 -0.087403 1.001078 -1.711038 -0.106376 1.570298 4000 2967.668301 3575.274930 1.000736
(34,) kernel_02 0.032245 0.997768 -1.583967 0.008987 1.649590 4000 3327.468348 3434.114192 1.000098
(35,) kernel_02 -0.128174 0.977477 -1.788555 -0.128229 1.467960 4000 3430.672160 3929.622040 1.000706
(36,) kernel_02 0.062274 0.992500 -1.554487 0.058193 1.648628 4000 3315.985625 3776.864888 0.999781
(37,) kernel_02 0.122681 0.992455 -1.533738 0.130786 1.727846 4000 3671.058446 3764.216992 1.000509
(38,) kernel_02 -0.077014 0.985287 -1.709078 -0.066543 1.585702 4000 3225.777713 3398.981197 1.000053
(39,) kernel_02 -0.143057 0.997734 -1.793833 -0.148066 1.480052 4000 2941.655515 3737.339716 1.000174
(40,) kernel_02 -0.112754 0.999094 -1.731114 -0.132243 1.505967 4000 3198.686090 3761.180329 1.000729
(41,) kernel_02 -0.092626 0.975503 -1.706437 -0.110860 1.495511 4000 2968.046558 3776.094934 1.000923
(42,) kernel_02 0.006991 1.000951 -1.659625 0.006960 1.602366 4000 3392.910944 3545.398646 0.999589
(43,) kernel_02 -0.023423 1.017520 -1.704195 -0.016146 1.662377 4000 3281.050568 3477.662056 1.000715
(44,) kernel_02 -0.041573 0.990635 -1.635391 -0.055294 1.588969 4000 3478.239695 3848.778160 0.999713
(45,) kernel_02 -0.128339 0.997562 -1.752092 -0.128040 1.554265 4000 2742.316891 3626.814222 1.001243
(46,) kernel_02 -0.058310 0.991914 -1.661159 -0.052917 1.593579 4000 3265.650162 3645.057141 0.999911
(47,) kernel_02 -0.045507 1.011601 -1.666413 -0.043299 1.659506 4000 3452.193099 3876.832874 1.000524
(48,) kernel_02 -0.096334 0.997661 -1.726075 -0.098920 1.572426 4000 3303.496173 3773.385828 1.000576
$\tau_{ri(district)}^2$ () kernel_03 0.018002 0.038290 0.001708 0.007032 0.067597 4000 104.295276 145.836175 1.040663
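The diagnostics in the table can be screened with a simple rule of thumb. The sketch below copies a few rows from the summary (the two intercepts, the random-intercept entry with the lowest `ess_bulk`, and the variance parameter) and flags any parameter with `rhat` above 1.01 or `ess_bulk` below 400. The thresholds, the plain-text parameter names, and the `flag` helper are assumptions made for illustration; none of them is part of liesel.

```python
# (name, ess_bulk, rhat) copied from the summary table above.
rows = [
    ("beta_0_mu", 4188.683621, 1.000383),
    ("beta_0_sigma", 2746.958717, 1.000328),
    ("beta_ri_district[15]", 2693.980971, 0.999771),
    ("tau2_ri_district", 104.295276, 1.040663),
]

RHAT_MAX = 1.01  # assumed rule-of-thumb threshold
ESS_MIN = 400.0  # assumed rule-of-thumb threshold


def flag(rows, rhat_max=RHAT_MAX, ess_min=ESS_MIN):
    """Return the names of parameters that fail either diagnostic."""
    return [name for name, ess, rhat in rows if rhat > rhat_max or ess < ess_min]


flagged = flag(rows)
print(flagged)  # ['tau2_ri_district']
```

Under these thresholds only the variance parameter is flagged, so its posterior summary deserves more caution than those of the other parameters.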

Plots

samples = results.get_posterior_samples()
gam.plot_regions(term=loc.terms["ri(district)"], samples=samples, polys=polys)
gam.plot_forest(
    term=loc.terms["ri(district)"],
    samples=samples,
)