# MRF: Markov Random Field
## Setup and Imports
import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel_gam as gam
# import data from R
from ryp import r, to_py
# Load mgcv's Columbus data set and its polygon data inside R, then convert
# both objects to Python. reset_index() turns the DataFrame's index into an
# ordinary column.
for r_command in ("library(mgcv)", "data(columb)", "data(columb.polys)"):
    r(r_command)
columb = to_py("columb", format="pandas").reset_index()
polys = to_py("columb.polys", format="numpy")
## Model Definition
# Remove a few rows (positions 10, 20 and 30) to simulate a dataset in which
# some clusters (districts) have no observations at all.
unobserved_rows = [10, 20, 30]
i = np.delete(np.arange(columb.shape[0]), unobserved_rows)
# drop=True: columb already carries its original index as a column (from the
# reset_index() applied when it was loaded), so re-inserting the positional
# index here would only add a redundant "level_0" column.
df = columb.iloc[i, :].reset_index(drop=True)
# Standardize the response to mean 0 and (sample) standard deviation 1;
# a unit-scale response is easier for the model to fit.
crime_raw = df["crime"]
df["crime"] = (crime_raw - crime_raw.mean()) / crime_raw.std()
# Additive predictors for the two parameters of the Normal response. The
# scale predictor uses exp as its inverse link, so sigma stays positive.
loc = gam.AdditivePredictor("$\\mu$")
scale = gam.AdditivePredictor("$\\sigma$", inv_link=jnp.exp)

# Start both intercepts at sensible values: the empirical mean of the
# response, and the log of its (sample) standard deviation for the
# log-scale predictor.
crime = df["crime"]
loc.intercept.value = crime.mean()
scale.intercept.value = jnp.log(crime.std())

# Observed response y ~ Normal(loc, scale).
response_dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
y = lsl.Var.new_obs(
    value=crime.to_numpy(),
    distribution=response_dist,
    name="y",
)
# Term builder bound to the columns of df.
tb = gam.TermBuilder.from_df(df)

