Multivariate Normal Distribution

lerax.distribution.MultivariateNormalDiag

Bases: AbstractDistreqxWrapper[Float[Array, ' dims']]

Multivariate Normal distribution with diagonal covariance.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| distribution | distributions.MultivariateNormalDiag | The underlying distreqx MultivariateNormalDiag distribution. |

Parameters:

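| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[ArrayLike, ' dims'] \| None | The mean of the distribution. | None |
| scale_diag | Float[ArrayLike, ' dims'] \| None | The diagonal of the scale matrix, i.e. the per-dimension standard deviations. It is passed unchanged to the underlying distreqx distribution, so the covariance matrix is diag(scale_diag**2). | None |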

distribution instance-attribute

distribution: distributions.MultivariateNormalDiag = (
    distributions.MultivariateNormalDiag(
        loc=loc, scale_diag=scale_diag
    )
)

loc property

loc: Float[Array, ' dims']

scale_diag property

scale_diag: Float[Array, ' dims']

log_prob

log_prob(value: SampleType) -> Float[Array, '']
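
For a Gaussian with diagonal covariance, the log-density is a sum of independent one-dimensional terms. The sketch below writes that formula out in plain Python so it can be checked by hand. It assumes `scale_diag` holds per-dimension standard deviations, and `diag_normal_log_prob` is a hypothetical helper name used only for illustration; it is not part of lerax or distreqx.

```python
import math


def diag_normal_log_prob(value, loc, scale_diag):
    """Log-density of N(loc, diag(scale_diag**2)) at `value`.

    All arguments are equal-length sequences of floats (assumption:
    scale_diag contains standard deviations, not variances).
    """
    total = 0.0
    for x, mu, sigma in zip(value, loc, scale_diag):
        z = (x - mu) / sigma  # standardized residual for this dimension
        total += -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return total


# At the mean of a standard 2-D normal the density equals 1 / (2 * pi).
lp = diag_normal_log_prob([0.0, 0.0], [0.0, 0.0], [1.0, 1.0])
```

The result is a single scalar regardless of the number of dimensions, which matches the `Float[Array, '']` return annotation.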

prob

prob(value: SampleType) -> Float[Array, '']

sample

sample(key: Key) -> SampleType
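
One common way to draw from a diagonal Gaussian is to shift and scale standard-normal noise, `loc + scale_diag * eps`. The plain-Python sketch below illustrates that idea. Whether lerax samples exactly this way internally is an assumption, `diag_normal_sample` is a hypothetical helper name, and the seeded `random.Random` generator stands in for the explicit `key` argument.

```python
import random


def diag_normal_sample(rng, loc, scale_diag):
    """Draw one sample as loc + scale_diag * eps, with eps ~ N(0, I)."""
    return [mu + sigma * rng.gauss(0.0, 1.0) for mu, sigma in zip(loc, scale_diag)]


rng = random.Random(0)  # explicit seeded generator, playing the role of `key`
draws = [diag_normal_sample(rng, [1.0, -2.0], [0.5, 3.0]) for _ in range(20000)]

# Per-dimension empirical means; these should land close to loc.
means = [sum(d[i] for d in draws) / len(draws) for i in range(2)]
```

Each draw has the same length as `loc`, and the randomness is fully determined by the generator passed in, which is the same contract an explicit `key` provides.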

entropy

entropy() -> Float[Array, '']
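
The differential entropy of a diagonal Gaussian has a closed form that depends only on the scale, not on the mean: `0.5 * d * (1 + log(2 * pi)) + sum(log(scale_diag))` for `d` dimensions. The sketch below evaluates it in plain Python. It assumes `scale_diag` holds standard deviations, and `diag_normal_entropy` is a hypothetical helper name, not a lerax function.

```python
import math


def diag_normal_entropy(scale_diag):
    """Differential entropy (in nats) of N(loc, diag(scale_diag**2))."""
    d = len(scale_diag)
    log_det_scale = sum(math.log(s) for s in scale_diag)  # log|diag(scale_diag)|
    return 0.5 * d * (1.0 + math.log(2.0 * math.pi)) + log_det_scale


# Standard 2-D normal: entropy is 1 + log(2 * pi) nats.
h = diag_normal_entropy([1.0, 1.0])
```

Doubling the standard deviation in one dimension adds `log(2)` nats, which is a quick sanity check for any implementation.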

mean

mean() -> SampleType

mode

mode() -> SampleType

sample_and_log_prob

sample_and_log_prob(
    key: Key,
) -> tuple[SampleType, Float[Array, ""]]

__init__

__init__(
    loc: Float[ArrayLike, " dims"] | None = None,
    scale_diag: Float[ArrayLike, " dims"] | None = None,
)