BasisBuilder.basis()

BasisBuilder.basis(*x, basis_fn, use_callback=True, cache_basis=True, penalty=None, basis_name='B')

Initializes a general basis from a user-supplied basis function.

Parameters:
  • *x (str | Var) – Names of input variables.

  • basis_fn (Callable[[Array], Array]) – Basis function. Must take a 2d array as input, with one column per variable in *x, and return a 2d array.

  • use_callback (bool, default: True) – If True, the basis function is evaluated using a Python callback, so it does not have to be JIT-compatible with JAX. As a consequence, the basis must remain constant throughout estimation. Passed on to Basis.

  • cache_basis (bool, default: True) – If True, the computed basis is cached in a persistent calculation node (lsl.Calc), which avoids recomputation when it is not required. Passed on to Basis.

  • penalty (Array | ndarray | bool | number | int | float | complex | Value | None, default: None) – Penalty matrix associated with the basis. Passed on to Basis.

  • basis_name (str, default: 'B') – Function name for the basis matrix. If basis_name is "B" and the basis is a function of the variable "x", the full name of the Basis object will be "B(x)". Names are made unique by appending a counter if necessary.
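
The basis_fn contract (2d array in, 2d array out) can be illustrated without liesel at all. The sketch below is a hypothetical polynomial basis; the name poly_basis and its degree parameter are illustrative assumptions, not part of the liesel_gam API. It is written in plain NumPy rather than jax.numpy, so it would not trace under jax.jit; per the use_callback description above, that is the kind of function the Python-callback path is meant to accommodate.

```python
import numpy as np


def poly_basis(x_mat: np.ndarray, degree: int = 3) -> np.ndarray:
    """Hypothetical polynomial basis with columns x, x**2, ..., x**degree.

    Follows the documented contract: takes a 2d array of shape (n, 1)
    and returns a 2d array, here of shape (n, degree).
    """
    x_mat = np.asarray(x_mat)
    if x_mat.ndim != 2 or x_mat.shape[1] != 1:
        raise ValueError(f"expected shape (n, 1), got {x_mat.shape}")
    x_vec = x_mat[:, 0]  # shape (n,)
    # Plain NumPy, not JAX: not jit-traceable, which use_callback=True
    # is documented to allow.
    return np.column_stack([x_vec**k for k in range(1, degree + 1)])


# A single input variable, passed as an (n, 1) matrix.
x = np.linspace(0.0, 1.0, 5).reshape(-1, 1)
B = poly_basis(x)
print(B.shape)  # (5, 3)
```

Validating the input shape up front makes a violated contract fail loudly instead of silently producing a basis with the wrong layout.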

Return type:

Basis

Examples

Manually specified B-spline basis

>>> from liesel.contrib.splines import basis_matrix, equidistant_knots
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> registry = gam.PandasRegistry(df)
>>> bb = gam.BasisBuilder(registry)
>>> knots = equidistant_knots(df["x_nonlin"].to_numpy(), n_param=20)
>>> pen = pspline_penalty(d=20)

The basis function should always expect a 2d (matrix-valued) array as input, even for a single variable.

>>> def bspline_basis(x_mat):
...     # x_mat has shape (n, 1)
...     x_vec = x_mat.squeeze()  # shape (n,)
...     return basis_matrix(x_vec, knots=knots)
>>> bb.basis("x_nonlin", basis_fn=bspline_basis, penalty=pen)
Basis(name="B(x_nonlin)")
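
For intuition about pen: a P-spline penalty is commonly built from a difference matrix D as P = DᵀD, which penalizes differences between adjacent B-spline coefficients. The sketch below constructs a second-order version in NumPy. The helper name difference_penalty is hypothetical, and whether pspline_penalty(d=20) uses second-order differences by default is an assumption here; check its documentation for the actual default.

```python
import numpy as np


def difference_penalty(d: int, order: int = 2) -> np.ndarray:
    """Hypothetical helper: P = D.T @ D, with D the order-th difference matrix."""
    # Each row of D takes an order-th difference of the d coefficients.
    D = np.diff(np.eye(d), n=order, axis=0)  # shape (d - order, d)
    return D.T @ D  # shape (d, d), symmetric and positive semi-definite


P = difference_penalty(20)
print(P.shape)  # (20, 20)

# Second-order differences vanish for constant and linear coefficient
# sequences, so both lie in the null space of P and go unpenalized.
coef_const = np.ones(20)
coef_lin = np.arange(20.0)
print(np.allclose(P @ coef_const, 0.0), np.allclose(P @ coef_lin, 0.0))  # True True
```

Because P has a two-dimensional null space under this construction, it is rank-deficient (rank d - 2), which is typical for difference penalties.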

Manually specified linear basis

This minimal example shows how a basis that depends on multiple variables works.

>>> import jax.numpy as jnp
>>> import liesel_gam as gam
>>> df = gam.demo_data(n=100)
>>> registry = gam.PandasRegistry(df)
>>> bb = gam.BasisBuilder(registry)
>>> def linear_basis(x_mat):
...     # x_mat has shape (n, 2); one column per input variable
...     basis_mat = jnp.column_stack((jnp.ones(x_mat.shape[0]), x_mat))
...     return basis_mat
>>> basis = bb.basis("x_nonlin", "x_lin", basis_fn=linear_basis)
>>> basis
Basis(name="B(x_nonlin,x_lin)")
>>> basis.value.shape
(100, 3)
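
The linear basis can also be checked outside liesel. This sketch mirrors linear_basis using NumPy in place of jax.numpy (an assumption made for portability; the hypothetical name linear_basis_np is illustrative only) and evaluates it on small hand-made inputs. Taking the length of the intercept column from x_mat.shape[0], rather than from a global data frame, keeps the function valid for inputs with any number of rows.

```python
import numpy as np


def linear_basis_np(x_mat: np.ndarray) -> np.ndarray:
    """NumPy mirror of linear_basis: an intercept column plus the raw inputs."""
    n = x_mat.shape[0]  # number of rows, taken from the input itself
    return np.column_stack((np.ones(n), x_mat))


# Two input variables, three observations: shape (3, 2).
x_mat = np.array([[0.5, 1.0],
                  [1.5, 2.0],
                  [2.5, 3.0]])
B = linear_basis_np(x_mat)
print(B.shape)  # (3, 3)

# The same function works unchanged for a different number of rows.
print(linear_basis_np(np.zeros((7, 2))).shape)  # (7, 3)
```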