Source code for liesel_gam.term

from __future__ import annotations

from collections.abc import Sequence
from functools import reduce
from itertools import combinations
from typing import Any, Literal, Self

import jax
import jax.numpy as jnp
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
from formulaic import ModelSpec

from liesel_gam.category_mapping import CategoryMapping

from .basis import Basis, is_diagonal
from .dist import MultivariateNormalSingular, MultivariateNormalStructured
from .var import ScaleIG, UserVar, VarIGPrior, _append_name

InferenceTypes = Any
Array = jax.Array
ArrayLike = jax.typing.ArrayLike


def mvn_diag_prior(scale: lsl.Var) -> lsl.Dist:
    """
    Create an independent normal prior for coefficient vectors.

    Parameters
    ----------
    scale
        Scale variable passed to :class:`tensorflow_probability`'s normal
        distribution.

    Examples
    --------
    >>> scale = lsl.Var.new_value(1.0)
    >>> mvn_diag_prior(scale).distribution
    <class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
    """
    return lsl.Dist(tfd.Normal, loc=0.0, scale=scale)


def mvn_structured_prior(scale: lsl.Var, penalty: lsl.Var | lsl.Value) -> lsl.Dist:
    """
    Create a structured Gaussian prior from a fixed penalty matrix.

    The penalty must be strong/fixed; varying penalties are not supported here.

    Examples
    --------
    >>> scale = lsl.Var.new_value(1.0)
    >>> penalty = lsl.Value(jnp.eye(2))
    >>> prior = mvn_structured_prior(scale, penalty)
    >>> prior.distribution is MultivariateNormalSingular
    True
    """
    if isinstance(penalty, lsl.Var) and not penalty.strong:
        raise NotImplementedError(
            "Varying penalties are currently not supported by this function."
        )
    prior = lsl.Dist(
        MultivariateNormalSingular,
        loc=0.0,
        scale=scale,
        penalty=penalty,
        penalty_rank=jnp.linalg.matrix_rank(penalty.value),
    )
    return prior


def term_prior(
    scale: lsl.Var | None,
    penalty: lsl.Var | lsl.Value | None,
) -> lsl.Dist | None:
    """
    Select the coefficient prior for a term.

    Returns ``None`` when ``scale`` is ``None``, an independent normal prior when no
    penalty is supplied, and a structured multivariate normal prior otherwise.

    Examples
    --------
    >>> term_prior(None, None) is None
    True
    >>> term_prior(lsl.Var.new_value(1.0), None).distribution
    <class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
    """
    if scale is None:
        if penalty is not None:
            raise ValueError(f"If {scale=}, then penalty must also be None.")
        return None

    if penalty is None:
        return mvn_diag_prior(scale)

    return mvn_structured_prior(scale, penalty)


def _init_scale_ig(
    x: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
    validate_scalar: bool = False,
) -> ScaleIG | lsl.Var | None:
    if isinstance(x, VarIGPrior):
        concentration = jnp.asarray(x.concentration)
        scale_ = jnp.asarray(x.scale)

        if validate_scalar:
            if not concentration.size == 1:
                raise ValueError(
                    "Expected scalar hyperparameter 'concentration', "
                    f"got size {concentration.size}"
                )

            if not scale_.size == 1:
                raise ValueError(
                    f"Expected scalar hyperparameter 'scale', got size {scale_.size}"
                )

        scale_var: ScaleIG | lsl.Var | None = ScaleIG(
            value=jnp.sqrt(jnp.array(x.value)),
            concentration=concentration,
            scale=scale_,
        )
    elif isinstance(x, ScaleIG | lsl.Var):
        if isinstance(x, ScaleIG):
            if x._variance_param.strong:
                x._variance_param.value = jnp.asarray(x._variance_param.value)
                x.update()
        elif x.strong:
            try:
                x.value = jnp.asarray(x.value)
            except Exception as e:
                raise TypeError(
                    f"Unexpected type for scale value: {type(x.value)}"
                ) from e

        scale_var = x
        if validate_scalar:
            size = jnp.asarray(scale_var.value).size
            if not size == 1:
                raise ValueError(f"Expected scalar scale, got size {size}")
    elif x is None:
        scale_var = x
    else:
        try:
            scale_var = lsl.Var.new_value(jnp.asarray(x))
        except Exception as e:
            raise TypeError(f"Unexpected type for scale: {type(x)}") from e
        if validate_scalar:
            size = scale_var.value.size
            if not size == 1:
                raise ValueError(f"Expected scalar scale, got size {size}")

    return scale_var


def _validate_scalar_or_p_scale(scale_value: Array, p):
    is_scalar = scale_value.size == 1
    is_p = scale_value.size == p
    if not (is_scalar or is_p):
        raise ValueError(
            f"Expected scale to have size 1 or {p}, got size {scale_value.size}"
        )


