MRF: Markov Random Field

Setup and Imports

import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel_gam as gam
# import data from R
from ryp import r, to_py

# Run R code through ryp: attach mgcv and load its `columb` data set together
# with the matching `columb.polys` polygons.
r("library(mgcv)")
r("data(columb)")
r("data(columb.polys)")

# Pull both R objects into Python. reset_index() turns the converted frame's
# index into an ordinary column and leaves a default 0..n-1 index.
columb = to_py("columb", format="pandas").reset_index()
polys = to_py("columb.polys", format="numpy")
# Plot the raw "crime" values per "district" region as a first look at the data.
gam.plot_polys(region="district", which=["crime"], df=columb, polys=polys)

Model Definition

# Hold out the rows at positions 10, 20 and 30 to simulate a dataset in which
# some clusters are unobserved.
held_out = [10, 20, 30]
i = np.arange(columb.shape[0])
i = i[~np.isin(i, held_out)]
df = columb.iloc[i, :].reset_index()

# Put the response on a zero-mean, unit-sd scale; this makes it easier for
# the model.
crime = df["crime"]
df["crime"] = (crime - crime.mean()) / crime.std()

# One additive predictor per distribution parameter; the scale predictor is
# given jnp.exp as its inverse link.
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)

# Start both intercepts at the empirical moments of the response. The scale
# intercept is set to the log of the standard deviation.
crime = df["crime"]
loc.intercept.value = crime.mean()
scale.intercept.value = jnp.log(crime.std())


# Observed response: a Normal distribution parameterised by the two predictors.
likelihood = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
y = lsl.Var.new_obs(
    value=crime.to_numpy(),
    distribution=likelihood,
    name="y",
)

# Term builder created from the data frame df.
tb = gam.TermBuilder.from_df(df)


# Add an MRF term over "district" to the location predictor, using the
# polygons loaded from R.
# NOTE(review): the meaning of the three VarIGPrior arguments and of
# factor_scale is not visible here — confirm against the liesel_gam docs.
loc += tb.mrf(
    "district",
    k=40,  # using a low-rank MRF here, because we have only 1 observation per cluster.
    polys=polys,
    scale=gam.VarIGPrior(0.01, 0.01, 0.1),
    factor_scale=True,
)

Build and plot model

# Build the liesel model from the response variable y and plot its variables.
model = lsl.Model([y])
model.plot_vars()
liesel.model.model - INFO - Converted dtype of Value(name="y_value").value
liesel.model.model - INFO - Converted dtype of Value(name="$\beta_{0,\mu}$_value").value
../../_images/a14666e7fe6af4651829c22d600bfe286b8a5e84209708f2937f13fc3af98f6b.png

Run MCMC

# Engine builder for the model with a fixed seed and four chains.
# No jitter functions are configured; per the builder's warning in the output
# below, the initial values are therefore not jittered.
eb = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)

# 3000 burn-in transitions, then 10_000 posterior transitions with thinning=10.
# NOTE(review): presumably thinning keeps every 10th draw per chain, which
# matches the sample_size of 4000 (4 chains x 1000) in the summary — confirm.
eb.add_burnin(3000)
eb.add_posterior(10_000, thinning=10)

# Build the engine, run every configured epoch and collect the results.
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{mrf(district)}$', '$\\tau_{mrf(district)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:02<00:00,  1.14chunk/s]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:01<00:00,  8.03chunk/s]
liesel.goose.engine - INFO - Finished epoch

MCMC summary

# Summarise the sampling results; the bare name on the last line displays the
# summary when run as a notebook cell.
summary = gs.Summary(results)
summary

Parameter summary:

kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
parameter index
$\beta_{0,\mu}$ () kernel_01 -0.003580 0.104040 -0.172420 -0.003226 0.168574 4000 4128.141289 3861.134067 1.000269
$\beta_{0,\sigma}$ () kernel_00 -0.382368 0.179775 -0.673807 -0.387672 -0.078226 4000 759.769426 1653.443321 1.001388
$\beta_{mrf(district)}$ (0,) kernel_02 0.734054 0.930112 -0.850236 0.757395 2.223578 4000 2224.144217 3795.382722 1.000776
(1,) kernel_02 0.183005 0.898392 -1.295767 0.181252 1.641307 4000 3331.498057 3738.339084 1.000012
(2,) kernel_02 -0.319014 0.896056 -1.772981 -0.340988 1.181685 4000 3096.563629 3522.443074 1.000185
(3,) kernel_02 -0.120443 0.874687 -1.554043 -0.111225 1.336551 4000 3651.294647 3597.077531 1.000916
(4,) kernel_02 0.144608 0.849683 -1.228167 0.143426 1.506603 4000 3541.345596 3714.847753 1.000765
(5,) kernel_02 -0.048042 0.845411 -1.402315 -0.059043 1.321099 4000 3350.572873 3293.832139 0.999785
(6,) kernel_02 0.890617 0.906810 -0.640480 0.904657 2.333886 4000 2218.303185 3025.331910 1.000710
(7,) kernel_02 -0.349899 0.863267 -1.741079 -0.362708 1.079250 4000 3265.416856 3104.258288 1.000485
(8,) kernel_02 0.012799 0.827630 -1.373064 0.015367 1.340184 4000 3497.628681 3418.055080 0.999645
(9,) kernel_02 0.290853 0.850292 -1.128265 0.314719 1.663848 4000 3165.981568 3879.076364 1.000071
(10,) kernel_02 -0.433604 0.916335 -1.950916 -0.419934 1.078530 4000 3195.804415 3595.036898 1.000954
(11,) kernel_02 0.324850 0.834942 -1.081650 0.336686 1.670527 4000 3386.210634 3363.387032 1.000490
(12,) kernel_02 -1.306409 0.919171 -2.768893 -1.320707 0.269126 4000 1659.370184 2494.403658 1.000072
(13,) kernel_02 0.151432 0.834099 -1.219404 0.155445 1.506486 4000 3441.744746 3777.351803 1.000898
(14,) kernel_02 -0.139186 0.850732 -1.538324 -0.148588 1.273791 4000 3425.541042 3973.719668 1.001252
(15,) kernel_02 0.367454 0.836638 -1.012700 0.379260 1.720965 4000 3057.460590 3521.844530 1.000354
(16,) kernel_02 -0.230954 0.825652 -1.554380 -0.240329 1.128883 4000 3434.843555 3559.509891 1.000024
(17,) kernel_02 0.220780 0.811259 -1.123885 0.224643 1.573010 4000 3269.807101 3424.065108 0.999672
(18,) kernel_02 -0.330824 0.817928 -1.647253 -0.354557 1.011196 4000 3630.937224 3853.568860 1.000377
(19,) kernel_02 0.539861 0.848227 -0.881640 0.558531 1.895959 4000 2889.461567 2968.601754 1.000823
(20,) kernel_02 -0.396995 0.812252 -1.682499 -0.437580 0.979201 4000 3424.528463 3265.868819 0.999690
(21,) kernel_02 -0.355798 0.815720 -1.663999 -0.369793 1.047433 4000 2869.333926 3158.055884 1.000677
(22,) kernel_02 -0.112538 0.819784 -1.451739 -0.116155 1.247474 4000 3529.129979 3762.472460 1.000039
(23,) kernel_02 -0.573999 0.796707 -1.840591 -0.600600 0.766930 4000 3029.835200 3822.408057 1.001840
(24,) kernel_02 -0.452299 0.791887 -1.775378 -0.450625 0.873158 4000 3464.533027 3713.061222 1.000142
(25,) kernel_02 0.374664 0.814586 -0.958073 0.380786 1.677336 4000 3347.453033 3512.795151 1.000279
(26,) kernel_02 -0.401156 0.759131 -1.618902 -0.418309 0.844527 4000 3340.090736 3330.505596 0.999955
(27,) kernel_02 0.644882 0.760511 -0.639059 0.672120 1.858458 4000 3199.615685 3526.843969 0.999695
(28,) kernel_02 0.230367 0.859052 -1.198759 0.252664 1.600876 4000 3421.379739 3516.977664 1.001228
(29,) kernel_02 -0.158961 0.798545 -1.459488 -0.167566 1.172190 4000 2982.818369 3641.169531 1.001088
(30,) kernel_02 -0.756595 0.703527 -1.866371 -0.786238 0.417410 4000 2863.970671 2773.858380 1.000485
(31,) kernel_02 -1.012188 0.687997 -2.131813 -1.018673 0.132353 4000 2425.108041 2808.495116 0.999825
(32,) kernel_02 0.751591 0.665425 -0.370213 0.781156 1.782203 4000 3235.040705 3112.609190 1.000178
(33,) kernel_02 -1.241534 0.719319 -2.386984 -1.252372 -0.060085 4000 1971.411656 2330.655004 1.001608
(34,) kernel_02 0.679548 0.594236 -0.288307 0.673541 1.627562 4000 3487.900814 3217.553146 1.000636
(35,) kernel_02 1.364089 0.533356 0.519688 1.359027 2.246238 4000 3296.216997 2532.424887 1.002792
(36,) kernel_02 -1.444965 0.516227 -2.334677 -1.421517 -0.660619 4000 2957.540390 2859.159832 1.000935
(37,) kernel_02 0.380117 0.380897 -0.214342 0.369101 1.005009 4000 3871.378499 2369.016009 1.001500
(38,) kernel_02 -0.634658 0.320186 -1.233288 -0.588642 -0.205414 4000 1721.129550 1363.259482 1.001619
$\tau_{mrf(district)}^2$ () kernel_03 3.475489 1.868635 0.896338 3.234976 6.973548 4000 528.736199 456.614325 1.001866

Plots

# Posterior draws and the MRF term shared by both maps below.
samples = results.get_posterior_samples()
mrf_term = loc.terms["mrf(district)"]

# Map of the district effect with the plotting defaults.
gam.plot_regions(term=mrf_term, samples=samples, polys=polys)

# Same map with explicit settings for unobserved and observed regions.
region_style = dict(
    show_unobserved=False,
    unobserved_color="red",
    observed_color="none",
)
gam.plot_regions(term=mrf_term, samples=samples, polys=polys, **region_style)
import plotnine as p9

# Region maps for three requested quantities ("hdi_low", "mean", "hdi_high"),
# combined with a plotnine theme (figure size) and a "RdYlBu" fill colormap.
# NOTE(review): the `+` with plotnine objects suggests plot_regions returns a
# plotnine plot — confirm.
(
    gam.plot_regions(
        term=loc.terms["mrf(district)"],
        samples=samples,
        polys=polys,
        which=["hdi_low", "mean", "hdi_high"],
        observed_color="none",
        unobserved_color="red",
        # show_unobserved=False is disabled here, so the function's default applies.
    )
    + p9.theme(figure_size=(10, 3.5))
    + p9.scale_fill_cmap("RdYlBu")
)
# Forest plot of the district effects; interval bounds come from the "q_0.05"
# and "q_0.95" summaries, and show_unobserved is switched on explicitly.
forest = gam.plot_forest(
    term=loc.terms["mrf(district)"],
    samples=samples,
    ymin="q_0.05",
    ymax="q_0.95",
    show_unobserved=True,
)
# Tall, narrow figure size for the plot.
forest + p9.theme(figure_size=(4, 7))