"""Source code for ``liesel_gam.constraint``."""

import jax.numpy as jnp
from jax import Array


def penalty_to_unit_design(penalty: Array, rank: Array | int | None = None) -> Array:
    """
    Convert a (semi-)definite penalty matrix into the design matrix
    projector used by mixed-model reparameterizations.

    The routine performs an eigenvalue decomposition of `penalty`, keeps the
    first `rank` eigenvectors (default: numerical rank of `penalty`), rescales
    them to have unit marginal variance (1 / sqrt(lambda)), and returns the
    resulting loading matrix.

    Parameters
    ----------
    penalty
        Positive semi-definite penalty matrix.
    rank
        Optional target rank. Defaults to the matrix rank inferred from
        ``penalty``.

    Returns
    -------
    A matrix whose columns span the penalized subspace and are scaled for
    mixed-model formulations.
    """
    if rank is None:
        rank = jnp.linalg.matrix_rank(penalty)

    evalues, evectors = jnp.linalg.eigh(penalty)
    evalues = evalues[::-1]  # put in decreasing order
    evectors = evectors[:, ::-1]  # make order correspond to eigenvalues
    rank = jnp.linalg.matrix_rank(penalty)

    if evectors[0, 0] < 0:
        evectors = -evectors

    U = evectors
    D = 1 / jnp.sqrt(jnp.ones_like(evalues).at[:rank].set(evalues[:rank]))
    D = D.at[rank:].set(0.0)
    Z = (U.T * jnp.expand_dims(D, 1)).T
    return Z


[docs] class LinearConstraintEVD: """ Computes reparameterization matrices for linear constraints. Reparameterization matrices are computed via eigenvalue decomposition. If you have a linear constraint ``A @ coef`` to be applied to a basis-coef product ``B @ coef``, were ``B`` is the basis matrix, then this constraint can be enforced by computing ``B @ Z @ latent_coef`` instead, where ``latent_coef`` is an unconstrained version of ``coef``, with penalty matrix ``Z.T @ K @ Z``, where ``K`` is the penalty matrix in the prior for ``coef``. See :meth:`.Basis.constrain` for more detailed documentation and Kneib et al. (2019) for an in-depth reference. See Also --------- .Basis.constrain : Uses this class to apply constraints. .StrctTerm.constrain : Uses this class to apply constraints. References ---------- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z """
[docs] @staticmethod def general(constraint: Array) -> Array: """ Reparameterization matrix for a general linear constraint ``constraint @ coef``. """ A = constraint nconstraints, _ = A.shape AtA = A.T @ A _, evecs = jnp.linalg.eigh(AtA) signs = jnp.sign(evecs[0, :]) signs = jnp.where(signs == 0, 1.0, signs) evecs = evecs * signs rank = A.shape[0] Abar = evecs[:, :-rank].T A_stacked = jnp.r_[A, Abar] C_stacked = jnp.linalg.inv(A_stacked) Cbar = C_stacked[:, nconstraints:] return Cbar
@classmethod def _nullspace(cls, penalty: Array, rank: float | Array | None = None) -> Array: if rank is None: rank = jnp.linalg.matrix_rank(penalty) evals, evecs = jnp.linalg.eigh(penalty) evals = evals[::-1] # put in decreasing order evecs = evecs[:, ::-1] # make order correspond to eigenvalues rank = jnp.sum(evals > 1e-6) if evecs[0, 0] < 0: evecs = -evecs U = evecs D = 1 / jnp.sqrt(jnp.ones_like(evals).at[:rank].set(evals[:rank])) Z = (U.T * jnp.expand_dims(D, 1)).T Abar = Z[:, :rank] return Abar
[docs] @classmethod def constant_and_linear(cls, x: Array, basis: Array) -> Array: """ Reparameterization matrix for removing a constant and a linear trend from a smooth like ``B(x) @ coef``. """ nobs = jnp.shape(x)[0] j = jnp.ones(shape=nobs) X = jnp.c_[j, x] A = jnp.linalg.inv(X.T @ X) @ X.T @ basis return cls.general(constraint=A)
[docs] @classmethod def sumzero_coef(cls, ncoef: int) -> Array: """ Reparameterization matrix for enforcing a constraint ``jnp.ones(...).T @ coef``. In other words, this applies a sum-to-zero constraint to the coefficient. """ j = jnp.ones(shape=(1, ncoef)) return cls.general(constraint=j)
[docs] @classmethod def sumzero_term(cls, basis: Array) -> Array: """ Reparameterization matrix for enforcing a constraint ``jnp.ones(...).T @ B(x) @ coef``. In other words, this applies a sum-to-zero-constraint to the full term. """ nobs = jnp.shape(basis)[0] j = jnp.ones(shape=nobs) A = jnp.expand_dims(j @ basis, 0) return cls.general(constraint=A)