from __future__ import annotations
from collections.abc import Sequence
from functools import reduce
from itertools import combinations
from typing import Any, Literal, Self
import jax
import jax.numpy as jnp
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
from formulaic import ModelSpec
from liesel_gam.category_mapping import CategoryMapping
from .basis import Basis, is_diagonal
from .dist import MultivariateNormalSingular, MultivariateNormalStructured
from .var import ScaleIG, UserVar, VarIGPrior, _append_name
# Type used for the ``inference`` arguments below. It is deliberately left as
# ``Any``; per ``StrctTerm.f``'s docstring it is typically a liesel.goose MCMC spec.
InferenceTypes = Any
# Shorthand aliases for JAX array types used in annotations throughout this module.
Array = jax.Array
ArrayLike = jax.typing.ArrayLike
def mvn_diag_prior(scale: lsl.Var) -> lsl.Dist:
    """
    Build an independent, zero-mean normal prior for a coefficient vector.

    The returned distribution node wraps ``tfd.Normal`` with ``loc=0.0`` and the
    supplied ``scale``.

    Parameters
    ----------
    scale
        Scale variable passed to :class:`tensorflow_probability`'s normal
        distribution.

    Returns
    -------
    A :class:`liesel.model.Dist` wrapping ``tfd.Normal``.

    Examples
    --------
    >>> scale = lsl.Var.new_value(1.0)
    >>> mvn_diag_prior(scale).distribution
    <class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
    """
    normal_kwargs = {"loc": 0.0, "scale": scale}
    return lsl.Dist(tfd.Normal, **normal_kwargs)
def mvn_structured_prior(scale: lsl.Var, penalty: lsl.Var | lsl.Value) -> lsl.Dist:
    """
    Build a structured, possibly singular, Gaussian prior from a fixed penalty.

    The prior is a :class:`.MultivariateNormalSingular` with zero mean, the given
    scale and penalty, and a ``penalty_rank`` that is computed once, right here,
    from the current penalty value. Because the rank is evaluated eagerly, the
    penalty must be strong/fixed; varying (non-strong) penalty variables are not
    supported.

    Parameters
    ----------
    scale
        Scale variable of the prior.
    penalty
        Penalty matrix, wrapped in a :class:`liesel.model.Value` or a strong
        :class:`liesel.model.Var`.

    Raises
    ------
    NotImplementedError
        If ``penalty`` is a :class:`liesel.model.Var` that is not strong.

    Examples
    --------
    >>> scale = lsl.Var.new_value(1.0)
    >>> penalty = lsl.Value(jnp.eye(2))
    >>> prior = mvn_structured_prior(scale, penalty)
    >>> prior.distribution is MultivariateNormalSingular
    True
    """
    penalty_varies = isinstance(penalty, lsl.Var) and not penalty.strong
    if penalty_varies:
        raise NotImplementedError(
            "Varying penalties are currently not supported by this function."
        )

    rank = jnp.linalg.matrix_rank(penalty.value)
    return lsl.Dist(
        MultivariateNormalSingular,
        loc=0.0,
        scale=scale,
        penalty=penalty,
        penalty_rank=rank,
    )
def term_prior(
    scale: lsl.Var | None,
    penalty: lsl.Var | lsl.Value | None,
) -> lsl.Dist | None:
    """
    Choose the coefficient prior for a term.

    Decision table:

    - ``scale is None`` and ``penalty is None``: no prior (``None``).
    - ``scale is None`` but a penalty is given: error.
    - scale given, no penalty: independent normal prior
      (:func:`mvn_diag_prior`).
    - scale and penalty given: structured multivariate normal prior
      (:func:`mvn_structured_prior`).

    Raises
    ------
    ValueError
        If a penalty is supplied without a scale.

    Examples
    --------
    >>> term_prior(None, None) is None
    True
    >>> term_prior(lsl.Var.new_value(1.0), None).distribution
    <class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
    """
    if scale is not None:
        if penalty is None:
            return mvn_diag_prior(scale)
        return mvn_structured_prior(scale, penalty)

    if penalty is not None:
        raise ValueError(f"If {scale=}, then penalty must also be None.")
    return None
def _init_scale_ig(
    x: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
    validate_scalar: bool = False,
) -> ScaleIG | lsl.Var | None:
    """
    Normalize a term's ``scale`` argument into a variable (or ``None``).

    Handling by input type:

    - ``None`` is returned unchanged.
    - :class:`.VarIGPrior`: a new :class:`.ScaleIG` is created. Its value is
      ``sqrt(x.value)`` and its hyperparameters are ``x.concentration`` and
      ``x.scale``.
    - :class:`.ScaleIG`: if its internal variance parameter is strong, that
      parameter's value is converted to a JAX array and ``x`` is updated.
    - any other :class:`liesel.model.Var`: if strong, its value is converted to a
      JAX array.
    - anything else is wrapped via ``lsl.Var.new_value(jnp.asarray(x))``.

    Parameters
    ----------
    x
        The user-supplied scale specification.
    validate_scalar
        If ``True``, require size 1. For :class:`.VarIGPrior` inputs this applies
        to the two hyperparameters; for all other non-``None`` inputs it applies
        to the resulting scale value.

    Raises
    ------
    ValueError
        If ``validate_scalar`` is ``True`` and a checked quantity is not scalar.
    TypeError
        If a value cannot be converted to a JAX array.
    """
    if x is None:
        return None

    if isinstance(x, VarIGPrior):
        concentration = jnp.asarray(x.concentration)
        ig_scale = jnp.asarray(x.scale)
        if validate_scalar and concentration.size != 1:
            raise ValueError(
                "Expected scalar hyperparameter 'concentration', "
                f"got size {concentration.size}"
            )
        if validate_scalar and ig_scale.size != 1:
            raise ValueError(
                f"Expected scalar hyperparameter 'scale', got size {ig_scale.size}"
            )
        return ScaleIG(
            value=jnp.sqrt(jnp.array(x.value)),
            concentration=concentration,
            scale=ig_scale,
        )

    if isinstance(x, ScaleIG):
        variance = x._variance_param
        if variance.strong:
            variance.value = jnp.asarray(variance.value)
            x.update()
        result = x
    elif isinstance(x, lsl.Var):
        if x.strong:
            try:
                x.value = jnp.asarray(x.value)
            except Exception as e:
                raise TypeError(
                    f"Unexpected type for scale value: {type(x.value)}"
                ) from e
        result = x
    else:
        try:
            result = lsl.Var.new_value(jnp.asarray(x))
        except Exception as e:
            raise TypeError(f"Unexpected type for scale: {type(x)}") from e

    if validate_scalar:
        n_elements = jnp.asarray(result.value).size
        if n_elements != 1:
            raise ValueError(f"Expected scalar scale, got size {n_elements}")
    return result
def _validate_scalar_or_p_scale(scale_value: Array, p):
    """
    Check that ``scale_value`` holds either a single element or exactly ``p``.

    Raises
    ------
    ValueError
        If ``scale_value.size`` is neither 1 nor ``p``.
    """
    n_elements = scale_value.size
    if n_elements in (1, p):
        return
    raise ValueError(
        f"Expected scale to have size 1 or {p}, got size {n_elements}"
    )
[docs]
class StrctTerm(UserVar):
    r"""
    General structured additive term.

    You probably want to initialize a term using :meth:`.StrctTerm.f`, which will
    automatically take the penalty matrix from the supplied basis and has automatic
    naming that is convenient in most situations.

    A structured additive term represents a smooth or structured effect in a generalized
    additive model. The term wraps a design/basis matrix together with a prior/penalty
    and a set of coefficients. The object exposes the coefficient variable and evaluates
    the term as the matrix-vector product of the basis and the coefficients. The term
    evaluates to ``basis @ coef``.

    Parameters
    ----------
    basis
        A :class:`.Basis` instance that produces the design matrix for the term. The
        basis must evaluate to a 2-D array with shape ``(n_obs, n_bases)``.
    penalty
        Penalty matrix or a variable/value wrapping the penalty used to construct the
        multivariate normal prior for the coefficients.
    scale
        Scale parameter passed to the coefficient prior.
    name
        Term name.
    inference
        Inference specification for this term's coefficient.
    coef_name
        Name for the coefficient variable. If ``None``, a default name based on ``name``
        will be used.
    _update_on_init
        If ``True`` (default) the internal calculation/graph nodes are evaluated during
        initialization. Set to ``False`` to delay initial evaluation.
    validate_scalar_scale
        If ``True`` (default), the term will error if the ``scale`` variable does not
        hold a scalar scale. This is appropriate for most cases. If ``False``, the term
        will also allow an array-valued ``scale`` variable of shape ``(nbases,)``. This
        only really makes sense when also reparameterizing the term using
        :meth:`.factor_scale`. Only use this if you know exactly what you are doing and
        you are certain that this is what you want.

    See Also
    --------
    .TermBuilder : Initializes structured additive terms.
    .BasisBuilder : Initializes structured additive term basis matrices.
    .Basis : Basis matrix object.
    .StrctTerm.f : Alternative, more convenient constructor.
    .StrctTensorProdTerm : Anisotropic tensor product terms.

    Notes
    -----
    The terms created by this class generally have the form

    .. math::

        s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i) \beta_j
        = \mathbf{b}(\mathbf{x}_i)^\top \boldsymbol{\beta}

    where

    - :math:`i=1, \dots, N` is the observation index,
    - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate
      observations, where :math:`M` denotes the number of covariates,
    - :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i),
      \dots, B_J(\mathbf{x}_i)]`
      are a set of basis function evaluations, and
    - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]`
      are the corresponding coefficients.

    In many cases, :math:`\mathbf{x}_i` will consist
    of only one covariate.

    The basis matrix for such a term is

    .. math::

        \mathbf{B} = \begin{bmatrix}
        \mathbf{b}(\mathbf{x}_1)^\top \\
        \vdots \\
        \mathbf{b}(\mathbf{x}_N)^\top
        \end{bmatrix}.

    The coefficient receives a potentially rank-deficient multivariate normal prior

    .. math::

        p(\boldsymbol{\beta}) \propto \left(\frac{1}{\tau^2}\right)^{
        \operatorname{rk}(\mathbf{K})/2}
        \exp \left(
        - \frac{1}{2\tau^2} \boldsymbol{\beta}^\top \mathbf{K} \boldsymbol{\beta}
        \right)

    with the potentially rank-deficient penalty matrix :math:`\mathbf{K}` of rank
    :math:`\operatorname{rk}(\mathbf{K})`. The variance
    parameter :math:`\tau^2` (the square of ``scale``) acts as an inverse smoothing
    parameter.

    The choice of basis functions :math:`B_j` and penalty matrix :math:`\mathbf{K}`
    determines the nature of the term.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> basis = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]),
    ...     xname="x",
    ...     penalty=jnp.eye(2),
    ... )
    >>> term = StrctTerm.f(basis, scale=1.0)
    >>> term.name, term.nbases, term.value.shape
    ('f(x)', 2, (4,))
    """

    def __init__(
        self,
        basis: Basis,
        penalty: lsl.Var | lsl.Value | ArrayLike | None,
        scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        _update_on_init: bool = True,
        validate_scalar_scale: bool = True,
    ):
        scale = _init_scale_ig(scale, validate_scalar=validate_scalar_scale)
        coef_name = _append_name(name, "_coef") if coef_name is None else coef_name

        # Must be set before ``self.nbases`` is read in the no-penalty branch below.
        self._basis = basis

        # The number of coefficients follows the penalty if one is supplied and
        # the number of basis columns otherwise.
        if isinstance(penalty, lsl.Var | lsl.Value):
            nparam = jnp.shape(penalty.value)[-1]
            self._penalty: lsl.Var | lsl.Value | None = penalty
        elif penalty is not None:
            nparam = jnp.shape(penalty)[-1]
            self._penalty = lsl.Value(jnp.asarray(penalty))
        else:
            nparam = self.nbases
            self._penalty = None

        prior = term_prior(scale, self._penalty)

        if scale is not None:
            _validate_scalar_or_p_scale(scale.value, nparam)

        self._coef = lsl.Var.new_param(
            jnp.zeros(nparam), prior, inference=inference, name=coef_name
        )
        calc = lsl.Calc(
            lambda basis, coef: jnp.dot(basis, coef),
            basis=basis,
            coef=self.coef,
            _update_on_init=_update_on_init,
        )
        self._scale = scale
        super().__init__(calc, name=name)

        if _update_on_init:
            self.coef.update()

        self._scale_is_factored = False
        self._disallow_scale_factorization = False

        if hasattr(self.scale, "setup_gibbs_inference"):
            try:
                self.scale.setup_gibbs_inference(self.coef)  # type: ignore
            except Exception as e:
                raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e

    @property
    def scale_is_factored(self) -> bool:
        """
        Whether the term has been reparameterized using :meth:`.factor_scale`.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm(basis, penalty=None, scale=1.0)
        >>> term.scale_is_factored
        False
        """
        return self._scale_is_factored

    @property
    def coef(self) -> lsl.Var:
        """
        The coefficient variable of this term.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm(basis, penalty=None, scale=1.0)
        >>> term.coef.value.shape
        (2,)
        """
        return self._coef

    @property
    def basis(self) -> Basis:
        """
        The basis matrix object of this term.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm(basis, penalty=None, scale=1.0)
        >>> term.basis.name
        'B(x)'
        """
        return self._basis

    @property
    def nbases(self) -> int:
        """
        Number of basis functions, equal to the number of basis-matrix columns.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> StrctTerm(basis, penalty=None, scale=1.0).nbases
        2
        """
        return jnp.shape(self.basis.value)[-1]

    @property
    def scale(self) -> lsl.Var | lsl.Node | None:
        """
        The scale variable used by the coefficient prior.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm(basis, penalty=None, scale=1.0)
        >>> float(term.scale.value)
        1.0
        """
        return self._scale

    def replace_scale(self, new: lsl.Var, disallow_factorization: bool = True) -> None:
        """
        Replace the scale variable and update the coefficient prior.

        Parameters
        ----------
        new
            Replacement scale variable.
        disallow_factorization
            Whether subsequent scale factorization should be disabled.

        Raises
        ------
        ValueError
            If the scale has already been factored out via :meth:`.factor_scale`.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm(basis, penalty=None, scale=1.0)
        >>> term.replace_scale(lsl.Var.new_value(2.0))
        >>> float(term.scale.value)
        2.0
        """
        if self.scale_is_factored:
            raise ValueError(
                f"Scale of {self} cannot be replaced, because it has been factored "
                "using .factor_scale()."
            )
        self._scale = new
        assert self.coef.dist_node is not None
        self.coef.dist_node["scale"] = new
        # NOTE(review): this flag is recorded here, but no code in this class reads
        # it; confirm whether factor_scale() is meant to honor it.
        self._disallow_scale_factorization = disallow_factorization

    def _validate_scale_for_factoring(self):
        """Raise ``ValueError`` unless the term has a scalar scale variable."""
        if self.scale is None:
            raise ValueError(
                f"Scale factorization of {self} fails, because {self.scale=}."
            )
        if self.scale.value.size > 1:
            raise ValueError(
                f"Scale factorization of {self} fails, "
                f"because scale must be scalar, but got {self.scale.value.size=}."
            )

    def _validate_penalty_for_factoring(self, atol: float = 1e-5) -> Array:
        """
        Check that the penalty permits factoring out the scale and return its rank.

        Without a penalty, the number of coefficients is returned. A full-rank
        penalty is always accepted. A rank-deficient penalty is only accepted if it
        is diagonal (up to ``atol``) and its trailing ``p - rank`` rows/columns are
        zero, so that the first ``rank`` coefficients are exactly the penalized ones.

        Raises
        ------
        ValueError
            If a rank-deficient penalty does not have this structure.
        """
        if self._penalty is None:
            return jnp.array(self.coef.value.shape[-1])

        penalty = self._penalty.value
        p = penalty.shape[-1]
        pen_rank = jnp.linalg.matrix_rank(penalty)
        if pen_rank == p:
            # full-rank penalty always works
            return pen_rank

        if not is_diagonal(penalty, atol):
            # rank-deficient penalty must be diagonal
            raise ValueError(
                "With rank deficient penalties, factoring out the scale is "
                "only supported when using diagonalized penalties. "
                "This is "
                "because the scale is only applied to the penalized part, "
                "and we cannot reliably distinguish the penalized and "
                "unpenalized parts without diagonalization."
            )

        unpenalized_parts = penalty[pen_rank:, pen_rank:]
        zeros = jnp.zeros_like(unpenalized_parts)
        if not jnp.allclose(unpenalized_parts, zeros, atol=atol):
            # the null space must occupy the last p - rank rows/columns
            raise ValueError(
                "With rank deficient penalties, factoring out the scale is "
                "only supported when using diagonalized penalties. "
                "The null space of the penalty must be organized in the "
                "last p - R rows/columns, i.e. these must be all zero. "
                "Here, p is the number of coefficients and R is the rank of "
                f"the penalty; in your case: p={p}, R={pen_rank}. "
                "Your penalty seems to be diagonal, but does not have these "
                "zero rows/columns. "
                "This is important "
                "because the scale is only applied to the penalized part, "
                "and we cannot reliably distinguish the penalized and "
                "unpenalized parts without this structure."
            )
        return pen_rank

    def factor_scale(self, atol: float = 1e-5) -> Self:
        """
        Turn this term into a partially standardized form.

        This means the prior for the coefficient will be turned from ``coef ~ N(0,
        scale^2 * inv(penalty))`` into ``latent_coef ~ N(0, inv(penalty)); coef = scale
        * latent_coef``. With a rank-deficient (diagonalized) penalty, only the first
        ``rank`` coefficients, i.e. the penalized part, are multiplied by the scale.

        Calling this method on an already factored term returns the term unchanged.

        Parameters
        ----------
        atol
            Absolute tolerance used when checking the penalty structure.

        Raises
        ------
        ValueError
            If the scale is missing or not scalar, or if a rank-deficient penalty is
            not diagonal with its null space in the last rows/columns.
        RuntimeError
            If setting up a Gibbs kernel for the factored scale fails.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]),
        ...     xname="x",
        ...     penalty=jnp.eye(2),
        ... )
        >>> term = StrctTerm.f(basis, scale=1.0)
        >>> term.factor_scale().scale_is_factored
        True
        """
        self._validate_scale_for_factoring()
        pen_rank = self._validate_penalty_for_factoring(atol)

        if self._scale_is_factored:
            return self

        # The latent coefficient gets a unit-scale prior ...
        assert self.coef.dist_node is not None
        self.coef.dist_node["scale"] = lsl.Value(jnp.array(1.0))

        assert self.scale is not None  # checked in validation method above
        if self.scale.name and self.coef.name:
            scaled_name = self.scale.name + "*" + self.coef.name
        else:
            scaled_name = _append_name(self.coef.name, "_scaled")

        # ... and the scale is multiplied back onto the penalized part only.
        def scale_coef(scale, coef):
            coef = coef.at[:pen_rank].set(coef[:pen_rank] * scale)
            return coef

        scaled_coef = lsl.Var.new_calc(
            scale_coef,
            self.scale,
            self.coef,
            name=scaled_name,
        )

        self.value_node["coef"] = scaled_coef
        self.coef.update()
        self.update()
        self._scale_is_factored = True

        if hasattr(self.scale, "setup_gibbs_inference_factored"):
            try:
                pen = self._penalty.value if self._penalty is not None else None
                self.scale.update()
                self.scale.setup_gibbs_inference_factored(  # type: ignore
                    scaled_coef, self.coef, penalty=pen
                )
            except Exception as e:
                raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e

        return self

    @classmethod
    def f(
        cls,
        basis: Basis,
        fname: str = "f",
        scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        factor_scale: bool = False,
    ) -> Self:
        """
        Construct a smooth term from a :class:`.Basis`.

        This convenience constructor builds a named ``term`` using the
        provided basis. The penalty matrix is taken from ``basis.penalty`` and
        a coefficient variable with an appropriate multivariate-normal prior
        is created. If ``scale`` is ``None``, no penalty and no prior are used.
        The returned term evaluates to ``basis @ coef``.

        Parameters
        ----------
        basis
            Basis object that provides the design matrix and penalty for the \
            smooth term. The basis must have an associated input variable with \
            a meaningful name (used to compose the term name).
        fname
            Function-name prefix used when constructing the term name. Default \
            is ``'f'`` which results in names like ``f(x)`` when the basis \
            input is named ``x``.
        scale
            Scale parameter passed to the coefficient prior.
        inference
            Inference specification forwarded to the coefficient variable \
            creation, a :class:`liesel.goose.MCMCSpec`.
        coef_name
            Coefficient name. The default coefficient name is a LaTeX-like string \
            ``"$\\beta_{f(x)}$"`` to improve readability in printed summaries.
        factor_scale
            If ``True``, the term is reparameterized by factoring out the scale \
            via :meth:`.factor_scale` before being returned.

        Returns
        -------
        A term instance configured with the given basis and prior settings.

        Raises
        ------
        ValueError
            If ``basis.x`` or ``basis`` is unnamed.
        TypeError
            If ``fname`` is not a string.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
        >>> term = StrctTerm.f(basis, scale=1.0)
        >>> term.name, term.coef.value.shape
        ('f(x)', (2,))
        """
        if not basis.x.name:
            raise ValueError("basis.x must be named.")
        if not basis.name:
            raise ValueError("basis must be named.")
        if not isinstance(fname, str):
            raise TypeError(f"Expected type str, got {type(fname)}.")

        name = f"{fname}({basis.x.name})"
        coef_name = coef_name or "$\\beta_{" + f"{name}" + "}$"

        term = cls(
            basis=basis,
            penalty=basis.penalty if scale is not None else None,
            scale=scale,
            inference=inference,
            coef_name=coef_name,
            name=name,
            # array-valued scales are only accepted together with factorization
            validate_scalar_scale=not factor_scale,
        )
        if factor_scale:
            term.factor_scale()
        return term

    def _assert_penalty_is_basis_penalty(self):
        """Raise ``ValueError`` unless this term shares its basis' penalty object."""
        if self._penalty is None:
            raise ValueError(
                f"Penalty of {self} is None."
                " This functionality is only available if the term is initialized with "
                "the same penalty object as its basis."
            )
        if self._penalty is not self.basis.penalty:
            raise ValueError(
                f"Different penalty objects found on {self} and its basis {self.basis}."
                " This functionality is only available if the term is initialized with "
                "the same penalty object as its basis."
            )

    def _refresh_penalty_rank(self) -> None:
        """
        Recompute the ``penalty_rank`` input of the coefficient prior.

        ``mvn_structured_prior`` evaluates the rank once from the initial penalty.
        Operations that change the penalty's dimension, such as :meth:`.constrain`,
        must refresh it; otherwise the prior's normalizing constant is stale.
        Does nothing if the coefficient has no prior or no ``penalty_rank`` input.
        """
        dist = self.coef.dist_node
        if dist is None or self._penalty is None:
            return
        if "penalty_rank" not in dist.kwinputs:
            return
        rank = jnp.linalg.matrix_rank(self._penalty.value)
        dist["penalty_rank"] = lsl.Value(rank)

    def diagonalize_penalty(self, atol: float = 1e-6) -> Self:
        """
        Diagonalize the penalty via an eigenvalue decomposition.

        This method computes a transformation that diagonalizes the penalty matrix and
        updates the internal basis function such that subsequent evaluations use the
        accordingly transformed basis. The penalty is updated to the diagonalized
        version.

        Returns
        -------
        The modified term instance (self).

        See Also
        --------
        .Basis.diagonalize_penalty : The term calls this method internally. More details
            are documented there.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]),
        ...     xname="x",
        ...     penalty=jnp.eye(2),
        ... )
        >>> term = StrctTerm.f(basis, scale=1.0).diagonalize_penalty()
        >>> bool(is_diagonal(term.basis.penalty.value))
        True
        """
        self._assert_penalty_is_basis_penalty()
        self.basis.diagonalize_penalty(atol)
        return self

    def scale_penalty(self) -> Self:
        """
        Scale the penalty matrix by its infinity norm.

        The penalty matrix is divided by its infinity norm (max absolute row sum) so
        that its values are numerically well-conditioned for downstream use. The updated
        penalty replaces the previous one.

        Returns
        -------
        The modified term instance (self).

        See Also
        --------
        .Basis.scale_penalty : The term calls this method internally. More details
            are documented there.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]),
        ...     xname="x",
        ...     penalty=2.0 * jnp.eye(2),
        ... )
        >>> term = StrctTerm.f(basis, scale=1.0).scale_penalty()
        >>> float(jnp.max(term.basis.penalty.value))
        1.0
        """
        self._assert_penalty_is_basis_penalty()
        self.basis.scale_penalty()
        return self

    def constrain(
        self,
        constraint: ArrayLike
        | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"],
    ) -> Self:
        """
        Apply a linear constraint to the term's basis and corresponding penalty.

        After constraining, the coefficient is reset to zeros of the new length and
        the prior's penalty rank is recomputed from the constrained penalty.

        Parameters
        ----------
        constraint
            Type of constraint or custom linear constraint matrix to apply. If an
            array is supplied, the constraint will be ``A @ coef == 0``, where ``A``
            is the supplied constraint matrix.

        Returns
        -------
        The modified term instance (self).

        See Also
        --------
        .Basis.constrain : The term calls this method internally. More details
            are documented there.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> basis = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]),
        ...     xname="x",
        ...     penalty=jnp.eye(2),
        ... )
        >>> term = StrctTerm.f(basis, scale=1.0).constrain(jnp.ones((1, 2)))
        >>> term.nbases
        1
        """
        self._assert_penalty_is_basis_penalty()
        self.basis.constrain(constraint)
        self.coef.value = jnp.zeros(self.nbases)
        # the constraint changes the penalty's dimension and possibly its rank
        self._refresh_penalty_rank()
        return self
# Alternative name bound to the very same class object as StrctTerm.
SmoothTerm = StrctTerm
[docs]
class MRFTerm(StrctTerm):
    """
    Structured term for Markov random fields.

    This is a :class:`.StrctTerm` that additionally carries metadata about the
    Markov random field: neighborhood structure, region polygons, region labels,
    a label-integer mapping, and (when available) labels ordered like the basis
    columns. All of these start out unset and are typically filled in by
    :meth:`.TermBuilder.mrf`.

    Examples
    --------
    ``MRFTerm`` objects are usually created by :class:`.TermBuilder`, which also
    attaches label and neighborhood metadata.

    >>> import pandas as pd
    >>> import liesel_gam as gam
    >>> df = pd.DataFrame({"region": ["a", "b", "a"]})
    >>> term = gam.TermBuilder.from_df(df).mrf(
    ...     "region",
    ...     nb={"a": ["b"], "b": ["a"]},
    ...     absorb_cons=False,
    ...     diagonal_penalty=False,
    ...     scale_penalty=False,
    ...     scale=1.0,
    ... )
    >>> term.neighbors
    {'a': ['b'], 'b': ['a']}
    >>> term.labels
    ['a', 'b']
    """

    # Class-level defaults; instances shadow them through the property setters.
    _neighbors: dict[str, list[str]] | None = None
    _polygons: dict[str, ArrayLike] | None = None
    _ordered_labels: list[str] | None = None
    _labels: list[str] | None = None
    _mapping: CategoryMapping | None = None

    @property
    def neighbors(self) -> dict[str, list[str]] | None:
        """
        Neighborhood structure, or ``None`` if it was never set.

        Maps each region label to the list of labels of its neighboring regions.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"region": ["a", "b", "a"]})
        >>> term = gam.TermBuilder.from_df(df).mrf(
        ...     "region",
        ...     nb={"a": ["b"], "b": ["a"]},
        ...     absorb_cons=False,
        ...     diagonal_penalty=False,
        ...     scale_penalty=False,
        ...     scale=1.0,
        ... )
        >>> term.neighbors["a"]
        ['b']
        """
        return self._neighbors

    @neighbors.setter
    def neighbors(self, value: dict[str, list[str]] | None) -> None:
        """
        Store the neighborhood dictionary.

        Mainly called by :meth:`.TermBuilder.mrf` once the MRF basis exists.
        """
        self._neighbors = value

    @property
    def polygons(self) -> dict[str, ArrayLike] | None:
        """
        Region polygons, or ``None`` if they were never set.

        Maps each region label to a 2-D array of polygon information describing
        that region.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"region": ["a", "b", "a"]})
        >>> polys = {
        ...     "a": jnp.array([[0.0, 0.0], [1.0, 0.0]]),
        ...     "b": jnp.array([[1.0, 0.0], [2.0, 0.0]]),
        ... }
        >>> term = gam.TermBuilder.from_df(df).mrf(
        ...     "region",
        ...     polys=polys,
        ...     absorb_cons=False,
        ...     diagonal_penalty=False,
        ...     scale_penalty=False,
        ...     scale=1.0,
        ... )
        >>> sorted(term.polygons)
        ['a', 'b']
        """
        return self._polygons

    @polygons.setter
    def polygons(self, value: dict[str, ArrayLike] | None) -> None:
        """
        Store the polygon coordinates keyed by region label.

        Mainly called by :meth:`.TermBuilder.mrf`.
        """
        self._polygons = value

    @property
    def labels(self) -> list[str] | None:
        """
        The region labels, or ``None`` if they were never set.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"region": ["b", "a", "b"]})
        >>> term = gam.TermBuilder.from_df(df).mrf(
        ...     "region",
        ...     nb={"a": ["b"], "b": ["a"]},
        ...     absorb_cons=False,
        ...     diagonal_penalty=False,
        ...     scale_penalty=False,
        ...     scale=1.0,
        ... )
        >>> term.labels
        ['a', 'b']
        """
        return self._labels

    @labels.setter
    def labels(self, value: list[str]) -> None:
        """
        Store the region labels.

        Mainly called by :meth:`.TermBuilder.mrf`.
        """
        self._labels = value

    @property
    def mapping(self) -> CategoryMapping:
        """
        Label-integer mapping for the regions of this term.

        Raises
        ------
        ValueError
            If no mapping has been set.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"region": ["a", "b", "a"]})
        >>> term = gam.TermBuilder.from_df(df).mrf(
        ...     "region",
        ...     nb={"a": ["b"], "b": ["a"]},
        ...     absorb_cons=False,
        ...     diagonal_penalty=False,
        ...     scale_penalty=False,
        ...     scale=1.0,
        ... )
        >>> term.mapping.labels_to_integers(["b"]).tolist()
        [1]
        """
        stored = self._mapping
        if stored is not None:
            return stored
        raise ValueError("No mapping defined.")

    @mapping.setter
    def mapping(self, value: CategoryMapping) -> None:
        """
        Store the label-integer mapping for the regions.

        Mainly called by :meth:`.TermBuilder.mrf`.
        """
        self._mapping = value

    @property
    def ordered_labels(self) -> list[str] | None:
        """
        Labels in the order of the basis/penalty columns, if available.

        Only set for an unconstrained MRF, where each parameter still corresponds
        to exactly one region.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"region": ["b", "a", "b"]})
        >>> term = gam.TermBuilder.from_df(df).mrf(
        ...     "region",
        ...     nb={"a": ["b"], "b": ["a"]},
        ...     absorb_cons=False,
        ...     diagonal_penalty=False,
        ...     scale_penalty=False,
        ...     scale=1.0,
        ... )
        >>> term.ordered_labels
        ['a', 'b']
        """
        return self._ordered_labels

    @ordered_labels.setter
    def ordered_labels(self, value: list[str]) -> None:
        """
        Store labels ordered like the basis and penalty columns.

        Mainly called by :meth:`.TermBuilder.mrf` while the basis still has a clear
        parameter-to-label correspondence.
        """
        self._ordered_labels = value
[docs]
class IndexingTerm(StrctTerm):
    """
    Structured term backed by an integer index instead of a dense dummy matrix.

    Derived from :class:`.StrctTerm`.

    When a basis matrix is a dummy matrix (every entry is 0 or 1 and every row has
    exactly one non-zero entry), storing it densely and evaluating ``basis @ coef``
    is wasteful. This class instead stores a 1-D integer array that gives, for each
    row, the column of its non-zero entry, and evaluates the term by looking up the
    corresponding coefficient. Independent random intercepts are the typical use
    case.

    Use :meth:`.IndexingTerm.init_full_basis` if the dense basis is needed after
    all.

    Examples
    --------
    >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None)
    >>> term = IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0)
    >>> term.value.shape, term.nclusters
    ((4,), 2)
    """

    def __init__(
        self,
        basis: Basis,
        penalty: lsl.Var | lsl.Value | Array | None,
        scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        _update_on_init: bool = True,
        validate_scalar_scale: bool = True,
    ):
        # The index basis must be a 1-D integer array.
        if basis.value.ndim != 1:
            raise ValueError(f"IndexingTerm requires 1d basis, got {basis.value.ndim=}")
        index_dtype = jnp.asarray(basis.value).dtype
        if not jnp.issubdtype(index_dtype, jnp.integer):
            raise TypeError(
                "IndexingTerm requires integer basis, "
                f"got {jnp.asarray(basis.value).dtype=}."
            )

        # Build the parent without evaluating it: its dot-product calculation does
        # not fit a 1-D index basis and is replaced right below.
        super().__init__(
            basis=basis,
            penalty=penalty,
            scale=scale,
            name=name,
            inference=inference,
            coef_name=coef_name,
            _update_on_init=False,
            validate_scalar_scale=validate_scalar_scale,
        )

        def lookup(basis, coef):
            return jnp.take(coef, basis)

        # The parent always creates an lsl.Calc as value node; the assert narrows
        # the type (an lsl.Node has no ``function`` attribute).
        assert isinstance(self.value_node, lsl.Calc)
        self.value_node.function = lookup

        if _update_on_init:
            self.coef.update()
            self.update()

    @property
    def nbases(self) -> int:
        """
        Number of coefficients represented by the index basis (``nclusters``).

        Examples
        --------
        >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None)
        >>> IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0).nbases
        2
        """
        return self.nclusters

    @property
    def nclusters(self) -> int:
        """
        Number of distinct index values in the basis (equals the number of
        coefficients).

        Examples
        --------
        >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None)
        >>> IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0).nclusters
        2
        """
        distinct = jnp.unique(self.basis.value)
        return int(distinct.size)

    def init_full_basis(self) -> Basis:
        """
        Build a :class:`.Basis` holding the dense one-hot matrix equivalent to this
        term's index basis.

        Examples
        --------
        >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None)
        >>> term = IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0)
        >>> term.init_full_basis().value.shape
        (4, 2)
        """
        return Basis(
            self.basis.x,
            basis_fn=jax.nn.one_hot,
            num_classes=self.nclusters,
            name="",
        )
[docs]
class RITerm(IndexingTerm):
    """
    Index-based term for independent random intercepts.

    A specialization of :class:`.IndexingTerm` (itself derived from
    :class:`.StrctTerm`) that additionally stores cluster labels and a
    label-integer mapping.

    Examples
    --------
    >>> import pandas as pd
    >>> import liesel_gam as gam
    >>> df = pd.DataFrame({"group": ["a", "b", "a"]})
    >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0)
    >>> term.nclusters
    2
    >>> term.labels
    ['a', 'b']
    """

    # Class-level defaults; instances shadow them through the property setters.
    _labels: list[str] | None = None
    _mapping: CategoryMapping | None = None

    @property
    def nclusters(self) -> int:
        """
        Number of clusters of this random-intercept term.

        Taken from the size of the label mapping when one is set; otherwise the
        number of distinct index values in the basis is used.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"group": ["a", "b", "a"]})
        >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0)
        >>> term.nclusters
        2
        """
        try:
            count = len(self.mapping.labels_to_integers_map)
        except ValueError:
            # no mapping available (yet): fall back to the observed indices
            count = jnp.unique(self.basis.value).size
        return int(count)

    @property
    def labels(self) -> list[str]:
        """
        Labels of all clusters represented by this term.

        Raises
        ------
        ValueError
            If no labels have been set.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"group": ["b", "a", "b"]})
        >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0)
        >>> term.labels
        ['a', 'b']
        """
        stored = self._labels
        if stored is None:
            raise ValueError("No labels defined.")
        return stored

    @labels.setter
    def labels(self, value: list[str]) -> None:
        """
        Store the cluster labels; their number must equal ``nclusters``.

        Mainly called by :meth:`.TermBuilder.ri`.
        """
        expected = self.nclusters
        if len(value) != expected:
            raise ValueError(f"Expected {self.nclusters} labels, got {len(value)}.")
        self._labels = value

    @property
    def mapping(self) -> CategoryMapping:
        """
        Label-integer mapping for the clusters of this term.

        Raises
        ------
        ValueError
            If no mapping has been set.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"group": ["a", "b", "a"]})
        >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0)
        >>> term.mapping.labels_to_integers(["a"]).tolist()
        [0]
        """
        stored = self._mapping
        if stored is None:
            raise ValueError("No mapping defined.")
        return stored

    @mapping.setter
    def mapping(self, value: CategoryMapping) -> None:
        """
        Store the label-integer mapping for the clusters.

        Mainly called by :meth:`.TermBuilder.ri`.
        """
        self._mapping = value
[docs]
class BasisDot(UserVar):
    """
    Minimal term variable that evaluates the dot product ``basis @ coef``.

    Unlike :class:`.StrctTerm`, this class assumes nothing about the coefficient
    prior: none is created automatically. If a prior is desired, pass it explicitly
    as a :class:`liesel.model.Dist` via ``prior``. The coefficient is initialized
    to zeros with one entry per basis column.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
    >>> term = BasisDot(basis)
    >>> term.coef.value.shape, term.value.shape
    ((2,), (4,))
    """

    def __init__(
        self,
        basis: Basis,
        prior: lsl.Dist | None = None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        _update_on_init: bool = True,
    ):
        self.basis = basis
        self.nbases = self.basis.nbases

        if coef_name is None:
            coef_name = _append_name(name, "_coef")

        initial_coef = jnp.zeros(self.basis.nbases)
        self.coef = lsl.Var.new_param(
            initial_coef, prior, inference=inference, name=coef_name
        )

        def dot(basis, coef):
            return jnp.dot(basis, coef)

        super().__init__(
            lsl.Calc(
                dot,
                basis=self.basis,
                coef=self.coef,
                _update_on_init=_update_on_init,
            ),
            name=name,
        )
[docs]
class LinMixin:
    """
    Mixin that stores formula metadata for linear-term classes.

    The metadata (the ``formulaic`` model spec, categorical mappings and column
    names) lives in private attributes that default to ``None``. Each public
    property raises :class:`ValueError` until a value has been assigned, and each
    setter validates the type of its replacement.

    Examples
    --------
    >>> import pandas as pd
    >>> import liesel_gam as gam
    >>> df = pd.DataFrame({"x": [0.0, 1.0], "z": [1.0, 2.0]})
    >>> term = gam.TermBuilder.from_df(df).lin("x + z")
    >>> term.column_names
    ['x', 'z']
    """

    _model_spec: ModelSpec | None = None
    _mappings: dict[str, CategoryMapping] | None = None
    _column_names: list[str] | None = None

    @property
    def model_spec(self) -> ModelSpec:
        """
        The ``formulaic`` model spec used internally to set up the basis matrix.

        Raises
        ------
        ValueError
            If no model spec has been assigned yet.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"x": [0.0, 1.0], "z": [1.0, 2.0]})
        >>> term = gam.TermBuilder.from_df(df).lin("x + z")
        >>> str(term.model_spec.formula)
        '1 + x + z'
        """
        spec = self._model_spec
        if spec is None:
            raise ValueError("No model spec defined.")
        return spec

    @model_spec.setter
    def model_spec(self, value: ModelSpec):
        """
        Assign the :class:`formulaic.ModelSpec` of this linear term.

        This setter is primarily used by :meth:`.TermBuilder.lin` and
        :meth:`.TermBuilder.slin`.

        Raises
        ------
        TypeError
            If ``value`` is not a :class:`formulaic.ModelSpec`.
        """
        if isinstance(value, ModelSpec):
            self._model_spec = value
            return
        raise TypeError(
            f"Replacement must be of type {ModelSpec}, got {type(value)}."
        )

    @property
    def mappings(self) -> dict[str, CategoryMapping]:
        """
        Label-to-integer mappings of all categorical variables in this term.

        Raises
        ------
        ValueError
            If no mappings have been assigned yet.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"cat": ["a", "b", "a"]})
        >>> term = gam.TermBuilder.from_df(df).lin("cat")
        >>> term.mappings["cat"].labels_to_integers(["b"]).tolist()
        [1]
        """
        stored = self._mappings
        if stored is None:
            raise ValueError("No mappings defined.")
        return stored

    @mappings.setter
    def mappings(self, value: dict[str, CategoryMapping]):
        """
        Assign the categorical label mappings of this linear term.

        This setter is primarily used by :meth:`.TermBuilder.lin` and
        :meth:`.TermBuilder.slin`.

        Raises
        ------
        TypeError
            If ``value`` is not a dict, or if any of its values is not a
            :class:`.CategoryMapping`.
        """
        if not isinstance(value, dict):
            raise TypeError(f"Replacement must be of type dict, got {type(value)}.")
        for candidate in value.values():
            if isinstance(candidate, CategoryMapping):
                continue
            raise TypeError(
                f"The values in the replacement must be of type {CategoryMapping}, "
                f"got {type(candidate)}."
            )
        self._mappings = value

    @property
    def column_names(self) -> list[str]:
        """
        Column names of this term.

        Raises
        ------
        ValueError
            If no column names have been assigned yet.

        Examples
        --------
        >>> import pandas as pd
        >>> import liesel_gam as gam
        >>> df = pd.DataFrame({"x": [0.0, 1.0], "cat": ["a", "b"]})
        >>> term = gam.TermBuilder.from_df(df).lin("x + cat")
        >>> term.column_names
        ['x', 'cat[T.b]']
        """
        names = self._column_names
        if names is None:
            raise ValueError("No column names defined.")
        return names

    @column_names.setter
    def column_names(self, value: Sequence[str]):
        """
        Assign the column names of this term; they are stored as a new list.

        This setter is primarily used by :meth:`.TermBuilder.lin` and
        :meth:`.TermBuilder.slin`.

        Raises
        ------
        TypeError
            If ``value`` is not a non-string sequence of strings.
        """
        if not isinstance(value, Sequence):
            raise TypeError(f"Replacement must be a sequence, got {type(value)}.")
        # A single string is a Sequence too, but it is never a valid name list.
        if isinstance(value, str):
            raise TypeError("Replacement type cannot be string.")
        offending = [type(item) for item in value if not isinstance(item, str)]
        if offending:
            raise TypeError(
                "The values in the replacement must be of type str, "
                f"got {offending[0]}."
            )
        self._column_names = list(value)
[docs]
class LinTerm(BasisDot, LinMixin):
    """
    :class:`.BasisDot` specialization for general linear effects.

    Combines the dot-product term ``basis @ coef`` with the formula metadata
    provided by :class:`.LinMixin`.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x")
    >>> term = LinTerm(basis)
    >>> term.value.shape
    (4,)
    """
[docs]
class StrctLinTerm(StrctTerm, LinMixin):
    """
    :class:`.StrctTerm` specialization for linear effects.

    Combines a structured additive term with the formula metadata provided by
    :class:`.LinMixin`. A typical use is a linear effect with a ridge prior.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> basis = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]),
    ...     xname="x",
    ...     penalty=jnp.eye(2),
    ... )
    >>> term = StrctLinTerm(basis, penalty=basis.penalty, scale=1.0)
    >>> term.coef.value.shape
    (2,)
    """
[docs]
class StrctInteractionTerm(UserVar):
    r"""
    Anisotropic structured additive interaction term.

    Parameters
    ----------
    *marginals
        Marginal terms. Every marginal must have a scale and a basis with a penalty
        matrix.
    common_scale
        A single, common scale to cover all marginal dimensions, resulting in an
        isotropic tensor product. This means setting
        :math:`\tau^2_1 = \dots = \tau^2_M = \tau^2` for all marginal smooths
        in the notation used below.
    name
        Name of the term.
    inference
        Inference specification for the coefficient variable of this term.
    coef_name
        Name of the coefficient variable. If ``None``, created automatically based on
        ``name``.
    basis_name
        Name of the basis variable. This variable is internally created to represent the
        tensor product of the marginal basis matrices. If ``None``, the name will be
        created automatically based on the names of the observed input variables to the
        marginal terms.
    include_main_effects
        If ``True``, the marginal terms will be added to this term's value. Only
        allowed for at most two marginals.
    _update_on_init
        Whether to update the term upon initialization.

    Raises
    ------
    ValueError
        If no marginals are given, if a marginal has no scale, if
        ``include_main_effects=True`` is combined with more than two marginals, or if
        ``common_scale`` is set up with a ``StarVarianceGibbs`` inference kernel.
    TypeError
        If a marginal basis has no penalty matrix.

    See Also
    --------
    .StrctTerm : Basic (isotropic) structured additive term.
    .StrctTensorProdTerm : Full anisotropic tensor product term.
    .TermBuilder : Initializes structured additive terms.
    .BasisBuilder : Initializes structured additive term basis matrices.
    .Basis : Basis matrix object.
    .StrctInteractionTerm.f : Alternative, more convenient constructor.

    Notes
    -----

    .. note::
        The classes :class:`.StrctInteractionTerm` and :class:`.StrctTensorProdTerm`
        are closely related. The former loosely corresponds to ``mgcv::ti``, and the
        latter loosely corresponds to ``mgcv::te``, meaning that, when you supply
        centered marginals, :class:`.StrctInteractionTerm` will *only* include the
        highest-order interaction of the supplied marginals, while
        :class:`.StrctTensorProdTerm` will include the highest-order interaction *and*
        all lower-order interactions, including the main effects.

    Assumes that the term is a tensor product of :math:`M` marginal bases that can be
    written as

    .. math::

        s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i)\beta_j =
        \mathbf{b}(\mathbf{x}_i)^\top \boldsymbol{\beta},

    where

    - :math:`i=1, \dots, N` is the observation index,
    - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate
      observations, where :math:`M` denotes the number of covariates included in
      this term,
    - :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i),
      \dots, B_J(\mathbf{x}_i)]`
      are a set of basis function evaluations, and
    - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]`
      are the corresponding coefficients.

    The vector of basis function evaluations is the Kronecker product of the marginal
    bases:

    .. math::

        \mathbf{b}(\mathbf{x}_i)^\top = \mathbf{b}_1(x_{i,1})^\top
        \otimes \mathbf{b}_2(x_{i,2})^\top
        \otimes \cdots \otimes
        \mathbf{b}_M(x_{i,M})^\top.

    In this notation, we assume that each marginal basis is a function of just one
    covariate, which is the common case.

    The individual terms have (potentially different) basis dimensions
    :math:`J_1, \dots, J_M`, such that the tensor product basis dimension is
    :math:`J = \prod_{m=1}^M J_m`.

    The coefficient vector is equipped with a potentially rank-deficient multivariate
    Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as

    .. math::

        p(\boldsymbol{\beta} | \boldsymbol{\tau}^2)
        \propto
        \operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2}
        \exp \left(
        - \frac{1}{2}
        \boldsymbol{\beta}^\top
        \mathbf{K}(\boldsymbol{\tau}^2)
        \boldsymbol{\beta}
        \right),

    with the precision matrix constructed from marginal penalties
    :math:`\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M`
    and variance parameters :math:`\tau^2_1,\dots, \tau^2_M` as

    .. math::

        \mathbf{K}(\boldsymbol{\tau}^2)
        = \frac{\mathbf{K}_1}{\tau^2_1}
        +
        \cdots
        +
        \frac{\mathbf{K}_M}{\tau^2_M},

    where

    .. math::

        \mathbf{K}_m = \mathbf{I}_{J_1}
        \otimes \cdots \otimes
        \mathbf{I}_{J_{m-1}}
        \otimes
        \tilde{\mathbf{K}}_m
        \otimes
        \mathbf{I}_{J_{m+1}}
        \otimes
        \cdots
        \otimes
        \mathbf{I}_{J_{M}},

    and :math:`\mathbf{I}_{J_m}` denotes the identity matrix of dimension
    :math:`J_m \times J_m`.

    Since :math:`\mathbf{K}(\boldsymbol{\tau}^2)` may be rank-deficient,
    :math:`\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))` is the
    pseudo-determinant, or generalized determinant.

    This term exploits the clearly defined structure of the precision matrix
    to obtain a computationally and memory-efficient evaluation of the prior,
    implemented in the :class:`.MultivariateNormalStructured` distribution class.
    We also implement the results obtained by Bach & Klein (2025) for efficiently
    computing the pseudo-determinant; a key prerequisite for making higher-dimensional
    tensor products feasible.

    References
    ----------
    - Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego
      system for building structured additive distributional regression models with
      tensor product interactions. TEST, 28(1), 1–39.
      https://doi.org/10.1007/s11749-019-00631-z
    - Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using
      Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43.
      https://doi.org/10.1007/s11222-025-10569-y

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> y = jnp.linspace(1.0, 2.0, 4)
    >>> bx = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
    ... )
    >>> by = Basis(
    ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
    ... )
    >>> sx = StrctTerm.f(bx, scale=1.0)
    >>> sy = StrctTerm.f(by, scale=2.0)
    >>> term = StrctInteractionTerm(sx, sy)
    >>> term.basis.value.shape, term.coef.value.shape
    ((4, 4), (4,))
    """

    def __init__(
        self,
        *marginals: StrctTerm | IndexingTerm | RITerm | MRFTerm,
        common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        basis_name: str | None = None,
        include_main_effects: bool = False,
        _update_on_init: bool = True,
    ):
        # Without marginals there is nothing to build; previously this failed
        # deep inside ``reduce`` with an obscure TypeError.
        if not marginals:
            raise ValueError(
                f"{type(self).__name__} requires at least one marginal term."
            )
        for term__ in marginals:
            if term__.scale is None:
                raise ValueError(
                    f"Scale of {term__} is None, which is not allowed "
                    f"in {type(self).__name__}."
                )
        if include_main_effects and len(marginals) > 2:
            raise ValueError(
                "For more than two marginals, include_main_effects=True is probably "
                "not the right tool, since it will not include lower-order "
                "interactions. Please consider using liesel_gam.StrctTensorProdTerm "
                "instead."
            )

        coef_name = _append_name(name, "_coef") if coef_name is None else coef_name
        bases = self._get_bases(marginals)
        penalties = self._get_penalties(bases)

        if common_scale is None:
            # One (anisotropic) scale per marginal dimension.
            scales = [t.scale for t in marginals if t.scale is not None]
        else:
            # The same scale object is used for every marginal dimension.
            scale_ = _init_scale_ig(common_scale)
            self._check_common_scale_kernel(scale_)
            assert scale_ is not None
            scales = [scale_] * len(bases)

        _rowwise_kron = jax.vmap(jnp.kron)

        def rowwise_kron(*bases):
            # Row-wise Kronecker product of all marginal basis matrices.
            return reduce(_rowwise_kron, bases)

        self.xnames = ",".join(self._input_obs(bases))
        if basis_name is None:
            basis_name = "B(" + self.xnames + ")"

        basis = lsl.Var.new_calc(rowwise_kron, *bases, name=basis_name)
        nbases = jnp.shape(basis.value)[-1]

        mvnds = MultivariateNormalStructured.get_locscale_constructor(
            penalties=penalties
        )
        # Stack the scales along the last axis for the structured prior.
        scales_var = lsl.Calc(lambda *x: jnp.stack(x, axis=-1), *scales)
        prior = lsl.Dist(distribution=mvnds, loc=jnp.zeros(nbases), scales=scales_var)
        coef = lsl.Var.new_param(
            jnp.zeros(nbases),
            distribution=prior,
            inference=inference,
            name=coef_name,
        )

        self.basis = basis
        self.marginals = marginals
        self.bases = bases
        self.penalties = penalties
        self.scales = scales
        self.nbases = nbases
        self.coef = coef
        self.scale = scales_var
        self.include_main_effects = include_main_effects

        if include_main_effects:
            calc = lsl.Calc(
                lambda *marginals, basis, coef: sum(marginals) + jnp.dot(basis, coef),
                *marginals,
                basis=basis,
                coef=self.coef,
                _update_on_init=_update_on_init,
            )
        else:
            calc = lsl.Calc(
                lambda basis, coef: jnp.dot(basis, coef),
                basis=basis,
                coef=self.coef,
                _update_on_init=_update_on_init,
            )

        super().__init__(calc, name=name)
        if _update_on_init:
            self.coef.update()

    @staticmethod
    def _check_common_scale_kernel(scale: Any) -> None:
        # Reject a common ScaleIG whose variance variable is set up with a
        # StarVarianceGibbs kernel, which is invalid for a tensor product.
        if not isinstance(scale, ScaleIG):
            return
        var_ = scale.value_node[0]
        if not isinstance(var_, lsl.Var) or var_.inference is None:
            return
        kernel = getattr(var_.inference, "kernel", None)
        scale_kernel_name = getattr(kernel, "__name__", "")
        if scale_kernel_name == "StarVarianceGibbs":
            raise ValueError(
                f"{scale_kernel_name} kernel is invalid for a tensor product. "
                "Please manually set a valid .inference specification for your "
                "common scale."
            )

    @staticmethod
    def _get_bases(
        marginals: Sequence[StrctTerm | RITerm | MRFTerm | IndexingTerm],
    ) -> list[Basis]:
        # Terms that provide ``init_full_basis`` contribute that basis instead of
        # their own ``.basis`` attribute.
        bases = []
        for t in marginals:
            if hasattr(t, "init_full_basis"):
                bases.append(t.init_full_basis())
            else:
                bases.append(t.basis)
        return bases

    @staticmethod
    def _get_penalties(bases: Sequence[Basis]) -> list[Array]:
        # Collect the penalty matrix values; every basis must define one.
        penalties = []
        for b in bases:
            if b.penalty is None:
                raise TypeError(
                    f"All bases must have a penalty matrix, got 'None' for {b}."
                )
            penalties.append(b.penalty.value)
        return penalties

    @staticmethod
    def _validate_marginals(marginals: Sequence[StrctTerm]):
        # Every marginal must carry a scale.
        for t in marginals:
            if t.scale is None:
                raise ValueError(f"Invalid scale for {t}: {t.scale}")

    @property
    def input_obs(self) -> dict[str, lsl.Var]:
        """
        A dictionary of strong input variables.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> y = jnp.linspace(1.0, 2.0, 4)
        >>> bx = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
        ... )
        >>> by = Basis(
        ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
        ... )
        >>> term = StrctInteractionTerm(
        ...     StrctTerm.f(bx, scale=1.0),
        ...     StrctTerm.f(by, scale=2.0),
        ... )
        >>> sorted(term.input_obs)
        ['x', 'y']
        """
        return self._input_obs(self.bases)

    @staticmethod
    def _input_obs(bases: Sequence[Basis]) -> dict[str, lsl.Var]:
        # This method includes assumptions about how the individual bases are
        # structured: Basis.x can be a strong observed variable directly, or a
        # calculator variable that depends on strong observed variables.
        # If these assumptions are violated, this method may produce unexpected
        # results. The bases created by BasisBuilder fit these assumptions.
        _input_x = {}
        for b in bases:
            if isinstance(b.x, lsl.Var):
                if b.x.strong and b.x.observed:
                    # case: ordinary univariate marginal basis, like ps
                    if not b.x.name:
                        raise ValueError(f"{b}.x is unnamed.")
                    _input_x[b.x.name] = b.x
                elif b.x.weak:
                    # currently, I don't expect this case to be present
                    # but it would make sense
                    for xi in b.x.all_input_vars():
                        if xi.observed:
                            if not xi.name:
                                raise ValueError(f"Observed name not found for {b}")
                            _input_x[xi.name] = xi
            else:
                # case: potentially multivariate marginal, possibly thin plate,
                # where basis.x is a calculator that collects the strong inputs.
                for xj in b.x.all_input_nodes():
                    if xj.var is not None and xj.var.observed:
                        if not xj.var.name:
                            raise ValueError(f"Observed name not found for {b}")
                        _input_x[xj.var.name] = xj.var
        return _input_x

    @classmethod
    def f(
        cls,
        *marginals: StrctTerm | IndexingTerm | RITerm | MRFTerm,
        common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
        fname: str = "ta",
        inference: InferenceTypes = None,
        include_main_effects: bool = False,
        _update_on_init: bool = True,
    ) -> Self:
        """
        Alternative constructor with more opinionated automatic naming.

        The term is named ``"{fname}(x1,x2,...)"`` after the observed inputs of the
        marginals, and its coefficient is named ``"$\\beta_{<term name>}$"``.

        Parameters
        ----------
        *marginals
            Marginal terms.
        common_scale
            A single, common scale to cover all marginal dimensions, resulting in an
            isotropic tensor product.
        fname
            Function-name prefix used when constructing the term name.
        inference
            Inference specification for the coefficient variable of this term.
        include_main_effects
            If ``True``, the marginal terms will be added to this term's value.
        _update_on_init
            Whether to update the term upon initialization.

        Examples
        --------
        >>> x = jnp.linspace(0.0, 1.0, 4)
        >>> y = jnp.linspace(1.0, 2.0, 4)
        >>> bx = Basis(
        ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
        ... )
        >>> by = Basis(
        ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
        ... )
        >>> term = StrctInteractionTerm.f(
        ...     StrctTerm.f(bx, scale=1.0),
        ...     StrctTerm.f(by, scale=2.0),
        ... )
        >>> term.name, term.coef.value.shape
        ('ta(x,y)', (4,))
        """
        xnames = list(cls._input_obs(cls._get_bases(marginals)))
        name = fname + "(" + ",".join(xnames) + ")"
        coef_name = "$\\beta_{" + name + "}$"
        term = cls(
            *marginals,
            common_scale=common_scale,
            inference=inference,
            coef_name=coef_name,
            name=name,
            basis_name=None,
            include_main_effects=include_main_effects,
            _update_on_init=_update_on_init,
        )
        return term
[docs]
class StrctTensorProdTerm(UserVar):
r"""
Anisotropic structured additive tensor product term.
Parameters
----------
*marginals
Marginal terms.
common_scale
A single, common scale to cover all marginal dimensions, resulting in an
isotropic tensor product. This means setting
:math:`\tau^2_1 = \dots = \tau^2_M = \tau^2` for all marginal smooths
in the notation used below.
order
Sequence of integers identifying the orders of interactions to be included
in this term. For example, if you want to include only the bi- and
trivariate interactions when supplying three marginals, pass
``order=(2,3)``. The default ``order=None`` means that *all* orders will be
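included; also the main effects.
inference
Inference specification for the coefficients of the interaction terms created
by this term. The marginal terms are not affected.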
names_prefix
Prefix to be added to the names of all variables created by this term.
The names of the main effects (marginals) will not be changed.
tx_name
Function component for the names of the interaction terms internally created by
this term. The naming pattern is ``"{tx_name}(x1, x2, ...)"``.
tf_name
Function component for this term's name.
The naming pattern is ``"{tf_name}(x1, x2, ...)"``.
coef_name
Symbol component of the coefficient names created by this term.
The naming pattern is ``"${coef_name}_{term_name}$"``, where ``term_name`` is
the name of a lower-order interaction in this term.
Does not affect the names of the marginal terms' coefficients.
basis_name
Function component for the names of interaction bases internally created by
this term. The naming pattern is ``"{basis_name}(x1, x2, ...)"``.
group_terms_by_order
If ``True``, an intermediate variable object will be created for each order
of interactions in this term. This can help to better organize the
plotted graph of a term, but otherwise has no effect except using slightly
more memory.
_update_on_init
Whether to update the term upon initialization.
See Also
--------
.StrctTerm : Basic (isotropic) structured additive term.
.StrctInteractionTerm : Anisotropic structured additive interaction term.
.TermBuilder : Initializes structured additive terms.
.BasisBuilder : Initializes structured additive term basis matrices.
.Basis : Basis matrix object.
.StrctTerm.f : Alternative, more convenient constructor.
Notes
-----
.. note::
The classes :class:`.StrctInteractionTerm` and :class:`.StrctTensorProdTerm`
are closely related. The former loosely corresponds to ``mgcv::ti``, and the
latter loosely corresponds to ``mgcv::te``, meaning that, when you supply
centered marginals, :class:`.StrctInteractionTerm` will *only* include the
highest-order interaction of the supplied marginals, while
:class:`.StrctTensorProdTerm` will include the highest-order interaction *and*
all lower-order interactions, including the main effects.
Assumes that the term is a tensor product of :math:`M` marginal bases that can be
written as
.. math::
s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i)\beta_j =
\mathbf{b}(\mathbf{x}_i)^\top \boldsymbol{\beta},
where
- :math:`i=1, \dots, N` is the observation index,
- :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate
observations, where :math:`M` denotes the number of covariates included in this
term,
- :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i),
\dots, B_J(\mathbf{x}_i)]`
are a set of basis function evaluations, and
- :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]`
are the corresponding coefficients.
The vector of basis function evaluations is the Kronecker product of the marginal
bases:
.. math::
\mathbf{b}(\mathbf{x}_i)^\top = \mathbf{b}_1(x_{i,1})^\top
\otimes \mathbf{b}_2(x_{i,2})^\top
\otimes \cdots \otimes
\mathbf{b}_M(x_{i,M})^\top,
In this notation, we assume that the marginal bases are functions of just
one covariate each, which is the common case.
The individual terms have (potentially different) basis dimensions
:math:`J_1, \dots, J_M`, such that the tensor product basis dimension is
:math:`J = \prod_{m=1}^M J_m`.
The coefficient vector is equipped with a potentially rank-deficient multivariate
Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as
.. math::
p(\boldsymbol{\beta} | \boldsymbol{\tau}^2)
\propto
\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2}
\exp \left(
- \frac{1}{2}
\boldsymbol{\beta}^\top
\mathbf{K}(\boldsymbol{\tau}^2)
\boldsymbol{\beta}
\right),
with the precision matrix constructed from marginal penalties
:math:`\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M`
and variance parameters :math:`\tau^2_1,\dots, \tau^2_M` as
.. math::
\mathbf{K}(\boldsymbol{\tau}^2)
= \frac{\mathbf{K}_1}{\tau^2_1}
+
\cdots
+
\frac{\mathbf{K}_M}{\tau^2_M},
where
.. math::
\mathbf{K}_m = \mathbf{I}_{J_1}
\otimes \cdots \otimes
\mathbf{I}_{J_{m-1}}
\otimes
\tilde{\mathbf{K}}_m
\otimes
\mathbf{I}_{J_{m+1}}
\otimes
\cdots
\otimes
\mathbf{I}_{J_{M}},
and :math:`\mathbf{I}_{J_m}` denotes the identity matrix of dimension
:math:`J_m \times J_m`.
Since :math:`\mathbf{K}(\boldsymbol{\tau}^2)` may be rank-deficient,
:math:`\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))` is the
pseudo-determinant, or generalized determinant.
This term exploits the clearly defined structure of the precision matrix
to obtain a computationally and memory-efficient evaluation of the prior,
implemented in the :class:`.MultivariateNormalStructured` distribution class.
We also implement the results obtained by Bach & Klein (2025) for efficiently
computing the pseudo-determinant; a key prerequisite for making higher-dimensional
tensor products feasible.
References
----------
- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego
system for building structured additive distributional regression models with
tensor product interactions. TEST, 28(1), 1–39.
https://doi.org/10.1007/s11749-019-00631-z
- Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using
Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43.
https://doi.org/10.1007/s11222-025-10569-y
Examples
--------
>>> x = jnp.linspace(0.0, 1.0, 4)
>>> y = jnp.linspace(1.0, 2.0, 4)
>>> bx = Basis(
... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
... )
>>> by = Basis(
... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
... )
>>> sx = StrctTerm.f(bx, scale=1.0)
>>> sy = StrctTerm.f(by, scale=2.0)
>>> term = StrctTensorProdTerm(sx, sy)
>>> len(term.terms), term.value.shape
(3, (4,))
"""
def __init__(
    self,
    *marginals: StrctTerm | IndexingTerm | RITerm | MRFTerm,
    common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
    order: Sequence[int] | None = None,
    inference: InferenceTypes = None,
    names_prefix: str = "",
    tx_name: str = "tx",
    tf_name: str = "tf",
    coef_name: str = r"\beta",
    basis_name: str = "B",
    group_terms_by_order: bool = False,
    _update_on_init: bool = True,
):
    """
    Build the marginal main effects and the requested interaction terms.

    Only the orders listed in ``order`` contribute to this term's value. In
    particular, the marginals (order 1) are only added when ``1`` is in ``order``,
    independently of ``group_terms_by_order``.

    Raises
    ------
    ValueError
        If a marginal has no scale, or if ``order`` contains an order that cannot
        be formed from the given marginals.
    """
    for term__ in marginals:
        if term__.scale is None:
            raise ValueError(
                f"Scale of {term__} is None, which is not allowed "
                f"in {type(self).__name__}."
            )

    nmargins = len(marginals)
    self.order = order if order is not None else tuple(range(1, nmargins + 1))
    self.terms_by_order: dict[int, list[StrctTerm | StrctInteractionTerm]] = {}
    if 1 in self.order:
        self.terms_by_order[1] = list(marginals)

    scale_ = _init_scale_ig(common_scale) if common_scale is not None else None

    # One interaction term per combination of marginals whose size is a
    # requested order (orders >= 2).
    interactions = []
    for order_term in range(2, nmargins + 1):
        if order_term not in self.order:
            continue
        for term_marginals in combinations(marginals, order_term):
            term = StrctInteractionTerm(
                *term_marginals,
                common_scale=scale_,
                inference=inference,
                _update_on_init=_update_on_init,
                include_main_effects=False,
            )
            term.name = names_prefix + f"{tx_name}({term.xnames})"
            term.coef.name = names_prefix + "$" + coef_name + r"_{" + term.name + r"}$"
            term.basis.name = names_prefix + basis_name + "(" + term.xnames + ")"
            interactions.append(term)
            self.terms_by_order.setdefault(order_term, []).append(term)

    for o in self.order:
        if o not in self.terms_by_order:
            raise ValueError(
                f"Order {order} was supplied, but no interactions "
                f"of order {o} found."
            )

    if common_scale is not None:
        assert scale_ is not None
        for term_ in marginals:
            term_.replace_scale(scale_)

    self.marginals = marginals
    # Main effects only contribute when order 1 is requested. This keeps the
    # ungrouped sum below consistent with the grouped sum over terms_by_order.
    main_effects = list(marginals) if 1 in self.order else []
    self._terms_list = main_effects + interactions
    self.bases = StrctInteractionTerm._get_bases(marginals)
    self.penalties = StrctInteractionTerm._get_penalties(self.bases)
    self.xnames = ",".join(self.input_obs)

    if group_terms_by_order:
        # One intermediate variable per order, then sum the groups.
        self.term_groups = {}
        for o, o_terms in self.terms_by_order.items():
            self.term_groups[o] = lsl.Var.new_calc(
                lambda *args: sum(args),
                *o_terms,
                _update_on_init=_update_on_init,
                name=names_prefix + f"${tf_name}^{{({o})}}({self.xnames})$",
            )
        calc = lsl.Calc(
            lambda *args: sum(args),
            *list(self.term_groups.values()),
            _update_on_init=_update_on_init,
        )
    else:
        calc = lsl.Calc(
            lambda *args: sum(args),
            *self._terms_list,
            _update_on_init=_update_on_init,
        )

    super().__init__(calc, name=names_prefix + f"{tf_name}({self.xnames})")
@property
def terms(
    self,
) -> dict[str, StrctTerm | StrctInteractionTerm | IndexingTerm | RITerm | MRFTerm]:
    """
    Terms contained in this tensor product, keyed by their names.

    If several terms share a name, the last one in the internal term list wins.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> y = jnp.linspace(1.0, 2.0, 4)
    >>> bx = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
    ... )
    >>> by = Basis(
    ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
    ... )
    >>> term = StrctTensorProdTerm(
    ...     StrctTerm.f(bx, scale=1.0),
    ...     StrctTerm.f(by, scale=2.0),
    ... )
    >>> sorted(term.terms)
    ['f(x)', 'f(y)', 'tx(x,y)']
    """
    by_name = {}
    for member in self._terms_list:
        by_name[member.name] = member
    return by_name
@property
def scales(self) -> list[lsl.Var | lsl.Node]:
    """
    Unique scale variables used by the marginal and interaction terms.

    Only the orders in ``self.order`` are considered. Scales are returned in order
    of first appearance; ``None`` scales of marginals are skipped.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> y = jnp.linspace(1.0, 2.0, 4)
    >>> bx = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
    ... )
    >>> by = Basis(
    ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
    ... )
    >>> term = StrctTensorProdTerm(
    ...     StrctTerm.f(bx, scale=1.0),
    ...     StrctTerm.f(by, scale=2.0),
    ... )
    >>> [float(scale.value) for scale in term.scales]
    [1.0, 2.0]
    """
    unique: list[lsl.Var | lsl.Node] = []

    def _collect(candidates):
        for candidate in candidates:
            if candidate not in unique:
                unique.append(candidate)

    for order_ in self.order:
        group = self.terms_by_order[order_]
        if order_ == 1:
            # marginals expose a single ``scale``
            _collect(t.scale for t in group if t.scale is not None)
            continue
        # interaction terms expose a list of per-dimension ``scales``
        for t in group:
            assert hasattr(t, "scales")
            _collect(t.scales)
    return unique
@property
def input_obs(self) -> dict[str, lsl.Var]:
    """
    Strong observed input variables of all marginal bases, keyed by name.

    Examples
    --------
    >>> x = jnp.linspace(0.0, 1.0, 4)
    >>> y = jnp.linspace(1.0, 2.0, 4)
    >>> bx = Basis(
    ...     jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2)
    ... )
    >>> by = Basis(
    ...     jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2)
    ... )
    >>> term = StrctTensorProdTerm(
    ...     StrctTerm.f(bx, scale=1.0),
    ...     StrctTerm.f(by, scale=2.0),
    ... )
    >>> sorted(term.input_obs)
    ['x', 'y']
    """
    marginal_bases = self.bases
    return StrctInteractionTerm._input_obs(marginal_bases)