MultivariateNormalStructured
class liesel_gam.MultivariateNormalStructured(loc, op, validate_args=False, allow_nan_stats=True, name='MultivariateNormalStructured', include_normalizing_constant=True)
Bases: Distribution
Potentially rank-deficient multivariate Gaussian distribution for the prior used in structured tensor product terms.
Implements the tfp.distributions.Distribution interface.
Parameters:
- loc (Array | ndarray | bool | number | bool | int | float | complex) – Location array with shape (B, J), where B is the batch shape and J is the event shape.
- op (StructuredPenaltyOperator) – A structured penalty operator for efficient computation of the pseudo log-determinant and quadratic form.
- include_normalizing_constant (bool, default: True) – Whether to include the normalizing constant when computing the log density in the .log_prob method. If True, the returned log probability will be equal to the log probability returned by a MultivariateNormalFullCovariance, if all marginal penalty matrices are of full rank.
Notes
This distribution is the prior used for the coefficient vector in StrctTensorProdTerm.
It is a potentially rank-deficient multivariate Gaussian prior, which, in the notation of Bach & Klein (2025), can be written as
\[p(\boldsymbol{\beta} | \boldsymbol{\tau}^2) \propto \operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))^{1/2} \exp \left( - \frac{1}{2} \boldsymbol{\beta}^\top \mathbf{K}(\boldsymbol{\tau}^2) \boldsymbol{\beta} \right),\]
with the precision matrix constructed from marginal penalties \(\tilde{\mathbf{K}}_1, \dots, \tilde{\mathbf{K}}_M\) and variance parameters \(\tau^2_1,\dots, \tau^2_M\) as
\[\mathbf{K}(\boldsymbol{\tau}^2) = \frac{\mathbf{K}_1}{\tau^2_1} + \cdots + \frac{\mathbf{K}_M}{\tau^2_M},\]
where
\[\mathbf{K}_m = \mathbf{I}_{J_1} \otimes \cdots \otimes \mathbf{I}_{J_{m-1}} \otimes \tilde{\mathbf{K}}_m \otimes \mathbf{I}_{J_{m+1}} \otimes \cdots \otimes \mathbf{I}_{J_{M}},\]
and \(\mathbf{I}_{J_m}\) denotes the identity matrix of dimension \(J_m \times J_m\).
Since \(\mathbf{K}(\boldsymbol{\tau}^2)\) may be rank-deficient, \(\operatorname{Det}(\mathbf{K}(\boldsymbol{\tau}^2))\) is the pseudo-determinant, or generalized determinant.
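To make the construction of \(\mathbf{K}(\boldsymbol{\tau}^2)\) concrete, the following is a minimal dense sketch in plain NumPy. It is for illustration only and is not how this class works internally. The helper names diff_penalty and dense_precision are hypothetical, and the second-order difference penalty is an assumed stand-in for a marginal penalty \(\tilde{\mathbf{K}}_m\).

```python
import numpy as np


def diff_penalty(d, order=2):
    """Hypothetical stand-in for a marginal penalty: D.T @ D, where D is
    the order-th difference matrix. Its rank is d - order."""
    D = np.diff(np.eye(d), n=order, axis=0)
    return D.T @ D


def dense_precision(penalties, tau2):
    """Assemble K(tau^2) = sum_m K_m / tau2_m as a dense matrix, where
    K_m = I x ... x K_tilde_m x ... x I (Kronecker products)."""
    dims = [K.shape[0] for K in penalties]
    total = int(np.prod(dims))
    K_full = np.zeros((total, total))
    for m, (K_tilde, t2) in enumerate(zip(penalties, tau2)):
        # Identity in every position except position m.
        factors = [np.eye(d) for d in dims]
        factors[m] = K_tilde
        K_m = factors[0]
        for F in factors[1:]:
            K_m = np.kron(K_m, F)
        K_full += K_m / t2
    return K_full


K1, K2 = diff_penalty(4), diff_penalty(5)
K = dense_precision([K1, K2], tau2=[1.0, 4.0])

print(K.shape)                    # (20, 20)
print(np.linalg.matrix_rank(K))   # 16
```

Each marginal penalty here has a two-dimensional null space, so the null space of the combined precision matrix has dimension 2 * 2 = 4 and the 20 x 20 matrix has rank 16. This is the rank deficiency that makes the pseudo-determinant necessary.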
This class exploits the known structure of the precision matrix to evaluate the prior efficiently in terms of both computation and memory. It also implements the results obtained by Bach & Klein (2025) for efficiently computing the pseudo-determinant, a key prerequisite for making higher-dimensional tensor products feasible.
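One standard way to see why the structure helps: the matrices \(\mathbf{K}_m\) share a common Kronecker eigenbasis, so the eigenvalues of \(\mathbf{K}(\boldsymbol{\tau}^2)\) are all sums \(\sum_m \lambda_{m, i_m} / \tau^2_m\) of the marginal eigenvalues. The pseudo log-determinant can therefore be computed from the small marginal matrices alone. The sketch below illustrates this in NumPy and checks it against a dense computation. It is not necessarily the algorithm used by this class or by Bach & Klein (2025); the function names, the stand-in penalty and the relative tolerance rtol are assumptions.

```python
import numpy as np


def diff_penalty(d, order=2):
    """Hypothetical stand-in for a marginal penalty matrix."""
    D = np.diff(np.eye(d), n=order, axis=0)
    return D.T @ D


def pseudo_logdet_from_margins(penalties, tau2, rtol=1e-10):
    """Pseudo log-determinant of sum_m K_m / tau2_m, using only the
    eigenvalues of the marginal penalties."""
    M = len(penalties)
    eig_sum = np.zeros([1] * M)
    for m, (K_tilde, t2) in enumerate(zip(penalties, tau2)):
        lam = np.linalg.eigvalsh(K_tilde) / t2
        # Place the m-th set of eigenvalues on axis m and broadcast,
        # so eig_sum holds one entry per index combination.
        shape = [1] * M
        shape[m] = lam.size
        eig_sum = eig_sum + lam.reshape(shape)
    eig = eig_sum.ravel()
    # Treat eigenvalues below a relative threshold as exact zeros.
    positive = eig[eig > rtol * eig.max()]
    return float(np.sum(np.log(positive)))


def pseudo_logdet_dense(K, rtol=1e-10):
    """Reference computation from the eigenvalues of the dense matrix."""
    eig = np.linalg.eigvalsh(K)
    return float(np.sum(np.log(eig[eig > rtol * eig.max()])))


K1, K2 = diff_penalty(4), diff_penalty(5)
tau2 = [1.0, 4.0]
K_dense = np.kron(K1, np.eye(5)) / tau2[0] + np.kron(np.eye(4), K2) / tau2[1]

structured = pseudo_logdet_from_margins([K1, K2], tau2)
dense = pseudo_logdet_dense(K_dense)
print(np.isclose(structured, dense))
```

The marginal route needs eigendecompositions of a 4 x 4 and a 5 x 5 matrix only, whereas the dense route decomposes the full 20 x 20 matrix. The gap grows quickly with the number and size of the margins.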
Sampling from this distribution.
Sampling from this distribution is implemented, but note that, if \(\mathbf{K}(\boldsymbol{\tau}^2)\) is rank-deficient, samples are drawn only from the stochastic part of the distribution; the constant part will remain fixed to zero. For sampling, we use a generalized inverse of the potentially rank-deficient precision matrix \(\mathbf{K}(\boldsymbol{\tau}^2)\).
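The sampling behaviour described above can be sketched with a dense eigendecomposition. This is an illustrative NumPy example, not the class's implementation. It assumes the Moore-Penrose pseudoinverse as the generalized inverse, and the names diff_penalty and sample_rank_deficient are hypothetical.

```python
import numpy as np


def diff_penalty(d, order=2):
    """Hypothetical stand-in for a marginal penalty matrix."""
    D = np.diff(np.eye(d), n=order, axis=0)
    return D.T @ D


def sample_rank_deficient(K, n_samples, rng, rtol=1e-10):
    """Draw zero-mean samples whose covariance is the pseudoinverse of K.

    Returns the samples and an orthonormal basis of the null space of K.
    """
    lam, V = np.linalg.eigh(K)
    keep = lam > rtol * lam.max()
    # Standard normal draws, one coordinate per positive eigenvalue.
    z = rng.standard_normal((n_samples, int(keep.sum())))
    # Scale by lam^{-1/2} and rotate back; null-space directions get no mass.
    samples = (z / np.sqrt(lam[keep])) @ V[:, keep].T
    return samples, V[:, ~keep]


rng = np.random.default_rng(0)
K = np.kron(diff_penalty(4), np.eye(5)) + np.kron(np.eye(4), diff_penalty(5)) / 4.0
samples, null_basis = sample_rank_deficient(K, 3, rng)

print(samples.shape)                        # (3, 20)
print(np.abs(samples @ null_basis).max())   # numerically zero
```

The projection of every sample onto the null space of the precision matrix is zero up to floating-point error, which corresponds to the constant part remaining fixed at zero.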
References
Kneib, T., Klein, N., Lang, S., & Umlauf, N. (2019). Modular regression—A Lego system for building structured additive distributional regression models with tensor product interactions. TEST, 28(1), 1–39. https://doi.org/10.1007/s11749-019-00631-z
Bach, P., & Klein, N. (2025). Anisotropic multidimensional smoothing using Bayesian tensor product P-splines. Statistics and Computing, 35(2), 43. https://doi.org/10.1007/s11222-025-10569-y
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)
>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor([K1, K2])
>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])
>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)
Draw some random samples from the stochastic part:
>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)
Methods
Initializes the distribution directly from marginal scales and penalties (computationally expensive).
Creates a constructor for this distribution that takes a location array and an array of marginal scales.