MultivariateNormalStructured.get_locscale_constructor()
- classmethod MultivariateNormalStructured.get_locscale_constructor(penalties, tol=1e-06, precompute_masks=True, validate_args=False, allow_nan_stats=True, include_normalizing_constant=True)
Creates a constructor for this distribution that takes a location array and an array of marginal scales.
- Parameters:
penalties (Sequence[Array|ndarray|bool|number|bool|int|float|complex]) – Sequence of arrays: the penalty matrices of the marginal smooths.
tol (float, default: 1e-06) – Tolerance used when computing the ranks of the marginal penalties and the pseudo log-determinant. Any eigenvalue of a marginal penalty smaller than this tolerance is treated as zero.
precompute_masks (bool, default: True) – Whether to pre-compute and store the indices of the zero and non-zero eigenvalues, or to compare the eigenvalues to the tolerance at runtime in every function call. Pre-computing is safer, because it allows an additional check that the tolerance successfully distinguishes the zero from the non-zero eigenvalues.
include_normalizing_constant (bool, default: True) – Whether to include the normalizing constant when computing the log density in the .log_prob method. If True, the returned log probability equals the log probability returned by a MultivariateNormalFullCovariance, provided all marginal penalty matrices have full rank.
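The eigenvalue handling behind tol, precompute_masks, and include_normalizing_constant can be illustrated for a single marginal penalty. The following is a minimal NumPy sketch under stated assumptions, not the liesel_gam implementation: difference_penalty is a hypothetical stand-in for pspline_penalty (assumed here to be a second-order difference penalty K = D^T D), masks_well_separated is a hypothetical version of the "additional check", and the normalizing constant is the standard one for a degenerate Gaussian with precision K / scale**2, namely -0.5 * r * log(2*pi) + 0.5 * log pdet(K / scale**2) with r = rank(K).

```python
import numpy as np


def difference_penalty(d: int, order: int = 2) -> np.ndarray:
    # Hypothetical helper: K = D^T D, where D is the order-th difference
    # matrix of shape (d - order, d). Assumed to resemble pspline_penalty.
    D = np.diff(np.eye(d), n=order, axis=0)
    return D.T @ D


def eigen_mask(K: np.ndarray, tol: float = 1e-6):
    # Eigenvalues below tol are treated as zero. The boolean mask depends
    # only on K, so it can be computed once and stored.
    evals = np.linalg.eigvalsh(K)
    return evals, evals > tol


def masks_well_separated(evals, nonzero, tol, margin=10.0) -> bool:
    # Hypothetical sanity check: tol should sit clearly between the
    # "zero" eigenvalues and the smallest "non-zero" eigenvalue.
    zero_part = np.abs(evals[~nonzero])
    pos_part = evals[nonzero]
    ok_zero = zero_part.size == 0 or zero_part.max() * margin < tol
    ok_pos = pos_part.size == 0 or pos_part.min() > tol * margin
    return bool(ok_zero and ok_pos)


def log_normalizing_constant(K: np.ndarray, scale: float, tol: float = 1e-6) -> float:
    # Degenerate Gaussian with precision K / scale**2:
    # -0.5 * r * log(2*pi) + 0.5 * log pdet(K / scale**2), r = rank(K).
    evals, nonzero = eigen_mask(K, tol)
    rank = int(nonzero.sum())
    log_pdet = np.sum(np.log(evals[nonzero])) - 2.0 * rank * np.log(scale)
    return float(-0.5 * rank * np.log(2.0 * np.pi) + 0.5 * log_pdet)


K = difference_penalty(6)
evals, nonzero = eigen_mask(K)
print(int(nonzero.sum()))  # 4
print(masks_well_separated(evals, nonzero, tol=1e-6))  # True
print(round(float(np.prod(evals[nonzero])), 6))  # 105.0
```

For a 6 x 6 second-order difference penalty the mask reports rank 4 (constant and linear trends span the null space), and the pseudo-determinant equals det(D D^T) = 105. With a full-rank penalty such as the identity, log_normalizing_constant reduces to the usual multivariate normal term -0.5 * d * log(2*pi) - 0.5 * log det(scale**2 * I), which is the full-rank case described for include_normalizing_constant.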
- Return type:
Callable[[Array|ndarray|bool|number|bool|int|float|complex, Array|ndarray|bool|number|bool|int|float|complex], MultivariateNormalStructured]
- Returns:
The returned constructor takes the following arguments:
loc: Location array with shape (B, J), where B is the batch shape and J is the event shape.
scales: Array of scales for the marginal smooths with shape (B, M), where B is the batch shape and M is the number of marginal smooths.
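To make the calling convention concrete, here is a minimal, hypothetical sketch of the factory pattern in plain NumPy. It does not build a distribution; it only fixes the penalties once and checks the shapes of loc and scales against the contract above. The names make_locscale_checker and constructor are illustrative and not part of liesel_gam; the assumption that J equals the product of the marginal penalty sizes is taken from the example below (48 = 6 x 8), and broadcasting of the two batch shapes is likewise an assumption. The np.eye matrices are placeholders, since only their sizes matter here.

```python
from typing import Callable, Sequence, Tuple

import numpy as np


def make_locscale_checker(
    penalties: Sequence[np.ndarray],
) -> Callable[[np.ndarray, np.ndarray], Tuple[int, ...]]:
    # Work that depends only on the penalties is done once, here.
    # Assumption: the event size J is the product of the marginal sizes.
    event_size = int(np.prod([K.shape[-1] for K in penalties]))
    n_smooths = len(penalties)

    def constructor(loc, scales) -> Tuple[int, ...]:
        loc = np.asarray(loc)
        scales = np.asarray(scales)
        if loc.ndim < 1 or loc.shape[-1] != event_size:
            raise ValueError(f"loc needs trailing size {event_size}, got shape {loc.shape}")
        if scales.ndim < 1 or scales.shape[-1] != n_smooths:
            raise ValueError(f"scales needs trailing size {n_smooths}, got shape {scales.shape}")
        # Assumption: the batch shapes B of loc and scales broadcast.
        return tuple(np.broadcast_shapes(loc.shape[:-1], scales.shape[:-1]))

    return constructor


check = make_locscale_checker([np.eye(6), np.eye(8)])
print(check(np.zeros(48), np.array([1.0, 2.0])))  # ()
print(check(np.zeros((3, 48)), np.ones((3, 2))))  # (3,)
```

Returning a callable keeps the penalty-dependent setup out of every call, which is the same reason a precomputed mask (precompute_masks=True) can be stored once per constructor.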
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from liesel.contrib.splines import pspline_penalty
>>> import liesel_gam as gam

>>> K1 = pspline_penalty(6)
>>> K2 = pspline_penalty(8)

>>> MVNDStrct = gam.MultivariateNormalStructured.get_locscale_constructor(
...     [K1, K2]
... )

>>> n = K1.shape[-1] * K2.shape[-1]
>>> loc = jnp.zeros(n)
>>> scales = jnp.array([1.0, 2.0])

>>> dist = MVNDStrct(loc=loc, scales=scales)
>>> dist.log_prob(jnp.zeros(n)).round(1)
Array(-22.2, dtype=float32)
Draw some random samples from the stochastic part:
>>> xnew = dist.sample(sample_shape=(2,), seed=jax.random.key(1))
>>> xnew.shape
(2, 48)