[docs] class StrctTerm(UserVar): r""" General structured additive term. You probably want to initialize a term using :meth:`.StrctTerm.f`, which will automatically take the penalty matrix from the supplied basis and has automatic naming that is convenient in most situations. A structured additive term represents a smooth or structured effect in a generalized additive model. The term wraps a design/basis matrix together with a prior/penalty and a set of coefficients. The object exposes the coefficient variable and evaluates the term as the matrix-vector product of the basis and the coefficients. The term evaluates to ``basis @ coef``. Parameters ---------- basis A :class:`.Basis` instance that produces the design matrix for the term. The basis must evaluate to a 2-D array with shape ``(n_obs, n_bases)``. penalty Penalty matrix or a variable/value wrapping the penalty used to construct the multivariate normal prior for the coefficients. scale Scale parameter passed to the coefficient prior. name Term name. inference Inference specification for this term's coefficient. coef_name Name for the coefficient variable. If ``None``, a default name based on ``name`` will be used. _update_on_init If ``True`` (default) the internal calculation/graph nodes are evaluated during initialization. Set to ``False`` to delay initial evaluation. validate_scalar_scale If ``True`` (default), the term will error if the ``scale`` variable does not hold a scalar scale. This is appropriate for most cases. If ``False``, the term will also allow an array-valued ``scale`` variable of shape ``(nbases,)``. This only really makes sense when also reparameterizing the term using :meth:`.factor_scale`. Only use this if you know exactly what you are doing and you are certain that this is what you want. See Also --------- .TermBuilder : Initializes structured additive terms. .BasisBuilder : Initializes structured additive term basis matrices. .Basis : Basis matrix object. .StrctTerm.f : Alternative, more convenient constructor. .StrctTensorProdTerm : Anisotropic tensor product terms. Notes ----- The terms created by this builder generally have the form .. math:: s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i) \beta_j = \mathbf{b}(\mathbf{x}_i)^\top \boldsymbol{\beta} where - :math:`i=1, \dots, N` is the observation index, - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate observations, where :math:`M` denotes the number of covariates, - :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i), \dots, B_J(\mathbf{x}_i)]` are a set of basis function evaluations, and - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]` are the corresponding coefficients. In many cases, :math:`\mathbf{x}_i` will consist of only one covariate. The basis matrix for such a term is .. math:: \mathbf{B} = \begin{bmatrix} \mathbf{b}(\mathbf{x}_1)^\top \\ \vdots \\ \mathbf{b}(\mathbf{x}_N)^\top \end{bmatrix}. The coefficient receives a potentially rank-deficient multivariate normal prior .. math:: p(\boldsymbol{\beta}) \propto \left(\frac{1}{\tau^2}\right)^{ \operatorname{rk}(\mathbf{K})/2} \exp \left( - \frac{1}{\tau^2} \boldsymbol{\beta}^\top \mathbf{K} \boldsymbol{\beta} \right) with the potentially rank-deficient penalty matrix :math:`\mathbf{K}` of rank :math:`\operatorname{rk}(\mathbf{K})`. The variance parameter :math:`\tau^2` acts as an inverse smoothing parameter. The choice of basis functions :math:`B_j` and penalty matrix :math:`\mathbf{K}` determines the nature of the term. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=jnp.eye(2), ... ) >>> term = StrctTerm.f(basis, scale=1.0) >>> term.name, term.nbases, term.value.shape ('f(x)', 2, (4,)) """ def __init__( self, basis: Basis, penalty: lsl.Var | lsl.Value | ArrayLike | None, scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None, name: str = "", inference: InferenceTypes = None, coef_name: str | None = None, _update_on_init: bool = True, validate_scalar_scale: bool = True, ): scale = _init_scale_ig(scale, validate_scalar=validate_scalar_scale) coef_name = _append_name(name, "_coef") if coef_name is None else coef_name self._basis = basis if isinstance(penalty, lsl.Var | lsl.Value): nparam = jnp.shape(penalty.value)[-1] self._penalty: lsl.Var | lsl.Value | None = penalty elif penalty is not None: nparam = jnp.shape(penalty)[-1] self._penalty = lsl.Value(jnp.asarray(penalty)) else: nparam = self.nbases self._penalty = None prior = term_prior(scale, self._penalty) if scale is not None: _validate_scalar_or_p_scale(scale.value, nparam) self._coef = lsl.Var.new_param( jnp.zeros(nparam), prior, inference=inference, name=coef_name ) calc = lsl.Calc( lambda basis, coef: jnp.dot(basis, coef), basis=basis, coef=self.coef, _update_on_init=_update_on_init, ) self._scale = scale super().__init__(calc, name=name) if _update_on_init: self.coef.update() self._scale_is_factored = False self._disallow_scale_factorization = False if hasattr(self.scale, "setup_gibbs_inference"): try: self.scale.setup_gibbs_inference(self.coef) # type: ignore except Exception as e: raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e @property def scale_is_factored(self) -> bool: """ Whether the term has been reparameterized using :meth:`.factor_scale`. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm(basis, penalty=None, scale=1.0) >>> term.scale_is_factored False """ return self._scale_is_factored @property def coef(self) -> lsl.Var: """ The coefficient variable of this term. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm(basis, penalty=None, scale=1.0) >>> term.coef.value.shape (2,) """ return self._coef @property def basis(self) -> Basis: """ The basis matrix object of this term. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm(basis, penalty=None, scale=1.0) >>> term.basis.name 'B(x)' """ return self._basis @property def nbases(self) -> int: """ Number of basis functions, equal to the number of basis-matrix columns. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> StrctTerm(basis, penalty=None, scale=1.0).nbases 2 """ return jnp.shape(self.basis.value)[-1] @property def scale(self) -> lsl.Var | lsl.Node | None: """ The scale variable used by the coefficient prior. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm(basis, penalty=None, scale=1.0) >>> float(term.scale.value) 1.0 """ return self._scale
[docs] def replace_scale(self, new: lsl.Var, disallow_factorization: bool = True) -> None: """ Replace the scale variable and update the coefficient prior. Parameters ---------- new Replacement scale variable. disallow_factorization Whether subsequent scale factorization should be disabled. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm(basis, penalty=None, scale=1.0) >>> term.replace_scale(lsl.Var.new_value(2.0)) >>> float(term.scale.value) 2.0 """ if self.scale_is_factored: raise ValueError( f"Scale of {self} cannot be replace, because it has been factored " " using .factor_scale()." ) self._scale = new assert self.coef.dist_node is not None self.coef.dist_node["scale"] = new self._disallow_scale_factorization = disallow_factorization
def _validate_scale_for_factoring(self): if self.scale is None: raise ValueError( f"Scale factorization of {self} fails, because {self.scale=}." ) if self.scale.value.size > 1: raise ValueError( f"Scale factorization of {self} fails, " f"because scale must be scalar, but got {self.scale.value.size=}." ) def _validate_penalty_for_factoring(self, atol: float = 1e-5) -> Array: if self._penalty is None: return jnp.array(self.coef.value.shape[-1]) pen_rank = jnp.linalg.matrix_rank(self._penalty.value) if pen_rank == self._penalty.value.shape[-1]: # full-rank penalty always works return pen_rank if not is_diagonal(self._penalty.value, atol): # rank-deficient penalty must be diagonal raise ValueError( "With rank deficient penalties, factoring out the scale is " "only supported when using diagonalized penalties. " "This is " "because the scale is only applied to the penalized part, " "and we cannot reliably distinguish the penalized and " "unpenalized parts without diagonalization." ) unpenalized_parts = self._penalty.value[pen_rank:, pen_rank:] zeros = jnp.zeros_like(unpenalized_parts) if not jnp.allclose(unpenalized_parts, zeros, atol=atol): # rank-deficient part must be the last rows/columns of the penalty raise ValueError( "With rank deficient penalties, factoring out the scale is " "only supported when using diagonalized penalties. " "The null space of the penalty must be organized in the " "last R rows/columns, i.e. these must be all zero. " "R refers to the rank of the penalty, in your " f"case: {pen_rank}. " "Your penalty seems to be diagonal, but not have these " "zero-row/columns." "This is important" "because the scale is only applied to the penalized part, " "and we cannot reliably distinguish the penalized and " "unpenalized parts without this structure." ) return pen_rank
[docs] def factor_scale(self, atol: float = 1e-5) -> Self: """ Turn this term into a partially standardized form. This means the prior for the coefficient will be turned from ``coef ~ N(0, scale^2 * inv(penalty))`` into ``latent_coef ~ N(0, inv(penalty)); coef = scale * latent_coef``. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=jnp.eye(2), ... ) >>> term = StrctTerm.f(basis, scale=1.0) >>> term.factor_scale().scale_is_factored True """ self._validate_scale_for_factoring() pen_rank = self._validate_penalty_for_factoring(atol) if self._scale_is_factored: return self assert self.coef.dist_node is not None self.coef.dist_node["scale"] = lsl.Value(jnp.array(1.0)) assert self.scale is not None # checked in validation method above if self.scale.name and self.coef.name: scaled_name = self.scale.name + "*" + self.coef.name else: scaled_name = _append_name(self.coef.name, "_scaled") def scale_coef(scale, coef): coef = coef.at[:pen_rank].set(coef[:pen_rank] * scale) return coef scaled_coef = lsl.Var.new_calc( scale_coef, self.scale, self.coef, name=scaled_name, ) self.value_node["coef"] = scaled_coef self.coef.update() self.update() self._scale_is_factored = True if hasattr(self.scale, "setup_gibbs_inference_factored"): try: pen = self._penalty.value if self._penalty is not None else None self.scale.update() self.scale.setup_gibbs_inference_factored( scaled_coef, self.coef, penalty=pen ) # type: ignore except Exception as e: raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e return self
[docs] @classmethod def f( cls, basis: Basis, fname: str = "f", scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None, inference: InferenceTypes = None, coef_name: str | None = None, factor_scale: bool = False, ) -> Self: """ Construct a smooth term from a :class:`.Basis`. This convenience constructor builds a named ``term`` using the provided basis. The penalty matrix is taken from ``basis.penalty`` and a coefficient variable with an appropriate multivariate-normal prior is created. The returned term evaluates to ``basis @ coef``. Parameters ---------- basis Basis object that provides the design matrix and penalty for the \ smooth term. The basis must have an associated input variable with \ a meaningful name (used to compose the term name). fname Function-name prefix used when constructing the term name. Default \ is ``'f'`` which results in names like ``f(x)`` when the basis \ input is named ``x``. scale Scale parameter passed to the coefficient prior. inference Inference specification forwarded to the coefficient variable \ creation, a :class:`liesel.goose.MCMCSpec`. factor_scale If ``True``, the term is reparameterized by factoring out the scale \ form via :meth:`.factor_scale` before being returned. coef_name Coefficient name. The default coefficient name is a LaTeX-like string \ ``"$\\beta_{f(x)}$"`` to improve readability in printed summaries. Returns ------- A term instance configured with the given basis and prior settings. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = StrctTerm.f(basis, scale=1.0) >>> term.name, term.coef.value.shape ('f(x)', (2,)) """ if not basis.x.name: raise ValueError("basis.x must be named.") if not basis.name: raise ValueError("basis must be named.") if not isinstance(fname, str): raise TypeError(f"Expected type str, got {type(fname)}.") name = f"{fname}({basis.x.name})" coef_name = coef_name or "$\\beta_{" + f"{name}" + "}$" term = cls( basis=basis, penalty=basis.penalty if scale is not None else None, scale=scale, inference=inference, coef_name=coef_name, name=name, validate_scalar_scale=not factor_scale, ) if factor_scale: term.factor_scale() return term
def _assert_penalty_is_basis_penalty(self): if self._penalty is None: raise ValueError( f"Penalty of {self} is None." " This functionality is only available if the term is initialized with " "the same penalty object as its basis." ) if self._penalty is not self.basis.penalty: raise ValueError( f"Different penalty objects found on {self} and its basis {self.basis}." " This functionality is only available if the term is initialized with " "the same penalty object as its basis." )
[docs] def diagonalize_penalty(self, atol: float = 1e-6) -> Self: """ Diagonalize the penalty via an eigenvalue decomposition. This method computes a transformation that diagonalizes the penalty matrix and updates the internal basis function such that subsequent evaluations use the accordingly transformed basis. The penalty is updated to the diagonalized version. Returns ------- The modified term instance (self). See Also -------- .Basis.diagonalize_penalty : The term calls this method internally. More details are documented there. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=jnp.eye(2), ... ) >>> term = StrctTerm.f(basis, scale=1.0).diagonalize_penalty() >>> bool(is_diagonal(term.basis.penalty.value)) True """ self._assert_penalty_is_basis_penalty() self.basis.diagonalize_penalty(atol) return self
[docs] def scale_penalty(self) -> Self: """ Scale the penalty matrix by its infinity norm. The penalty matrix is divided by its infinity norm (max absolute row sum) so that its values are numerically well-conditioned for downstream use. The updated penalty replaces the previous one. Returns ------- The modified term instance (self). See Also -------- .Basis.scale_penalty : The term calls this method internally. More details are documented there. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=2.0 * jnp.eye(2), ... ) >>> term = StrctTerm.f(basis, scale=1.0).scale_penalty() >>> float(jnp.max(term.basis.penalty.value)) 1.0 """ self._assert_penalty_is_basis_penalty() self.basis.scale_penalty() return self
[docs] def constrain( self, constraint: ArrayLike | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"], ) -> Self: """ Apply a linear constraint to the term's basis and corresponding penalty. Parameters ---------- constraint Type of constraint or custom linear constraint matrix to apply. If an array is supplied, the constraint will be ``A @ coef == 0``, where ``A`` is the supplied constraint matrix. Returns ------- The modified term instance (self). See Also -------- .Basis.constrain : The term calls this method internally. More details are documented there. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=jnp.eye(2), ... ) >>> term = StrctTerm.f(basis, scale=1.0).constrain(jnp.ones((1, 2))) >>> term.nbases 1 """ self._assert_penalty_is_basis_penalty() self.basis.constrain(constraint) self.coef.value = jnp.zeros(self.nbases) return self
SmoothTerm = StrctTerm
[docs] class MRFTerm(StrctTerm): """ Term object for Markov random fields. Derived from :class:`.StrctTerm`, with a few additional attributes that give access to information about the Markov random field setup. Examples -------- ``MRFTerm`` objects are usually created by :class:`.TermBuilder`, which also attaches label and neighborhood metadata. >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... nb={"a": ["b"], "b": ["a"]}, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> term.neighbors {'a': ['b'], 'b': ['a']} >>> term.labels ['a', 'b'] """ _neighbors = None _polygons = None _ordered_labels = None _labels = None _mapping = None @property def neighbors(self) -> dict[str, list[str]] | None: """ Dictionary of neighborhood structure (if available). The keys are region labels. The values are lists of the labels of neighboring regions. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... nb={"a": ["b"], "b": ["a"]}, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> term.neighbors["a"] ['b'] """ return self._neighbors @neighbors.setter def neighbors(self, value: dict[str, list[str]] | None) -> None: """ Set the neighborhood dictionary for the term. This setter is primarily used by :meth:`.TermBuilder.mrf` after constructing the Markov random field basis. """ self._neighbors = value @property def polygons(self) -> dict[str, ArrayLike] | None: """ Dictionary of arrays. The keys of the dict are the region labels. The corresponding values define each region through a 2-D array of polygon information. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["a", "b", "a"]}) >>> polys = { ... "a": jnp.array([[0.0, 0.0], [1.0, 0.0]]), ... "b": jnp.array([[1.0, 0.0], [2.0, 0.0]]), ... } >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... polys=polys, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> sorted(term.polygons) ['a', 'b'] """ return self._polygons @polygons.setter def polygons(self, value: dict[str, ArrayLike] | None) -> None: """ Set polygon coordinates keyed by region label. This setter is primarily used by :meth:`.TermBuilder.mrf`. """ self._polygons = value @property def labels(self) -> list[str] | None: """ Region labels. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["b", "a", "b"]}) >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... nb={"a": ["b"], "b": ["a"]}, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> term.labels ['a', 'b'] """ return self._labels @labels.setter def labels(self, value: list[str]) -> None: """ Set the region labels. This setter is primarily used by :meth:`.TermBuilder.mrf`. """ self._labels = value @property def mapping(self) -> CategoryMapping: """ A label-integer mapping for the regions in this term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... nb={"a": ["b"], "b": ["a"]}, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> term.mapping.labels_to_integers(["b"]).tolist() [1] """ if self._mapping is None: raise ValueError("No mapping defined.") return self._mapping @mapping.setter def mapping(self, value: CategoryMapping) -> None: """ Set the label-integer mapping for the regions. This setter is primarily used by :meth:`.TermBuilder.mrf`. """ self._mapping = value @property def ordered_labels(self) -> list[str] | None: """ Ordered labels, if they are available. Ordering is such that the order corresponds to the columns of the basis and penalty matrices. Only available for unconstrained MRF. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"region": ["b", "a", "b"]}) >>> term = gam.TermBuilder.from_df(df).mrf( ... "region", ... nb={"a": ["b"], "b": ["a"]}, ... absorb_cons=False, ... diagonal_penalty=False, ... scale_penalty=False, ... scale=1.0, ... ) >>> term.ordered_labels ['a', 'b'] """ return self._ordered_labels @ordered_labels.setter def ordered_labels(self, value: list[str]) -> None: """ Set labels ordered like the basis and penalty columns. This setter is primarily used by :meth:`.TermBuilder.mrf` when the basis still has a clear parameter-to-label correspondence. """ self._ordered_labels = value
[docs] class IndexingTerm(StrctTerm): """ Term object for memory-efficient representation of sparse bases. Derived from :class:`.StrctTerm`. If the basis matrix of a term is a dummy matrix, where each column consists only of binary (0/1) entries, and each row has only one non-zero entry, then it is not necessary to store the full matrix in memory and evaluate the term as a dot product ``basis @ coef``. Instead, we can simply store a 1-D array of indices, identifying the nonzero column for each row of the basis matrix, and use this index to access the corresponding coefficient. This scenario is common for independent random intercepts. This class implements such a sparse representation. In case you do need to materialize the full, sparse basis of such a term, you can use :meth:`.IndexingTerm.init_full_basis`. Examples -------- >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None) >>> term = IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0) >>> term.value.shape, term.nclusters ((4,), 2) """ def __init__( self, basis: Basis, penalty: lsl.Var | lsl.Value | Array | None, scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None, name: str = "", inference: InferenceTypes = None, coef_name: str | None = None, _update_on_init: bool = True, validate_scalar_scale: bool = True, ): if not basis.value.ndim == 1: raise ValueError(f"IndexingTerm requires 1d basis, got {basis.value.ndim=}") if not jnp.issubdtype(jnp.asarray(basis.value).dtype, jnp.integer): raise TypeError( "IndexingTerm requires integer basis, " f"got {jnp.asarray(basis.value).dtype=}." ) super().__init__( basis=basis, penalty=penalty, scale=scale, name=name, inference=inference, coef_name=coef_name, _update_on_init=False, validate_scalar_scale=validate_scalar_scale, ) # mypy warns that self.value_node might be a lsl.Node, which does not have the # attribute "function". # But we can assume safely that self.value_node is a lsl.Calc, which does have # one. assert isinstance(self.value_node, lsl.Calc) self.value_node.function = lambda basis, coef: jnp.take(coef, basis) if _update_on_init: self.coef.update() self.update() @property def nbases(self) -> int: """ Number of coefficients represented by the indexed basis. Examples -------- >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None) >>> IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0).nbases 2 """ return self.nclusters @property def nclusters(self) -> int: """ Number of unique clusters in this term (equals the number of coefficients). Examples -------- >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None) >>> IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0).nclusters 2 """ nclusters = jnp.unique(self.basis.value).size return int(nclusters)
[docs] def init_full_basis(self) -> Basis: """ Materializes a :class:`.Basis` object that holds the full basis matrix corresponding to this term. Examples -------- >>> basis = Basis(jnp.array([0, 1, 0, 1]), xname="group", penalty=None) >>> term = IndexingTerm(basis, penalty=jnp.eye(2), scale=1.0) >>> term.init_full_basis().value.shape (4, 2) """ full_basis = Basis( self.basis.x, basis_fn=jax.nn.one_hot, num_classes=self.nclusters, name="" ) return full_basis
[docs] class RITerm(IndexingTerm): """ Term object for memory-efficient representation of independent random intercepts. Specialized subclass of :class:`.IndexingTerm`, which itself is derived from :class:`.StrctTerm`. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"group": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0) >>> term.nclusters 2 >>> term.labels ['a', 'b'] """ _labels = None _mapping = None @property def nclusters(self) -> int: """ Number of clusters represented by this random-intercept term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"group": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0) >>> term.nclusters 2 """ try: nclusters = len(self.mapping.labels_to_integers_map) except ValueError: nclusters = jnp.unique(self.basis.value).size return int(nclusters) @property def labels(self) -> list[str]: """ List of labels for all clusters represented by this term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"group": ["b", "a", "b"]}) >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0) >>> term.labels ['a', 'b'] """ if self._labels is None: raise ValueError("No labels defined.") return self._labels @labels.setter def labels(self, value: list[str]) -> None: """ Set the labels for all clusters. This setter is primarily used by :meth:`.TermBuilder.ri`. """ if not len(value) == self.nclusters: raise ValueError(f"Expected {self.nclusters} labels, got {len(value)}.") self._labels = value @property def mapping(self) -> CategoryMapping: """ A label-integer mapping for the clusters in this term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"group": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).ri("group", scale=1.0) >>> term.mapping.labels_to_integers(["a"]).tolist() [0] """ if self._mapping is None: raise ValueError("No mapping defined.") return self._mapping @mapping.setter def mapping(self, value: CategoryMapping) -> None: """ Set the label-integer mapping for the clusters. This setter is primarily used by :meth:`.TermBuilder.ri`. """ self._mapping = value
[docs] class BasisDot(UserVar): """ Basic term variable for a dot-product ``basis @ coef``. In comparison to :class:`.StrctTerm`, this class makes fewer assumptions, since it does not assume any prior distribution, or structure of the prior distribution, for the coefficients. Instead, a prior for the coefficients of this term (if desired) is defined manually as a :class:`liesel.model.Dist` in the ``prior`` argument. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = BasisDot(basis) >>> term.coef.value.shape, term.value.shape ((2,), (4,)) """ def __init__( self, basis: Basis, prior: lsl.Dist | None = None, name: str = "", inference: InferenceTypes = None, coef_name: str | None = None, _update_on_init: bool = True, ): self.basis = basis self.nbases = self.basis.nbases coef_name = _append_name(name, "_coef") if coef_name is None else coef_name self.coef = lsl.Var.new_param( jnp.zeros(self.basis.nbases), prior, inference=inference, name=coef_name ) calc = lsl.Calc( lambda basis, coef: jnp.dot(basis, coef), basis=self.basis, coef=self.coef, _update_on_init=_update_on_init, ) super().__init__(calc, name=name)
[docs] class LinMixin: """ Mixin that adds formula metadata to linear-term classes. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"x": [0.0, 1.0], "z": [1.0, 2.0]}) >>> term = gam.TermBuilder.from_df(df).lin("x + z") >>> term.column_names ['x', 'z'] """ _model_spec: ModelSpec | None = None _mappings: dict[str, CategoryMapping] | None = None _column_names: list[str] | None = None @property def model_spec(self) -> ModelSpec: """ The model spec used internally by ``formulaic`` to set up the basis matrix. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"x": [0.0, 1.0], "z": [1.0, 2.0]}) >>> term = gam.TermBuilder.from_df(df).lin("x + z") >>> str(term.model_spec.formula) '1 + x + z' """ if self._model_spec is None: raise ValueError("No model spec defined.") return self._model_spec @model_spec.setter def model_spec(self, value: ModelSpec): """ Set the :class:`formulaic.ModelSpec` used by this linear term. This setter is primarily used by :meth:`.TermBuilder.lin` and :meth:`.TermBuilder.slin`. """ if not isinstance(value, ModelSpec): raise TypeError( f"Replacement must be of type {ModelSpec}, got {type(value)}." ) self._model_spec = value @property def mappings(self) -> dict[str, CategoryMapping]: """ A dictionary of label-integer mappings for all categorical variables in this term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"cat": ["a", "b", "a"]}) >>> term = gam.TermBuilder.from_df(df).lin("cat") >>> term.mappings["cat"].labels_to_integers(["b"]).tolist() [1] """ if self._mappings is None: raise ValueError("No mappings defined.") return self._mappings @mappings.setter def mappings(self, value: dict[str, CategoryMapping]): """ Set categorical label mappings for this linear term. This setter is primarily used by :meth:`.TermBuilder.lin` and :meth:`.TermBuilder.slin`. """ if not isinstance(value, dict): raise TypeError(f"Replacement must be of type dict, got {type(value)}.") for val in value.values(): if not isinstance(val, CategoryMapping): raise TypeError( f"The values in the replacement must be of type {CategoryMapping}, " f"got {type(val)}." ) self._mappings = value @property def column_names(self) -> list[str]: """ List of column names for this term. Examples -------- >>> import pandas as pd >>> import liesel_gam as gam >>> df = pd.DataFrame({"x": [0.0, 1.0], "cat": ["a", "b"]}) >>> term = gam.TermBuilder.from_df(df).lin("x + cat") >>> term.column_names ['x', 'cat[T.b]'] """ if self._column_names is None: raise ValueError("No column names defined.") return self._column_names @column_names.setter def column_names(self, value: Sequence[str]): """ Set column names for this term. This setter is primarily used by :meth:`.TermBuilder.lin` and :meth:`.TermBuilder.slin`. """ if not isinstance(value, Sequence): raise TypeError(f"Replacement must be a sequence, got {type(value)}.") if isinstance(value, str): raise TypeError("Replacement type cannot be string.") for val in value: if not isinstance(val, str): raise TypeError( f"The values in the replacement must be of type str, " f"got {type(val)}." ) self._column_names = list(value)
[docs] class LinTerm(BasisDot, LinMixin): """ Specialized :class:`.BasisDot` for general linear effects. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis(jnp.column_stack([jnp.ones_like(x), x]), xname="x") >>> term = LinTerm(basis) >>> term.value.shape (4,) """ pass
[docs] class StrctLinTerm(StrctTerm, LinMixin): """ Specialized :class:`.StrctTerm` for linear effects. This term can be used, for example, to set up linear effects with a ridge prior. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> basis = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), ... xname="x", ... penalty=jnp.eye(2), ... ) >>> term = StrctLinTerm(basis, penalty=basis.penalty, scale=1.0) >>> term.coef.value.shape (2,) """ pass
[docs] class StrctInteractionTerm(UserVar): r""" Anisotropic structured additive interaction term. Parameters ---------- *marginals Marginal terms. common_scale A single, common scale to cover all marginal dimensions, resulting in an isotropic tensor product. This means setting :math:`\tau^2_1 = \dots = \tau^2_M = \tau^2` for all marginal smooths in the notation used below. name Name of the term coef_name Name of the coefficient variable. If ``None``, created automatically based on ``name``. basis_name Name of the basis variable. This variable is internally created to represent the tensor product of the marginal basis matrices. If ``None``, the name will be created automatically based on the names of the observed input variables to the marginal terms. include_main_effects If ``True``, the marginal terms will be added to this term's value. _update_on_init Whether to update the term upon initialization. See Also -------- .StrctTerm : Basic (isotropic) structured additive term. .StrctTensorProdTerm : Full anisotropic tensor product term. .TermBuilder : Initializes structured additive terms. .BasisBuilder : Initializes structured additive term basis matrices. .Basis : Basis matrix object. .StrctTerm.f : Alternative, more convenient constructor. Notes ----- .. note:: The classes :class:`.StrctInteractionTerm` and :class:`.StrctTensorProdTerm` are closely related. The former loosely corresponds to ``mgcv::ti``, and the latter loosely corresponds to ``mgcv::te``, meaning that, when you supply centered marginals, :class:`.StrctInteractionTerm` will *only* include the highest-order interaction of the supplied marginals, while :class:`.StrctTensorProdTerm` will include the highest-order interaction *and* all lower-order interactions, including the main effects. Assumes that the term is a tensor product of :math:`M` marginal bases that can be written as .. math:: s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i)\beta_j = \mathbf{b}^\top \boldsymbol{\beta}, where - :math:`i=1, \dots, N` is the observation index, - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate observations, where :math:`M` denotes the number of covariates included in this term, - :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i), \dots, B_J(\mathbf{x}_i)]` are a set of basis function evaluations, and - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]` are the corresponding coefficients. The vector of basis function evaluations is the Kronecker product of the marginal bases: .. math:: \mathbf{b}(\mathbf{x}_i)^\top = \mathbf{b}_1(x_{i,1})^\top \otimes \mathbf{b}_2(x_{i,2})^\top \otimes \cdots \otimes \mathbf{b}_M(x_{i,M})^\top, In this notation, we assume that the marginal bases are often functions of just one covariate each, which is the common case. The individual terms have (potentially different) basis dimensions :math:`J_1, \dots, J_M`, such that the tensor product basis dimension is :math:`J = \prod_{m=1}^M J_m`. The coefficient vector is equipped with a potentially rank-deficient multivariate Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as .. math:: p(\boldsymbol{\beta} | \boldsymbol{\tau}^2) \propto \operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2} \exp \left( - \frac{1}{2} \boldsymbol{\beta}^\top \mathbf{K}(\boldsymbol{\tau}^2) \boldsymbol{\beta} \right), with the precision matrix constructed from marginal penalties :math:`\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M` and variance parameters :math:`\tau^2_1,\dots, \tau^2_M` as .. math:: \mathbf{K}(\boldsymbol{\tau}^2) = \frac{\mathbf{K}_1}{\tau^2_1} + \cdots + \frac{\mathbf{K}_M}{\tau^2_M}, where .. math:: \mathbf{K}_m = \mathbf{I}_{J_1} \otimes \cdots \otimes \mathbf{I}_{J_{m-1}} \otimes \tilde{\mathbf{K}}_m \otimes \mathbf{I}_{J_{m+1}} \otimes \cdots \mathbf{I}_{J_{M}}, and :math:`\mathbf{I}_{J_m}` denotes the identity matrix of dimension :math:`J_m \times J_m`. Since :math:`\mathbf{K}(\boldsymbol{\tau}^2)` may be rank-deficient, :math:`\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))` is the pseudo-determinant, or generalized determinant. This term exploits the clearly defined structure of the precision matrix to obtain a computationally and memory-efficient evaluation of the prior, implemented in the :class:`.MultivariateNormalStructured` distribution class. We also implement the results obtained by Bach & Klein (2025) for efficiently computing the pseudo-determinant; a key prerequisite for making higher-dimensional tensor products feasible. References ---------- - Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z - Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43. https://doi.org/10.1007/s11222-025-10569-y Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> sx = StrctTerm.f(bx, scale=1.0) >>> sy = StrctTerm.f(by, scale=2.0) >>> term = StrctInteractionTerm(sx, sy) >>> term.basis.value.shape, term.coef.value.shape ((4, 4), (4,)) """ def __init__( self, *marginals: StrctTerm | IndexingTerm | RITerm | MRFTerm, common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None, name: str = "", inference: InferenceTypes = None, coef_name: str | None = None, basis_name: str | None = None, include_main_effects: bool = False, _update_on_init: bool = True, ): for term__ in marginals: if term__.scale is None: raise ValueError( f"Scale of {term__} is None, which is not allowed " f"in {type(self).__name__}." ) if include_main_effects and len(marginals) > 2: raise ValueError( "For more than two marginals, include_main_effects=True is probably " "not the right tool, since it will not include lower-order " "interactions. Please consider using liesel_gam.StrctTensorProdTerm " "instead." ) self._validate_marginals(marginals) coef_name = _append_name(name, "_coef") if coef_name is None else coef_name bases = self._get_bases(marginals) penalties = self._get_penalties(bases) if common_scale is None: scales = [t.scale for t in marginals if t.scale is not None] else: scale_ = _init_scale_ig(common_scale) if isinstance(scale_, ScaleIG): var_ = scale_.value_node[0] if isinstance(var_, lsl.Var): if var_.inference is None or not hasattr(var_.inference, "kernel"): pass else: if hasattr(var_.inference.kernel, "__name__"): scale_kernel_name = var_.inference.kernel.__name__ else: scale_kernel_name = "" if scale_kernel_name == "StarVarianceGibbs": raise ValueError( f"{scale_kernel_name} kernel is invalid for a tensor " "product. " "Please manually set a valid .inference specification " "for your " "common scale." ) assert scale_ is not None scales = [scale_] * len(bases) _rowwise_kron = jax.vmap(jnp.kron) def rowwise_kron(*bases): return reduce(_rowwise_kron, bases) self.xnames = ",".join(list(self._input_obs(bases))) if basis_name is None: basis_name = "B(" + self.xnames + ")" assert basis_name is not None basis = lsl.Var.new_calc(rowwise_kron, *bases, name=basis_name) nbases = jnp.shape(basis.value)[-1] mvnds = MultivariateNormalStructured.get_locscale_constructor( penalties=penalties ) scales_var = lsl.Calc(lambda *x: jnp.stack(x, axis=-1), *scales) prior = lsl.Dist(distribution=mvnds, loc=jnp.zeros(nbases), scales=scales_var) coef = lsl.Var.new_param( jnp.zeros(nbases), distribution=prior, inference=inference, name=coef_name, ) self.basis = basis self.marginals = marginals self.bases = bases self.penalties = penalties self.scales = scales self.nbases = nbases self.basis = basis self.coef = coef self.scale = scales_var self.include_main_effects = include_main_effects if include_main_effects: calc = lsl.Calc( lambda *marginals, basis, coef: sum(marginals) + jnp.dot(basis, coef), *marginals, basis=basis, coef=self.coef, _update_on_init=_update_on_init, ) else: calc = lsl.Calc( lambda basis, coef: jnp.dot(basis, coef), basis=basis, coef=self.coef, _update_on_init=_update_on_init, ) super().__init__(calc, name=name) if _update_on_init: self.coef.update() @staticmethod def _get_bases( marginals: Sequence[StrctTerm | RITerm | MRFTerm | IndexingTerm], ) -> list[Basis]: bases = [] for t in marginals: if hasattr(t, "init_full_basis"): bases.append(t.init_full_basis()) else: bases.append(t.basis) return bases @staticmethod def _get_penalties(bases: Sequence[Basis]) -> list[Array]: penalties = [] for b in bases: if b.penalty is None: raise TypeError( f"All bases must have a penalty matrix, got 'None' for {b}." ) penalties.append(b.penalty.value) return penalties @staticmethod def _validate_marginals(marginals: Sequence[StrctTerm]): for t in marginals: if t.scale is None: raise ValueError(f"Invalid scale for {t}: {t.scale}") @property def input_obs(self) -> dict[str, lsl.Var]: """ A dictionary of strong input variables. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> term = StrctInteractionTerm( ... StrctTerm.f(bx, scale=1.0), ... StrctTerm.f(by, scale=2.0), ... ) >>> sorted(term.input_obs) ['x', 'y'] """ return self._input_obs(self.bases) @staticmethod def _input_obs(bases: Sequence[Basis]) -> dict[str, lsl.Var]: # this method includes assumptions about how the individual bases are # structured: Basis.x can be a strong observed variable directly, or a # calculator variable that depends on strong observed variables. # If these assumptions are violated, this method may produce unexpected results. # The bases created by BasisBuilder fit theses assumptions. _input_x = {} for b in bases: if isinstance(b.x, lsl.Var): if b.x.strong and b.x.observed: # case: ordinary univariate marginal basis, like ps if not b.x.name: raise ValueError(f"{b}.x is unnamed.") _input_x[b.x.name] = b.x elif b.x.weak: # currently, I don't expect this case to be present # but it would make sense for xi in b.x.all_input_vars(): if xi.observed: if not xi.name: raise ValueError(f"Observed name not found for {b}") _input_x[xi.name] = xi else: # case: potentially multivariate marginal, possibly thin plate, # where basis.x is a calculator that collects the strong inputs. for xj in b.x.all_input_nodes(): if xj.var is not None: if xj.var.observed: if not xj.var.name: raise ValueError(f"Observed name not found for {b}") _input_x[xj.var.name] = xj.var return _input_x
[docs] @classmethod def f( cls, *marginals: StrctTerm, common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None, fname: str = "ta", inference: InferenceTypes = None, include_main_effects: bool = False, _update_on_init: bool = True, ) -> Self: """ Alternative constructor with more opinionated automatic naming. Parameters ---------- *marginals Marginal terms. common_scale A single, common scale to cover both marginal dimensions, resulting in an isotropic tensor product. fname Function-name prefix used when constructing the term name. include_main_effects If ``True``, the marginal terms will be added to this term's value. _update_on_init Whether to update the term upon initialization. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> term = StrctInteractionTerm.f( ... StrctTerm.f(bx, scale=1.0), ... StrctTerm.f(by, scale=2.0), ... ) >>> term.name, term.coef.value.shape ('ta(x,y)', (4,)) """ xnames = list(cls._input_obs(cls._get_bases(marginals))) name = fname + "(" + ",".join(xnames) + ")" coef_name = "$\\beta_{" + name + "}$" term = cls( *marginals, common_scale=common_scale, inference=inference, coef_name=coef_name, name=name, basis_name=None, include_main_effects=include_main_effects, _update_on_init=_update_on_init, ) return term
[docs] class StrctTensorProdTerm(UserVar): r""" Anisotropic structured additive tensor product term. Parameters ---------- *marginals Marginal terms. common_scale A single, common scale to cover all marginal dimensions, resulting in an isotropic tensor product. This means setting :math:`\tau^2_1 = \dots = \tau^2_M = \tau^2` for all marginal smooths in the notation used below. order Sequence of integers identifying the orders of interactions to be included in this term. For example, if you want to include only the bi- and trivariate interactions when supplying three marginals, pass ``order=(2,3)``. The default ``order=None`` means that *all* orders will be included; also the main effects. names_prefix Prefix to be added to the names of all variables created by this term. The names of the main effects (marginals) will not be changed. tx_name Function component for the names of the interaction terms internally created by this term. The naming pattern is ``"{tx_name}(x1, x2, ...)"``. tf_name Function component for this term's name. The naming pattern is ``"{tf_name}(x1, x2, ...)"``. coef_name Symbol component of the coefficient names created by this term. The naming pattern is ``"${coef_name}_{term_name}$"``, where ``term_name`` is the name of a lower-order interaction in this term. Does not affect the names of the marginal terms' coefficients. basis_name Function component for the names of interaction bases internally created by this term. The naming pattern is ``"{basis_name}(x1, x2, ...)"``. group_terms_by_order If ``True``, an intermediate variable object will be created for each order of interactions in this term. This can help to better organize the plotted graph of a term, but otherwise has no effect except using slightly more memory. _update_on_init Whether to update the term upon initialization. See Also -------- .StrctTerm : Basic (isotropic) structured additive term. .StrctTensorProdTerm : Full anisotropic tensor product term. .TermBuilder : Initializes structured additive terms. .BasisBuilder : Initializes structured additive term basis matrices. .Basis : Basis matrix object. .StrctTerm.f : Alternative, more convenient constructor. Notes ----- .. note:: The classes :class:`.StrctInteractionTerm` and :class:`.StrctTensorProdTerm` are closely related. The former loosely corresponds to ``mgcv::ti``, and the latter loosely corresponds to ``mgcv::te``, meaning that, when you supply centered marginals, :class:`.StrctInteractionTerm` will *only* include the highest-order interaction of the supplied marginals, while :class:`.StrctTensorProdTerm` will include the highest-order interaction *and* all lower-order interactions, including the main effects. Assumes that the term is a tensor product of :math:`M` marginal bases that can be written as .. math:: s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i)\beta_j = \mathbf{b}^\top \boldsymbol{\beta}, where - :math:`i=1, \dots, N` is the observation index, - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate observations, where :math:`M` denotes the number of covariates included in this term, - :math:`\mathbf{b}(\mathbf{x}_i)^\top = [B_1(\mathbf{x}_i), \dots, B_J(\mathbf{x}_i)]` are a set of basis function evaluations, and - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]` are the corresponding coefficients. The vector of basis function evaluations is the Kronecker product of the marginal bases: .. math:: \mathbf{b}(\mathbf{x}_i)^\top = \mathbf{b}_1(x_{i,1})^\top \otimes \mathbf{b}_2(x_{i,2})^\top \otimes \cdots \otimes \mathbf{b}_M(x_{i,M})^\top, In this notation, we assume that the marginal bases are functions of just one covariate each, which is the common case. The individual terms have (potentially different) basis dimensions :math:`J_1, \dots, J_M`, such that the tensor product basis dimension is :math:`J = \prod_{m=1}^M J_m`. The coefficient vector is equipped with a potentially rank-deficient multivariate Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as .. math:: p(\boldsymbol{\beta} | \boldsymbol{\tau}^2) \propto \operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2} \exp \left( - \frac{1}{2} \boldsymbol{\beta}^\top \mathbf{K}(\boldsymbol{\tau}^2) \boldsymbol{\beta} \right), with the precision matrix constructed from marginal penalties :math:`\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M` and variance parameters :math:`\tau^2_1,\dots, \tau^2_M` as .. math:: \mathbf{K}(\boldsymbol{\tau}^2) = \frac{\mathbf{K}_1}{\tau^2_1} + \cdots + \frac{\mathbf{K}_M}{\tau^2_M}, where .. math:: \mathbf{K}_m = \mathbf{I}_{J_1} \otimes \cdots \otimes \mathbf{I}_{J_{m-1}} \otimes \tilde{\mathbf{K}}_m \otimes \mathbf{I}_{J_{m+1}} \otimes \cdots \mathbf{I}_{J_{M}}, and :math:`\mathbf{I}_{J_m}` denotes the identity matrix of dimension :math:`J_m \times J_m`. Since :math:`\mathbf{K}(\boldsymbol{\tau}^2)` may be rank-deficient, :math:`\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))` is the pseudo-determinant, or generalized determinant. This term exploits the clearly defined structure of the precision matrix to obtain a computationally and memory-efficient evaluation of the prior, implemented in the :class:`.MultivariateNormalPenaltyOperator` distribution class. We also implement the results obtained by Bach & Klein (2025) for efficiently computing the pseudo-determinant; a key prerequisite for making higher-dimensional tensor products feasible. References ---------- - Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z - Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43. https://doi.org/10.1007/s11222-025-10569-y Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> sx = StrctTerm.f(bx, scale=1.0) >>> sy = StrctTerm.f(by, scale=2.0) >>> term = StrctTensorProdTerm(sx, sy) >>> len(term.terms), term.value.shape (3, (4,)) """ def __init__( self, *marginals: StrctTerm | IndexingTerm | RITerm | MRFTerm, common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None, order: Sequence[int] | None = None, inference: InferenceTypes = None, names_prefix: str = "", tx_name: str = "tx", tf_name: str = "tf", coef_name: str = r"\beta", basis_name: str = "B", group_terms_by_order: bool = False, _update_on_init: bool = True, ): for term__ in marginals: if term__.scale is None: raise ValueError( f"Scale of {term__} is None, which is not allowed " f"in {type(self).__name__}." ) nmargins = len(marginals) terms_combinations = [] for i in range(2, (nmargins + 1)): terms_combinations += list(combinations(marginals, i)) self.order = order if order is not None else tuple(range(1, len(marginals) + 1)) self.terms_by_order: dict[int, list[StrctTerm | StrctInteractionTerm]] = {} if 1 in self.order: self.terms_by_order[1] = list(marginals) scale_ = _init_scale_ig(common_scale) if common_scale is not None else None interactions = [] for term_marginals in terms_combinations: order_term = len(term_marginals) if order_term not in self.order: continue term = StrctInteractionTerm( *term_marginals, common_scale=scale_, inference=inference, _update_on_init=_update_on_init, include_main_effects=False, ) term.name = names_prefix + f"{tx_name}({term.xnames})" term.coef.name = names_prefix + "$" + coef_name + r"_{" + term.name + r"}$" term.basis.name = names_prefix + basis_name + "(" + term.xnames + ")" interactions.append(term) if not self.terms_by_order.get(order_term, None): self.terms_by_order[order_term] = [term] else: self.terms_by_order[order_term].append(term) for o in self.order: if o not in self.terms_by_order: raise ValueError( f"Order {order} was supplied, but no interactions " f"of order {o} found." ) if common_scale is not None: assert scale_ is not None for term_ in marginals: term_.replace_scale(scale_) self.marginals = marginals self._terms_list = list(marginals) + interactions self.bases = StrctInteractionTerm._get_bases(marginals) self.penalties = StrctInteractionTerm._get_penalties(self.bases) self.xnames = ",".join(list(self.input_obs)) if group_terms_by_order: self.term_groups = {} for o, o_terms in self.terms_by_order.items(): self.term_groups[o] = lsl.Var.new_calc( lambda *args: sum(args), *o_terms, _update_on_init=_update_on_init, name=names_prefix + f"${tf_name}^{{({o})}}({self.xnames})$", ) calc = lsl.Calc( lambda *args: sum(args), *list(self.term_groups.values()), _update_on_init=_update_on_init, ) else: calc = lsl.Calc( lambda *args: sum(args), *self._terms_list, _update_on_init=_update_on_init, ) super().__init__(calc, name=names_prefix + f"{tf_name}({self.xnames})") @property def terms( self, ) -> dict[str, StrctTerm | StrctInteractionTerm | IndexingTerm | RITerm | MRFTerm]: """ Dictionary of terms contained in this tensor product. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> term = StrctTensorProdTerm( ... StrctTerm.f(bx, scale=1.0), ... StrctTerm.f(by, scale=2.0), ... ) >>> sorted(term.terms) ['f(x)', 'f(y)', 'tx(x,y)'] """ return {term.name: term for term in self._terms_list} @property def scales(self) -> list[lsl.Var | lsl.Node]: """ Unique scale variables used by the marginal and interaction terms. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> term = StrctTensorProdTerm( ... StrctTerm.f(bx, scale=1.0), ... StrctTerm.f(by, scale=2.0), ... ) >>> [float(scale.value) for scale in term.scales] [1.0, 2.0] """ scales = [] for i in self.order: if i == 1: for term in self.terms_by_order[i]: if term.scale not in scales and term.scale is not None: scales.append(term.scale) else: for term in self.terms_by_order[i]: assert hasattr(term, "scales") for scale in term.scales: if scale not in scales: scales.append(scale) return scales @property def input_obs(self) -> dict[str, lsl.Var]: """ A dictionary of strong input variables. Examples -------- >>> x = jnp.linspace(0.0, 1.0, 4) >>> y = jnp.linspace(1.0, 2.0, 4) >>> bx = Basis( ... jnp.column_stack([jnp.ones_like(x), x]), xname="x", penalty=jnp.eye(2) ... ) >>> by = Basis( ... jnp.column_stack([jnp.ones_like(y), y]), xname="y", penalty=jnp.eye(2) ... ) >>> term = StrctTensorProdTerm( ... StrctTerm.f(bx, scale=1.0), ... StrctTerm.f(by, scale=2.0), ... ) >>> sorted(term.input_obs) ['x', 'y'] """ return StrctInteractionTerm._input_obs(self.bases)