RS: Random Slope

Setup and Imports

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
from ryp import r, to_py

# Load the columb data set and its district polygons from the R package mgcv
r("library(mgcv)")
r("data(columb)")
r("data(columb.polys)")

columb = to_py("columb", format="pandas").reset_index()
polys = to_py("columb.polys", format="numpy")
Loading required package: nlme
This is mgcv 1.9-3. For overview type 'help("mgcv-package")'.
columb.head()
   index      area  home.value  income      crime  open.space  district         x          y
0      0  0.309441   80.467003  19.531  15.725980    2.850747         0  8.827218  14.369076
1      1  0.259329   44.567001  21.232  18.801754    5.296720         1  8.332658  14.031624
2      2  0.192468   26.350000  15.956  30.626781    4.534649         2  9.012265  13.819719
3      3  0.083841   33.200001   4.477  32.387760    0.394427         3  8.460801  13.716962
4      4  0.488888   23.225000  11.252  50.731510    0.405664         4  9.007982  13.296366

Model Definition

Set up the response model

df = columb
tb = gam.TermBuilder.from_df(df)

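# Additive predictors for the mean and, via an exponential inverse link,
# for the standard deviation of the response
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)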


y = lsl.Var.new_obs(
    value=df.crime.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)
# Add a district-specific (random) slope for income to the location predictor
loc += tb.rs("income", cluster="district", factor_scale=True)
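As a rough illustration of the structure this cell specifies, the sketch below evaluates a random-slope predictor by hand in plain Python. It assumes the generic form mu_i = beta0 + b[district_i] * income_i for the location and sigma = exp(gamma0) for the scale, the latter mirroring inv_link=jnp.exp. The names beta0, gamma0 and slopes, and all parameter values, are hypothetical and chosen only for illustration; nothing here describes how liesel_gam builds the term internally.

```python
import math

# Hypothetical toy parameter values, not estimates from the columb data.
beta0 = 35.0                  # intercept of the location predictor
gamma0 = 2.8                  # intercept of the log-scale predictor
slopes = {0: -0.5, 1: 0.25}   # one income slope per district (cluster)

# Three observations: income value and the district each one belongs to.
income = [19.531, 21.232, 15.956]
district = [0, 1, 0]

# Location: intercept plus a district-specific slope times income.
mu = [beta0 + slopes[d] * x for x, d in zip(income, district)]

# Scale: the exponential inverse link keeps sigma strictly positive.
sigma = math.exp(gamma0)

print(mu, sigma)
```

Observations from the same district share one slope, which is what lets the model pool information within a cluster while still letting the income effect differ between districts.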

Build and plot model

model = lsl.Model([y])
model.plot_vars()
liesel.model.model - INFO - Converted dtype of Value(name="y_value").value
[Figure: model graph produced by model.plot_vars()]

Run MCMC

eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

# 3000 burn-in transitions, then 10,000 posterior transitions thinned by a factor of 10
eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{ri(district)}$', '$\\tau_{ri(district)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:02<00:00,  1.13chunk/s]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:00<00:00, 12.22chunk/s]
liesel.goose.engine - INFO - Finished epoch
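The sampler settings and the log above fit together by simple arithmetic. The sketch below reproduces that accounting in plain Python, assuming that thinning=10 keeps every tenth posterior draw; the variable names are illustrative only.

```python
num_chains = 4
burnin = 3000
posterior = 10_000
thinning = 10
chunk = 1000  # transitions jitted together, as reported in the log

# Draws retained per chain and in total after thinning.
kept_per_chain = posterior // thinning
total_kept = kept_per_chain * num_chains

# Number of jitted chunks per epoch, matching the 3/3 and 10/10 progress bars.
burnin_chunks = burnin // chunk
posterior_chunks = posterior // chunk

print(kept_per_chain, total_kept, burnin_chunks, posterior_chunks)
```

The total of 4000 retained draws agrees with the sample_size column reported in the MCMC summary.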

MCMC summary

summary = gs.Summary(results)
summary

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,\mu}$ () kernel_01 35.205700 2.449469 31.168429 35.229168 39.283052 4000 3924.292316 4030.000052 1.000404
$\beta_{0,\sigma}$ () kernel_00 2.814785 0.107803 2.643978 2.813509 2.996415 4000 2356.963316 2905.203126 1.000539
$\beta_{ri(district)}$ (0,) kernel_02 -0.149550 1.011764 -1.793733 -0.162534 1.539376 4000 3367.928786 3773.725689 1.000630
(1,) kernel_02 -0.156694 0.999332 -1.818371 -0.171744 1.461504 4000 3155.740826 3735.610473 1.000427
(2,) kernel_02 -0.022470 1.005256 -1.675336 -0.038248 1.637338 4000 3503.473663 3574.118301 1.001032
(3,) kernel_02 0.020327 0.989178 -1.587049 0.015807 1.639589 4000 3517.764780 3762.809801 1.000361
(4,) kernel_02 0.062295 0.974226 -1.544892 0.074181 1.672216 4000 3332.056868 3729.830537 1.000939
(5,) kernel_02 -0.067607 0.996991 -1.695521 -0.083134 1.582862 4000 3519.030512 3922.129330 0.999627
(6,) kernel_02 -0.140477 1.010062 -1.803054 -0.131058 1.529071 4000 3097.072549 3665.523298 1.001043
(7,) kernel_02 0.015098 0.994286 -1.572399 0.010808 1.668436 4000 3328.965457 3847.687621 0.999965
(8,) kernel_02 -0.029623 0.979467 -1.659029 -0.009136 1.546134 4000 3340.362424 3843.223659 1.000080
(9,) kernel_02 -0.007052 1.001350 -1.669084 -0.008519 1.670057 4000 3311.713423 3738.028882 1.000227
(10,) kernel_02 0.105332 0.981781 -1.514487 0.111449 1.672988 4000 3363.686726 3705.502605 1.000872
(11,) kernel_02 0.096634 0.985696 -1.517471 0.102809 1.716055 4000 3403.561674 3507.783298 1.000189
(12,) kernel_02 0.021292 0.983049 -1.629216 0.031030 1.620191 4000 3258.018811 3461.665919 1.001731
(13,) kernel_02 0.085475 1.011775 -1.567616 0.079958 1.806172 4000 3390.260886 3494.992028 1.000858
(14,) kernel_02 0.065382 0.989038 -1.545948 0.071293 1.682244 4000 3483.577938 3300.564350 0.999452
(15,) kernel_02 0.083964 0.998953 -1.534829 0.072235 1.717191 4000 2886.588212 3009.055486 0.999795
(16,) kernel_02 0.020062 0.992169 -1.600778 0.002476 1.650164 4000 3188.162443 3493.687826 0.999671
(17,) kernel_02 0.061893 0.987636 -1.578367 0.059938 1.678118 4000 3500.872930 3774.369502 1.000692
(18,) kernel_02 0.101397 1.015818 -1.580298 0.103325 1.737488 4000 3362.259617 3973.672341 1.000933
(19,) kernel_02 -0.409298 0.993877 -2.028548 -0.396448 1.196150 4000 1864.669623 3414.244414 1.003948
(20,) kernel_02 0.021687 0.991052 -1.607439 0.016225 1.652792 4000 3412.315951 3495.904836 0.999883
(21,) kernel_02 -0.021332 0.974231 -1.618732 -0.022891 1.587855 4000 3247.459893 3388.044071 1.000909
(22,) kernel_02 -0.145468 0.993787 -1.783674 -0.136202 1.462505 4000 2960.064175 3716.002313 1.001537
(23,) kernel_02 0.019553 1.003393 -1.599527 0.011594 1.674545 4000 3369.969354 3810.491327 1.000334
(24,) kernel_02 0.139218 1.002922 -1.517436 0.141735 1.790566 4000 3434.375932 3703.512824 0.999970
(25,) kernel_02 0.021509 1.008070 -1.650407 0.040934 1.683132 4000 3217.514499 3639.108952 1.000412
(26,) kernel_02 0.078028 0.985659 -1.542917 0.090441 1.700466 4000 3420.850759 3840.535596 0.999952
(27,) kernel_02 0.051379 1.002433 -1.634726 0.076623 1.676202 4000 3532.436647 3304.433850 1.000266
(28,) kernel_02 0.074784 1.002733 -1.603812 0.068848 1.701423 4000 3145.326895 3156.395323 1.000968
(29,) kernel_02 0.200842 1.001843 -1.429474 0.201164 1.843652 4000 2677.785319 3615.137987 1.000299
(30,) kernel_02 -0.129754 0.995484 -1.730766 -0.153370 1.563187 4000 3082.741793 3800.786186 1.001786
(31,) kernel_02 -0.093704 1.000903 -1.720932 -0.102037 1.581529 4000 3319.086242 3664.065457 1.000056
(32,) kernel_02 0.062308 0.987723 -1.552611 0.068107 1.698465 4000 3175.188307 3581.334684 0.999661
(33,) kernel_02 -0.086987 0.998064 -1.705815 -0.108131 1.560569 4000 3042.619591 3728.003983 1.000619
(34,) kernel_02 0.035038 0.999820 -1.610464 0.015525 1.652445 4000 3243.582598 3529.843222 1.000460
(35,) kernel_02 -0.161511 0.981592 -1.812029 -0.167667 1.433410 4000 3328.288214 3713.067457 1.001138
(36,) kernel_02 0.057044 0.989867 -1.550283 0.050247 1.622215 4000 3319.842979 3853.787938 0.999746
(37,) kernel_02 0.089711 0.993911 -1.567816 0.094480 1.690042 4000 3684.572040 3764.117878 1.000552
(38,) kernel_02 -0.101033 0.983928 -1.725200 -0.087279 1.551378 4000 3118.720846 3168.181417 1.000423
(39,) kernel_02 -0.240801 0.984874 -1.855641 -0.244579 1.395711 4000 2630.958506 3627.274554 0.999900
(40,) kernel_02 -0.152367 1.001512 -1.752209 -0.177241 1.462355 4000 3226.324696 3664.848533 1.000568
(41,) kernel_02 -0.166027 0.977543 -1.761164 -0.178427 1.437798 4000 2851.808542 3714.860008 1.000768
(42,) kernel_02 0.003221 1.001013 -1.657148 0.001118 1.597729 4000 3355.544983 3442.079222 0.999666
(43,) kernel_02 -0.032537 1.017155 -1.704087 -0.023640 1.625575 4000 3317.676506 3442.898932 1.000473
(44,) kernel_02 -0.043390 0.991569 -1.659987 -0.049425 1.595721 4000 3492.736744 3887.032527 0.999424
(45,) kernel_02 -0.149088 0.994964 -1.763726 -0.145288 1.541716 4000 2783.741794 3457.732355 1.001290
(46,) kernel_02 -0.070280 0.991442 -1.680281 -0.067694 1.569601 4000 3249.080499 3642.869324 0.999990
(47,) kernel_02 -0.040485 1.011548 -1.662248 -0.035718 1.657670 4000 3507.634937 3850.195077 1.000455
(48,) kernel_02 -0.114785 0.995690 -1.738748 -0.108378 1.541380 4000 3195.439066 3771.031957 1.000378
$\tau_{ri(district)}^2$ () kernel_03 0.022282 0.049855 0.001709 0.008178 0.089705 4000 67.388287 160.065913 1.051167
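One way to read a summary of this size is to flag parameters whose diagnostics look weak. The sketch below does this in plain Python for four rows copied from the table above. The thresholds (rhat above 1.01, bulk ESS below 100 per chain) are common rules of thumb and an assumption here, not something the tutorial prescribes, and the dictionary keys are simplified stand-ins for the parameter names.

```python
# rhat and ess_bulk copied from selected rows of the summary table.
diagnostics = {
    "beta_0_mu": {"rhat": 1.000404, "ess_bulk": 3924.292316},
    "beta_0_sigma": {"rhat": 1.000539, "ess_bulk": 2356.963316},
    "beta_ri_district[19]": {"rhat": 1.003948, "ess_bulk": 1864.669623},
    "tau2_ri_district": {"rhat": 1.051167, "ess_bulk": 67.388287},
}

# Assumed rule-of-thumb thresholds.
RHAT_MAX = 1.01
ESS_MIN = 100 * 4  # 100 effective draws per chain, 4 chains

# Collect every parameter that violates at least one threshold.
flagged = [
    name
    for name, d in diagnostics.items()
    if d["rhat"] > RHAT_MAX or d["ess_bulk"] < ESS_MIN
]

print(flagged)
```

Under these thresholds only the variance parameter of the district term is flagged, which suggests that its posterior summary should be treated with more caution than those of the other parameters.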

Plots

samples = results.get_posterior_samples()
gam.plot_regions(model.vars["ri(district)"], samples, polys=polys)
gam.plot_forest(model.vars["ri(district)"], samples)
gam.plot_1d_smooth_clustered(
    clustered_term=model.vars["rs(income|district)"],
    samples=samples,
    ngrid=10,
)