# Source code for liesel_gam.basis

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, Self

import jax
import jax.numpy as jnp
import liesel.model as lsl
from formulaic import ModelSpec

from liesel_gam.category_mapping import CategoryMapping

from .constraint import LinearConstraintEVD, penalty_to_unit_design
from .var import UserVar, _append_name, _ensure_var_or_node

# Type aliases used in annotations throughout this module.
Array = jax.Array  # concrete JAX array type
ArrayLike = jax.typing.ArrayLike  # any value JAX can convert to an array
InferenceTypes = Any  # unconstrained alias (currently plain ``Any``)


def make_callback(function, output_shape, dtype, m: int = 0):
    """
    Wrap ``function`` in an ordered :func:`jax.experimental.io_callback`.

    The returned wrapper can be called with (possibly traced) JAX arrays. The
    wrapped Python ``function`` is executed on the host through the callback, so
    it does not need to be jax-compatible.

    Parameters
    ----------
    function
        Python callable ``function(x, **kwargs)`` computing the output.
    output_shape
        Shape of ``function``'s output for a reference input. Only its rank and
        its last dimension are used. The leading dimension of the callback's
        result is recomputed from the input on every call.
    dtype
        Dtype of the callback's result.
    m
        Number of cases that ``function`` removes from its input. An input with
        ``n`` cases yields ``n - m`` output rows (for 1-d and 2-d outputs).

    Returns
    -------
    A function ``fn(x, **basis_kwargs)`` that evaluates
    ``function(x, **basis_kwargs)`` through an ordered ``io_callback``.

    Raises
    ------
    RuntimeError
        If ``output_shape`` has more than two dimensions. This is checked once,
        when the wrapper is created.
    """
    # Import the submodule explicitly rather than relying on the attribute
    # ``jax.experimental`` being available after ``import jax``.
    from jax.experimental import io_callback

    ndim = len(output_shape)
    # The output rank never changes between calls, so validate it once up front.
    if ndim > 2:
        raise RuntimeError(
            "Return shape of 'basis_fn(value)' must"
            f" have <= 2 dimensions, got {output_shape}"
        )
    k = output_shape[-1] if ndim else None

    def fn(x, **basis_kwargs):
        n = jnp.shape(jnp.atleast_1d(x))[0]

        if ndim == 2:
            shape = (n - m, k)
        elif ndim == 1:
            shape = (n - m,)
        else:
            shape = ()

        sig = jax.ShapeDtypeStruct(shape, dtype)

        # ordered=True enforces sequencing of callback executions
        return io_callback(
            function,
            sig,
            x,
            ordered=True,
            **basis_kwargs,
        )

    return fn


def is_diagonal(M, atol=1e-12):
    """
    Check whether all off-diagonal entries of ``M`` are (near) zero.

    Parameters
    ----------
    M
        Array whose last two axes form the matrix to check. Leading axes are
        treated as batch dimensions, and the matrix need not be square.
    atol
        Absolute tolerance. An off-diagonal entry counts as zero if its absolute
        value is strictly smaller than ``atol``.

    Returns
    -------
    A scalar boolean JAX array. It is ``True`` if every off-diagonal entry (in
    all batch elements) is below ``atol``. Diagonal entries are ignored.
    """
    M = jnp.asarray(M)
    # Boolean-mask indexing (``M[mask]``) has a data-dependent result shape and
    # cannot be traced under ``jax.jit``. It also mis-indexes batched input.
    # Masking with ``where`` keeps shapes static and broadcasts over batch axes.
    diag_mask = jnp.eye(M.shape[-2], M.shape[-1], dtype=bool)
    off_diag_small = jnp.abs(M) < atol
    return jnp.all(jnp.where(diag_mask, True, off_diag_small))


