from collections.abc import Callable, Sequence
from functools import cached_property, reduce
from math import prod
from typing import Self
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.substrates.jax import tf2jax as tf
from tensorflow_probability.substrates.jax.internal.parameter_properties import (
ParameterProperties,
)
Array = jax.typing.ArrayLike
[docs]
class MultivariateNormalSingular(tfd.Distribution):
r"""
Potentially rank-deficient multivariate Gaussian distribution used as a prior in
structured additive terms.
Implements the :class:`tfp.distributions.Distribution` interface.
Parameters
-----------
loc
Location array of shape (B, J) where B is the batching dimension and J is the
number of coefficients; the event shape of this distribution.
scale
Scale array of shape (B,), where B is the batching dimension. This is
:math:`\tau = \sqrt{\tau^2}`.
penalty
Penalty matrix of shape (B, J, J), where B is the batching dimension and J is
the number of coefficients; the event shape of this distribution.
penalty_rank
Array of shape (B,), where B is the batching dimension, giving the rank of the
penalty matrix. This is required in addition to the penalty for computational
efficiency, to avoid re-computing the rank of the penalty matrix in each
execution.
See Also
--------
.StrctTerm : Structured additive term object, a Liesel var that uses this
distribution.
.MultivariateNormalStructured : More general distribution, implementing the prior
for tensor product terms.
Notes
-----
The coefficient :math:`\boldsymbol{\beta} \in \mathbb{R}^{J \times 1}`
in a structured additive term receives a potentially rank-deficient
multivariate normal prior
.. math::
p(\boldsymbol{\beta}) \propto
\left(\frac{1}{\tau^2}\right)^{\operatorname{rk}(\mathbf{K})/2} \exp \left( -
\frac{1}{\tau^2} \boldsymbol{\beta}^\top \mathbf{K} \boldsymbol{\beta} \right)
with the potentially rank-deficient penalty matrix :math:`\mathbf{K}` of rank
:math:`\operatorname{rk}(\mathbf{K})`. The variance parameter :math:`\tau^2` acts as
an inverse smoothing parameter.
.. rubric:: Sampling from this distribution.
Sampling from this distribution is implemented, but note that, if
:math:`\mathbf{K}` is rank-deficient, samples are drawn only
from the stochastic part of the distribution; the constant part will remain fixed
to zero. For sampling, we use a generalized inverse of the potentially
rank-deficient penalty matrix :math:`\mathbf{K}`.
References
----------
- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego
system for building structured additive distributional regression models with
tensor product interactions. TEST, 28(1), 1–39.
https://doi.org/10.1007/s11749-019-00631-z
"""
def __init__(
self,
loc: Array,
scale: Array,
penalty: Array,
penalty_rank: Array,
validate_args: bool = False,
allow_nan_stats: bool = True,
name: str = "MultivariateNormalSingular",
):
parameters = dict(locals())
self._loc = jnp.asarray(loc)
self._scale = jnp.asarray(scale)
self._penalty = jnp.asarray(penalty)
self._penalty_rank = jnp.asarray(penalty_rank)
super().__init__(
dtype=self._loc.dtype,
reparameterization_type=tfd.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
name=name,
)
@classmethod
def _parameter_properties(cls, dtype=jnp.float32, num_classes=None):
return dict(
loc=ParameterProperties(event_ndims=1),
scale=ParameterProperties(event_ndims=0),
penalty=ParameterProperties(event_ndims=2),
penalty_rank=ParameterProperties(event_ndims=0),
)
def _event_shape(self):
return tf.TensorShape((jnp.shape(self._penalty)[-1],))
def _event_shape_tensor(self):
return jnp.array((jnp.shape(self._penalty)[-1],))
def _log_prob(self, x: Array) -> Array:
x_centered = x - self._loc
# The following lines illustrate what the jnp.einsum call is conceptually
# doing.
# xt = jnp.expand_dims(x, axis=-2) # [batch_dims, 1, event_dim]
# x = jnp.swapaxes(xt, -2, -1) # [batch_dims, event_dim, 1]
# quad_form = jnp.squeeze((xt @ self._penalty @ x))
quad_form = jnp.einsum(
"...i,...ij,...j->...", x_centered, self._penalty, x_centered
)
neg_kernel = 0.5 * quad_form * jnp.power(self._scale, -2.0)
return -(jnp.log(self._scale) * self._penalty_rank + neg_kernel)
def _sample_n(self, n, seed=None) -> Array:
shape = [n] + self.batch_shape + self.event_shape
# The added dimension at the end here makes sure that matrix multiplication
# with the "sqrt pcov" matrices works out correctly.
z = jax.random.normal(key=seed, shape=shape + [1])
# Add a dimension at 0 for the sample size.
sqrt_cov = jnp.expand_dims(self._sqrt_cov, 0)
centered_samples = jnp.reshape(sqrt_cov @ z, shape)
# Add a dimension at 0 for the sample size.
loc = jnp.expand_dims(self._loc, 0)
scale = jnp.expand_dims(self._scale, 0)
return scale * centered_samples + loc
@cached_property
def _sqrt_cov(self) -> Array:
eigenvalues, evecs = jnp.linalg.eigh(self._penalty)
sqrt_eval = jnp.sqrt(1 / eigenvalues)
sqrt_eval = sqrt_eval.at[: -self._penalty_rank].set(0.0)
event_shape = sqrt_eval.shape[-1]
shape = sqrt_eval.shape + (event_shape,)
r = tuple(range(event_shape))
diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
return evecs @ diags
def _diag_of_kron_of_diag_with_identities(in_diags: Sequence[jax.Array]) -> jax.Array:
sizes = [v.shape[-1] for v in in_diags]
diag = jnp.zeros(prod(sizes))
for j in range(len(in_diags)):
left_size = prod(sizes[:j])
right_size = prod(sizes[(j + 1) :])
d = in_diags[j]
# First handle identities to the right: repeat
d_rep = jnp.repeat(d, right_size)
# Then handle identities to the left: tile
d_tile = jnp.tile(d_rep, left_size)
diag = diag + d_tile
return diag
def _materialize_precision(penalties: Sequence[jax.Array]) -> jax.Array:
"""
Build K(tau^2) = sum_{j=1}^p K_j / tau_j^2
with K_j = I_{d1} ⊗ ... ⊗ I_{dj-1} ⊗ Ktilde_j ⊗ I_{dj+1} ⊗ ... ⊗ I_{dp}.
Parameters
----------
tau2 : array-like, shape (p,)
Squared smoothing parameters (tau_1^2, ..., tau_p^2).
K_tilde : sequence of arrays
List/tuple of p matrices, K_tilde[j] has shape (d_j, d_j).
dims : sequence of ints, optional
d_j for each dimension. If None, inferred from K_tilde.
Returns
-------
K : jnp.ndarray, shape (∏ d_j, ∏ d_j)
"""
p = len(penalties)
dims = [penalties[j].shape[-1] for j in range(p)]
# Build K_1 / tau1^2 as initial value
factors = [penalties[0] if i == 0 else jnp.eye(dims[i]) for i in range(p)]
def kron_all(mats):
"""Kronecker product of a list of matrices."""
return reduce(jnp.kron, mats)
K = kron_all(factors)
# Add remaining K_j / tau_j^2
for j in range(1, p):
factors = [penalties[j] if i == j else jnp.eye(dims[i]) for i in range(p)]
Kj = kron_all(factors)
K = K + Kj
return K
def _compute_masks(
penalties: Sequence[jax.Array],
penalties_eigvalues: Sequence[jax.Array],
eps: float = 1e-6,
) -> jax.Array:
diag = _diag_of_kron_of_diag_with_identities
B = penalties_eigvalues[0].shape[:-1]
B_flat = int(jnp.prod(jnp.array(B))) if B else 1
flat_evs = [ev.reshape(B_flat, ev.shape[-1]) for ev in penalties_eigvalues]
diags = jax.vmap(diag)(flat_evs) # (B_flat, N)
K = _materialize_precision(penalties) # (B, N, N)
K = K.reshape((B_flat,) + K.shape[-2:])
ranks = jax.vmap(jnp.linalg.matrix_rank)(K) # (B_flat,)
masks = (diags > eps).sum(-1) # (B_flat,)
if not jnp.allclose(masks, ranks):
raise ValueError(
f"Number of zero eigenvalues ({masks}) does not "
f"correspond to penalty rank ({ranks}). Maybe a different value for "
f"{eps=} can help."
)
mask = diags > eps
return mask.reshape(B + (mask.shape[-1],))
def _apply_Ki_along_axis(B, K, axis, dims, total_size):
"""
Apply K (Di x Di) along a given axis of B (shape dims),
returning an array with the same shape as B.
"""
# Move the axis we want to the front: (Di, rest...)
B_perm = jnp.moveaxis(B, axis, 0)
Di = dims[axis]
rest = total_size // Di
# Flatten everything except that axis: (Di, rest)
B_flat = B_perm.reshape(Di, rest)
# Matrix multiply: K @ B_flat -> (Di, rest)
C_flat = K @ B_flat
# Restore original shape/order
C_perm = C_flat.reshape(B_perm.shape)
C = jnp.moveaxis(C_perm, 0, axis) # back to shape = dims
return C
def _kron_sum_quadratic(x: jax.Array, Ks: Sequence[jax.Array]) -> jax.Array:
dims = [K.shape[0] for K in Ks]
# Basic sanity checks (cheap, can remove if you like)
for K, d in zip(Ks, dims):
assert K.shape == (d, d)
total_size = prod(dims)
assert x.size == total_size
# Reshape x into m-dimensional tensor
B = x.reshape(dims)
total = jnp.array(0.0, dtype=x.dtype)
for axis, K in enumerate(Ks):
C = _apply_Ki_along_axis(B, K, axis, dims, total_size)
total = total + jnp.vdot(B, C) # scalar
return total
[docs]
class StructuredPenaltyOperator:
"""
Operator for efficiently computing the pseudo log determinant and
quadratic form of a structured tensor product precision matrix.
Parameters
----------
scales
Array of shape (B, M), where B is the batch shape and M is the number of
marginal penalties
penalties
Sequence of length M, containing array of shape (B, Jm, Jm), where
B is the batch shape, and Jm is the block size of the individual penalty.
This block size can differ between elements of the penalty sequence.
penalties_eigenvalues
Sequence of arrays of shape (B, Jm), containing the eigenvalues of the
marignal penalties.
masks
Boolean array of shape (B, J) where B is the batch shape and
``J = prod(J1, ... JM)``
is the dimension of the overall precision matrix. The mask indicates the
nonzero eigenvalues of the overall precision matrix,
where a ``True`` entry means that the corresponding eigenvalue is nonzero.
If ``None``, will be inferred from ``penalties_eigenvalues`` and ``tol``
at runtime.
tol
Used to infer ``masks`` at runtime, if ``masks=None`` was passed.
"""
def __init__(
self,
scales: jax.Array,
penalties: Sequence[jax.Array],
penalties_eigvalues: Sequence[jax.Array],
masks: jax.Array | None = None,
validate_args: bool = False,
tol: float = 1e-6,
) -> None:
self._scales = jnp.asarray(scales)
self._penalties = tuple([jnp.asarray(p) for p in penalties])
self._penalties_eigvalues = tuple(
[jnp.asarray(ev) for ev in penalties_eigvalues]
)
if validate_args:
self._validate_penalties()
self._sizes = [K.shape[-1] for K in self._penalties]
self._masks = masks
self._tol = tol
[docs]
@classmethod
def from_penalties(
cls,
scales: jax.Array,
penalties: Sequence[jax.Array],
eps: float = 1e-6,
validate_args: bool = False,
) -> Self:
"""
Shortcut for initializing the operator from marignal scales and penalties
(expensive).
Computes the eigenvalue decompositions of the marginal penalties and the
``mask`` at runtime, so this is not appropriate for repeated execution if
computational efficiency is a concern.
"""
evs = [jnp.linalg.eigh(K) for K in penalties]
evals = [ev.eigenvalues for ev in evs]
masks = _compute_masks(penalties=penalties, penalties_eigvalues=evals, eps=eps)
return cls(
scales=scales,
penalties=penalties,
penalties_eigvalues=evals,
masks=masks,
validate_args=validate_args,
)
[docs]
@cached_property
def variances(self) -> jax.Array:
"""
Array of marginal variances (squares of marginal scales).
"""
return jnp.square(self._scales)
[docs]
def materialize_precision(self) -> jax.Array:
"""
Materializes the full J x J overall precision matrix.
Usually not necessary, mostly useful for testing and checking.
"""
return self._materialize_precision(self.variances)
[docs]
def materialize_penalty(self) -> jax.Array:
"""
Materializes the full J x J overall precision matrix.
Usually not necessary, mostly useful for testing and checking.
"""
return self._materialize_precision(jnp.ones_like(self.variances))
def _materialize_precision(self, variances: jax.Array) -> jax.Array:
"""This is inefficient, should be used for testing only."""
p = len(self._penalties)
dims = [self._penalties[j].shape[-1] for j in range(p)]
def kron_all(mats):
"""Kronecker product of a list of matrices."""
return reduce(jnp.kron, mats)
def one_batch(variances, *penalties):
# Build K_1 / tau1^2 as initial value
factors = [penalties[0] if i == 0 else jnp.eye(dims[i]) for i in range(p)]
K = kron_all(factors) / variances[0]
# Add remaining K_j / tau_j^2
for j in range(1, p):
factors = [
penalties[j] if i == j else jnp.eye(dims[i]) for i in range(p)
]
Kj = kron_all(factors)
K = K + Kj / variances[j]
return K
batch_shape = variances.shape[:-1]
K = variances.shape[-1]
# flatten batch dims so we can vmap over a single leading dim
B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
tau2_flat = variances.reshape(B_flat, K) # (B_flat, K)
pens_flat = [p.reshape((B_flat,) + p.shape[-2:]) for p in self._penalties]
big_K_fun = jax.vmap(one_batch, in_axes=(0,) + (0,) * K)
big_K = big_K_fun(tau2_flat, *pens_flat)
N = prod(dims)
big_K = jnp.reshape(big_K, batch_shape + (N, N))
return big_K
def _sum_of_scaled_eigenvalues(
self, variances: jax.Array, eigenvalues: Sequence[jax.Array]
) -> jax.Array:
"""
Expects
- variances (p,)
- eigenvalues (p, Di)
Returns (N,) where N = prod(Di)
"""
diag = jnp.zeros(prod(self._sizes))
for j in range(len(self._penalties)):
left_size = prod(self._sizes[:j])
right_size = prod(self._sizes[(j + 1) :])
d = eigenvalues[j] / variances[j]
# First handle identities to the right: repeat
d_rep = jnp.repeat(d, right_size)
# Then handle identities to the left: tile
d_tile = jnp.tile(d_rep, left_size)
diag = diag + d_tile
return diag
[docs]
def log_pdet(self) -> jax.Array:
"""
Efficiently computes the log pseudo-determinant of the overall precision matrix.
Implements the result obtained by Bach & Klein (2025).
Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using
Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43.
https://doi.org/10.1007/s11222-025-10569-y
"""
variances = self.variances # shape (B..., K)
batch_shape = variances.shape[:-1]
K = variances.shape[-1]
# flatten batch dims so we can vmap over a single leading dim
B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
tau2_flat = variances.reshape(B_flat, K) # (B_flat, K)
# eigenvalues per penalty, flattened over batch
eigvals_flat = [
ev.reshape(B_flat, ev.shape[-1]) for ev in self._penalties_eigvalues
] # list of K arrays (B_flat, Di)
def _single_diag(variances, *eigenvalues):
diag = self._sum_of_scaled_eigenvalues(variances, eigenvalues)
return diag
# vmap over flattened batch dimension
diag_flat = jax.vmap(_single_diag, in_axes=(0,) + (0,) * K)(
tau2_flat,
*eigvals_flat,
)
diag = jnp.reshape(diag_flat, batch_shape + (diag_flat.shape[-1],))
if self._masks is None:
mask = diag > self._tol
else:
mask = self._masks
logdet = jnp.log(jnp.where(mask, diag, 1.0)).sum(-1)
return logdet
def _validate_penalties(self) -> None:
# validate number of penalty matrices
n_penalties1 = self._scales.shape[-1]
n_penalties2 = len(self._penalties)
n_penalties3 = len(self._penalties_eigvalues)
if not len({n_penalties1, n_penalties2, n_penalties3}) == 1:
msg1 = "Got inconsistent numbers of parameters. "
msg2 = f"Number of scale parameters: {n_penalties1} "
msg3 = f"Number of penalty matrices: {n_penalties2}. "
msg4 = f"Number of eigenvalue vectors: {n_penalties3}. "
raise ValueError(msg1 + msg2 + msg3 + msg4)
[docs]
class MultivariateNormalStructured(tfd.Distribution):
r"""
Potentially rank-deficient multivariate Gaussian distribution for the prior used in
structured tensor product terms.
Implements the :class:`tfp.distributions.Distribution` interface.
Parameters
----------
loc
Location array with shape (B, J), where B is the batch shape and J is the
event shape.
op
A structured penalty operator for efficient computation of the
pseudo log-determinant and quadratic form.
include_normalizing_constant
Whether to include the normalizing constant when computing the
log density in the ``.log_prob`` method. If True, the returned log
probability will be equal to the log probability returned by a
``MultivariateNormalFullCovariance``, if all marginal penalty matrices
are of full rank.
Notes
-----
This distribution is the prior used for the coefficient vector
in :class:`.StrctTensorProdTerm`.
It is a potentially rank-deficient multivariate
Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as
.. math::
p(\boldsymbol{\beta} | \boldsymbol{\tau}^2)
\propto
\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2}
\exp \left(
- \frac{1}{2}
\boldsymbol{\beta}^\top
\mathbf{K}(\boldsymbol{\tau}^2)
\boldsymbol{\beta}
\right),
with the precision matrix constructed from marginal penalties
:math:`\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M`
and variance parameters :math:`\tau^2_1,\dots, \tau^2_M` as
.. math::
\mathbf{K}(\boldsymbol{\tau}^2)
= \frac{\mathbf{K}_1}{\tau^2_1}
+
\cdots
+
\frac{\mathbf{K}_M}{\tau^2_M},
where
.. math::
\mathbf{K}_m = \mathbf{I}_{J_1}
\otimes \cdots \otimes
\mathbf{I}_{J_{m-1}}
\otimes
\tilde{\mathbf{K}}_m
\otimes
\mathbf{I}_{J_{m+1}}
\otimes
\cdots
\mathbf{I}_{J_{M}},
and :math:`\mathbf{I}_{J_m}` denotes the identity matrix of dimension
:math:`J_m \times J_m`.
Since :math:`\mathbf{K}(\boldsymbol{\tau}^2)` may be rank-deficient,
:math:`\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))` is the
pseudo-determinant, or generalized determinant.
This class exploits the clearly defined structure of the precision matrix
to obtain a computationally and memory-efficient evaluation of the prior.
We also implement the results obtained by Bach & Klein (2025) for efficiently
computing the pseudo-determinant; a key prerequisite for making higher-dimensional
tensor products feasible.
.. rubric:: Sampling from this distribution.
Sampling from this distribution is implemented, but note that, if
:math:`\mathbf{K}(\boldsymbol{\tau}^2)` is rank-deficient, samples are drawn only
from the stochastic part of the distribution; the constant part will remain fixed
to zero. For sampling, we use a generalized inverse of the potentially
rank-deficient precision matrix :math:`\mathbf{K}(\boldsymbol{\tau}^2)`.
References
----------
- Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego
system for building structured additive distributional regression models with
tensor product interactions. TEST, 28(1), 1–39.
https://doi.org/10.1007/s11749-019-00631-z
- Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using
Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43.
https://doi.org/10.1007/s11222-025-10569-y
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)
>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor([K1, K2])
>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])
>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)
Draw some random samples from the stochastic part:
>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)
"""
def __init__(
self,
loc: Array,
op: StructuredPenaltyOperator,
validate_args: bool = False,
allow_nan_stats: bool = True,
name: str = "MultivariateNormalStructured",
include_normalizing_constant: bool = True,
):
parameters = dict(locals())
self._loc = jnp.asarray(loc)
self._op = op
self._n = self._loc.shape[-1]
self._include_normalizing_constant = include_normalizing_constant
if validate_args:
self._op._validate_penalties()
self._validate_event_dim()
super().__init__(
dtype=self._loc.dtype,
reparameterization_type=tfd.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
name=name,
)
def _validate_event_dim(self) -> None:
# validate sample size
n_loc = self._loc.shape[-1]
ndim_penalties = [p.shape[-1] for p in self._op._penalties]
n_penalties = prod(ndim_penalties)
if not n_loc == n_penalties:
msg1 = "Got inconsistent event dimensions. "
msg2 = f"Event dimension implied by loc: {n_loc}. "
msg3 = f"Event dimension implied by penalties: {n_penalties}"
raise ValueError(msg1 + msg2 + msg3)
def _batch_shape(self):
variances = self._op.variances # shape (B..., K)
batch_shape = tuple(variances.shape[:-1])
return tf.TensorShape(batch_shape)
def _batch_shape_tensor(self):
variances = self._op.variances # shape (B..., K)
batch_shape = tuple(variances.shape[:-1])
return jnp.array(batch_shape)
def _event_shape(self):
return tf.TensorShape((jnp.shape(self._loc)[-1],))
def _event_shape_tensor(self):
return jnp.array((jnp.shape(self._loc)[-1],), dtype=self._loc.dtype)
def _log_prob(self, x: Array) -> jax.Array:
x = jnp.asarray(x)
x_centered = x - self._loc
log_pdet = self._op.log_pdet()
quad_form = self._op.quad_form(x_centered)
# early returns, minimally more efficient
if not self._include_normalizing_constant:
return 0.5 * (log_pdet - quad_form)
const = -(self._n / 2) * jnp.log(2 * jnp.pi)
return 0.5 * (log_pdet - quad_form) + const
[docs]
@classmethod
def from_penalties(
cls,
loc: Array,
scales: Array,
penalties: Sequence[Array],
tol: float = 1e-6,
precompute_masks: bool = True,
validate_args: bool = False,
allow_nan_stats: bool = True,
include_normalizing_constant: bool = True,
) -> Self:
"""
Initializes the distribution directly from marginal scales and penalties
(computationally expensive).
This is expensive, because it computes eigenvalue decompositions of all
penalty matrices. Should only be used when performance is irrelevant.
"""
constructor = cls.get_locscale_constructor(
penalties=penalties,
tol=tol,
precompute_masks=precompute_masks,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
include_normalizing_constant=include_normalizing_constant,
)
return constructor(loc, scales)
[docs]
@classmethod
def get_locscale_constructor(
cls,
penalties: Sequence[Array],
tol: float = 1e-6,
precompute_masks: bool = True,
validate_args: bool = False,
allow_nan_stats: bool = True,
include_normalizing_constant: bool = True,
) -> Callable[[Array, Array], "MultivariateNormalStructured"]:
"""
Creates a constructor for this distribution that takes a location array and
an array of marginal scales.
Parameters
----------
penalties
Sequence of arrays, the penalty matrices of the marginal smooths.
tol
Tolerance used when computing the ranks of the marginal penalties and the
pseudo log-determinant. Any eigenvalue of a marginal penalty smaller than
this tolerance will be treated as zero.
precompute_masks
Whether to pre-compute and store the indices referring to non-zero
and zero eigenvalues, are compare the eigenvalues to the tolerance
in every function call at runtime. Pre-computing is safer, because this
allows us to run an additional check whether the zero- and non-zero
eigenvalues are successfully distinguished by the tolrance.
include_normalizing_constant
Whether to include the normalizing constant when computing the
log density in the ``.log_prob`` method. If True, the returned log
probability will be equal to the log probability returned by a
``MultivariateNormalFullCovariance``, if all marginal penalty matrices
are of full rank.
Returns
--------
The returned constructor takes the following arguments:
- *loc*: Location array with shape (B, J), where B is the batch shape and
J is the event shape.
- *scales*: Array of scales for the marginal smooths, has shape (B, M) where
B is the batch shape and M is the number of marginal smooths.
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)
>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor(
... [K1, K2]
... )
>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])
>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)
Draw some random samples from the stochastic part:
>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)
"""
penalties_ = [jnp.asarray(p) for p in penalties]
evs = [jnp.linalg.eigh(K) for K in penalties]
evals = [ev.eigenvalues for ev in evs]
if precompute_masks:
masks = _compute_masks(
penalties=penalties_, penalties_eigvalues=evals, eps=tol
)
else:
masks = None
def construct_dist(loc: Array, scales: Array) -> "MultivariateNormalStructured":
loc = jnp.asarray(loc)
scales = jnp.asarray(scales)
op = StructuredPenaltyOperator(
scales=scales,
penalties=penalties_,
penalties_eigvalues=evals,
masks=masks,
tol=tol,
)
dist = cls(
loc=loc,
op=op,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
include_normalizing_constant=include_normalizing_constant,
)
return dist
return construct_dist
@cached_property
def _sqrt_cov(self) -> Array:
# TODO this is inefficient, because it does not make use of the sparsity
# at the moment, at least it makes sampling possible. But we should update it.
prec = self._op.materialize_precision()
eigenvalues, evecs = jnp.linalg.eigh(prec)
sqrt_eval = jnp.sqrt(1 / eigenvalues)
if self._op._masks is None:
raise ValueError("self._op_masks is None, but need pre-computed masks.")
sqrt_eval = sqrt_eval.at[..., ~self._op._masks].set(0.0)
event_shape = sqrt_eval.shape[-1]
shape = sqrt_eval.shape + (event_shape,)
r = tuple(range(event_shape))
diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
return evecs @ diags
def _sample_n(self, n, seed=None) -> Array:
shape = [n] + self.batch_shape + self.event_shape
# The added dimension at the end here makes sure that matrix multiplication
# with the "sqrt pcov" matrices works out correctly.
z = jax.random.normal(key=seed, shape=shape + [1])
# Add a dimension at 0 for the sample size.
sqrt_cov = jnp.expand_dims(self._sqrt_cov, 0)
centered_samples = jnp.reshape(sqrt_cov @ z, shape)
# Add a dimension at 0 for the sample size.
loc = jnp.expand_dims(self._loc, 0)
return centered_samples + loc