MultivariateNormalStructured.get_locscale_constructor()

classmethod MultivariateNormalStructured.get_locscale_constructor(penalties, tol=1e-06, precompute_masks=True, validate_args=False, allow_nan_stats=True, include_normalizing_constant=True)

Creates a constructor for this distribution that takes a location array and an array of marginal scales.
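To make "marginal scales" concrete, the sketch below assumes that the distribution follows the usual tensor-product smooth construction. Under this assumption, the precision matrix is the sum of the marginal penalties, each embedded in the joint space by Kronecker products with identity matrices and divided by the squared scale of its marginal. This structure and the helpers `diff_penalty` and `structured_precision` are assumptions for illustration, not the liesel_gam implementation. The code uses NumPy only and builds the dense matrix, which a structured distribution would presumably avoid.

```python
from functools import reduce

import numpy as np


def diff_penalty(d, order=2):
    """Difference penalty D^T D of the given order (hypothetical helper)."""
    D = np.diff(np.eye(d), n=order, axis=0)  # shape (d - order, d)
    return D.T @ D


def structured_precision(penalties, scales):
    """Dense precision sum_m K_m / scale_m**2, each K_m embedded via Kronecker products.

    Assumed structure for illustration only, not taken from liesel_gam.
    """
    dims = [K.shape[-1] for K in penalties]
    n = int(np.prod(dims))
    Q = np.zeros((n, n))
    for m, (K, s) in enumerate(zip(penalties, scales)):
        # I_{d_1} ⊗ ... ⊗ K_m ⊗ ... ⊗ I_{d_M}
        factors = [K if k == m else np.eye(d) for k, d in enumerate(dims)]
        Q += reduce(np.kron, factors) / s**2
    return Q


K1, K2 = diff_penalty(6), diff_penalty(8)
Q = structured_precision([K1, K2], scales=[1.0, 2.0])
print(Q.shape)  # (48, 48)
```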

Parameters:
  • penalties (Sequence[Array | ndarray | bool | number | bool | int | float | complex]) – Sequence of penalty matrices, one for each marginal smooth.

  • tol (float, default: 1e-06) – Tolerance used when computing the ranks of the marginal penalties and the pseudo log-determinant. Any eigenvalue of a marginal penalty smaller than this tolerance will be treated as zero.

  • precompute_masks (bool, default: True) – Whether to pre-compute and store the indices of the zero and non-zero eigenvalues, or to compare the eigenvalues to the tolerance at runtime in every function call. Pre-computing is safer, because it allows an additional check of whether the tolerance successfully separates the zero from the non-zero eigenvalues.

  • include_normalizing_constant (bool, default: True) – Whether to include the normalizing constant when computing the log density in the .log_prob method. If True and all marginal penalty matrices have full rank, the returned log probability equals the log probability returned by a MultivariateNormalFullCovariance with the corresponding covariance matrix.
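The following is a minimal sketch of how `tol`, the eigenvalue masks and the normalizing constant could fit together. It assumes that the precision is a Kronecker sum of the marginal penalties divided by the squared scales. For such a matrix, the eigenvalues are sums of the scaled marginal eigenvalues and the eigenvectors are Kronecker products of the marginal eigenvectors. Rank and pseudo log-determinant therefore follow from the marginal eigendecompositions alone.

The function `structured_log_prob` is hypothetical. So is the convention of using `-rank / 2 * log(2π)` in the rank-deficient case. With full-rank penalties, the result matches the dense Gaussian log density, mirroring the documented behaviour of `include_normalizing_constant`.

```python
from functools import reduce

import numpy as np


def structured_log_prob(x, loc, scales, penalties, tol=1e-6,
                        include_normalizing_constant=True):
    """Hypothetical log density of a Gaussian with Kronecker-sum precision.

    Eigenvalues <= tol are treated as zero; the normalizing constant uses the
    rank and the pseudo log-determinant of the joint precision.
    """
    eigs = [np.linalg.eigh(K) for K in penalties]
    # Masks of non-zero eigenvalues per marginal (in the spirit of
    # precompute_masks, these could be computed once and stored)
    masks = [evals > tol for evals, _ in eigs]
    scaled = [np.where(mask, evals, 0.0) / s**2
              for (evals, _), mask, s in zip(eigs, masks, scales)]
    # Joint eigenvalues: all sums of one scaled eigenvalue per marginal
    joint_evals = reduce(lambda a, b: np.add.outer(a, b).ravel(), scaled)
    # Joint eigenvectors: Kronecker products of the marginal eigenvectors
    U = reduce(np.kron, [evecs for _, evecs in eigs])

    z = U.T @ (x - loc)
    log_prob = -0.5 * np.sum(joint_evals * z**2)
    if include_normalizing_constant:
        nonzero = joint_evals > 0
        rank = np.count_nonzero(nonzero)
        log_prob += 0.5 * np.sum(np.log(joint_evals[nonzero]))
        log_prob -= 0.5 * rank * np.log(2 * np.pi)
    return log_prob


# Full-rank marginal penalties, so the pseudo log-determinant is the ordinary one
rng = np.random.default_rng(0)
A1, A2 = rng.normal(size=(3, 3)), rng.normal(size=(4, 4))
K1, K2 = A1 @ A1.T + np.eye(3), A2 @ A2.T + np.eye(4)
scales = np.array([1.0, 2.0])
loc = rng.normal(size=12)
x = rng.normal(size=12)

lp_struct = structured_log_prob(x, loc, scales, [K1, K2])

# Dense reference: Gaussian log density in precision form
Q = np.kron(K1, np.eye(4)) / scales[0]**2 + np.kron(np.eye(3), K2) / scales[1]**2
r = x - loc
_, logdet = np.linalg.slogdet(Q)
lp_dense = -0.5 * (12 * np.log(2 * np.pi) - logdet + r @ Q @ r)

print(np.isclose(lp_struct, lp_dense))  # True
```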

Return type:

Callable[[Array | ndarray | bool | number | bool | int | float | complex, Array | ndarray | bool | number | bool | int | float | complex], MultivariateNormalStructured]

Returns:

The returned constructor takes the following arguments:

  • loc: Location array with shape (B, J), where B is the batch shape and J is the event size.

  • scales: Array of scales for the marginal smooths, with shape (B, M), where B is the batch shape and M is the number of marginal smooths.
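The constructor-factory pattern and the expected shapes can be sketched as follows. `get_locscale_constructor_sketch` is a hypothetical stand-in that returns a plain dict instead of a MultivariateNormalStructured. It assumes the precision is the sum over m of K_m / scale_m**2, with each K_m embedded by Kronecker products. The penalties are embedded once when the factory is called. The returned closure then only combines them with the (possibly batched) scales.

```python
from functools import reduce

import numpy as np


def get_locscale_constructor_sketch(penalties):
    """Hypothetical stand-in: returns a constructor taking (loc, scales)."""
    dims = [K.shape[-1] for K in penalties]
    J, M = int(np.prod(dims)), len(penalties)
    # Embed each penalty once: I ⊗ ... ⊗ K_m ⊗ ... ⊗ I, stacked to shape (M, J, J)
    embedded = np.stack([
        reduce(np.kron, [K if k == m else np.eye(d) for k, d in enumerate(dims)])
        for m, K in enumerate(penalties)
    ])

    def constructor(loc, scales):
        loc, scales = np.asarray(loc), np.asarray(scales)
        if loc.shape[-1] != J or scales.shape[-1] != M:
            raise ValueError(f"expected loc (..., {J}) and scales (..., {M})")
        # Batched precision sum_m K_m / scale_m**2, shape (B, J, J)
        precision = np.einsum("...m,mij->...ij", scales**-2.0, embedded)
        return {"loc": loc, "precision": precision}

    return constructor


Ka = np.array([[1.0, -1.0], [-1.0, 1.0]])  # 2 x 2 penalty
Kb = np.eye(3)                              # 3 x 3 penalty
make_dist = get_locscale_constructor_sketch([Ka, Kb])

loc = np.zeros((3, 6))                                   # (B, J) with B = 3, J = 2 * 3
scales = np.array([[1.0, 2.0], [0.5, 1.0], [2.0, 2.0]])  # (B, M) with M = 2
dist = make_dist(loc=loc, scales=scales)
print(dist["precision"].shape)  # (3, 6, 6)
```

Keeping the penalties in the closure means that only `loc` and `scales` vary between calls. This fits the pattern of placing priors on the scales.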

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam
>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)
>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor(
...     [K1, K2]
... )
>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])
>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)

Draw some random samples from the stochastic part:

>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)
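Suppose `pspline_penalty(d)` returns a second-order difference penalty DᵀD of rank d - 2 (an assumption here). Then both marginal penalties in the example are rank-deficient, and the joint precision has a null space. One plausible reading of "stochastic part" is the subspace spanned by the penalized eigen-directions. The sketch below uses NumPy only, with a hypothetical `diff_penalty` helper in place of `pspline_penalty`. For the 6 × 8 example, it counts the non-zero and zero joint eigenvalues under a Kronecker-sum precision. It then draws samples with no component in the null space. It is not the liesel_gam sampling code.

```python
from functools import reduce

import numpy as np


def diff_penalty(d, order=2):
    """Hypothetical stand-in for pspline_penalty: difference penalty D^T D."""
    D = np.diff(np.eye(d), n=order, axis=0)
    return D.T @ D


tol = 1e-6
scales = np.array([1.0, 2.0])
penalties = [diff_penalty(6), diff_penalty(8)]

eigs = [np.linalg.eigh(K) for K in penalties]
# Treat marginal eigenvalues <= tol as exact zeros, then scale by 1 / scale**2
marginal_evals = [np.where(ev > tol, ev, 0.0) / s**2 for (ev, _), s in zip(eigs, scales)]
# Kronecker-sum eigenvalues and eigenvectors of the assumed joint precision
joint_evals = reduce(lambda a, b: np.add.outer(a, b).ravel(), marginal_evals)
U = reduce(np.kron, [vecs for _, vecs in eigs])

nonzero = joint_evals > 0
print(nonzero.sum(), (~nonzero).sum())  # 44 4

# Sample only along penalized eigen-directions; null-space coordinates stay zero
rng = np.random.default_rng(1)
loc = np.zeros(U.shape[0])
z = rng.standard_normal((2, nonzero.sum())) / np.sqrt(joint_evals[nonzero])
samples = loc + z @ U[:, nonzero].T
print(samples.shape)  # (2, 48)
```

A joint eigenvalue is zero only if every marginal eigenvalue in the sum is zero. So two null directions per marginal give 2 × 2 = 4 null directions in the joint space.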