Distribution

lerax.distribution.AbstractDistribution

Bases: eqx.Module

Base class for all distributions in Lerax.

log_prob abstractmethod

log_prob(value: SampleType) -> Float[Array, '']

Compute the log probability of a sample.

prob abstractmethod

prob(value: SampleType) -> Float[Array, '']

Compute the probability of a sample.

sample abstractmethod

sample(key: Key) -> SampleType

Return a sample from the distribution.

entropy abstractmethod

entropy() -> Float[Array, '']

Compute the entropy of the distribution.

mean abstractmethod

mean() -> SampleType

Compute the mean of the distribution.

mode abstractmethod

mode() -> SampleType

Compute the mode of the distribution.

sample_and_log_prob abstractmethod

sample_and_log_prob(
    key: Key,
) -> tuple[SampleType, Float[Array, ""]]

Return a sample and its log probability.
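
The methods above are related in a predictable way: `prob` is the exponential of `log_prob`, and `sample_and_log_prob` returns a draw together with the log probability of that same draw. The following framework-free sketch illustrates those relationships for a univariate normal distribution. It does not import Lerax; the class name `ToyNormal` is hypothetical, plain Python floats stand in for `Float[Array, '']`, and a seeded `random.Random` stands in for the JAX `Key`.

```python
import math
import random


class ToyNormal:
    """Hypothetical stand-in that mirrors the AbstractDistribution method names."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, value):
        # Log density of Normal(loc, scale) evaluated at `value`.
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)

    def prob(self, value):
        # The density is the exponential of the log density.
        return math.exp(self.log_prob(value))

    def sample(self, key):
        # `key` is a random.Random here (assumption), not a JAX PRNG key.
        return key.gauss(self.loc, self.scale)

    def entropy(self):
        # Differential entropy of a normal: 0.5 * log(2 * pi * e) + log(scale).
        return 0.5 * math.log(2.0 * math.pi * math.e) + math.log(self.scale)

    def mean(self):
        return self.loc

    def mode(self):
        # For a normal distribution the mode coincides with the mean.
        return self.loc

    def sample_and_log_prob(self, key):
        # Draw once, then score that same draw.
        value = self.sample(key)
        return value, self.log_prob(value)


dist = ToyNormal(loc=0.0, scale=1.0)
value, logp = dist.sample_and_log_prob(random.Random(0))
```

Returning the sample and its log probability from a single call avoids scoring a different draw by accident, which matters when the two values are used together (for example in a policy-gradient update).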

lerax.distribution.AbstractMaskableDistribution

Bases: AbstractDistribution[SampleType]

Base class for all maskable distributions in Lerax.

Maskable distributions allow masking of elements in the distribution.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| distribution | | The underlying distreqx distribution. |

mask abstractmethod

mask(mask: MaskType) -> Self

Return a masked version of the distribution.

A masked distribution only considers the elements where the mask is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mask | MaskType | A mask indicating which elements to consider. | required |

Returns:

| Type | Description |
| --- | --- |
| Self | A new masked distribution. |
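
One common way to realise masking for a categorical distribution is to set the logits of masked-out elements to negative infinity and renormalise, so that those elements receive zero probability. The sketch below shows that idea in plain Python. Whether Lerax implements `mask` this way is an assumption; `masked_log_probs` is a hypothetical helper and not part of the Lerax API.

```python
import math


def masked_log_probs(logits, mask):
    """Return normalised log-probabilities over the entries where mask is True."""
    if not any(mask):
        raise ValueError("mask must keep at least one element")
    # Masked-out entries get a logit of -inf, i.e. zero probability.
    kept = [logit if keep else -math.inf for logit, keep in zip(logits, mask)]
    # Log-sum-exp, shifted by the maximum for numerical stability.
    top = max(kept)
    log_norm = top + math.log(sum(math.exp(logit - top) for logit in kept))
    return [logit - log_norm for logit in kept]


log_probs = masked_log_probs([1.0, 2.0, 3.0], [True, False, True])
probs = [math.exp(lp) for lp in log_probs]
```

In this sketch the remaining probabilities keep their relative proportions and sum to one, while the masked element can never be sampled.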

lerax.distribution.AbstractTransformedDistribution

Bases: AbstractDistreqxWrapper[SampleType]

Base class for all transformed distributions in Lerax.

Transformed distributions apply a bijective transformation to a base distribution.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| distribution | eqx.AbstractVar[distributions.AbstractTransformed] | The underlying distreqx transformed distribution. |
| bijector | bijectors.AbstractBijector | The bijective transformation applied to the base distribution. |

distribution instance-attribute

distribution: eqx.AbstractVar[
    distributions.AbstractTransformed
]

bijector property

bijector: bijectors.AbstractBijector
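
For a bijection `y = f(x)`, the density of the transformed variable follows the change-of-variables rule `log p_Y(y) = log p_X(x) - log |f'(x)|`, where `x` is the inverse image of `y`. The sketch below works this out for a standard normal base distribution and a `tanh` bijector in plain Python. It illustrates the general rule only; it does not use distreqx or Lerax, and the function names are hypothetical.

```python
import math


def base_log_prob(x):
    """Log density of a standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def tanh_transformed_log_prob(y):
    """Log density of y = tanh(x) with x ~ Normal(0, 1), valid for -1 < y < 1."""
    x = math.atanh(y)  # invert the bijector
    # log |d tanh(x) / dx| = log(1 - tanh(x)^2) = log(1 - y^2)
    log_det = math.log(1.0 - y * y)
    return base_log_prob(x) - log_det


logp = tanh_transformed_log_prob(0.5)
```

The Jacobian term is what keeps the transformed density normalised: squashing the real line into the interval (-1, 1) compresses probability mass, so the density there must be scaled up accordingly.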