Bernoulli Distribution

lerax.distribution.Bernoulli

Bases: AbstractMaskableDistribution[Bool[Array, ' dims'], Bool[Array, ' dims']], AbstractDistreqxWrapper[Bool[Array, ' dims']]

Bernoulli distribution over boolean arrays, parameterized by log-odds (logits) or probabilities (probs).

Attributes:

distribution (distributions.Bernoulli): The underlying distreqx Bernoulli distribution.

Parameters:

logits (Float[ArrayLike, ' dims'] | None, default None): The log-odds of the distribution.

probs (Float[ArrayLike, ' dims'] | None, default None): The probabilities of the distribution.
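
The two parameters describe the same quantity on different scales: logits are log-odds, so probs = sigmoid(logits) and logits = log(p / (1 - p)). The sketch below illustrates that relationship for a single scalar using only the standard library. The helper names `sigmoid` and `logit` are hypothetical and are not part of lerax or distreqx.

```python
import math


def sigmoid(logit_value: float) -> float:
    """Map a log-odds value to a probability in (0, 1)."""
    # Branch on the sign so that exp() never overflows for large |logit_value|.
    if logit_value >= 0:
        return 1.0 / (1.0 + math.exp(-logit_value))
    exp_l = math.exp(logit_value)
    return exp_l / (1.0 + exp_l)


def logit(prob: float) -> float:
    """Map a probability in (0, 1) to its log-odds."""
    return math.log(prob) - math.log1p(-prob)


p = sigmoid(0.0)                  # log-odds of 0 correspond to even odds
round_trip = sigmoid(logit(0.8))  # converting back and forth recovers the input
```

Working in log-odds is usually the more convenient choice when the parameter comes from an unconstrained model output, since any real number is a valid logit.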

distribution instance-attribute

distribution: distributions.Bernoulli = (
    distributions.Bernoulli(logits=logits, probs=probs)
)

logits property

logits: Float[Array, ' dims']

probs property

probs: Float[Array, ' dims']

log_prob

log_prob(value: SampleType) -> Float[Array, '']
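
For a single Bernoulli variable with log-odds l, the log-probabilities can be written as log P(x = True) = -softplus(-l) and log P(x = False) = -softplus(l). The following is a minimal standard-library sketch of that formula for one element. It is not the library's implementation, and the helper names `softplus` and `bernoulli_log_prob` are hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bernoulli_log_prob(value: bool, logit_value: float) -> float:
    """Log-probability of one boolean outcome given its log-odds."""
    return -softplus(-logit_value) if value else -softplus(logit_value)


lp_true = bernoulli_log_prob(True, 0.0)  # equals log(0.5) at even odds
# The probabilities of the two outcomes sum to one.
total = math.exp(bernoulli_log_prob(True, 2.0)) + math.exp(bernoulli_log_prob(False, 2.0))
```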

prob

prob(value: SampleType) -> Float[Array, '']

sample

sample(key: Key) -> SampleType
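
A common way to draw a Bernoulli sample is to threshold a uniform random number against the success probability. The sketch below shows the idea with Python's standard library, where a seeded `random.Random` stands in for the explicit PRNG key. It is an illustration only, not how lerax or distreqx samples, and `bernoulli_sample` is a hypothetical helper.

```python
import random


def bernoulli_sample(rng: random.Random, prob: float) -> bool:
    """Return True with probability `prob` by thresholding a uniform draw."""
    return rng.random() < prob


rng = random.Random(0)  # explicit seed, playing the role of the PRNG key
draws = [bernoulli_sample(rng, 0.3) for _ in range(10_000)]
empirical_rate = sum(draws) / len(draws)  # should be close to 0.3
```

As with an explicit key, reusing the same seed reproduces the same sequence of draws.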

entropy

entropy() -> Float[Array, '']
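
For a single Bernoulli variable with success probability p, the entropy is H = -p log p - (1 - p) log(1 - p). The sketch below evaluates this in nats (natural logarithm), which is an assumption about units rather than something the page states. `bernoulli_entropy` is a hypothetical helper, not a lerax function.

```python
import math


def bernoulli_entropy(prob: float) -> float:
    """Entropy in nats of a Bernoulli variable with success probability `prob`."""
    entropy = 0.0
    for q in (prob, 1.0 - prob):
        if q > 0.0:  # treat 0 * log(0) as 0
            entropy -= q * math.log(q)
    return entropy


h_max = bernoulli_entropy(0.5)   # maximum entropy, equal to log(2)
h_zero = bernoulli_entropy(1.0)  # a deterministic outcome has zero entropy
```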

mean

mean() -> SampleType

mode

mode() -> SampleType

sample_and_log_prob

sample_and_log_prob(
    key: Key,
) -> tuple[SampleType, Float[Array, ""]]

__init__

__init__(
    logits: Float[ArrayLike, " dims"] | None = None,
    probs: Float[ArrayLike, " dims"] | None = None,
)

mask

mask(mask: Bool[Array, ' dims']) -> Bernoulli