[docs] class Basis(UserVar): """ General basis for a structured additive term. The ``Basis`` class wraps an observation variable (or an array) and a basis-generation function. It constructs an internal calculation node that produces the basis (design) matrix by computing ``basis_fn(value)``. The basis function may be executed via a callback, in which case it does not need to be jax-compatible. This is the default, but it is potentially very slow, if the value of the basis needs to be recomputed during estimation. We recommend it only for bases that remain static during estimation. Parameters ---------- value If a :class:`liesel.model.Var` or node is provided it is used as the input variable for the basis. Otherwise a raw array-like object may be supplied together with ``xname`` to create an observed variable internally. basis_fn Function mapping the input variable's values to a basis matrix or vector. It must accept the input array and any ``basis_kwargs`` and return an array of shape ``(n_obs, n_bases)`` (or a scalar/1-d array for simpler bases). By default this is the identity function (``lambda x: x``). name Optional name for the basis object. If omitted, a sensible name is constructed from the input variable's name (``B(<xname>)``). xname Required when ``value`` is a raw array: provides a name for the observation variable that will be created. use_callback If ``True`` (default) the basis_fn is wrapped in a JAX ``pure_callback`` via :func:`make_callback` to allow arbitrary Python basis functions while preserving JAX tracing. If ``False`` the function is used directly and must be jittable via JAX. cache_basis If ``True`` the computed basis is cached in a persistent calculation node (``lsl.Calc``), which avoids re-computation when not required, but uses memory. If ``False`` a transient calculation node (``lsl.TransientCalc``) is used and the basis will be recomputed with each evaluation of ``Basis.value``, but not stored in memory. 
penalty Penalty matrix associated with the basis. If ``"identity"``, a default identity penalty is created based on the number of basis functions. If *None*, an identity penalty is assumed, but not materialized, which saves memory but must be handled explicitly later, if downstream functionality relies on an explicit penalty matrix. **basis_kwargs Additional keyword arguments forwarded to ``basis_fn``. See Also --------- .TermBuilder : Initializes structured additive terms. .BasisBuilder : Initializesstructured additive terms. .StrctTerm : A general structured additive term. Notes ----- The basis is evaluated once during initialization (via ``self.update()``) to determine its shape and dtype. The internal callback wrapper inspects the return shape to build a compatible JAX ShapeDtypeStruct for the pure callback. Examples -------- Implementing a B-spline basis manually: >>> from liesel.contrib.splines import ( ... basis_matrix, ... equidistant_knots, ... pspline_penalty, ... ) >>> import liesel_gam as gam >>> df = gam.demo_data(n=100) >>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20) >>> pen = pspline_penalty(d=20) >>> def bspline_basis(x): ... return basis_matrix(x, knots=knots) >>> gam.Basis( ... value=df["x_nonlin"].to_numpy(), ... basis_fn=bspline_basis, ... xname="x", ... penalty=pen, ... ) Basis(name="B(x)") Implementing a fixed basis matrix (without using the basis function). This is not recommended, because it means you cannot simply supply new covariate values to :meth:`liesel.model.Model.predict` for evaluating the basis matrix for predictions. >>> from liesel.contrib.splines import equidistant_knots, basis_matrix >>> import liesel_gam as gam >>> df = gam.demo_data(n=100) >>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20) >>> def bspline_basis(x): ... 
return basis_matrix(x, knots=knots) >>> x = df["x_nonlin"].to_numpy() >>> gam.Basis(value=bspline_basis(x), name="B(x)") Basis(name="B(x)") """ def __init__( self, value: lsl.Var | lsl.Node | ArrayLike, basis_fn: Callable[[Array], Array] | Callable[..., Array] = lambda x: x, name: str | None = None, xname: str | None = None, use_callback: bool = True, cache_basis: bool = True, penalty: ArrayLike | lsl.Value | Literal["identity"] | None = "identity", **basis_kwargs, ) -> None: self._validate_xname(value, xname) value_var = _ensure_var_or_node(value, xname) if use_callback: value_ar = jnp.asarray(value_var.value) basis_kwargs_arr = {} for key, val in basis_kwargs.items(): if isinstance(val, lsl.Var | lsl.Node): basis_kwargs_arr[key] = val.value else: basis_kwargs_arr[key] = val basis_ar = basis_fn(value_ar, **basis_kwargs_arr) dtype = basis_ar.dtype input_shape = jnp.shape(basis_ar) # This is special-case handling for compatibility with # basis functions that remove cases. For example, if you have a formulaic # formula "x + lag(x)", then the resulting basis will have one case less # than the original x, because the first case is dropped. 
if value_ar.shape: p = value_ar.shape[0] if value_ar.shape else 0 k = input_shape[0] if input_shape else 0 m = p - k else: m = 0 fn = make_callback(basis_fn, input_shape, dtype, m) else: fn = basis_fn name_ = self._basis_name(value_var, name) if cache_basis: calc = lsl.Calc( fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc") ) else: calc = lsl.TransientCalc( fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc") ) super().__init__(calc, name=name_) self.update() if isinstance(penalty, lsl.Value): penalty_var = penalty elif isinstance(penalty, str) and penalty == "identity": penalty_arr = jnp.eye(self.nbases) penalty_var = lsl.Value(penalty_arr) elif penalty is None: penalty_var = None else: penalty_arr = jnp.asarray(penalty) penalty_var = lsl.Value(penalty_arr) self._penalty = penalty_var self._constraint: str | None = None self._reparam_matrix: Array | None = None @property def nbases(self) -> int: """Number of basis functions (number of columns in the basis matrix).""" basis_shape = jnp.shape(self.value) if len(basis_shape) > 1: nbases: int = basis_shape[-1] else: nbases = 1 # scalar case return nbases @property def x(self) -> lsl.Var | lsl.Node: """The input variable (observations) used to construct the basis.""" return self.value_node[0] @property def constraint(self) -> str | None: """ The type of constraint applied to this basis and penalty (if any). See :meth:`.Basis.constrain` for details. """ return self._constraint @property def reparam_matrix(self) -> Array | None: """ Reparameterization matrix used for constraint of this basis and penalty (if any). See :meth:`.Basis.constrain` for details. """ return self._reparam_matrix def _validate_xname(self, value: lsl.Var | lsl.Node | ArrayLike, xname: str | None): if isinstance(value, lsl.Var | lsl.Node) and xname is not None: raise ValueError( "When supplying a variable or node to `value`, `xname` must not be " "used. Name the variable instead." 
) def _basis_name(self, value: lsl.Var | lsl.Node, name: str | None): if name is not None: return name if value.name == "": return "" return f"B({value.name})" @property def penalty(self) -> lsl.Value | None: """ Penalty matrix, wrapped as a :class:`liesel.model.Value` (if any). """ return self._penalty def _validate_penalty_shape(self, pen: ArrayLike | lsl.Value) -> lsl.Value: if isinstance(pen, lsl.Value): pen_arr = jnp.asarray(pen.value) pen_val = pen pen_val.value = pen_arr else: pen_arr = jnp.asarray(pen) pen_val = lsl.Value(pen_arr) if self.value.ndim < 2: # 1d bases don't necessarily allow shape validation of penalties. return pen_val if not pen_arr.shape[-1] == self.nbases: raise ValueError( f"Basis has {self.nbases} columns, replacement penalty has " f"{pen_arr.shape[-1]}" ) return pen_val
[docs] def update_penalty(self, value: ArrayLike | lsl.Value) -> None: """ Updates the penalty matrix for this basis. If :attr:`.Basis.penalty` is not None, this method will only update the value of the penalty node, not the whole object. Even if the argument to this method is a node. Parameters ---------- value New penalty matrix or a :class:`liesel.model.Value` wrapping a penalty matrix. """ if self._penalty is None: self._penalty = self._validate_penalty_shape(value) else: self._penalty.value = self._validate_penalty_shape(value).value
[docs] @classmethod def new_linear( cls, value: lsl.Var | lsl.Node | Array, name: str | None = None, xname: str | None = None, add_intercept: bool = False, ): """ Create a linear basis (design matrix) from input values. Parameters ---------- value Input variable or raw array used to construct the design matrix. name Optional name for the basis. xname Name for the observation variable when ``value`` is \ a raw array. add_intercept If ``True``, adds an intercept column of ones as the first \ column of the design matrix. Returns ------- A :class:`.Basis` instance that produces a (n_obs, n_features) design matrix. """ def as_matrix(x): x = jnp.atleast_1d(x) if len(jnp.shape(x)) == 1: x = jnp.expand_dims(x, -1) if add_intercept: ones = jnp.ones(x.shape[0]) x = jnp.c_[ones, x] return x basis = cls( value=value, basis_fn=as_matrix, name=name, xname=xname, use_callback=False, cache_basis=False, ) return basis
[docs] def diagonalize_penalty(self, atol: float = 1e-6) -> Self: """ Diagonalize the penalty via an eigenvalue decomposition. This method computes a transformation that diagonalizes the penalty matrix and updates the internal basis function such that subsequent evaluations use the accordingly transformed basis. The penalty is updated to the diagonalized version. Parameters ---------- atol Absolute tolerance used in testing whether the existing penalty is already diagonal. If that is the case, the basis instance is returned without any further changes. Returns ------- The modified basis instance (self). Notes ----- Penalty diagonalization works via an eigenvalue decomposition of the penalty matrix. Let the eigenvalue decomposition of the :math:`d \\times d` penalty matrix :math:`\\mathbf{K}` be given by .. math:: \\mathbf{K} = \\mathbf{U} \\boldsymbol{\\Lambda} \\mathbf{U}^\\top, where :math:`\\boldsymbol{\\Lambda} = \\operatorname{diag}(\\lambda_1, \\dots, \\lambda_d)` contains the eigenvalues of :math:`\\mathbf{K}` in decreasing order and :math:`\\mathbf{U}` the corresponding eigenvectors. Let :math:`r` denote the rank of :math:`\\mathbf{K}`. The function obtains a reparameterization matrix :math:`\\mathbf{Z}` as .. math:: \\mathbf{Z} = \\mathbf{U} \\boldsymbol{\\Lambda}^{-1/2}, where :math:`\\boldsymbol{\\Lambda}^{-1/2} = \\operatorname{diag}(\\lambda_1^{-1/2}, \\dots, \\lambda_r^{-1/2}, \\mathbf{0}_{d-r}^\\top)`. The element :math:`\\mathbf{0}_{d-r}^\\top` is a zero-vector of length :math:`d-r`, corresponding to the zero eigenvalues of the penalty matrix. The basis matrix :math:`\\mathbf{B}` is then updated as :math:`\\mathbf{B}_Z = \\mathbf{B} \\mathbf{Z}`, and the penalty matrix is updated to :math:`\\operatorname{diag}(\\mathbf{1}_{r}^\\top, \\mathbf{0}_{d-r}^\\top)`. The basis function is likewise updated to evaluate to the reparamterized basis matrix during prediction. References ---------- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). 
Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z """ if self.penalty is None: raise TypeError("Basis.penalty is None, cannot apply transformation.") assert isinstance(self.value_node, lsl.Calc) basis_fn = self.value_node.function K = self.penalty.value if is_diagonal(K, atol=atol): return self rank = jnp.linalg.matrix_rank(K) Z = penalty_to_unit_design(K, rank=rank) def reparam_basis(*args, **kwargs): return basis_fn(*args, **kwargs) @ Z self.value_node.function = reparam_basis self.update() penalty = jnp.eye(Z.shape[-1]) # practically equal to: penalty = Z.T @ K @ Z penalty = penalty.at[rank:, rank:].set(0.0) self.update_penalty(penalty) return self
[docs] def scale_penalty(self) -> Self: """ Scale the penalty matrix by its infinite norm. The penalty matrix is divided by its infinity norm (max absolute row sum) so that its values are numerically well-conditioned for downstream use. The updated penalty replaces the previous one. Returns ------- The modified basis instance (self). """ if self.penalty is None: raise TypeError("Basis.penalty is None, cannot apply transformation.") K = self.penalty.value scale = jnp.linalg.norm(K, ord=jnp.inf) penalty = K / scale self.update_penalty(penalty) return self
def _apply_constraint(self, Z: Array) -> Self: """ Apply a linear reparameterisation to the basis using matrix Z. This internal helper multiplies the basis functions by ``Z`` (i.e. right-multiplies the design matrix) and updates the penalty to reflect the change of basis: ``K_new = Z.T @ K @ Z``. Parameters ---------- Z Transformation matrix applied to the basis functions. Returns ------- The modified basis instance (self). """ if self.penalty is None: raise TypeError("Basis.penalty is None, cannot apply transformation.") assert isinstance(self.value_node, lsl.Calc) basis_fn = self.value_node.function K = self.penalty.value def reparam_basis(*args, **kwargs): return basis_fn(*args, **kwargs) @ Z self.value_node.function = reparam_basis self.update() penalty = Z.T @ K @ Z self.update_penalty(penalty) return self
[docs] def constrain( self, constraint: ArrayLike | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"], ) -> Self: r""" Apply a linear constraint to the basis and corresponding penalty. When a constraint is applied, the type of constraint is saved to :attr:`.Basis.constraint`, and the reparamterization matrix is saved to :attr:`.Basis.reparam_matrix`. Parameters ---------- constraint Type of constraint or custom linear constraint matrix to apply. If an array is supplied, the constraint will be ``A @ coef == 0``, where ``A`` is the supplied array (the constraint matrix). Returns ------- The modified basis instance (self). Notes ----- This method implements the procedure detailed by Kneib et al. (2019). For the following exposition, which is quoted almost verbatim from Kneib et al. (2019), assume that this basis is used to evaluate a function .. math:: s(\mathbf{x}_i) = \sum_{j=1}^J B_j(\mathbf{x}_i) \beta_j = \mathbf{b}(\mathbf{x}_i)^\top \boldsymbol{\beta}, where - :math:`i=1, \dots, N` is the observation index, - :math:`\mathbf{x}_i^\top = [x_{i,1}, \dots, x_{i,M}]` are covariate observations, where :math:`M` denotes the number of covariates, - :math:`\mathbf{b}^\top = [B_1(\mathbf{x}), \dots, B_J(\mathbf{x})]` are a set of basis function evaluations, and - :math:`\boldsymbol{\beta}^\top = [\beta_1, \dots, \beta_J]` are the corresponding coefficients. The basis matrix for such a term is .. math:: \mathbf{B} = \begin{bmatrix} \mathbf{b}(\mathbf{x}_1)^\top \\ \vdots \\ \mathbf{b}(\mathbf{x}_N)^\top \end{bmatrix}, and the term can be written in matrix form as .. math:: \mathbf{s} = \mathbf{B} \boldsymbol{\beta}, where :math:`\mathbf{B}` is the basis matrix of dimension :math:`N \times J`. We consider :math:`\boldsymbol{\beta} \in \mathbb{R}^J` to be subject to linear constraints of the form .. math:: \mathbf{A} \boldsymbol{\beta} = \mathbf{0}. :math:`\mathbf{A}` is an :math:`A \times J` constraint matrix. 
To explicitly remove the constrained component, we construct a complementary matrix :math:`\bar{\mathbf{A}} \in \mathbb{R}^{(J-A) \times J}` such that .. math:: \bar{\mathbf{A}} \mathbf{A}^\top = \mathbf{0}, and the stacked matrix :math:`[\mathbf{A}^\top, \bar{\mathbf{A}}^\top]^\top` is of full rank. One possible construction of :math:`\bar{\mathbf{A}}` is based on the eigenvalue decomposition of :math:`\mathbf{A}^\top \mathbf{A}`, using the eigenvectors corresponding to zero eigenvalues. This is the construction of :math:`\bar{\mathbf{A}}` used in this method. Under the full-rank assumption, the inverse of the composed matrix exists and can be written as .. math:: \begin{bmatrix} \mathbf{A} \\ \bar{\mathbf{A}} \end{bmatrix}^{-1} = \begin{bmatrix} \mathbf{C}, \bar{\mathbf{C}} \end{bmatrix}, where :math:`\mathbf{C} \in \mathbb{R}^{J \times A}` and :math:`\bar{\mathbf{C}} \in \mathbb{R}^{J \times (J-A)}`. This yields the reparameterisation .. math:: \boldsymbol{\beta} = \mathbf{C} \boldsymbol{\alpha} + \bar{\mathbf{C}} \boldsymbol{\gamma}, where :math:`\boldsymbol{\alpha} = \mathbf{A} \boldsymbol{\beta} = \mathbf{0}` vanishes due to the constraint and :math:`\boldsymbol{\gamma} = \bar{\mathbf{A}} \boldsymbol{\beta}` represents the remaining unconstrained coefficients. Applying this reparameterisation to the functional effect gives :math:`\bar{\mathbf{s}} = \bar{\mathbf{B}} \boldsymbol{\alpha}`, where the basis matrix is reparameterized as .. math:: \bar{\mathbf{B}} = \mathbf{B} \bar{\mathbf{C}}. Accordingly, the original penalty matrix :math:`\mathbf{K}` is reparamterized as .. math:: \bar{\mathbf{K}} = \bar{\mathbf{C}}^\top \mathbf{K} \bar{\mathbf{C}}. .. rubric:: Default constraint options The default options correspond to the following constraint matrices: - ``"sumzero_term"``: :math:`\mathbf{A} = \mathbf{1}^\top \mathbf{B}`, where :math:`\mathbf{B}` is the basis matrix. 
This is the preferred option for a sum to zero constraint, because it centers the evaluated term. - ``"sumzero_coef"``: :math:`\mathbf{A} = \mathbf{1}^\top`. This is an alternative sum to zero constraint, focusing only on ensuring that the coefficients sum to zero. - ``"constant_and_linear"``: :math:`\mathbf{A}=(\mathbf{X}^\top\mathbf{X})^{-1}\mathbf{X}^\top \mathbf{B}`, where :math:`\mathbf{X} = [\mathbf{1}, \mathbf{x}]` is a design matrix built with the covariate observations :math:`\mathbf{x}` used in this basis. This constraint removes both a constant (like ``"sumzero_term"``) and a linear trend from the term modeled with this basis. References ---------- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z """ # noqa: E501 if not self.value.ndim == 2: raise ValueError( "Constraints can only be applied to matrix-valued bases. " f"{self} has shape {self.value.shape}" ) if self.constraint is not None: raise ValueError( f"A '{self.constraint}' constraint has already been applied." ) if isinstance(constraint, str): type_: str = constraint else: constraint_matrix = jnp.asarray(constraint) type_ = "custom" match type_: case "sumzero_coef": Z = LinearConstraintEVD.sumzero_coef(self.nbases) case "sumzero_term": Z = LinearConstraintEVD.sumzero_term(self.value) case "constant_and_linear": Z = LinearConstraintEVD.constant_and_linear(self.x.value, self.value) case "custom": Z = LinearConstraintEVD.general(constraint_matrix) self._apply_constraint(Z) self._constraint = type_ self._reparam_matrix = Z return self
class MRFBasis(Basis):
    """
    Basis object dedicated to Markov random fields.

    Usage is the same as for :class:`.Basis`. In addition, this class stores a
    description of the Markov random field setup, available via
    :attr:`.mrf_spec`.
    """

    # Unset by default; accessing ``mrf_spec`` before assigning it raises.
    _mrf_spec: MRFSpec | None = None

    @property
    def mrf_spec(self) -> MRFSpec:
        """
        Named tuple describing the Markov random field setup.

        The :class:`.MRFSpec` provides ``mapping`` (label-integer map for the
        region labels), ``nb`` (neighborhood structure), ``ordered_labels``
        (region labels ordered like the columns of the basis matrix), and
        ``polys``.

        Raises
        ------
        ValueError
            If no spec has been assigned yet.
        """
        spec = self._mrf_spec
        if spec is None:
            raise ValueError("No MRF spec defined.")
        return spec

    @mrf_spec.setter
    def mrf_spec(self, value: MRFSpec):
        # Only genuine MRFSpec instances are accepted as replacements.
        if isinstance(value, MRFSpec):
            self._mrf_spec = value
            return
        raise TypeError(
            f"Replacement must be of type {MRFSpec}, got {type(value)}."
        )
class LinBasis(Basis):
    """
    Dedicated basis object for linear effects.

    See :class:`.Basis` for general usage information. This class additionally
    offers

    - :attr:`.model_spec`: The model spec used internally by ``formulaic`` to
      set up the basis matrix.
    - :attr:`.mappings`: A dictionary of label-integer mappings for all
      categorical variables in this basis.
    - :attr:`.column_names`: List of column names for this basis.

    Each of these raises a ``ValueError`` when read before it has been set.
    """

    _model_spec: ModelSpec | None = None
    _mappings: dict[str, CategoryMapping] | None = None
    _column_names: list[str] | None = None

    @property
    def model_spec(self) -> ModelSpec:
        """
        The model spec used internally by ``formulaic`` to set up the basis
        matrix.
        """
        if self._model_spec is None:
            raise ValueError("No model spec defined.")
        return self._model_spec

    @model_spec.setter
    def model_spec(self, value: ModelSpec):
        if not isinstance(value, ModelSpec):
            raise TypeError(
                f"Replacement must be of type {ModelSpec}, got {type(value)}."
            )
        self._model_spec = value

    @property
    def mappings(self) -> dict[str, CategoryMapping]:
        """
        A dictionary of label-integer mappings for all categorical variables in
        this basis.
        """
        if self._mappings is None:
            # Previously reported the copy-pasted "No model spec defined."
            raise ValueError("No mappings defined.")
        return self._mappings

    @mappings.setter
    def mappings(self, value: dict[str, CategoryMapping]):
        if not isinstance(value, dict):
            raise TypeError(f"Replacement must be of type dict, got {type(value)}.")

        for val in value.values():
            if not isinstance(val, CategoryMapping):
                raise TypeError(
                    f"The values in the replacement must be of type {CategoryMapping}, "
                    f"got {type(val)}."
                )
        self._mappings = value

    @property
    def column_names(self) -> list[str]:
        """List of column names for this basis."""
        # The docstring lives on the getter: a property's documentation is
        # taken from the getter, so a setter docstring is not shown.
        if self._column_names is None:
            raise ValueError("No column names defined.")
        return self._column_names

    @column_names.setter
    def column_names(self, value: Sequence[str]):
        # Must be a non-string sequence of str, one entry per basis column.
        if not isinstance(value, Sequence):
            raise TypeError(f"Replacement must be a sequence, got {type(value)}.")

        if isinstance(value, str):
            raise TypeError("Replacement type cannot be string.")

        if not len(value) == self.value.shape[-1]:
            raise ValueError(
                f"Expected {self.value.shape[-1]} column names, got {len(value)}"
            )

        for val in value:
            if not isinstance(val, str):
                raise TypeError(
                    f"The values in the replacement must be of type str, "
                    f"got {type(val)}."
                )
        self._column_names = list(value)
class MRFSpec(NamedTuple):
    """
    A named tuple containing information about the Markov random field setup.

    Attributes
    ----------
    mapping
        Label-integer map for the region labels.
    nb
        Neighborhood structure, mapping each region label to a list of labels
        (or ``None``).
    ordered_labels
        Ordered region labels, such that the order corresponds to the columns
        of the basis matrix (or ``None``).
    polys
        Mapping from region label to an array (or ``None``). Presumably the
        polygon coordinates of each region; verify against the code that builds
        the spec.
    """

    mapping: CategoryMapping
    nb: dict[str, list[str]] | None
    ordered_labels: list[str] | None
    polys: dict[str, ArrayLike] | None