MultivariateNormalStructured


class liesel_gam.MultivariateNormalStructured(loc, op, validate_args=False, allow_nan_stats=True, name='MultivariateNormalStructured', include_normalizing_constant=True)

Bases: Distribution

Potentially rank-deficient multivariate Gaussian distribution, used as the prior in structured tensor product terms.

Implements the tfp.distributions.Distribution interface.

Parameters:
  • loc (Array | ndarray | bool | number | bool | int | float | complex) – Location array with shape (B, J), where B is the batch shape and J is the event shape.

  • op (StructuredPenaltyOperator) – A structured penalty operator for efficient computation of the pseudo log-determinant and quadratic form.

  • include_normalizing_constant (bool, default: True) – Whether to include the normalizing constant when computing the log density in the .log_prob method. If True, the returned log probability will be equal to the log probability returned by a MultivariateNormalFullCovariance, if all marginal penalty matrices are of full rank.

Notes

This distribution is the prior used for the coefficient vector in StrctTensorProdTerm.

It is a potentially rank-deficient multivariate Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as

\[p(\boldsymbol{\beta} | \boldsymbol{\tau}^2) \propto \operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2} \exp \left( - \frac{1}{2} \boldsymbol{\beta}^\top \mathbf{K}(\boldsymbol{\tau}^2) \boldsymbol{\beta} \right),\]

with the precision matrix constructed from marginal penalties \(\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M\) and variance parameters \(\tau^2_1,\dots, \tau^2_M\) as

\[\mathbf{K}(\boldsymbol{\tau}^2) = \frac{\mathbf{K}_1}{\tau^2_1} + \cdots + \frac{\mathbf{K}_M}{\tau^2_M},\]

where

\[\mathbf{K}_m = \mathbf{I}_{J_1} \otimes \cdots \otimes \mathbf{I}_{J_{m-1}} \otimes \tilde{\mathbf{K}}_m \otimes \mathbf{I}_{J_{m+1}} \otimes \cdots \otimes \mathbf{I}_{J_{M}},\]

and \(\mathbf{I}_{J_m}\) denotes the identity matrix of dimension \(J_m \times J_m\).
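The construction above can be sketched densely in NumPy for a small two-dimensional case. This is only an illustration of the formulas, not how the class works internally: the helper names (`second_diff_penalty`, `expand_penalty`, `dense_precision`) are hypothetical, the second-order difference penalty is an assumed stand-in for the marginal penalties, and forming the dense matrix is exactly what the structured implementation is meant to avoid.

```python
import numpy as np


def second_diff_penalty(J):
    """Assumed marginal penalty: D.T @ D for second-order differences (rank J - 2)."""
    D = np.diff(np.eye(J), n=2, axis=0)
    return D.T @ D


def expand_penalty(K_tilde, dims, m):
    """Embed marginal penalty m as kron(I, ..., K_tilde, ..., I)."""
    out = np.eye(1)
    for i, J in enumerate(dims):
        out = np.kron(out, K_tilde if i == m else np.eye(J))
    return out


def dense_precision(penalties, tau2):
    """Dense K(tau^2) = sum_m K_m / tau2_m (illustration only)."""
    dims = [K.shape[-1] for K in penalties]
    return sum(
        expand_penalty(K, dims, m) / t2
        for m, (K, t2) in enumerate(zip(penalties, tau2))
    )


penalties = [second_diff_penalty(6), second_diff_penalty(8)]
tau2 = np.array([1.0, 4.0])  # hypothetical variance parameters
K = dense_precision(penalties, tau2)
print(K.shape)  # (48, 48)
print(np.linalg.matrix_rank(K))  # 44: the shared null space has dimension 2 * 2
```

With these assumed penalties, each margin has a two-dimensional null space, so the summed precision loses four dimensions of rank.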

Since \(\mathbf{K}(\boldsymbol{\tau}^2)\) may be rank-deficient, \(\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))\) is the pseudo-determinant, or generalized determinant.

This class exploits the clearly defined structure of the precision matrix to evaluate the prior in a computationally and memory-efficient way. We also implement the results obtained by Bach & Klein (2025) for efficiently computing the pseudo-determinant, a key prerequisite for making higher-dimensional tensor products feasible.
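To give a sense of why the structure helps, the following sketch computes the pseudo log-determinant from the eigenvalues of the marginal penalties alone. It relies on a standard property of Kronecker sums: the eigenvalues of the full precision matrix are all sums of one scaled eigenvalue per margin. Whether this matches the class internals or the exact results of Bach & Klein (2025) is an assumption; the function names and the tolerance `tol` are hypothetical.

```python
import numpy as np


def second_diff_penalty(J):
    """Assumed marginal penalty: D.T @ D for second-order differences."""
    D = np.diff(np.eye(J), n=2, axis=0)
    return D.T @ D


def structured_pseudo_logdet(penalties, tau2, tol=1e-8):
    """Pseudo log-determinant of K(tau^2) using marginal eigenvalues only."""
    total = np.zeros(())
    for K_tilde, t2 in zip(penalties, tau2):
        ev = np.linalg.eigvalsh(K_tilde) / t2
        # Outer sum: each entry combines one scaled eigenvalue per margin.
        total = np.add.outer(total, ev)
    ev_all = total.ravel()
    # Sum the logs of the eigenvalues that are numerically nonzero.
    return np.sum(np.log(ev_all[ev_all > tol]))


K1, K2 = second_diff_penalty(6), second_diff_penalty(8)
tau2 = np.array([1.0, 4.0])  # hypothetical variance parameters

fast = structured_pseudo_logdet([K1, K2], tau2)

# Dense reference: eigenvalues of the full 48 x 48 precision matrix.
K = np.kron(K1, np.eye(8)) / tau2[0] + np.kron(np.eye(6), K2) / tau2[1]
ev_dense = np.linalg.eigvalsh(K)
dense = np.sum(np.log(ev_dense[ev_dense > 1e-8]))

print(np.isclose(fast, dense))
```

The structured version only decomposes a 6 x 6 and an 8 x 8 matrix, while the dense reference decomposes the full 48 x 48 matrix; this gap grows quickly with the number of margins.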

Sampling from this distribution

Sampling from this distribution is implemented. Note, however, that if \(\mathbf{K}(\boldsymbol{\tau}^2)\) is rank-deficient, samples are drawn only from the stochastic part of the distribution; the constant part remains fixed at zero. Sampling uses a generalized inverse of the potentially rank-deficient precision matrix \(\mathbf{K}(\boldsymbol{\tau}^2)\).
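A minimal dense sketch of this kind of sampling is shown below. It assumes the Moore-Penrose pseudo-inverse as the generalized inverse, which may differ from the choice made in the class, and uses hypothetical helper names. The draws have no component in the null space of the precision matrix, which is what "fixed at zero" means here.

```python
import numpy as np


def second_diff_penalty(J):
    """Assumed marginal penalty: D.T @ D for second-order differences."""
    D = np.diff(np.eye(J), n=2, axis=0)
    return D.T @ D


def sample_stochastic_part(K, n_samples, rng, tol=1e-8):
    """Draw from N(0, K^+), with K^+ the Moore-Penrose pseudo-inverse of K."""
    ev, V = np.linalg.eigh(K)
    keep = ev > tol  # eigen-directions with nonzero precision
    z = rng.standard_normal((n_samples, int(keep.sum())))
    # Scale by 1 / sqrt(eigenvalue) and rotate back to the original basis.
    return (z / np.sqrt(ev[keep])) @ V[:, keep].T


K1, K2 = second_diff_penalty(6), second_diff_penalty(8)
K = np.kron(K1, np.eye(8)) / 1.0 + np.kron(np.eye(6), K2) / 4.0

x = sample_stochastic_part(K, n_samples=2, rng=np.random.default_rng(1))
print(x.shape)  # (2, 48)

# The constant vector lies in the null space of K, so every draw is
# orthogonal to it up to floating-point error.
print(np.allclose(x @ np.ones(48), 0.0, atol=1e-8))
```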

References

  • Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z

  • Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43. https://doi.org/10.1007/s11222-025-10569-y

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)
>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor([K1, K2])
>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])
>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)

Draw some random samples from the stochastic part:

>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)

Methods

from_penalties

Initializes the distribution directly from marginal scales and penalties (computationally expensive).

get_locscale_constructor

Creates a constructor for this distribution that takes a location array and an array of marginal scales.