
Squashed Multivariate Normal Distribution

lerax.distribution.SquashedMultivariateNormalDiag

Bases: AbstractTransformedDistribution[Float[Array, ' dims']]

A diagonal multivariate normal distribution whose samples are passed through a squashing bijector, so that every output lies between low and high.
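The page does not name the squashing bijector. A common way to map an unbounded value onto [low, high] is tanh followed by an affine rescale. The sketch below assumes that choice; `squash` and `unsquash` are hypothetical helpers, not lerax functions.

```python
import jax.numpy as jnp

# Assumed squash: tanh followed by an affine rescale onto [low, high].
# lerax's actual bijector is not documented on this page.
def squash(x, low=-1.0, high=1.0):
    """Map unbounded x into [low, high]."""
    return low + 0.5 * (high - low) * (jnp.tanh(x) + 1.0)

def unsquash(y, low=-1.0, high=1.0):
    """Inverse of `squash`, valid for y strictly inside (low, high)."""
    return jnp.arctanh(2.0 * (y - low) / (high - low) - 1.0)

x = jnp.array([-10.0, 0.0, 0.3, 10.0])
y = squash(x, low=0.0, high=2.0)  # every entry lies in [0, 2]
```

In float32, inputs far from zero saturate tanh to exactly plus or minus one, so outputs can land on the bounds themselves. The range is therefore closed in practice, even though it is open mathematically.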

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| distribution | distributions.Transformed | The underlying distreqx Transformed distribution. |

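Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[ArrayLike, ' dims'] | The mean of the underlying (pre-squashing) multivariate normal distribution. | required |
| scale_diag | Float[ArrayLike, ' dims'] | The diagonal of the scale matrix, i.e. the per-dimension standard deviations of the underlying normal; the covariance is diag(scale_diag**2). | required |
| high | Float[ArrayLike, ' dims'] | The upper bound of the squashed output. | jnp.array(1.0) |
| low | Float[ArrayLike, ' dims'] | The lower bound of the squashed output. | jnp.array(-1.0) |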
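Assuming `scale_diag` follows the distreqx `MultivariateNormalDiag` convention (a vector of standard deviations, so the covariance is `diag(scale_diag**2)`), the base log-density factorises over dimensions. This sketch checks the factorised form against the full-covariance formula. It also shows that reading `scale_diag` as variances gives a different density. Both function names are illustrative.

```python
import math
import jax.numpy as jnp

def diag_normal_log_prob(x, loc, scale_diag):
    """log N(x; loc, diag(scale_diag**2)), summed over dimensions."""
    z = (x - loc) / scale_diag
    return jnp.sum(-0.5 * z**2 - jnp.log(scale_diag) - 0.5 * math.log(2.0 * math.pi))

def full_cov_log_prob(x, loc, cov):
    """The same Gaussian log-density written with an explicit covariance matrix."""
    d = x.shape[-1]
    diff = x - loc
    _, logdet = jnp.linalg.slogdet(cov)
    maha = diff @ jnp.linalg.solve(cov, diff)  # squared Mahalanobis distance
    return -0.5 * (maha + logdet + d * math.log(2.0 * math.pi))

loc = jnp.array([0.0, 1.0])
scale_diag = jnp.array([0.5, 2.0])
x = jnp.array([0.2, -1.0])

a = diag_normal_log_prob(x, loc, scale_diag)
b = full_cov_log_prob(x, loc, jnp.diag(scale_diag**2))  # matches a
c = full_cov_log_prob(x, loc, jnp.diag(scale_diag))     # scale_diag misread as variances
```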

bijector property

bijector: bijectors.AbstractBijector

distribution instance-attribute

distribution: distributions.Transformed = (
    distributions.Transformed(mvn, bijector)
)

loc property

loc: Float[Array, ' dims']

scale_diag property

scale_diag: Float[Array, ' dims']

log_prob

log_prob(value: SampleType) -> Float[Array, '']

prob

prob(value: SampleType) -> Float[Array, '']
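Because `distribution` is a `Transformed` distribution, `log_prob` follows the change-of-variables rule log p_Y(y) = log p_X(x) - log|det J_f(x)|, with x = f^{-1}(y). The sketch below assumes the same hypothetical tanh-plus-affine squash, since the page does not name the bijector. It uses the numerically stable identity log(1 - tanh(x)^2) = 2(log 2 - x - softplus(-2x)). `prob` is simply the exponential of `log_prob`.

```python
import math
import jax
import jax.numpy as jnp

def log_abs_det_jacobian(x, low, high):
    # Per-dimension log|dy/dx| for y = low + (high - low) * (tanh(x) + 1) / 2.
    return jnp.log(0.5 * (high - low)) + 2.0 * (
        math.log(2.0) - x - jax.nn.softplus(-2.0 * x)
    )

def log_prob(y, loc, scale_diag, low, high):
    x = jnp.arctanh(2.0 * (y - low) / (high - low) - 1.0)  # invert the squash
    z = (x - loc) / scale_diag
    base = jnp.sum(-0.5 * z**2 - jnp.log(scale_diag) - 0.5 * math.log(2.0 * math.pi))
    return base - jnp.sum(log_abs_det_jacobian(x, low, high))

def prob(y, loc, scale_diag, low, high):
    return jnp.exp(log_prob(y, loc, scale_diag, low, high))
```

The squash acts elementwise, so its Jacobian is diagonal and the log-determinant is a sum of per-dimension terms.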

sample

sample(key: Key) -> SampleType
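Sampling from a transformed distribution draws from the base normal and pushes the draw through the bijector. If the draw is written with the reparameterisation x = loc + scale_diag * eps, the sample is differentiable with respect to `loc` and `scale_diag`. The tanh-plus-affine squash is again an assumption, and this `sample` is a standalone sketch rather than the lerax method.

```python
import jax
import jax.numpy as jnp

def sample(key, loc, scale_diag, low=-1.0, high=1.0):
    eps = jax.random.normal(key, jnp.shape(loc))  # standard normal noise
    x = loc + scale_diag * eps                    # reparameterised base draw
    return low + 0.5 * (high - low) * (jnp.tanh(x) + 1.0)  # squash into [low, high]

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 1000)
loc = jnp.array([0.0, 2.0, -1.0])
scale = jnp.array([1.0, 0.5, 3.0])
low = jnp.array([-1.0, 0.0, -5.0])
high = jnp.array([1.0, 10.0, 5.0])

ys = jax.vmap(lambda k: sample(k, loc, scale, low, high))(keys)  # shape (1000, 3)

# Gradients flow through the sample back to loc.
grad_loc = jax.grad(lambda l: jnp.sum(sample(key, l, scale, low, high)))(loc)
```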

entropy

entropy() -> Float[Array, '']
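The page does not say how `entropy()` is computed. For a non-affine squashing bijector there is generally no closed form. One option uses the identity H(Y) = H(X) + E[log|det J_f(X)|] for a bijection y = f(x): the Gaussian term is exact, and the expectation is estimated by Monte Carlo. The sketch assumes the hypothetical tanh-plus-affine squash, and `entropy_estimate` is an illustrative name.

```python
import math
import jax
import jax.numpy as jnp

def entropy_estimate(key, loc, scale_diag, low=-1.0, high=1.0, num_samples=10_000):
    # Closed-form entropy of the diagonal Gaussian base distribution.
    base_entropy = jnp.sum(0.5 * math.log(2.0 * math.pi * math.e) + jnp.log(scale_diag))
    # Monte Carlo estimate of E[log|det J|] under the base distribution.
    eps = jax.random.normal(key, (num_samples,) + jnp.shape(loc))
    x = loc + scale_diag * eps
    log_det = jnp.log(0.5 * (high - low)) + 2.0 * (
        math.log(2.0) - x - jax.nn.softplus(-2.0 * x)
    )
    return base_entropy + jnp.mean(jnp.sum(log_det, axis=-1))
```

Because log(1 - tanh(x)^2) is never positive, this estimate never exceeds the base entropy plus the log of the affine scale factor. It approaches that value when `scale_diag` is small and `loc` is near zero.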

mean

mean() -> SampleType

mode

mode() -> SampleType
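The page does not say how `mean()` and `mode()` are computed. For a nonlinear squash f, E[f(X)] differs from f(E[X]) in general. The mode also shifts, because the density of Y includes the Jacobian term. Squashing `loc` is therefore not the mean, and in general not the mode either. This sketch compares a Monte Carlo mean with the plug-in value in 1-D, under the assumed tanh squash.

```python
import jax
import jax.numpy as jnp

def squash(x, low=-1.0, high=1.0):
    # Hypothetical tanh-plus-affine squash (reduces to tanh for the default bounds).
    return low + 0.5 * (high - low) * (jnp.tanh(x) + 1.0)

loc, scale = 0.5, 1.0
key = jax.random.PRNGKey(0)
x = loc + scale * jax.random.normal(key, (200_000,))

mc_mean = jnp.mean(squash(x))       # Monte Carlo estimate of E[squash(X)]
plug_in = squash(jnp.asarray(loc))  # squash applied to the Gaussian mean
```

Here the plug-in value is tanh(0.5), roughly 0.46, while the Monte Carlo mean is noticeably smaller, because tanh compresses the upper tail more than the lower one when `loc` is positive.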

sample_and_log_prob

sample_and_log_prob(
    key: Key,
) -> tuple[SampleType, Float[Array, ""]]
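Computing the sample and its log-density together lets the log-density use the pre-squash value x directly, so the bijector never has to be inverted. This matters for tanh: saturated samples round to exactly low or high in floating point, and inverting them yields non-finite values. The sketch assumes the tanh-plus-affine squash and does not claim to mirror lerax's implementation.

```python
import math
import jax
import jax.numpy as jnp

def sample_and_log_prob(key, loc, scale_diag, low=-1.0, high=1.0):
    eps = jax.random.normal(key, jnp.shape(loc))
    x = loc + scale_diag * eps
    y = low + 0.5 * (high - low) * (jnp.tanh(x) + 1.0)
    # Base log-density at x; note (x - loc) / scale_diag == eps, so no inverse is needed.
    base = jnp.sum(-0.5 * eps**2 - jnp.log(scale_diag) - 0.5 * math.log(2.0 * math.pi))
    # Stable log|dy/dx| evaluated at x, which stays finite even when tanh(x) saturates.
    log_det = jnp.log(0.5 * (high - low)) + 2.0 * (
        math.log(2.0) - x - jax.nn.softplus(-2.0 * x)
    )
    return y, base - jnp.sum(log_det)

# A loc far in the tail: y saturates at the upper bound, yet the log-density stays finite.
y, lp = sample_and_log_prob(jax.random.PRNGKey(0), jnp.array([12.0]), jnp.array([0.1]))
```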

__init__

__init__(
    loc: Float[ArrayLike, " dims"],
    scale_diag: Float[ArrayLike, " dims"],
    high: Float[ArrayLike, " dims"] = jnp.array(1.0),
    low: Float[ArrayLike, " dims"] = jnp.array(-1.0),
)
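The constructor accepts ArrayLike values and scalar defaults for the bounds. The page does not show how these are normalised. The hypothetical helper `prepare_params` sketches one reasonable approach: convert to arrays, broadcast scalar bounds to the shape of `loc`, and reject `low >= high`. Note that `high` comes before `low` in the signature, so passing the bounds by keyword avoids swapping them by accident.

```python
import jax.numpy as jnp

def prepare_params(loc, scale_diag, high=1.0, low=-1.0):
    """Hypothetical helper (not part of lerax): normalise constructor inputs."""
    loc = jnp.asarray(loc, dtype=jnp.float32)
    scale_diag = jnp.asarray(scale_diag, dtype=jnp.float32)
    # Broadcast scalar bounds such as jnp.array(1.0) to one bound per dimension.
    high = jnp.broadcast_to(jnp.asarray(high, dtype=jnp.float32), loc.shape)
    low = jnp.broadcast_to(jnp.asarray(low, dtype=jnp.float32), loc.shape)
    if loc.shape != scale_diag.shape:
        raise ValueError("loc and scale_diag must have the same shape")
    # Concrete Python check: fine eagerly, but not on traced values inside jax.jit.
    if not bool(jnp.all(low < high)):
        raise ValueError("every entry of low must be strictly less than high")
    return loc, scale_diag, high, low

loc, scale_diag, high, low = prepare_params([0.0, 1.0, 2.0], [1.0, 1.0, 1.0])
```

With the documented signature, a keyword call such as `SquashedMultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2), low=jnp.array(0.0), high=jnp.array(1.0))` keeps the bounds unambiguous.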