Source code for liesel_gam.experimental.approx_bspline

from __future__ import annotations

from collections.abc import Callable

import jax
import jax.numpy as jnp
from jax.tree_util import Partial as partial
from liesel.contrib.splines import basis_matrix

Array = jax.Array
ArrayLike = jax.typing.ArrayLike


def _bspline_basis(x: ArrayLike, knots: ArrayLike, order: int) -> Array:
    """Return B-spline basis, allowing values outside knots."""
    x = jnp.asarray(x)
    knots = jnp.asarray(knots)
    return basis_matrix(x, knots, order, outer_ok=True)


@partial(jax.jit, static_argnums=2)
@partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)")
def bspline_basis(x: ArrayLike, knots: ArrayLike, order: int) -> Array:
    """
    Vectorized B-spline basis function evaluation.

    Values outside the support are set to zero, where::

        min = knots[order]
        max = knots[-(order + 1)]

    Parameters
    ----------
    x
        Input array.
    knots
        Array of knots.
    order
        Order of the spline (``order=3`` for a cubic spline).

    Returns
    -------
    B-spline basis matrix.
    """
    x = jnp.asarray(x)
    knots = jnp.asarray(knots)

    min_knot = knots[order]
    max_knot = knots[-(order + 1)]
    basis = _bspline_basis(x, knots, order)
    mask = jnp.logical_or(x < min_knot, x > max_knot)
    mask = jnp.expand_dims(mask, -1)
    return jnp.where(mask, 0.0, basis)


@partial(jax.jit, static_argnums=2)
@partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)")
def bspline_basis_deriv(x: ArrayLike, knots: ArrayLike, order: int) -> Array:
    """
    Evaluate matrix of first derivatives of B-spline bases.

    Values outside the support are set to zero.
    """
    x = jnp.asarray(x)
    knots = jnp.asarray(knots)

    min_knot = knots[order]
    max_knot = knots[-(order + 1)]

    basis = _bspline_basis(x, knots[1:-1], order - 1)
    dknots = jnp.diff(knots).mean()
    D = jnp.diff(jnp.identity(jnp.shape(knots)[-1] - order - 1)).T
    basis_grad = basis @ (D / dknots)

    mask = jnp.logical_or(x < min_knot, x > max_knot)
    mask = jnp.expand_dims(mask, -1)
    return jnp.where(mask, 0.0, basis_grad)


@partial(jax.jit, static_argnums=2)
@partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)")
def bspline_basis_deriv2(x, knots, order):
    """Evaluate matrix of second derivatives of B-spline bases."""
    x = jnp.asarray(x)
    knots = jnp.asarray(knots)

    min_knot = knots[order]
    max_knot = knots[-(order + 1)]
    basis = _bspline_basis(x, knots[2:-2], order - 2)

    dknots = jnp.diff(knots).mean()
    D = jnp.diff(jnp.identity(jnp.shape(knots)[-1] - order - 1)).T
    basis_grad = basis @ D[1::, 1:] @ (D / (dknots**2))

    mask = jnp.logical_or(x < min_knot, x > max_knot)
    mask = jnp.expand_dims(mask, -1)
    return jnp.where(mask, 0.0, basis_grad)


