from collections.abc import Mapping, Sequence
from typing import Any, Literal
import jax
import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import pandas as pd
from jax import Array
from jax.typing import ArrayLike
from liesel.goose.summary_m import SummaryQuantities
from .registry import CategoryMapping
from .term import (
LinTerm,
MRFTerm,
RITerm,
StrctInteractionTerm,
StrctLinTerm,
StrctTensorProdTerm,
StrctTerm,
)
KeyArray = Any
def _summarise_which(which: str | Sequence[str] | None) -> Sequence[SummaryQuantities]:
basics: Sequence[SummaryQuantities] = ["mean", "sd", "var", "hdi", "quantiles"]
if which is None:
return basics
if "hdi" not in ",".join(which):
basics = [w for w in basics if w != "hdi"]
if "q_" not in ",".join(which):
basics = [w for w in basics if w != "quantiles"]
return basics
[docs]
def summarise_by_samples(
key: Array, a: ArrayLike, name: str = "value", n: int = 100
) -> pd.DataFrame:
"""
Summarizes an array of posterior samples via subsamples.
Parameters
----------
key
Jax key-array (created by ``jax.random.key``) for drawing subsamples.
a
The array to be summarized, assumed to have shape ``(C, S, N)``, where
``C`` is the number of MCMC chains, ``S`` is the number of samples, and
``N`` is the dimension of the quantity to summarize.
Arrays of shape ``(C, S)`` are currently not supported by this function.
name
Column name for the value column in the returned dataframe.
n
Number of subsamples to draw from ``a``.
Returns
-------
A dataframe with the following columns:
- value: sample value
- index: index of the flattened array
- sample: sample number
- obs: observation number (enumerates quantity dimension, ``N``)
- chain: chain number
"""
a = jnp.asarray(a)
iterations = a.shape[1]
a = np.concatenate(a, axis=0)
idx = jax.random.choice(key, a.shape[0], shape=(n,), replace=True)
a_column = a[idx, :].ravel()
sample_column = np.repeat(np.arange(n), a.shape[-1])
index_column = np.repeat(idx, a.shape[-1])
obs_column = np.tile(np.arange(a.shape[-1]), n)
data = {name: a_column, "sample": sample_column}
data["index"] = index_column
data["obs"] = obs_column
df = pd.DataFrame(data)
df["chain"] = df["index"] // iterations
return df
[docs]
def summarise_1d_smooth(
term: StrctTerm,
samples: Mapping[str, ArrayLike],
newdata: None | Mapping[str, ArrayLike] = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
ngrid: int = 150,
) -> pd.DataFrame:
"""
Creates a summary dataframe for a one-dimensional :class:`.StrctTerm`.
Parameters
----------
term
The term to summarise.
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
newdata
Optional dictionary of covariate data at which to summarise the term.
If ``None``, a grid of length ``ngrid`` will be created internally, using the
minimum and maximum observed values of this term's input covariate.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
ngrid
Number of covariate values in the grid used for summary, if ``newdata=None``.
"""
if newdata is None:
# TODO: Currently, this branch of the function assumes that term.basis.x is
# a strong node.
# That is not necessarily always the case.
xgrid = np.linspace(term.basis.x.value.min(), term.basis.x.value.max(), ngrid)
newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: xgrid}
else:
newdata_x = newdata
xgrid = np.asarray(newdata[term.basis.x.name])
newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
term_samples = term.predict(dict(samples), newdata=newdata_x)
term_summary = (
gs.SamplesSummary.from_array(
term_samples,
name=term.name,
quantiles=quantiles,
hdi_prob=hdi_prob,
which=_summarise_which(None),
)
.to_dataframe()
.reset_index()
)
term_summary[term.basis.x.name] = xgrid
return term_summary
def grid_nd(inputs: dict[str, jax.typing.ArrayLike], ngrid: int) -> dict[str, Any]:
"""
Creates a meshgrid of the values in ``inputs``.
"""
mins = {k: jnp.min(v) for k, v in inputs.items()}
maxs = {k: jnp.max(v) for k, v in inputs.items()}
grids = {k: np.linspace(mins[k], maxs[k], ngrid) for k in inputs}
full_grid_arrays = [v.flatten() for v in np.meshgrid(*grids.values())]
full_grids = dict(zip(inputs.keys(), full_grid_arrays))
return full_grids
def input_grid_nd_smooth(
term: StrctTensorProdTerm | StrctInteractionTerm | StrctTerm | LinTerm, ngrid: int
) -> dict[str, jax.typing.ArrayLike]:
"""
Creates a dictionary of meshgrids of the covariate input variables to the ``term``
argument.
"""
if isinstance(term, StrctInteractionTerm | StrctTensorProdTerm):
inputs = {k: v.value for k, v in term.input_obs.items()}
return grid_nd(inputs, ngrid)
if not isinstance(term.basis.x, lsl.TransientCalc | lsl.Calc):
raise NotImplementedError(
"Function not implemented for bases with inputs of "
f"type {type(term.basis.x)}."
)
inputs = {n.var.name: n.var.value for n in term.basis.x.all_input_nodes()} # type: ignore
return grid_nd(inputs, ngrid)
def grid_nd_nomesh(
inputs: dict[str, jax.typing.ArrayLike], ngrid: int
) -> dict[str, Any]:
"""
Creates a dictionary of grids based on the ``inputs``.
"""
mins = {k: jnp.min(v) for k, v in inputs.items()}
maxs = {k: jnp.max(v) for k, v in inputs.items()}
grids = {k: np.linspace(mins[k], maxs[k], ngrid) for k in inputs}
return grids
def input_grid_nd_smooth_nomesh(
term: StrctTerm | LinTerm, ngrid: int
) -> dict[str, jax.typing.ArrayLike]:
"""
Creates a dictionary of grids for the covariate input variables to the ``term``
argument.
"""
if not isinstance(term.basis.x, lsl.TransientCalc | lsl.Calc):
raise NotImplementedError(
"Function not implemented for bases with inputs of "
f"type {type(term.basis.x)}."
)
inputs = {n.var.name: n.var.value for n in term.basis.x.all_input_nodes()} # type: ignore
return grid_nd_nomesh(inputs, ngrid)
# using q_0.05 and q_0.95 explicitly here
# even though users could choose to return other quantiles like 0.1 and 0.9
# then they can supply q_0.1 and q_0.9, etc.
PlotVars = Literal[
"mean", "sd", "var", "hdi_low", "hdi_high", "q_0.05", "q_0.5", "q_0.95"
]
[docs]
def summarise_nd_smooth(
term: StrctTerm | StrctInteractionTerm | StrctTensorProdTerm,
samples: Mapping[str, ArrayLike],
newdata: None | Mapping[str, ArrayLike] = None,
ngrid: int = 20,
which: PlotVars | Sequence[PlotVars] = "mean",
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
newdata_meshgrid: bool = False,
) -> pd.DataFrame:
"""
Summarises an n-dimensional smooth.
Parameters
----------
term
The term to summarise, a :class:`.StrctTerm` or :class:`.StrctTensorProdTerm`.
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
newdata
Optional dictionary of covariate data at which to summarise the term. If
``None``, a grid will be created internally, using the minimum and maximum
observed values of this term's input covariates. The ``ngrid`` argument refers
to the number of grid elements used in the marginal grids, so the total grid
length will be ``ngrid**k``, where ``k`` is the number of terms.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
newdata_meshgrid
If *True*, then the function will create a large grid of all combinations of
covariate values in ``newdata`` that correspond to this term.
"""
if isinstance(which, str):
which = [which]
if newdata is None:
newdata_x: Mapping[str, ArrayLike] = input_grid_nd_smooth(term, ngrid=ngrid)
elif newdata_meshgrid:
full_grid_arrays = [v.flatten() for v in np.meshgrid(*newdata.values())]
newdata_x = dict(zip(newdata.keys(), full_grid_arrays))
else:
newdata_x = newdata
newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
term_samples = term.predict(dict(samples), newdata=newdata_x)
ci_quantiles_ = (0.05, 0.95) if quantiles is None else quantiles
hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
term_summary = (
gs.SamplesSummary.from_array(
term_samples,
name=term.name,
quantiles=ci_quantiles_,
hdi_prob=hdi_prob_,
which=_summarise_which(which),
)
.to_dataframe()
.reset_index()
)
for k, v in newdata_x.items():
term_summary[k] = np.asarray(v)
term_summary.reset_index(inplace=True)
term_summary = term_summary.melt(
id_vars=["index"] + list(newdata_x.keys()),
value_vars=which,
var_name="variable",
value_name="value",
)
term_summary["variable"] = pd.Categorical(
term_summary["variable"], categories=which
)
return term_summary
[docs]
def polys_to_df(polys: Mapping[str, ArrayLike]) -> pd.DataFrame:
"""
Turns a ``polys`` dictionary into a dataframe appropriate for plotting.
Parameters
----------
polys
Dictionary of arrays. The keys of the dict are the region labels. The
corresponding values define the region by defining polygons. The
neighborhood structure can be inferred from this polygon information.
"""
poly_labels = list(polys)
poly_coords = list(polys.values())
poly_coord_dim = np.shape(poly_coords[0])[-1]
poly_df = pd.concat(
[
pd.DataFrame(
poly_coords[i], columns=[f"V{i}" for i in range(poly_coord_dim)]
).assign(vertex=lambda df: df.index + 1, id=i, label=poly_labels[i])
for i in range(len(polys))
],
ignore_index=True,
)
return poly_df
def _convert_to_integers(
grid: np.typing.NDArray,
labels: Sequence[str] | CategoryMapping | None,
term: RITerm | MRFTerm | lsl.Var,
) -> np.typing.NDArray[np.int_]:
if isinstance(labels, CategoryMapping):
grid = labels.to_integers(grid)
else:
try:
grid = term.mapping.to_integers(grid) # type: ignore
except (ValueError, AttributeError):
if not np.issubdtype(grid.dtype, np.integer):
raise TypeError(
f"There's no mapping available on the term {term}. "
"In this case, its values in 'newdata' must be specified "
f"as integer codes. Got data type {grid.dtype}"
)
return grid
[docs]
def summarise_cluster(
term: RITerm | MRFTerm | StrctTerm,
samples: Mapping[str, ArrayLike],
newdata: gs.Position
| None
| Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = None,
labels: CategoryMapping | Sequence[str] | None = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
) -> pd.DataFrame:
"""
Summarises a discrete term represented by :class:`.RITerm` or :class:`.MRFTerm`.
Parameters
----------
term
The term to summarise.
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
newdata
Dictionary of covariate data at which to summarise the term. If ``None``, uses
the unique clusters known to the term.
labels
Custom mapping to use for mapping between string labels and integer codes.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
"""
if labels is None:
try:
labels = term.mapping # type: ignore
except (AttributeError, ValueError):
labels = None
if newdata is None and isinstance(labels, CategoryMapping):
grid = np.asarray(list(labels.integers_to_labels_map))
unique_x = np.unique(term.basis.x.value)
newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: grid}
observed = [x in unique_x for x in grid]
elif newdata is None:
grid = np.unique(term.basis.x.value)
newdata_x = {term.basis.x.name: grid}
observed = [True for _ in grid]
else:
unique_x = np.unique(term.basis.x.value)
grid = np.asarray(newdata[term.basis.x.name])
grid = _convert_to_integers(grid, labels, term)
observed = [x in unique_x for x in grid]
newdata_x = {term.basis.x.name: grid}
newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
predictions = term.predict(samples=dict(samples), newdata=newdata_x)
predictions_summary = (
gs.SamplesSummary.from_array(
predictions,
quantiles=quantiles,
hdi_prob=0.9 if hdi_prob is None else hdi_prob,
which=_summarise_which(None),
)
.to_dataframe()
.reset_index()
)
if isinstance(labels, CategoryMapping):
codes = newdata_x[term.basis.x.name]
labels_str = list(labels.integers_to_labels(codes))
categories = list(labels.labels_to_integers_map)
predictions_summary[term.basis.x.name] = pd.Categorical(
labels_str, categories=categories
)
elif labels is not None:
labels_str = list(labels)
categories = sorted(set(labels_str))
predictions_summary[term.basis.x.name] = pd.Categorical(
labels_str, categories=categories
)
else:
predictions_summary[term.basis.x.name] = pd.Categorical(
np.asarray(term.basis.x.value)
)
predictions_summary["observed"] = observed
return predictions_summary
[docs]
def summarise_regions(
term: RITerm | MRFTerm | StrctTerm,
samples: Mapping[str, ArrayLike],
newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
which: PlotVars | Sequence[PlotVars] = "mean",
polys: Mapping[str, ArrayLike] | None = None,
labels: CategoryMapping | Sequence[str] | None = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
) -> pd.DataFrame:
"""
Summarises a discrete spatial term.
Parameters
----------
term
The term to summarise, a :class:`.RITerm` or :class:`.MRFTerm`.
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
newdata
Dictionary of covariate data at which to summarise the term. If ``None``, uses
the unique clusters known to the term.
which
Sequence of strings, indicating the summary quantities to include.
polys
Dictionary of arrays. The keys of the dict are the region labels. The
corresponding values define the region by defining polygons. The
neighborhood structure can be inferred from this polygon information.
labels
Custom mapping to use for mapping between string labels and integer codes.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
"""
polygons = None
if polys is not None:
polygons = polys
else:
try:
# using type ignore here, since the case of term not having the attribute
# polygons is handle by the try except
polygons = term.polygons # type: ignore
except AttributeError:
pass
if not polygons:
raise ValueError(
"When passing a term without polygons, polygons must be supplied manually "
"through the argument 'polys'"
)
df = summarise_cluster(
term=term,
samples=samples,
newdata=newdata,
labels=labels,
quantiles=quantiles,
hdi_prob=hdi_prob,
)
region = term.basis.x.name
if isinstance(which, str):
which = [which]
unique_labels_in_df = df[term.basis.x.name].unique().tolist()
assert polygons is not None
for region_label in polygons:
if region_label not in unique_labels_in_df:
raise ValueError(
f"Label '{region_label}' found in polys, but not in cluster summary. "
f"Known labels: {unique_labels_in_df}"
)
poly_df = polys_to_df(polygons)
df["label"] = df[region].astype(str)
plot_df = poly_df.merge(df, on="label")
plot_df = plot_df.melt(
id_vars=["label", "V0", "V1", "observed"],
value_vars=which,
var_name="variable",
value_name="value",
)
plot_df["variable"] = pd.Categorical(plot_df["variable"], categories=which)
return plot_df
[docs]
def summarise_lin(
term: LinTerm | StrctLinTerm,
samples: Mapping[str, ArrayLike],
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
indices: Sequence[int] | None = None,
) -> pd.DataFrame:
"""
Summarises a linear term.
Parameters
----------
term
The term to summarise.
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
indices
Sequence of integers, selects coefficients or clusters to be included in the
plot. If ``None``, all coefficients/clusters are plotted.
"""
if indices is not None:
coef_samples = jnp.asarray(samples[term.coef.name])[..., indices]
colnames = [term.column_names[i] for i in indices]
else:
coef_samples = jnp.asarray(samples[term.coef.name])
colnames = term.column_names
df = (
gs.SamplesSummary.from_array(
coef_samples,
quantiles=quantiles,
hdi_prob=hdi_prob,
which=_summarise_which(None),
)
.to_dataframe()
.reset_index()
)
df["x"] = colnames
df.drop(["variable", "var_fqn", "var_index"], axis=1, inplace=True)
df.insert(0, "x", df.pop("x"))
return df
[docs]
def summarise_1d_smooth_clustered(
clustered_term: lsl.Var,
samples: Mapping[str, ArrayLike],
newdata: Mapping[str, ArrayLike] | None = None,
which: PlotVars | Sequence[PlotVars] = "mean",
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
labels: CategoryMapping | None | Sequence[str] = None,
ngrid: int = 20,
newdata_meshgrid: bool = False,
):
"""
Summarises a clustered smooth or linear function.
Intended for terms, as returned by :meth:`.TermBuilder.rs`.
Parameters
----------
clustered_term
The term to plot. Must be a weak :class:`liesel.model.Var` with named inputs
``"x"`` (the function) and ``"cluster"`` (the cluster).
samples
Dictionary of posterior samples. Must contain samples for the term's
coefficient.
newdata
Dictionary of covariate data at which to plot the term. If ``None``, plots the
term for the unique clusters known to the term, and uses a grid of length
``ngrid`` between the minimum and maximum observed value in the clustered
function's covariate.
which
Sequence of strings, indicating the summary quantities to include.
quantiles
Probability levels of quantiles to include.
hdi_prob
Probability level for highest posterior density interval.
labels
Custom mapping to use for mapping between string labels and integer codes.
ngrid
Number of covariate values in the grid used for plotting, if ``newdata=None``.
newdata_meshgrid
If *True*, then the function will create a large grid of all combinations of
covariate values in ``newdata`` that correspond to this term.
"""
if isinstance(which, str):
which = [which]
term = clustered_term.value_node["x"]
cluster = clustered_term.value_node["cluster"]
assert isinstance(term, StrctTerm | lsl.Var)
assert isinstance(cluster, RITerm | MRFTerm)
if labels is None:
try:
labels = cluster.mapping # type: ignore
except (AttributeError, ValueError):
labels = None
if isinstance(term, StrctTerm):
x = term.basis.x
else:
x = term
if newdata is None:
if not jnp.issubdtype(x.value.dtype, jnp.floating):
raise TypeError(
"Automatic grid creation is valid only for continuous x, got "
f"dtype {jnp.asarray(x.value).dtype} for {x}."
)
if newdata is None and isinstance(labels, CategoryMapping):
cgrid = np.asarray(list(labels.integers_to_labels_map)) # integer codes
unique_clusters = np.unique(cluster.basis.x.value) # unique codes
if isinstance(x, lsl.Node) and len(x.inputs) == 1:
xgrid: Mapping[str, ArrayLike] = {
x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)
}
elif isinstance(x, lsl.Var) and x.strong:
xgrid = {x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)}
else:
assert isinstance(term, StrctTerm | LinTerm), (
f"Wrong type for term: {type(term)}"
)
ncols = jnp.shape(term.basis.x.value)[-1]
xgrid = {}
xgrid = input_grid_nd_smooth_nomesh(
term, ngrid=int(np.pow(ngrid, 1 / ncols))
)
grid: Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = dict(xgrid) | {
cluster.basis.x.name: cgrid
}
# code : bool
observed = {x: x in unique_clusters for x in cgrid}
elif newdata is None:
cgrid = np.unique(cluster.basis.x.value)
if isinstance(x, lsl.Node) and len(x.inputs) == 1:
xgrid = {x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)}
elif isinstance(x, lsl.Var) and x.strong:
xgrid = {x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)}
else:
assert isinstance(term, StrctTerm | LinTerm), (
f"Wrong type for term: {type(term)}"
)
ncols = jnp.shape(term.basis.x.value)[-1]
xgrid = input_grid_nd_smooth_nomesh(
term, ngrid=int(np.pow(ngrid, 1 / ncols))
)
grid = xgrid | {cluster.basis.x.name: cgrid}
# code : bool
observed = {x: True for x in cgrid}
else:
pass
if newdata is not None and newdata_meshgrid:
cgrid = np.asarray(newdata[cluster.basis.x.name])
cgrid = _convert_to_integers(cgrid, labels, cluster)
grid = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
newdata_x: Mapping[str, ArrayLike] = dict(zip(grid.keys(), full_grid_arrays))
observed = {x: x in cluster.basis.x.value for x in cgrid}
elif newdata is not None:
cgrid = np.asarray(newdata[cluster.basis.x.name])
cgrid = _convert_to_integers(cgrid, labels, cluster)
newdata_x = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
# code : bool
if isinstance(labels, CategoryMapping):
observed = {x: x in cluster.basis.x.value for x in cgrid}
else:
observed = {x: True for x in cgrid}
else: # then we use the grid created from observed data
full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
newdata_x = dict(zip(grid.keys(), full_grid_arrays))
newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
term_samples = clustered_term.predict(dict(samples), newdata=newdata_x)
term_summary = (
gs.SamplesSummary.from_array(
term_samples,
name=clustered_term.name,
quantiles=quantiles,
hdi_prob=hdi_prob,
)
.to_dataframe()
.reset_index()
)
for k, v in newdata_x.items():
term_summary[k] = np.asarray(v)
if labels is not None:
if isinstance(labels, CategoryMapping):
labels_long = labels.to_labels(newdata_x[cluster.basis.x.name])
categories = list(labels.labels_to_integers_map)
term_summary[cluster.basis.x.name] = pd.Categorical(
labels_long, categories=categories
)
else:
term_summary[cluster.basis.x.name] = labels
term_summary["observed"] = [
observed[x] for x in np.asarray(newdata_x[cluster.basis.x.name])
]
term_summary.reset_index(inplace=True)
return term_summary