Basis

class liesel_gam.Basis(value, basis_fn=<function Basis.<lambda>>, name=None, xname=None, use_callback=True, cache_basis=True, penalty='identity', **basis_kwargs)

Bases: UserVar

General basis for a structured additive term.

The Basis class wraps an observation variable (or an array) and a basis-generation function. It constructs an internal calculation node that produces the basis (design) matrix by computing basis_fn(value). The basis function may be executed via a callback, in which case it does not need to be JAX-compatible. This is the default, but it can be very slow if the basis has to be recomputed during estimation, so we recommend it only for bases that remain static during estimation.

Parameters:
  • value (Var | Node | Array | ndarray | bool | number | int | float | complex) – If a liesel.model.Var or Node is provided, it is used as the input variable for the basis. Otherwise, a raw array-like object may be supplied together with xname to create an observed variable internally.

  • basis_fn (Callable[[Array], Array] | Callable[..., Array], default: <function Basis.<lambda>>) – Function mapping the input variable’s values to a basis matrix or vector. It must accept the input array and any basis_kwargs and return an array of shape (n_obs, n_bases) (or a scalar or 1-d array for simpler bases). By default, this is the identity function (lambda x: x).

  • name (str | None, default: None) – Optional name for the basis object. If omitted, the name is constructed from the input variable’s name as B(<xname>).

  • xname (str | None, default: None) – Required when value is a raw array: provides a name for the observation variable that will be created.

  • use_callback (bool, default: True) – If True (default), basis_fn is wrapped in a JAX pure_callback via make_callback(), which allows arbitrary Python basis functions while preserving JAX tracing. If False, the function is used directly and must be jittable with JAX.

  • cache_basis (bool, default: True) – If True, the computed basis is cached in a persistent calculation node (lsl.Calc), which avoids unnecessary recomputation but uses memory. If False, a transient calculation node (lsl.TransientCalc) is used: the basis is recomputed on each evaluation of Basis.value and is not stored in memory.

  • penalty (Array | ndarray | bool | number | int | float | complex | Value | Literal['identity'] | None, default: 'identity') – Penalty matrix associated with the basis. If "identity", an identity penalty is created with one row and one column per basis function. If None, an identity penalty is assumed but not materialized. This saves memory, but any downstream functionality that relies on an explicit penalty matrix must then handle the missing matrix explicitly.

  • **basis_kwargs – Additional keyword arguments forwarded to basis_fn.
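The interplay of basis_fn, **basis_kwargs and the default identity penalty can be illustrated with plain NumPy. The helpers evaluate_basis and poly_basis below are hypothetical and only mirror the parameter semantics described above; they are not part of liesel_gam. Treating a 1-d basis as a single column is an assumption of this sketch.

```python
import numpy as np


def poly_basis(x, degree=3):
    """Hypothetical basis function: columns x, x**2, ..., x**degree."""
    x = np.asarray(x, dtype=float)
    return np.column_stack([x**k for k in range(1, degree + 1)])


def evaluate_basis(value, basis_fn=lambda x: x, penalty="identity", **basis_kwargs):
    """Hypothetical helper mirroring the documented parameters.

    basis_kwargs are forwarded to basis_fn. penalty="identity" yields an
    identity matrix with one row/column per basis function; penalty=None
    is passed through unchanged (no matrix is materialized).
    """
    basis = np.asarray(basis_fn(value, **basis_kwargs))
    # Assumption: a 2-d basis has one column per basis function,
    # anything lower-dimensional is treated as a single basis function.
    nbases = basis.shape[-1] if basis.ndim == 2 else 1
    if isinstance(penalty, str) and penalty == "identity":
        penalty = np.eye(nbases)
    return basis, penalty


x = np.linspace(-1.0, 1.0, 5)
# degree=4 is forwarded to poly_basis, like **basis_kwargs in Basis.
B, K = evaluate_basis(x, basis_fn=poly_basis, degree=4)
print(B.shape, K.shape)  # (5, 4) (4, 4)
```

The string check comes first so that passing an explicit penalty array never triggers an elementwise comparison against "identity".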

See also

TermBuilder

Initializes structured additive terms.

BasisBuilder

Initializes structured additive terms.

StrctTerm

A general structured additive term.

Notes

The basis is evaluated once during initialization (via self.update()) to determine its shape and dtype. The internal callback wrapper inspects the return shape to build a compatible JAX ShapeDtypeStruct for the pure callback.
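The callback mechanism described in this note can be sketched with jax.pure_callback. The helper make_callback_sketch and the basis numpy_basis are hypothetical illustrations, not liesel_gam's make_callback(). Following the note, the sketch evaluates the function once to fix the output shape and dtype used in the ShapeDtypeStruct.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_basis(x, n_bins=3):
    """Hypothetical indicator basis built with NumPy.

    np.digitize cannot operate on traced JAX values, so calling this
    function directly inside jax.jit would fail.
    """
    edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    idx = np.digitize(x, edges)
    return np.eye(n_bins)[idx]


def make_callback_sketch(fn, example_x, **kwargs):
    """Wrap a NumPy-only function so it can be used in traced JAX code."""
    # Evaluate once on concrete values to learn the output shape and dtype.
    example_out = np.asarray(fn(np.asarray(example_x), **kwargs))
    # jnp.asarray applies JAX's default dtype (float32 unless x64 is enabled).
    dtype = jnp.asarray(example_out).dtype
    result_shape = jax.ShapeDtypeStruct(example_out.shape, dtype)

    def host_fn(x):
        # Runs on the host with concrete inputs, so fn need not be jittable.
        return np.asarray(fn(np.asarray(x), **kwargs), dtype=dtype)

    def wrapped(x):
        return jax.pure_callback(host_fn, result_shape, x)

    return wrapped


x = jnp.array([0.1, 0.5, 0.9])
basis_fn = make_callback_sketch(numpy_basis, x, n_bins=3)
B = jax.jit(basis_fn)(x)
print(B.shape)  # (3, 3)
```

Because the ShapeDtypeStruct is fixed from the example input, this sketch only handles inputs with the same number of observations; evaluating the basis at new covariate values would require wrapping it again.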

Examples

Implementing a B-spline basis manually:

>>> from liesel.contrib.splines import (
...     basis_matrix,
...     equidistant_knots,
...     pspline_penalty,
... )
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> pen = pspline_penalty(d=20)
>>> def bspline_basis(x):
...     return basis_matrix(x, knots=knots)
>>> gam.Basis(
...     value=df["x_nonlin"].to_numpy(),
...     basis_fn=bspline_basis,
...     xname="x",
...     penalty=pen,
... )
Basis(name="B(x)")

Implementing a fixed basis matrix without a basis function. This is not recommended, because you then cannot simply supply new covariate values to liesel.model.Model.predict() to evaluate the basis matrix for predictions.

>>> from liesel.contrib.splines import equidistant_knots, basis_matrix
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> def bspline_basis(x):
...     return basis_matrix(x, knots=knots)
>>> x = df["x_nonlin"].to_numpy()
>>> gam.Basis(value=bspline_basis(x), name="B(x)")
Basis(name="B(x)")

Methods

constrain

Apply a linear constraint to the basis and corresponding penalty.

diagonalize_penalty

Diagonalize the penalty via an eigenvalue decomposition.

new_linear

Create a linear basis (design matrix) from input values.

scale_penalty

Scale the penalty matrix by its infinity norm.

update_penalty

Update the penalty matrix for this basis.
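The linear algebra behind scale_penalty and diagonalize_penalty can be sketched in NumPy. The functions below are hypothetical stand-ins, not the liesel_gam methods: scaling is assumed to mean dividing by the infinity norm (maximum absolute row sum), and diagonalization uses the eigendecomposition K = U diag(λ) Uᵀ to rotate the basis. The real methods may additionally rescale the eigenvectors.

```python
import numpy as np


def scale_penalty_sketch(K):
    """Divide K by its infinity norm (maximum absolute row sum)."""
    return K / np.linalg.norm(K, ord=np.inf)


def diagonalize_penalty_sketch(B, K):
    """Reparameterize so that the penalty becomes diagonal.

    With K = U diag(lam) U^T and U orthogonal, the basis B @ U and the
    penalty U^T K U = diag(lam) describe the same model, because
    B @ beta == (B @ U) @ (U.T @ beta).
    """
    _, U = np.linalg.eigh(K)
    return B @ U, U.T @ K @ U, U


d = 6
D = np.diff(np.eye(d), n=2, axis=0)  # second-order difference matrix, (d-2, d)
K = D.T @ D                           # P-spline-style penalty
rng = np.random.default_rng(0)
B = rng.normal(size=(10, d))          # stand-in basis matrix

K_scaled = scale_penalty_sketch(K)
print(np.linalg.norm(K_scaled, ord=np.inf))  # 1.0
B_new, K_new, U = diagonalize_penalty_sketch(B, K_scaled)
```

For this second-order difference penalty, two eigenvalues are (numerically) zero, corresponding to the unpenalized constant and linear parts of the function.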

Attributes

constraint

The type of constraint applied to this basis and penalty (if any).

nbases

Number of basis functions (number of columns in the basis matrix).

penalty

Penalty matrix, wrapped as a liesel.model.Value (if any).

reparam_matrix

Reparameterization matrix used for constraint of this basis and penalty (if any).

x

The input variable (observations) used to construct the basis.
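The constrain method and the reparam_matrix attribute refer to a linear constraint applied through a reparameterization. One common way to do this, used for example for centering constraints in mgcv, is sketched below for a sum-to-zero constraint. The function sum_to_zero_reparam is hypothetical, and liesel_gam's constrain() may differ in its details. The matrix Z spans the null space of the constraint C = 1ᵀB, so every coefficient vector of the constrained basis B Z yields a function that sums to zero over the observations.

```python
import numpy as np


def sum_to_zero_reparam(B):
    """Null-space reparameterization for the constraint 1^T B beta = 0."""
    C = B.sum(axis=0, keepdims=True)           # (1, nbases) constraint matrix
    Q, _ = np.linalg.qr(C.T, mode="complete")  # (nbases, nbases), orthogonal
    # The first column of Q is proportional to C^T; the rest span its
    # orthogonal complement, i.e. the null space of C.
    return Q[:, 1:]


rng = np.random.default_rng(1)
B = rng.uniform(size=(20, 5))        # stand-in basis matrix
D = np.diff(np.eye(5), n=2, axis=0)
K = D.T @ D                          # stand-in penalty

Z = sum_to_zero_reparam(B)           # plays the role of a reparameterization matrix
B_c = B @ Z                          # constrained basis, (20, 4)
K_c = Z.T @ K @ Z                    # constrained penalty, (4, 4)
```

The constraint removes one degree of freedom, so the constrained basis has one column fewer than the original one.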