[docs] class BSplineApprox: """ Approximate B-spline evaluations on a fixed grid. Parameters ---------- knots Knot positions. degree Degree of the spline, with ``degree=3`` indicating a cubic spline. ngrid Number of grid points used to precompute the basis (default 1000). Z Optional matrix to post-multiply the basis. In ``B(x) @ Z``, ``B(x)`` is the basis matrix and ``Z`` is the postmultiplication matrix. Can be used to apply linear constraints via reparameterization matrices such as those returned by :class:`.LinearConstraintEVD`. See Also -------- .LinearConstraintEVD : Compute reparameterization matrices for linear constraints. Examples -------- Basic example. >>> import jax >>> from liesel.contrib.splines import equidistant_knots >>> import liesel_gam as gam >>> nbases = 20 >>> x = jax.random.uniform(jax.random.key(1234), shape=(40,)) >>> coef = jax.random.normal(jax.random.key(4321), shape=(nbases,)) >>> knots = equidistant_knots(x, n_param=nbases) >>> bspline = gam.experimental.BSplineApprox(knots) >>> bspline.dot(x, coef).shape (40,) Applying a linear constraint matrix. >>> import jax >>> from liesel.contrib.splines import basis_matrix, equidistant_knots >>> import liesel_gam as gam >>> nbases = 20 >>> x = jax.random.uniform(jax.random.key(1234), shape=(40,)) >>> coef = jax.random.normal(jax.random.key(4321), shape=((nbases - 1),)) >>> knots = equidistant_knots(x, n_param=nbases) >>> basis = basis_matrix(x, knots) >>> Z = gam.LinearConstraintEVD.sumzero_term(basis) >>> bspline = gam.experimental.BSplineApprox(knots, Z=Z) >>> bspline.dot(x, coef).shape (40,) """ def __init__( self, knots: Array, degree: int = 3, ngrid: int = 1000, Z: Array | None = None, subscripts: str = "...ij,...j->...i", ) -> None: self.knots = jnp.asarray(knots) self.dknots = jnp.mean(jnp.diff(knots)) self.degree = degree self.nparam = jnp.shape(knots)[0] - degree - 1 self.subscripts = subscripts self.min_knot = self.knots[degree] self.max_knot = self.knots[-(degree + 1)] grid = jnp.linspace(self.min_knot, self.max_knot, ngrid) self.step = (self.max_knot - self.min_knot) / ngrid prepend = jnp.array([self.min_knot - self.step]) append = jnp.array([self.max_knot + self.step]) self.ngrid = ngrid self.grid = jnp.concatenate((prepend, grid, append)) self.Z = Z self.basis_grid = self.compute_basis(grid) self.basis_deriv_grid = self.compute_deriv(grid) self.basis_deriv2_grid = self.compute_deriv2(grid) self._dot_fn = self._get_dot_fn() self._dot_and_deriv_fn = self._get_dot_and_deriv_fn()
[docs] def compute_basis(self, x: Array) -> Array: """Computes the basis matrix.""" basis = bspline_basis(x, self.knots, self.degree) if self.Z is not None: basis = basis @ self.Z return basis
[docs] def compute_deriv(self, x: Array) -> Array: """Computes the matrix of basis function derivatives with respect to x.""" deriv = bspline_basis_deriv(x, self.knots, self.degree) if self.Z is not None: deriv = deriv @ self.Z return deriv
[docs] def compute_deriv2(self, x: Array) -> Array: """ Computes the matrix of basis function second derivatives with respect to x. """ deriv2 = bspline_basis_deriv2(x, self.knots, self.degree) if self.Z is not None: deriv2 = deriv2 @ self.Z return deriv2
@partial(jax.jit, static_argnums=0) def _approx_basis(self, x: Array) -> Array: i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis_grid[i, :] + (k * self.basis_grid[i + 1, :]) return jnp.atleast_2d(basis) @partial(jax.jit, static_argnums=0) def _approx_basis_and_deriv(self, x: Array) -> tuple[Array, Array]: """ Returns the basis matrix approximation and its gradient with respect to the data. """ i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis_grid[i, :] + (k * self.basis_grid[i + 1, :]) basis_deriv = (1.0 - k) * self.basis_deriv_grid[i, :] + ( k * self.basis_deriv_grid[i + 1, :] ) basis = jnp.atleast_2d(basis) basis_deriv = jnp.atleast_2d(basis_deriv) return basis, basis_deriv @partial(jax.jit, static_argnums=0) def _approx_basis_deriv_and_deriv2(self, x: Array) -> tuple[Array, Array, Array]: """ Returns the basis matrix approximation and its first and second derivative with respect to the data. """ i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis_grid[i, :] + (k * self.basis_grid[i + 1, :]) basis_deriv = (1.0 - k) * self.basis_deriv_grid[i, :] + ( k * self.basis_deriv_grid[i + 1, :] ) basis_deriv2 = (1.0 - k) * self.basis_deriv2_grid[i, :] + ( k * self.basis_deriv2_grid[i + 1, :] ) basis = jnp.atleast_2d(basis) basis_deriv = jnp.atleast_2d(basis_deriv) basis_deriv2 = jnp.atleast_2d(basis_deriv2) return basis, basis_deriv, basis_deriv2
[docs] def approx_basis(self, x: Array) -> Array: """ Approximates B-spline basis for input data. Parameters ---------- x Input data. Returns ------- Basis matrix. """ return self._approx_basis(x)
[docs] def approx_basis_and_deriv(self, x: Array) -> tuple[Array, Array]: """ Approximates basis and matrix of first derivatives. Parameters ---------- x Input data. Returns ------- basis Basis matrix. deriv First derivative. """ return self._approx_basis_and_deriv(x)
[docs] def approx_basis_and_deriv2(self, x: Array) -> tuple[Array, Array, Array]: """ Approximates basis matrix and matrices of first and second derivatives. Parameters ---------- x Input data. Returns ------- basis : Array Basis matrix. deriv : Array First derivative. deriv2 : Array Second derivative. """ return self._approx_basis_deriv_and_deriv2(x)
def _compute_dot(self, x: Array, coef: Array) -> Array: return jnp.einsum(self.subscripts, x, coef) def _get_dot_fn(self) -> Callable[[Array, Array], Array]: @jax.custom_jvp def _dot( x: Array, coef: Array, ) -> Array: basis = self.approx_basis(x) smooth = self._compute_dot(basis, coef) return smooth @_dot.defjvp def _dot_jvp(primals, tangents): # x and x_dot: (n,) # coef and coef_dot: (k,) x, coef = primals x_dot, coef_dot = tangents # both shape (n, k) basis, basis_deriv = self.approx_basis_and_deriv(x) # shape (n,) smooth = self._compute_dot(basis, coef) tangent_x = self._compute_dot(basis_deriv, coef) * x_dot tangent_coef = self._compute_dot(basis, coef_dot) tangent = tangent_x + tangent_coef return smooth, tangent return jax.jit(_dot) def _get_dot_and_deriv_fn( self, ) -> Callable[[Array, Array], tuple[Array, Array]]: @jax.custom_jvp def _dot_and_deriv( x: Array, coef: Array, ) -> tuple[Array, Array]: """ Assumes x is (,) And coef is (p,) """ basis, basis_deriv = self.approx_basis_and_deriv(x) # (p,) and (p,) shapes smooth = self._compute_dot(basis, coef) smooth_deriv = self._compute_dot(basis_deriv, coef) return smooth, smooth_deriv # (,) and (,) shapes @_dot_and_deriv.defjvp def _dot_and_deriv_jvp(primals, tangents): x, coef = primals x_dot, coef_dot = tangents basis, basis_deriv, basis_deriv2 = self.approx_basis_and_deriv2(x) smooth = self._compute_dot(basis, coef) smooth_deriv = self._compute_dot(basis_deriv, coef) smooth_deriv2 = self._compute_dot(basis_deriv2, coef) primal_out = (smooth, smooth_deriv) tangent_bdot_x = smooth_deriv * x_dot tangent_bdot_coef = self._compute_dot(basis, coef_dot) tangent_bdot = tangent_bdot_x + tangent_bdot_coef tangent_deriv_x = smooth_deriv2 * x_dot tangent_deriv_coef = self._compute_dot(basis_deriv, coef_dot) tangent_deriv = tangent_deriv_x + tangent_deriv_coef tangent_out = (tangent_bdot, tangent_deriv) return primal_out, tangent_out return jax.jit(_dot_and_deriv)
[docs] def dot(self, x: Array, coef: Array) -> Array: """ Evaluate spline at given points. Parameters ---------- x Input data, an array of shape (n,). coef Spline coefficients. """ return self._dot_fn(x, coef)
[docs] def dot_and_deriv(self, x: Array, coef: Array) -> tuple[Array, Array]: """ Evaluate spline and its derivative. Parameters ---------- x Input data, an array of shape (n,). coef Spline coefficients. Returns ------- value : Array Spline values. deriv : Array Spline derivatives. """ return self._dot_and_deriv_fn(x, coef)