BasisBuilder.basis()
- BasisBuilder.basis(*x, basis_fn, use_callback=True, cache_basis=True, penalty=None, basis_name='B')
Initializes a general basis given a basis function.
- Parameters:
  - basis_fn (Callable[[Array], Array]) – Basis function. Must take a 2d array as input and return a 2d array.
  - use_callback (bool, default: True) – If True, the basis function is evaluated using a Python callback, so it does not have to be JAX jit-compatible. This also means that the basis must remain constant throughout estimation. Passed on to Basis.
  - cache_basis (bool, default: True) – If True, the computed basis is cached in a persistent calculation node (lsl.Calc), which avoids re-computation when not required. Passed on to Basis.
  - penalty (Array | ndarray | bool | number | int | float | complex | Value | None, default: None) – Penalty matrix associated with the basis. Passed on to Basis.
  - basis_name (str, default: 'B') – Function name for the basis matrix. If "B", and the basis is a function of the variable "x", the full name of the Basis object will be "B(x)". Names are made unique by appending a counter if necessary.
- Return type:
  Basis
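The use_callback parameter refers to evaluating the basis function through a Python callback. As a general JAX illustration of that idea, and not a description of how liesel_gam implements it, the sketch below wraps a plain NumPy basis function with jax.pure_callback so it can be called inside jax.jit. The names numpy_basis and evaluate_with_callback are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_basis(x_mat):
    # Plain NumPy code that JAX cannot trace: returns columns [1, x, x**2]
    x = np.asarray(x_mat)[:, 0]
    return np.column_stack([np.ones_like(x), x, x**2])


def evaluate_with_callback(x_mat):
    # Declare the output shape and dtype so JAX can trace around the callback
    out_spec = jax.ShapeDtypeStruct((x_mat.shape[0], 3), x_mat.dtype)
    return jax.pure_callback(numpy_basis, out_spec, x_mat)


x_mat = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)  # 2d input, shape (5, 1)
B = jax.jit(evaluate_with_callback)(x_mat)
print(B.shape)  # (5, 3)
```

JAX cannot differentiate through jax.pure_callback without a custom rule, so a basis evaluated this way is best treated as fixed data rather than something that changes during estimation.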
Examples
Manually specified B-Spline basis
>>> from liesel.contrib.splines import basis_matrix, equidistant_knots
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> registry = gam.PandasRegistry(df)
>>> bb = gam.BasisBuilder(registry)
>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> pen = pspline_penalty(d=20)
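The penalty passed to the basis here is a P-spline penalty for 20 coefficients. Assuming pspline_penalty builds the usual difference penalty K = DᵀD, where D takes second-order differences of neighbouring coefficients, the following NumPy sketch reproduces that construction. The helper difference_penalty is hypothetical and only illustrates the structure of such a matrix.

```python
import numpy as np


def difference_penalty(d, diff=2):
    # D maps d coefficients to their (d - diff) differences of order `diff`
    D = np.diff(np.eye(d), n=diff, axis=0)
    # K = D^T D, so beta^T K beta is the sum of squared differences
    return D.T @ D


K = difference_penalty(20)
print(K.shape)  # (20, 20)

# Linear coefficient sequences have zero second differences: not penalized
beta_lin = np.arange(20.0)
print(np.allclose(K @ beta_lin, 0.0))  # True
```

Because constant and linear coefficient sequences lie in the null space of K, the matrix has rank d - 2 and shrinks the fit towards a straight line rather than towards zero.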
The basis function should always expect a 2d array as input, with one column per variable, even when the basis depends on a single variable.
>>> def bspline_basis(x_mat):
...     # x_mat is shape (n, 1)
...     x_vec = x_mat.squeeze()  # shape (n,)
...     return basis_matrix(x_vec, knots=knots)
>>> bb.basis("x_nonlin", basis_fn=bspline_basis, penalty=pen)
Basis(name="B(x_nonlin)")
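To show what a basis function like bspline_basis computes, here is a self-contained NumPy sketch of a cubic B-spline design matrix via the Cox-de Boor recursion. It is not liesel's basis_matrix or equidistant_knots; the helpers equidistant_knot_vector and bspline_design are hypothetical, and the knot layout (n_param + order + 1 equidistant knots extending past the data range) is an assumption. The wrapper keeps the same contract as above: a 2d (n, 1) input and a 2d output.

```python
import numpy as np


def equidistant_knot_vector(x, n_param, order=3):
    # Hypothetical helper: n_param + order + 1 equally spaced knots,
    # extended by `order` intervals beyond each end of the data range
    inner = np.linspace(x.min(), x.max(), n_param - order + 1)
    h = inner[1] - inner[0]
    left = inner[0] - h * np.arange(order, 0, -1)
    right = inner[-1] + h * np.arange(1, order + 1)
    return np.concatenate([left, inner, right])


def bspline_design(x, knots, order=3):
    # Cox-de Boor recursion; returns an (n, len(knots) - order - 1) matrix
    x = np.asarray(x, dtype=float)[:, None]
    # Degree 0: indicator of the half-open knot interval [t_j, t_{j+1})
    B = ((knots[:-1] <= x) & (x < knots[1:])).astype(float)
    for k in range(1, order + 1):
        m = len(knots) - k - 1  # number of basis functions of degree k
        t0, tk = knots[:m], knots[k:k + m]
        t1, tk1 = knots[1:1 + m], knots[k + 1:k + 1 + m]
        B = (x - t0) / (tk - t0) * B[:, :m] + (tk1 - x) / (tk1 - t1) * B[:, 1:m + 1]
    return B


x = np.linspace(-2.0, 2.0, 100)
knots = equidistant_knot_vector(x, n_param=20)


def bspline_basis_sketch(x_mat):
    # Same contract as bspline_basis: (n, 1) input, (n, n_param) output
    return bspline_design(x_mat[:, 0], knots)


B = bspline_basis_sketch(x.reshape(-1, 1))
print(B.shape)                          # (100, 20)
print(np.allclose(B.sum(axis=1), 1.0))  # True: each row sums to one
```

The rows summing to one (partition of unity) is a standard property of B-splines on the covered range, which is why the knots are extended beyond the data.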
Manually specified linear basis
This is a minimal example of a basis that is a function of multiple variables.
>>> import jax.numpy as jnp
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> registry = gam.PandasRegistry(df)
>>> bb = gam.BasisBuilder(registry)
>>> def linear_basis(x_mat):
...     # x_mat is shape (n, 2); size the intercept column from x_mat itself
...     basis_mat = jnp.column_stack((jnp.ones(x_mat.shape[0]), x_mat))
...     return basis_mat
>>> basis = bb.basis("x_nonlin", "x_lin", basis_fn=linear_basis)
>>> basis
Basis(name="B(x_nonlin,x_lin)")
>>> basis.value.shape
(100, 3)
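The shape (100, 3) comes from one intercept column plus one column per input variable. The plain JAX sketch below reproduces this without liesel_gam, using hypothetical stand-in arrays for "x_nonlin" and "x_lin". It assumes the variables are stacked column-wise in the order they are passed to *x, which this page does not state explicitly.

```python
import jax.numpy as jnp


def linear_basis(x_mat):
    # x_mat has shape (n, k), one column per input variable;
    # prepend an intercept column of ones -> shape (n, k + 1)
    return jnp.column_stack((jnp.ones(x_mat.shape[0]), x_mat))


# Hypothetical stand-ins for the columns "x_nonlin" and "x_lin"
x_nonlin = jnp.linspace(-1.0, 1.0, 100)
x_lin = jnp.linspace(0.0, 2.0, 100)

# Assumption: variables are stacked column-wise in the order passed to *x
B = linear_basis(jnp.column_stack((x_nonlin, x_lin)))
print(B.shape)  # (100, 3)

# The same function also works for new data of a different length
B_new = linear_basis(jnp.column_stack((x_nonlin[:7], x_lin[:7])))
print(B_new.shape)  # (7, 3)
```

Sizing the intercept column from x_mat rather than from a fixed data frame keeps the basis function valid for inputs of any length.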