# Markov random field term over "district"; the neighbourhood structure is
# presumably derived from `polys` (confirm in liesel_gam docs). A low-rank
# MRF (k=40) is used because there is only 1 observation per cluster.
# NOTE(review): VarIGPrior(0.01, 0.01, 0.1) looks like an inverse-gamma prior
# on the term's variance (tau^2 in the summary) - verify argument meanings.
mrf_district = tb.mrf(
    "district",
    k=40,
    polys=polys,
    scale=gam.VarIGPrior(0.01, 0.01, 0.1),
    factor_scale=True,
)
loc += mrf_district
## Build and Plot the Model
# Build the model graph starting from the response y (the intercepts and
# terms are pulled in through y's distribution), then draw it.
response_vars = [y]
model = lsl.Model(response_vars)
model.plot_vars()
liesel.model.model - INFO - Converted dtype of Value(name="y_value").value
liesel.model.model - INFO - Converted dtype of Value(name="$\beta_{0,\mu}$_value").value
## Run MCMC
# Engine builder derived from the model, with 4 chains and a fixed seed.
mcmc = gs.LieselMCMC(model)
eb = mcmc.get_engine_builder(seed=1, num_chains=4)
# 3000 warm-up transitions, then 10,000 posterior transitions of which every
# 10th is kept (1000 draws per chain, 4000 across all chains).
# NOTE(review): the log reports that no jitter functions are set, so all
# chains start from identical values; consider adding jitter so that R-hat
# across chains is more informative.
eb.add_burnin(3_000)
eb.add_posterior(10_000, thinning=10)
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{0,\\sigma}$', '$\\beta_{0,\\mu}$', '$\\beta_{mrf(district)}$', '$\\tau_{mrf(district)}^2$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: BURNIN, 3000 transitions, 1000 jitted together
100%|██████████████████████████████████████████| 3/3 [00:02<00:00, 1.14chunk/s]
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 10000 transitions, 1000 jitted together
100%|████████████████████████████████████████| 10/10 [00:01<00:00, 8.03chunk/s]
liesel.goose.engine - INFO - Finished epoch
## MCMC Summary
# Per-parameter posterior summary (mean, sd, quantiles, ESS, R-hat); the
# parenthesised assignment both stores it and displays it in the notebook.
(summary := gs.Summary(results))
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| $\beta_{0,\mu}$ | () | kernel_01 | -0.003580 | 0.104040 | -0.172420 | -0.003226 | 0.168574 | 4000 | 4128.141289 | 3861.134067 | 1.000269 |
| $\beta_{0,\sigma}$ | () | kernel_00 | -0.382368 | 0.179775 | -0.673807 | -0.387672 | -0.078226 | 4000 | 759.769426 | 1653.443321 | 1.001388 |
| $\beta_{mrf(district)}$ | (0,) | kernel_02 | 0.734054 | 0.930112 | -0.850236 | 0.757395 | 2.223578 | 4000 | 2224.144217 | 3795.382722 | 1.000776 |
| | (1,) | kernel_02 | 0.183005 | 0.898392 | -1.295767 | 0.181252 | 1.641307 | 4000 | 3331.498057 | 3738.339084 | 1.000012 |
| | (2,) | kernel_02 | -0.319014 | 0.896056 | -1.772981 | -0.340988 | 1.181685 | 4000 | 3096.563629 | 3522.443074 | 1.000185 |
| | (3,) | kernel_02 | -0.120443 | 0.874687 | -1.554043 | -0.111225 | 1.336551 | 4000 | 3651.294647 | 3597.077531 | 1.000916 |
| | (4,) | kernel_02 | 0.144608 | 0.849683 | -1.228167 | 0.143426 | 1.506603 | 4000 | 3541.345596 | 3714.847753 | 1.000765 |
| | (5,) | kernel_02 | -0.048042 | 0.845411 | -1.402315 | -0.059043 | 1.321099 | 4000 | 3350.572873 | 3293.832139 | 0.999785 |
| | (6,) | kernel_02 | 0.890617 | 0.906810 | -0.640480 | 0.904657 | 2.333886 | 4000 | 2218.303185 | 3025.331910 | 1.000710 |
| | (7,) | kernel_02 | -0.349899 | 0.863267 | -1.741079 | -0.362708 | 1.079250 | 4000 | 3265.416856 | 3104.258288 | 1.000485 |
| | (8,) | kernel_02 | 0.012799 | 0.827630 | -1.373064 | 0.015367 | 1.340184 | 4000 | 3497.628681 | 3418.055080 | 0.999645 |
| | (9,) | kernel_02 | 0.290853 | 0.850292 | -1.128265 | 0.314719 | 1.663848 | 4000 | 3165.981568 | 3879.076364 | 1.000071 |
| | (10,) | kernel_02 | -0.433604 | 0.916335 | -1.950916 | -0.419934 | 1.078530 | 4000 | 3195.804415 | 3595.036898 | 1.000954 |
| | (11,) | kernel_02 | 0.324850 | 0.834942 | -1.081650 | 0.336686 | 1.670527 | 4000 | 3386.210634 | 3363.387032 | 1.000490 |
| | (12,) | kernel_02 | -1.306409 | 0.919171 | -2.768893 | -1.320707 | 0.269126 | 4000 | 1659.370184 | 2494.403658 | 1.000072 |
| | (13,) | kernel_02 | 0.151432 | 0.834099 | -1.219404 | 0.155445 | 1.506486 | 4000 | 3441.744746 | 3777.351803 | 1.000898 |
| | (14,) | kernel_02 | -0.139186 | 0.850732 | -1.538324 | -0.148588 | 1.273791 | 4000 | 3425.541042 | 3973.719668 | 1.001252 |
| | (15,) | kernel_02 | 0.367454 | 0.836638 | -1.012700 | 0.379260 | 1.720965 | 4000 | 3057.460590 | 3521.844530 | 1.000354 |
| | (16,) | kernel_02 | -0.230954 | 0.825652 | -1.554380 | -0.240329 | 1.128883 | 4000 | 3434.843555 | 3559.509891 | 1.000024 |
| | (17,) | kernel_02 | 0.220780 | 0.811259 | -1.123885 | 0.224643 | 1.573010 | 4000 | 3269.807101 | 3424.065108 | 0.999672 |
| | (18,) | kernel_02 | -0.330824 | 0.817928 | -1.647253 | -0.354557 | 1.011196 | 4000 | 3630.937224 | 3853.568860 | 1.000377 |
| | (19,) | kernel_02 | 0.539861 | 0.848227 | -0.881640 | 0.558531 | 1.895959 | 4000 | 2889.461567 | 2968.601754 | 1.000823 |
| | (20,) | kernel_02 | -0.396995 | 0.812252 | -1.682499 | -0.437580 | 0.979201 | 4000 | 3424.528463 | 3265.868819 | 0.999690 |
| | (21,) | kernel_02 | -0.355798 | 0.815720 | -1.663999 | -0.369793 | 1.047433 | 4000 | 2869.333926 | 3158.055884 | 1.000677 |
| | (22,) | kernel_02 | -0.112538 | 0.819784 | -1.451739 | -0.116155 | 1.247474 | 4000 | 3529.129979 | 3762.472460 | 1.000039 |
| | (23,) | kernel_02 | -0.573999 | 0.796707 | -1.840591 | -0.600600 | 0.766930 | 4000 | 3029.835200 | 3822.408057 | 1.001840 |
| | (24,) | kernel_02 | -0.452299 | 0.791887 | -1.775378 | -0.450625 | 0.873158 | 4000 | 3464.533027 | 3713.061222 | 1.000142 |
| | (25,) | kernel_02 | 0.374664 | 0.814586 | -0.958073 | 0.380786 | 1.677336 | 4000 | 3347.453033 | 3512.795151 | 1.000279 |
| | (26,) | kernel_02 | -0.401156 | 0.759131 | -1.618902 | -0.418309 | 0.844527 | 4000 | 3340.090736 | 3330.505596 | 0.999955 |
| | (27,) | kernel_02 | 0.644882 | 0.760511 | -0.639059 | 0.672120 | 1.858458 | 4000 | 3199.615685 | 3526.843969 | 0.999695 |
| | (28,) | kernel_02 | 0.230367 | 0.859052 | -1.198759 | 0.252664 | 1.600876 | 4000 | 3421.379739 | 3516.977664 | 1.001228 |
| | (29,) | kernel_02 | -0.158961 | 0.798545 | -1.459488 | -0.167566 | 1.172190 | 4000 | 2982.818369 | 3641.169531 | 1.001088 |
| | (30,) | kernel_02 | -0.756595 | 0.703527 | -1.866371 | -0.786238 | 0.417410 | 4000 | 2863.970671 | 2773.858380 | 1.000485 |
| | (31,) | kernel_02 | -1.012188 | 0.687997 | -2.131813 | -1.018673 | 0.132353 | 4000 | 2425.108041 | 2808.495116 | 0.999825 |
| | (32,) | kernel_02 | 0.751591 | 0.665425 | -0.370213 | 0.781156 | 1.782203 | 4000 | 3235.040705 | 3112.609190 | 1.000178 |
| | (33,) | kernel_02 | -1.241534 | 0.719319 | -2.386984 | -1.252372 | -0.060085 | 4000 | 1971.411656 | 2330.655004 | 1.001608 |
| | (34,) | kernel_02 | 0.679548 | 0.594236 | -0.288307 | 0.673541 | 1.627562 | 4000 | 3487.900814 | 3217.553146 | 1.000636 |
| | (35,) | kernel_02 | 1.364089 | 0.533356 | 0.519688 | 1.359027 | 2.246238 | 4000 | 3296.216997 | 2532.424887 | 1.002792 |
| | (36,) | kernel_02 | -1.444965 | 0.516227 | -2.334677 | -1.421517 | -0.660619 | 4000 | 2957.540390 | 2859.159832 | 1.000935 |
| | (37,) | kernel_02 | 0.380117 | 0.380897 | -0.214342 | 0.369101 | 1.005009 | 4000 | 3871.378499 | 2369.016009 | 1.001500 |
| | (38,) | kernel_02 | -0.634658 | 0.320186 | -1.233288 | -0.588642 | -0.205414 | 4000 | 1721.129550 | 1363.259482 | 1.001619 |
| $\tau_{mrf(district)}^2$ | () | kernel_03 | 3.475489 | 1.868635 | 0.896338 | 3.234976 | 6.973548 | 4000 | 528.736199 | 456.614325 | 1.001866 |
## Plots
# Posterior draws of all parameters, reused by both region plots.
samples = results.get_posterior_samples()

# Map of the district MRF term over the region polygons.
# NOTE(review): unobserved_color is passed although show_unobserved=False;
# confirm in liesel_gam whether the colour has any effect in that case.
mrf_term = loc.terms["mrf(district)"]
gam.plot_regions(
    term=mrf_term,
    samples=samples,
    polys=polys,
    show_unobserved=False,
    unobserved_color="red",
    observed_color="none",
)
import plotnine as p9

# Maps of the district MRF term for the lower HDI bound, the posterior mean
# and the upper HDI bound (presumably one panel per entry of `which`).
mrf_term = loc.terms["mrf(district)"]
region_plot = gam.plot_regions(
    term=mrf_term,
    samples=samples,
    polys=polys,
    which=["hdi_low", "mean", "hdi_high"],
    observed_color="none",
    unobserved_color="red",
)
# Wide figure for the panels and a diverging colour map for the fill.
region_plot + p9.theme(figure_size=(10, 3.5)) + p9.scale_fill_cmap("RdYlBu")