Basis
class liesel_gam.Basis(value, basis_fn=<function Basis.<lambda>>, name=None, xname=None, use_callback=True, cache_basis=True, penalty='identity', **basis_kwargs)
Bases: UserVar

General basis for a structured additive term.
The Basis class wraps an observation variable (or an array) and a basis-generation function. It constructs an internal calculation node that produces the basis (design) matrix by computing basis_fn(value). By default, the basis function is executed via a callback, so it does not need to be JAX-compatible. However, this can be very slow if the basis has to be recomputed during estimation, so we recommend the callback only for bases that remain static during estimation.

Parameters:
- value (Var | Node | Array | ndarray | bool | number | bool | int | float | complex) – If a liesel.model.Var or node is provided, it is used as the input variable for the basis. Otherwise, a raw array-like object may be supplied together with xname to create an observed variable internally.
- basis_fn (Callable[[Array], Array] | Callable[..., Array], default: identity function) – Function mapping the input variable's values to a basis matrix or vector. It must accept the input array and any basis_kwargs and return an array of shape (n_obs, n_bases) (or a scalar/1-d array for simpler bases). By default, this is lambda x: x, so the input values are used directly as the basis.
- name (str | None, default: None) – Optional name for the basis object. If omitted, a name of the form B(<xname>) is constructed from the input variable's name.
- xname (str | None, default: None) – Required when value is a raw array: the name of the observation variable that is created internally.
- use_callback (bool, default: True) – If True (default), basis_fn is wrapped in a JAX pure_callback via make_callback(), so arbitrary Python basis functions can be used while preserving JAX tracing. If False, the function is used directly and must be jittable.
- cache_basis (bool, default: True) – If True, the computed basis is cached in a persistent calculation node (lsl.Calc), which avoids recomputation when it is not required but uses memory. If False, a transient calculation node (lsl.TransientCalc) is used: the basis is recomputed on each evaluation of Basis.value instead of being stored in memory.
- penalty (Array | ndarray | bool | number | bool | int | float | complex | Value | Literal['identity'] | None, default: 'identity') – Penalty matrix associated with the basis. If "identity", an identity penalty matching the number of basis functions is created. If None, an identity penalty is assumed but not materialized. This saves memory, but any downstream functionality that relies on an explicit penalty matrix must then handle this case itself.
- **basis_kwargs – Additional keyword arguments forwarded to basis_fn.
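To illustrate what use_callback=True means, here is a minimal sketch in plain JAX of wrapping a NumPy-only basis function with jax.pure_callback. The helper make_callback_sketch and the quadratic numpy_basis are hypothetical stand-ins, not liesel_gam's make_callback(); we only assume the general pattern described on this page: the function is evaluated once to obtain a jax.ShapeDtypeStruct, which pure_callback needs in order to know the shape and dtype the host function returns.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_basis(x):
    # Arbitrary NumPy code: JAX cannot trace through this on its own.
    x = np.asarray(x)
    return np.stack([np.ones_like(x), x, x**2], axis=-1)


def make_callback_sketch(fn, example_x):
    # Hypothetical helper: evaluate fn once to learn output shape and dtype.
    example_out = np.asarray(fn(example_x))
    result_shape = jax.ShapeDtypeStruct(example_out.shape, example_out.dtype)

    def wrapped(x):
        # pure_callback runs fn on the host and tells JAX what comes back.
        return jax.pure_callback(fn, result_shape, x)

    return wrapped


x = jnp.linspace(0.0, 1.0, 5)
basis = make_callback_sketch(numpy_basis, x)
B = jax.jit(basis)(x)  # works under jit despite the NumPy internals
print(B.shape)  # (5, 3)
```

Because the result shape is fixed from the example input, this sketch only accepts inputs of that same shape.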
See also
TermBuilder – Initializes structured additive terms.
BasisBuilder – Initializes structured additive terms.
StrctTerm – A general structured additive term.
Notes
The basis is evaluated once during initialization (via self.update()) to determine its shape and dtype. The internal callback wrapper inspects the returned shape to build a compatible JAX ShapeDtypeStruct for the pure callback.

Examples
Implementing a B-spline basis manually:
>>> from liesel.contrib.splines import (
...     basis_matrix,
...     equidistant_knots,
...     pspline_penalty,
... )
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)

>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> pen = pspline_penalty(d=20)

>>> def bspline_basis(x):
...     return basis_matrix(x, knots=knots)

>>> gam.Basis(
...     value=df["x_nonlin"].to_numpy(),
...     basis_fn=bspline_basis,
...     xname="x",
...     penalty=pen,
... )
Basis(name="B(x)")
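The penalty pen in the example above comes from pspline_penalty, whereas penalty="identity" yields an identity matrix. The NumPy sketch below contrasts the two kinds of penalty matrix. It assumes, without checking liesel's implementation, that a P-spline penalty has the usual difference form K = D^T D, where D takes order-th differences of neighbouring coefficients; difference_penalty and identity_penalty are hypothetical helpers, not liesel functions.

```python
import numpy as np


def difference_penalty(n_bases, order=2):
    # D is the order-th difference matrix, shape (n_bases - order, n_bases).
    D = np.diff(np.eye(n_bases), n=order, axis=0)
    # K = D^T D penalizes differences between neighbouring coefficients.
    return D.T @ D


def identity_penalty(n_bases):
    # Analogue of penalty="identity": a ridge-type penalty on all coefficients.
    return np.eye(n_bases)


K = difference_penalty(20)
print(K.shape)  # (20, 20)
# Constant and linear coefficient sequences are unpenalized under a
# second-order difference penalty, so K has a two-dimensional null space.
print(np.linalg.matrix_rank(K))  # 18
```

Unlike the identity penalty, which shrinks every coefficient toward zero, a second-order difference penalty only penalizes curvature in the coefficient sequence, which is what makes the fitted B-spline smooth.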
Implementing a fixed basis matrix (without using a basis function). This is not recommended: because the basis matrix is precomputed, you cannot simply pass new covariate values to liesel.model.Model.predict() to evaluate the basis matrix for predictions.

>>> from liesel.contrib.splines import equidistant_knots, basis_matrix
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)

>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> def bspline_basis(x):
...     return basis_matrix(x, knots=knots)

>>> x = df["x_nonlin"].to_numpy()
>>> gam.Basis(value=bspline_basis(x), name="B(x)")
Basis(name="B(x)")
Methods
Apply a linear constraint to the basis and corresponding penalty.
Diagonalize the penalty via an eigenvalue decomposition.
Create a linear basis (design matrix) from input values.
Scale the penalty matrix by its infinity norm.
Update the penalty matrix for this basis.
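The sketch below does not call liesel_gam. It illustrates, in plain NumPy, the kind of linear algebra behind these operations: absorbing a sum-to-zero constraint into the basis through a reparameterization matrix Z (computed here from a QR decomposition, a standard choice in penalized-spline software), scaling a penalty by its infinity norm, and rotating the basis so that the penalty becomes diagonal. The helper names and the choice of a sum-to-zero constraint are assumptions for illustration; the library's own constraint types and conventions may differ.

```python
import numpy as np


def sum_to_zero_reparam(B):
    # Constraint A @ beta = 0 with A = 1^T B, i.e. fitted values sum to zero.
    A = B.sum(axis=0, keepdims=True)  # shape (1, k)
    # Complete QR of A^T: the last k - 1 columns of Q span the null space of A.
    Q, _ = np.linalg.qr(A.T, mode="complete")
    return Q[:, 1:]  # reparameterization matrix Z, shape (k, k - 1)


def scale_by_inf_norm(K):
    # Matrix infinity norm = maximum absolute row sum.
    return K / np.linalg.norm(K, ord=np.inf)


def diagonalize(B, K):
    # K is symmetric, so eigh gives K = U diag(lam) U^T with orthogonal U.
    lam, U = np.linalg.eigh(K)
    # Substituting beta = U gamma turns the basis into B U and the penalty
    # into U^T K U = diag(lam).
    return B @ U, np.diag(lam)


# Toy basis and second-order difference penalty with k = 6 coefficients.
rng = np.random.default_rng(0)
B = rng.uniform(size=(50, 6))
D = np.diff(np.eye(6), n=2, axis=0)
K = D.T @ D

Z = sum_to_zero_reparam(B)
B_c, K_c = B @ Z, Z.T @ K @ Z  # constrained basis and penalty
K_s = scale_by_inf_norm(K_c)
B_d, K_d = diagonalize(B_c, K_s)
print(B_c.shape, K_c.shape)  # (50, 5) (5, 5)
print(np.allclose(B_c.sum(axis=0), 0.0))  # True
```

Coefficients gamma of the constrained basis map back to the original coefficients as beta = Z gamma; this is presumably the role of the reparameterization matrix listed under Attributes, assuming the library uses the same convention.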
Attributes
The type of constraint applied to this basis and penalty (if any).
Number of basis functions (number of columns in the basis matrix).
Penalty matrix, wrapped as a liesel.model.Value (if any).
Reparameterization matrix used to constrain this basis and penalty (if any).
The input variable (observations) used to construct the basis.