
Squashed Normal Distribution

lerax.distribution.SquashedNormal

Bases: AbstractTransformedDistribution[Float[Array, ' dims']]

A normal distribution transformed by a squashing bijector, so that its samples are bounded by low and high.

Attributes:

distribution (distributions.Transformed): The underlying distreqx Transformed distribution.

Parameters:

loc (Float[ArrayLike, '']): The mean of the underlying normal distribution, before squashing. Required.

scale (Float[ArrayLike, '']): The standard deviation of the underlying normal distribution, before squashing. Required.

high (Float[ArrayLike, '']): The upper bound of the squashed output. Default: jnp.array(1.0).

low (Float[ArrayLike, '']): The lower bound of the squashed output. Default: jnp.array(-1.0).

bijector property

bijector: bijectors.AbstractBijector
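This page does not state which squashing bijector is used. As a minimal sketch, assuming the common choice of tanh followed by an affine rescale from (-1, 1) onto (low, high), the forward map, its inverse, and the log-Jacobian could be written in plain JAX as below. The function names are hypothetical and are not part of lerax or distreqx.

```python
import jax.numpy as jnp

# Assumption: squashing = tanh, then an affine map from (-1, 1) onto (low, high).
# Hypothetical helper names; a sketch, not lerax's bijector.


def squash_forward(x, low=-1.0, high=1.0):
    """Map an unbounded x into the open interval (low, high)."""
    return low + (high - low) * (jnp.tanh(x) + 1.0) / 2.0


def squash_inverse(y, low=-1.0, high=1.0):
    """Invert squash_forward for y strictly inside (low, high)."""
    u = 2.0 * (y - low) / (high - low) - 1.0  # rescale back to (-1, 1)
    return jnp.arctanh(u)


def squash_forward_log_det_jacobian(x, low=-1.0, high=1.0):
    """Elementwise log |dy/dx| of squash_forward, evaluated at x."""
    # log(1 - tanh(x)^2) rewritten as 2 * (log 2 - x - softplus(-2x)),
    # which stays finite even when tanh(x) rounds to +-1.
    log_dtanh = 2.0 * (jnp.log(2.0) - x - jnp.logaddexp(0.0, -2.0 * x))
    return log_dtanh + jnp.log((high - low) / 2.0)
```

With the default bounds (-1, 1) the affine part is the identity and the map reduces to plain tanh.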

distribution instance-attribute

distribution: distributions.Transformed = (
    distributions.Transformed(normal, bijector)
)

loc property

loc: Float[Array, ' dims']

scale property

scale: Float[Array, ' dims']

log_prob

log_prob(value: SampleType) -> Float[Array, '']
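For a bijective squash y = f(x), the density follows from the change-of-variables formula: log p(y) = log N(x; loc, scale) - log |dy/dx|, with x = f^-1(y). The sketch below assumes tanh-plus-affine squashing and sums the per-dimension terms into one scalar; that summation is an assumption based on the Float[Array, ''] return type. The function name is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def squashed_normal_log_prob(y, loc, scale, low=-1.0, high=1.0):
    """Log-density of a tanh-plus-affine squashed normal (hypothetical helper).

    y must lie strictly inside (low, high); at the bounds arctanh is infinite.
    """
    u = 2.0 * (y - low) / (high - low) - 1.0  # rescale to (-1, 1)
    x = jnp.arctanh(u)  # pre-squash value
    base = norm.logpdf(x, loc, scale)  # log N(x; loc, scale)
    # log |dy/dx| = log((high - low) / 2) + log(1 - tanh(x)^2), and tanh(x) = u.
    log_det = jnp.log1p(-(u**2)) + jnp.log((high - low) / 2.0)
    # Treat all dims as a single event and return a scalar.
    return jnp.sum(base - log_det)
```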

prob

prob(value: SampleType) -> Float[Array, '']

sample

sample(key: Key) -> SampleType
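Sampling from a transformed normal is usually reparameterized: draw standard normal noise from the key, shift and scale it by loc and scale, then squash. A sketch assuming tanh-plus-affine squashing; squashed_normal_sample is a hypothetical name, not a lerax function.

```python
import jax
import jax.numpy as jnp


def squashed_normal_sample(key, loc, scale, low=-1.0, high=1.0):
    """Reparameterized sample from a tanh-plus-affine squashed normal (sketch)."""
    loc, scale = jnp.asarray(loc), jnp.asarray(scale)
    shape = jnp.broadcast_shapes(loc.shape, scale.shape)
    eps = jax.random.normal(key, shape)  # standard normal noise
    x = loc + scale * eps  # pre-squash sample
    return low + (high - low) * (jnp.tanh(x) + 1.0) / 2.0
```

Because the randomness enters only through eps, gradients with respect to loc and scale flow through the sample, and reusing a key reproduces the same sample.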

entropy

entropy() -> Float[Array, '']
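For a bijection, the entropy of the squashed variable satisfies H(Y) = H(X) + E[log |dy/dx|], where H(X) is the closed-form normal entropy 0.5 * log(2 * pi * e * scale^2) per dimension. The expectation generally has no closed form under tanh squashing, so one option is a Monte Carlo estimate. This page does not document how entropy() is computed; the sketch below, including its name and num_samples parameter, is hypothetical and assumes tanh-plus-affine squashing.

```python
import jax
import jax.numpy as jnp


def squashed_normal_entropy_estimate(key, loc, scale, low=-1.0, high=1.0, num_samples=10_000):
    """Monte Carlo estimate of H(Y) = H(X) + E[log |dy/dx|] (hypothetical helper)."""
    event_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    scale_b = jnp.broadcast_to(scale, event_shape)
    x = loc + scale_b * jax.random.normal(key, (num_samples,) + event_shape)
    # Exact entropy of the base normal, summed over dims.
    base_entropy = jnp.sum(0.5 * jnp.log(2.0 * jnp.pi * jnp.e * scale_b**2))
    # Stable log |dy/dx| for tanh followed by the affine rescale.
    log_det = 2.0 * (jnp.log(2.0) - x - jnp.logaddexp(0.0, -2.0 * x))
    log_det = log_det + jnp.log((high - low) / 2.0)
    # Sum over event dims, then average over samples.
    return base_entropy + jnp.mean(log_det.reshape(num_samples, -1).sum(axis=1))
```

As a sanity check, a very small scale makes the squash nearly linear around loc = 0, so the estimate approaches the base normal entropy. A large scale pushes mass toward the bounds and the estimate falls well below log(high - low), the entropy of the uniform distribution on the same interval.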

mean

mean() -> SampleType
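Because tanh is nonlinear, E[f(X)] generally differs from f(E[X]), so squashing loc does not give the mean of the squashed distribution. A Monte Carlo estimate of E[Y] under the tanh-plus-affine assumption is sketched below. The function and its num_samples parameter are hypothetical, and this is not necessarily how mean() is implemented.

```python
import jax
import jax.numpy as jnp


def squashed_normal_mean_estimate(key, loc, scale, low=-1.0, high=1.0, num_samples=10_000):
    """Monte Carlo estimate of E[Y] for a tanh-plus-affine squashed normal (sketch)."""
    event_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    x = loc + scale * jax.random.normal(key, (num_samples,) + event_shape)
    y = low + (high - low) * (jnp.tanh(x) + 1.0) / 2.0
    return jnp.mean(y, axis=0)  # average over samples, keep event dims
```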

mode

mode() -> SampleType

sample_and_log_prob

sample_and_log_prob(
    key: Key,
) -> tuple[SampleType, Float[Array, ""]]
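Computing a sample and its log-probability together lets the log-Jacobian be evaluated at the pre-squash value x, instead of recovering x from y with arctanh. That inversion loses precision near the bounds and returns infinity once tanh(x) rounds to +-1 in floating point. A sketch of this pattern, assuming tanh-plus-affine squashing; the function name is hypothetical and this is not presented as lerax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def squashed_normal_sample_and_log_prob(key, loc, scale, low=-1.0, high=1.0):
    """Sample y and compute log p(y) without inverting the squash (sketch)."""
    loc, scale = jnp.asarray(loc), jnp.asarray(scale)
    shape = jnp.broadcast_shapes(loc.shape, scale.shape)
    x = loc + scale * jax.random.normal(key, shape)  # pre-squash sample
    y = low + (high - low) * (jnp.tanh(x) + 1.0) / 2.0
    # log |dy/dx| evaluated directly at x, stable even when tanh(x) rounds to +-1.
    log_det = 2.0 * (jnp.log(2.0) - x - jnp.logaddexp(0.0, -2.0 * x))
    log_det = log_det + jnp.log((high - low) / 2.0)
    log_prob = jnp.sum(norm.logpdf(x, loc, scale) - log_det)  # scalar over dims
    return y, log_prob
```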

__init__

__init__(
    loc: Float[ArrayLike, ""],
    scale: Float[ArrayLike, ""],
    high: Float[ArrayLike, ""] = jnp.array(1.0),
    low: Float[ArrayLike, ""] = jnp.array(-1.0),
)