RS: Random Scale for Smooth

Setup and Imports

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
# ryp runs R code from Python; used here to load the columb data shipped with mgcv
from ryp import r, to_py

r("library(mgcv)")
r("data(columb)")
r("data(columb.polys)")

columb = to_py("columb", format="pandas").reset_index()
polys = to_py("columb.polys", format="numpy")
Loading required package: nlme
This is mgcv 1.9-3. For overview type 'help("mgcv-package")'.
columb.head()
index area home.value income crime open.space district x y
0 0 0.309441 80.467003 19.531 15.725980 2.850747 0 8.827218 14.369076
1 1 0.259329 44.567001 21.232 18.801754 5.296720 1 8.332658 14.031624
2 2 0.192468 26.350000 15.956 30.626781 4.534649 2 9.012265 13.819719
3 3 0.083841 33.200001 4.477 32.387760 0.394427 3 8.460801 13.716962
4 4 0.488888 23.225000 11.252 50.731510 0.405664 4 9.007982 13.296366

Model Definition

Set up the response model

df = columb
tb = gam.TermBuilder.from_df(df)

loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)


y = lsl.Var.new_obs(
    value=df.crime.to_numpy(),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)


smooth = tb.ps("area", k=20)

loc += smooth
loc += tb.rs(x=smooth, cluster="district")

loc += tb.ri("district", factor_scale=True)
Warning message:
In smooth.construct.ps.smooth.spec(object, dk$data, dk$knots) :
  there is *no* information about some basis coefficients
Warning message:
In smooth.construct.ps.smooth.spec(object, dk$data, dk$knots) :
  there is *no* information about some basis coefficients
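The warnings come from mgcv's P-spline constructor (`smooth.construct.ps.smooth.spec`), so `tb.ps` appears to build its basis through mgcv. They mean that some of the 20 B-spline basis functions are zero at every observed `area` value, so the data alone say nothing about their coefficients. (The summary later reports 19 coefficients for $\beta_{ps(area)}$, presumably because an identifiability constraint removes one dimension.) A difference penalty ties each coefficient to its neighbours, so such coefficients are still informed through the prior. A minimal sketch of that penalty matrix, assuming mgcv's default second-order difference penalty for `bs="ps"`; the helper name `difference_penalty` is hypothetical:

```python
import jax.numpy as jnp


def difference_penalty(k: int, order: int = 2) -> jnp.ndarray:
    """Return K = D^T D, where D takes order-th differences of k coefficients.

    Under a Bayesian P-spline prior, p(beta) is proportional to
    exp(-beta^T K beta / (2 * tau^2)), so large jumps between
    neighbouring coefficients are penalized.
    """
    d = jnp.eye(k)
    for _ in range(order):
        d = jnp.diff(d, axis=0)  # each pass takes differences of adjacent rows
    return d.T @ d  # (k, k) penalty matrix


K = difference_penalty(20)  # k=20, as in tb.ps("area", k=20)
```

With a second-order penalty, constant and linear coefficient sequences are not penalized at all, so `K` has rank `k - 2`; the variance parameter $\tau_{ps(area)}^2$ in the summary presumably controls how strongly the remaining wiggliness is shrunk.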

Build and plot the model

model = lsl.Model([y])
model.plot_vars()
liesel.model.model - INFO - Converted dtype of Value(name="y_value").value
[Figure: graph of the model's variables, drawn by model.plot_vars()]
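The response `y` is normal, with its mean given by the $\mu$ predictor and its standard deviation by the $\sigma$ predictor passed through the `jnp.exp` inverse link. A minimal sketch of the log-likelihood this implies, written in plain JAX with `jax.scipy.stats.norm` rather than through Liesel; the function name, the names `eta_mu` and `eta_sigma`, and all values are hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_loglik(y, eta_mu, eta_sigma):
    """Log-likelihood of y under Normal(loc=eta_mu, scale=exp(eta_sigma)).

    eta_mu and eta_sigma stand in for the additive predictors of mu and
    sigma; exp is the inverse link used for sigma in the model above.
    """
    sigma = jnp.exp(eta_sigma)  # inverse link keeps the scale positive
    return jnp.sum(norm.logpdf(y, loc=eta_mu, scale=sigma))


# Made-up toy values, not taken from the columb data
y = jnp.array([15.7, 18.8, 30.6])
eta_mu = jnp.array([20.0, 20.0, 25.0])
eta_sigma = jnp.log(jnp.array([5.0, 5.0, 5.0]))

ll = normal_loglik(y, eta_mu, eta_sigma)
```

Because $\sigma$ has its own additive predictor, the model can let the residual spread vary across observations instead of assuming a single constant variance.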

Run MCMC

eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{ri(district)1}$', '$\\tau_{ri(district)1}^2$', '$\\beta_{ri(district)}$', '$\\tau_{ri(district)}^2$', '$\\beta_{ps(area)}$', '$\\tau_{ps(area)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:05<00:00,  1.97s/chunk]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:02<00:00,  4.09chunk/s]
liesel.goose.engine - INFO - Finished epoch
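The number of retained draws follows directly from these settings. Each of the 4 chains runs 10,000 posterior transitions and keeps every 10th, which gives 1,000 draws per chain and 4,000 in total (the `sample_size` column in the summary). The engine log compiles 1,000 transitions at a time, which accounts for the 3 burn-in chunks and 10 posterior chunks in the progress bars. The arithmetic, spelled out:

```python
num_chains = 4
burnin = 3000
posterior = 10_000
thinning = 10
chunk = 1000  # "1000 jitted together" in the engine log

draws_per_chain = posterior // thinning     # draws kept per chain after thinning
total_draws = num_chains * draws_per_chain  # matches the sample_size column
burnin_chunks = burnin // chunk             # the "3/3" progress bar
posterior_chunks = posterior // chunk       # the "10/10" progress bar
```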

MCMC summary

summary = gs.Summary(results)
summary

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,\mu}$ () kernel_01 34.560280 1.562532 32.490276 34.752068 37.517949 4000 18.735609 134.285662 1.467820
$\beta_{0,\sigma}$ () kernel_00 -2.314584 5.060426 -10.453270 -0.725150 2.772572 4000 5.444350 12.086463 1.986426
$\beta_{ps(area)}$ (0,) kernel_06 -0.034334 0.436838 -0.602005 -0.059070 0.686317 4000 17.202776 336.204159 1.155748
(1,) kernel_06 0.055793 0.478513 -0.589033 -0.034411 0.955265 4000 20.653830 102.445664 1.144747
(2,) kernel_06 -0.099708 0.400067 -0.622008 -0.091584 0.558553 4000 24.052051 821.728335 1.115812
(3,) kernel_06 0.163132 0.413789 -0.493787 0.152299 0.741832 4000 22.697727 83.547707 1.118194
(4,) kernel_06 0.046597 0.482175 -0.535355 0.000273 0.914321 4000 14.968767 57.580296 1.191884
(5,) kernel_06 -0.183150 0.433294 -0.773132 -0.170342 0.540009 4000 16.079404 225.542933 1.166451
(6,) kernel_06 -0.008397 0.381778 -0.565468 -0.022081 0.602477 4000 35.169724 625.079329 1.075295
(7,) kernel_06 0.165807 0.392643 -0.483105 0.187697 0.650284 4000 17.119914 291.030064 1.160779
(8,) kernel_06 -0.127015 0.401156 -0.719396 -0.109715 0.473841 4000 25.128343 358.230064 1.241952
(9,) kernel_06 0.166496 0.426673 -0.530213 0.176164 0.714673 4000 11.747482 194.110019 1.249471
(10,) kernel_06 0.049425 0.525392 -0.850056 0.044701 0.715587 4000 8.683869 51.887989 1.386288
(11,) kernel_06 -0.012129 0.412609 -0.587041 -0.035999 0.690778 4000 15.234813 137.271878 1.197513
(12,) kernel_06 -0.061306 0.372631 -0.565858 -0.075590 0.554469 4000 24.519723 401.944175 1.109482
(13,) kernel_06 -0.028801 0.403503 -0.613669 -0.017792 0.650820 4000 28.125367 207.312296 1.094357
(14,) kernel_06 0.257033 0.428621 -0.417775 0.239404 0.847522 4000 13.575460 71.279237 1.207953
(15,) kernel_06 -0.058105 0.361044 -0.607081 -0.051086 0.425249 4000 12.089223 147.390087 1.240711
(16,) kernel_06 -0.156314 0.396944 -0.717581 -0.114378 0.360091 4000 6.496481 25.480715 1.640415
(17,) kernel_06 0.487242 0.190459 0.122755 0.496571 0.822002 4000 41.627772 130.316891 1.489187
(18,) kernel_06 -0.861178 0.667154 -2.099230 -0.784701 0.085061 4000 27.728257 122.587370 1.287118
$\beta_{ri(district)1}$ (0,) kernel_02 -0.210397 0.740211 -1.208341 -0.287891 1.188438 4000 23.138103 105.744392 1.115732
(1,) kernel_02 -0.231975 0.692929 -1.141416 -0.246511 1.142817 4000 26.356160 92.952687 1.209690
(2,) kernel_02 0.062086 0.676410 -1.175381 0.023497 1.241417 4000 36.295276 111.349071 1.202638
(3,) kernel_02 -0.452722 0.780259 -1.321563 -0.665373 1.184502 4000 20.731557 110.626846 1.459223
(4,) kernel_02 1.129287 1.123565 -1.110904 1.656574 2.271644 4000 8.540031 88.632568 1.404480
(5,) kernel_02 0.109048 0.693470 -1.172898 0.023826 1.167032 4000 32.661799 114.540020 1.317076
(6,) kernel_02 -1.060851 1.096300 -2.233408 -1.474931 1.132582 4000 8.382973 96.310757 1.414668
(7,) kernel_02 0.466980 0.764957 -1.123821 0.553576 1.324791 4000 16.086639 140.391500 1.418436
(8,) kernel_02 0.264345 0.696725 -1.183168 0.341251 1.191645 4000 32.288652 132.929602 1.360367
(9,) kernel_02 0.465457 0.813487 -1.155299 0.509998 1.417796 4000 12.010607 119.787695 1.432486
(10,) kernel_02 0.643772 0.872154 -1.173777 0.850263 1.586231 4000 9.836230 109.755496 1.330349
(11,) kernel_02 0.309434 0.721403 -1.124535 0.315746 1.157232 4000 17.083001 101.537798 1.272516
(12,) kernel_02 -0.081668 0.670474 -1.160042 -0.096998 1.114781 4000 28.165619 165.988403 1.093010
(13,) kernel_02 0.594503 0.834426 -1.131151 0.972451 1.330932 4000 13.123226 133.360176 1.291769
(14,) kernel_02 0.385967 0.712060 -1.104641 0.601598 1.223950 4000 51.378999 139.588272 1.387321
(15,) kernel_02 0.612820 0.802635 -1.057363 0.940876 1.324676 4000 20.108663 108.132612 1.324230
(16,) kernel_02 -0.227549 0.676514 -1.192853 -0.330187 1.147926 4000 51.326553 115.840268 1.569003
(17,) kernel_02 -0.138231 0.703200 -1.162477 -0.084101 1.145223 4000 22.134504 108.211888 1.140934
(18,) kernel_02 0.412815 0.781696 -1.234762 0.543583 1.229529 4000 15.026583 112.034411 1.327440
(19,) kernel_02 -0.986920 1.044361 -1.968388 -1.491401 1.162952 4000 11.113658 98.778707 1.343208
(20,) kernel_02 0.230925 0.801186 -1.201906 0.070954 1.393350 4000 20.150471 97.492404 1.476524
(21,) kernel_02 0.180308 0.681637 -1.168158 0.145603 1.166621 4000 29.117435 102.460288 1.485201
(22,) kernel_02 -0.168793 0.711451 -1.175013 -0.228368 1.152732 4000 24.863779 128.296490 1.106957
(23,) kernel_02 0.395157 0.752233 -1.156565 0.467216 1.215543 4000 14.576257 114.213514 1.433247
(24,) kernel_02 1.351564 1.249704 -1.102056 1.903437 2.688057 4000 7.472297 69.707119 1.497787
(25,) kernel_02 0.047273 0.655953 -1.193755 0.063258 1.205156 4000 200.388062 101.349621 1.787167
(26,) kernel_02 0.729052 0.873019 -1.129991 1.102707 1.519112 4000 13.484559 113.100512 1.366030
(27,) kernel_02 1.132830 1.167319 -1.215946 1.632173 2.362207 4000 7.583196 69.744421 1.479838
(28,) kernel_02 0.780808 0.925196 -1.190862 1.166052 1.661330 4000 10.423349 96.750041 1.344328
(29,) kernel_02 1.220885 1.178876 -1.077651 1.814745 2.399752 4000 9.543208 93.422886 1.367211
(30,) kernel_02 -0.973321 0.978733 -1.871164 -1.443396 1.061687 4000 11.593601 122.360551 1.315178
(31,) kernel_02 -0.475671 0.803943 -1.261170 -0.678146 1.200655 4000 13.983002 87.677902 1.393235
(32,) kernel_02 -0.084451 0.663400 -1.146282 0.007450 1.159211 4000 38.813649 113.936478 1.459016
(33,) kernel_02 -0.079085 0.698643 -1.200582 -0.160723 1.163111 4000 23.840740 127.446272 1.124157
(34,) kernel_02 0.340175 0.724636 -1.171661 0.431308 1.176727 4000 22.324178 134.357205 1.477083
(35,) kernel_02 -0.739811 0.919037 -1.645923 -1.101170 1.173585 4000 10.267470 85.696996 1.375068
(36,) kernel_02 0.160795 0.672669 -1.154445 0.191888 1.207212 4000 60.625259 124.205534 1.555407
(37,) kernel_02 0.352233 0.745126 -1.157745 0.440405 1.184331 4000 17.363734 140.480046 1.351546
(38,) kernel_02 -0.531034 0.801360 -1.342097 -0.688830 1.097220 4000 14.786406 86.777160 1.375103
(39,) kernel_02 -0.282801 0.697194 -1.218451 -0.335077 1.118890 4000 28.939941 111.755848 1.367269
(40,) kernel_02 -0.799280 0.925178 -1.615940 -1.214909 1.174641 4000 12.733657 114.386709 1.284430
(41,) kernel_02 -0.390402 0.726635 -1.157517 -0.590107 1.121922 4000 18.240927 100.000594 1.354350
(42,) kernel_02 -0.419510 0.762337 -1.308000 -0.554701 1.181960 4000 21.155510 132.973264 1.493734
(43,) kernel_02 -0.261450 0.725753 -1.241222 -0.376921 1.228634 4000 28.134646 108.089946 1.349016
(44,) kernel_02 0.212092 0.701066 -1.213525 0.168330 1.166764 4000 22.749940 92.548676 1.464938
(45,) kernel_02 -0.986106 1.023072 -1.925927 -1.500493 1.120921 4000 10.124921 137.696116 1.300160
(46,) kernel_02 0.166518 0.689830 -1.126864 0.066489 1.136345 4000 24.350092 119.069851 1.401489
(47,) kernel_02 -0.759161 0.895805 -1.672453 -1.086105 1.124089 4000 13.319756 102.233084 1.432165
(48,) kernel_02 -0.227137 0.720506 -1.234200 -0.298955 1.168923 4000 24.598455 114.779287 1.104993
$\beta_{ri(district)}$ (0,) kernel_04 0.047229 0.118916 -0.143021 0.023086 0.198332 4000 13.418754 193.481880 1.584372
(1,) kernel_04 0.014594 0.097405 -0.136229 0.028880 0.145455 4000 51.063239 149.331167 1.316570
(2,) kernel_04 0.005225 0.097543 -0.132859 0.006773 0.148755 4000 174.205631 135.468083 1.213922
(3,) kernel_04 -0.027421 0.098052 -0.144303 -0.049500 0.134784 4000 51.987875 181.505487 1.555147
(4,) kernel_04 -0.064163 0.113380 -0.212736 -0.051250 0.109011 4000 25.522609 119.826148 1.561254
(5,) kernel_04 -0.041028 0.103384 -0.184235 -0.059735 0.132460 4000 21.477930 143.548060 1.282516
(6,) kernel_04 0.029280 0.101880 -0.111004 0.021606 0.184811 4000 49.216133 156.385572 1.053620
(7,) kernel_04 -0.023320 0.114344 -0.154980 -0.009135 0.149605 4000 13.272549 152.621682 1.212604
(8,) kernel_04 0.026883 0.113182 -0.156069 0.037852 0.237574 4000 14.319801 25.105032 1.195341
(9,) kernel_04 0.088846 0.157398 -0.135001 0.088518 0.319894 4000 8.397407 81.627804 1.676276
(10,) kernel_04 0.008287 0.098580 -0.126779 -0.007517 0.163779 4000 47.040739 116.388437 1.192137
(11,) kernel_04 0.066116 0.114881 -0.118648 0.087816 0.159374 4000 14.310578 132.322131 1.195257
(12,) kernel_04 -0.013234 0.093450 -0.137976 -0.016009 0.125289 4000 57.177594 109.013366 1.183942
(13,) kernel_04 -0.083035 0.176295 -0.372703 -0.031608 0.150402 4000 8.577789 15.607517 1.920676
(14,) kernel_04 -0.030141 0.112032 -0.179747 -0.009829 0.143731 4000 12.835090 151.130882 1.272613
(15,) kernel_04 -0.026541 0.142134 -0.233822 0.012305 0.144922 4000 9.834768 107.558895 1.325157
(16,) kernel_04 0.064576 0.132229 -0.142519 0.083507 0.246909 4000 9.715126 125.731069 1.355246
(17,) kernel_04 0.036524 0.138995 -0.133593 -0.005258 0.285847 4000 9.111601 17.722244 1.355814
(18,) kernel_04 -0.026100 0.104505 -0.136027 -0.024464 0.142363 4000 15.648701 154.916577 1.173933
(19,) kernel_04 0.062875 0.140521 -0.113384 0.005259 0.272945 4000 11.269739 139.448203 1.335932
(20,) kernel_04 -0.035150 0.109229 -0.179480 -0.032882 0.139683 4000 40.581767 209.038716 1.338626
(21,) kernel_04 -0.007429 0.106233 -0.141170 0.009506 0.143021 4000 37.220998 150.795945 1.131612
(22,) kernel_04 0.058256 0.147896 -0.139913 0.015675 0.286919 4000 10.615476 82.323343 1.569527
(23,) kernel_04 0.026233 0.136882 -0.145514 -0.008880 0.215057 4000 9.798700 71.894894 1.320895
(24,) kernel_04 0.029461 0.142042 -0.152766 -0.003200 0.245248 4000 10.331970 144.359328 1.296346
(25,) kernel_04 -0.048708 0.123778 -0.225166 -0.015621 0.131567 4000 12.602690 208.028835 1.433780
(26,) kernel_04 -0.020998 0.111355 -0.167778 -0.034113 0.162452 4000 19.625106 224.073173 1.131365
(27,) kernel_04 0.047846 0.145432 -0.144941 0.048181 0.410522 4000 10.986621 13.740739 1.435892
(28,) kernel_04 0.015312 0.106665 -0.110207 0.003497 0.186921 4000 15.557007 118.582631 1.175642
(29,) kernel_04 0.060696 0.151277 -0.114089 0.002283 0.325635 4000 9.550471 21.878903 1.334894
(30,) kernel_04 -0.041786 0.114668 -0.165535 -0.052663 0.140416 4000 31.672095 270.152315 1.143585
(31,) kernel_04 -0.037367 0.112731 -0.176247 -0.037638 0.154296 4000 15.424747 169.345209 1.183066
(32,) kernel_04 0.074653 0.146235 -0.137294 0.068334 0.339031 4000 7.814564 17.082677 1.668177
(33,) kernel_04 0.037015 0.167099 -0.139765 -0.008537 0.300371 4000 7.891730 55.823378 1.445308
(34,) kernel_04 0.002268 0.110832 -0.171839 0.014766 0.132161 4000 25.209897 167.136046 1.106625
(35,) kernel_04 -0.050530 0.121963 -0.161559 -0.055744 0.142271 4000 12.006862 113.480451 1.244361
(36,) kernel_04 0.012846 0.160402 -0.213198 0.008874 0.215399 4000 6.813592 102.729015 1.581480
(37,) kernel_04 -0.022358 0.101680 -0.138509 -0.050042 0.143538 4000 19.430965 129.337095 1.299413
(38,) kernel_04 0.103251 0.172750 -0.141274 0.110883 0.424867 4000 8.526050 43.594026 1.506704
(39,) kernel_04 -0.006913 0.122935 -0.163297 0.014727 0.148352 4000 11.465263 139.989122 1.258119
(40,) kernel_04 0.005470 0.126884 -0.154979 -0.017621 0.220493 4000 13.777033 45.432736 1.210269
(41,) kernel_04 -0.080501 0.162613 -0.361462 -0.037965 0.137927 4000 7.996195 14.022373 1.928843
(42,) kernel_04 0.039222 0.120962 -0.153197 0.021278 0.179992 4000 9.649868 89.326565 1.331018
(43,) kernel_04 -0.073462 0.171591 -0.402516 -0.012812 0.138327 4000 8.976526 14.491843 1.448537
(44,) kernel_04 0.012244 0.095088 -0.135333 0.022006 0.130795 4000 31.278508 113.338322 1.095503
(45,) kernel_04 0.036122 0.111700 -0.141798 0.032313 0.175361 4000 19.346124 270.771501 1.399930
(46,) kernel_04 0.053805 0.116540 -0.135362 0.069666 0.196974 4000 10.565007 100.779168 1.484020
(47,) kernel_04 -0.065835 0.111461 -0.212778 -0.086043 0.111801 4000 14.770707 105.533097 1.603224
(48,) kernel_04 0.039685 0.176952 -0.132460 0.000329 0.360946 4000 7.448330 16.829189 1.493100
$\tau_{ps(area)}^2$ () kernel_07 0.198391 0.231142 0.009850 0.148903 0.524625 4000 23.183294 175.114580 1.126574
$\tau_{ri(district)1}^2$ () kernel_03 106.969002 94.016472 0.002374 144.080162 236.058330 4000 7.647214 76.696954 1.479171
$\tau_{ri(district)}^2$ () kernel_05 0.018785 0.025597 0.002189 0.008328 0.053526 4000 16.129285 87.178800 1.317112
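Many `rhat` values in this table are well above 1 (the largest, about 1.99, is for $\beta_{0,\sigma}$), and many `ess_bulk` values are in the single digits or low tens out of 4,000 draws. Both usually indicate that the chains have not mixed, so the estimates above should be read with caution. For reference, a minimal sketch of the classic split-$\hat{R}$ diagnostic in JAX; `gs.Summary` may well use the rank-normalized variant (an assumption), so its numbers need not match this sketch exactly:

```python
import jax
import jax.numpy as jnp


def split_rhat(draws: jnp.ndarray) -> jnp.ndarray:
    """Classic split-R-hat for one scalar parameter.

    draws: array of shape (num_chains, num_draws).
    Each chain is cut in half, so trends within a chain also inflate R-hat.
    """
    n_half = draws.shape[1] // 2
    halves = jnp.concatenate(
        [draws[:, :n_half], draws[:, n_half : 2 * n_half]], axis=0
    )  # shape (2 * num_chains, n_half)
    n = halves.shape[1]
    b = n * jnp.var(halves.mean(axis=1), ddof=1)      # between-chain variance
    w = jnp.mean(jnp.var(halves, axis=1, ddof=1))      # mean within-chain variance
    var_plus = (n - 1) / n * w + b / n                 # pooled variance estimate
    return jnp.sqrt(var_plus / w)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Four well-mixed chains drawn from the same distribution
mixed = jax.random.normal(key_a, (4, 1000))
# Four chains stuck around different values (shifted by 0, 1, 2, 3)
stuck = jax.random.normal(key_b, (4, 1000)) + jnp.arange(4.0)[:, None]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
```

Values close to 1 mean the chains agree with each other; a common rule of thumb is to be suspicious of anything above about 1.01.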

Plots

samples = results.get_posterior_samples()
gam.plot_1d_smooth(term=model.vars[smooth.name], samples=samples)
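A plot of a 1D smooth summarizes posterior draws of a whole function. A minimal sketch of one way to compute such a summary, assuming a basis matrix evaluated on a grid and a matrix of coefficient draws; `summarize_curves` and the random stand-in inputs are hypothetical and not liesel_gam's internals. Each coefficient draw gives one curve, and pointwise quantiles of those curves give a credible band:

```python
import jax
import jax.numpy as jnp


def summarize_curves(basis, coef_draws, level=0.9):
    """Pointwise posterior mean and equal-tailed band of f(x) = basis @ beta.

    basis:      (num_grid, num_coef) basis functions evaluated on a grid
    coef_draws: (num_draws, num_coef) posterior draws of the coefficients
    """
    curves = coef_draws @ basis.T  # (num_draws, num_grid), one curve per draw
    alpha = (1.0 - level) / 2.0
    lower = jnp.quantile(curves, alpha, axis=0)
    upper = jnp.quantile(curves, 1.0 - alpha, axis=0)
    return curves.mean(axis=0), lower, upper


# Hypothetical stand-ins: a random "basis" and random coefficient draws
key_b, key_c = jax.random.split(jax.random.PRNGKey(1))
basis = jax.random.uniform(key_b, (30, 19))        # 30 grid points, 19 coefficients
coef_draws = jax.random.normal(key_c, (4000, 19))  # 4000 draws, as in the summary
mean, lower, upper = summarize_curves(basis, coef_draws)
```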
gam.plot_1d_smooth_clustered(
    clustered_term=model.vars["rs(ps(area)|district)"],
    samples=samples,
    ngrid=30,
)
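The clustered plot draws one curve of the `area` effect per district. As a hedged illustration of the idea in the title, assume (hypothetically; the exact parameterization behind `tb.rs` is not shown in this example) that each district $c$ multiplies a shared smooth $f$ by its own scale $s_c$, so $f_c(x) = s_c \, f(x)$. Under that assumption every district curve has the same shape and differs only in amplitude and possibly sign. The helper `scaled_cluster_curves` and all values below are made up:

```python
import jax.numpy as jnp


def scaled_cluster_curves(shared_curve, cluster_scales):
    """Hypothetical random-scale smooth: f_c(x) = s_c * f(x).

    shared_curve:   (num_grid,) values of the shared smooth f on a grid
    cluster_scales: (num_clusters,) one scale factor s_c per cluster
    Returns an array of shape (num_clusters, num_grid).
    """
    return cluster_scales[:, None] * shared_curve[None, :]


grid = jnp.linspace(0.0, 1.0, 30)      # 30 grid points, like ngrid=30 above
shared = jnp.sin(2.0 * jnp.pi * grid)  # made-up stand-in for f(area)
scales = jnp.array([0.5, 1.0, 1.5])    # made-up scales for three clusters
curves = scaled_cluster_curves(shared, scales)
```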
gam.plot_regions(
    term=model.vars["ri(district)"],
    samples=samples,
    polys=polys,
    observed_color="black",
)
gam.plot_regions(term=model.vars["ri(district)1"], samples=samples, polys=polys)