Skip to content

Multi Categorical Distribution

lerax.distribution.MultiCategorical

Bases: AbstractMaskableDistribution[Integer[Array, ' dims'], Bool[Array, '... sum_of_classes'] | Sequence[Bool[Array, '... classes']]]

Product of independent Categorical distributions.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `distribution` | `tuple[distributions.Categorical, ...]` | Tuple of underlying Categorical distributions. |
| `action_dims` | `tuple[int, ...]` | Number of classes for each categorical dimension. |

distribution instance-attribute

distribution: tuple[distributions.Categorical, ...]

action_dims instance-attribute

action_dims: tuple[int, ...]

logits property

logits: Float[Array, '... sum_of_classes']

probs property

probs: Float[Array, '... sum_of_classes']

__init__

__init__(
    logits: Float[ArrayLike, " sum_of_classes"]
    | Sequence[Float[ArrayLike, " classes"]]
    | None = None,
    probs: Float[ArrayLike, " sum_of_classes"]
    | Sequence[Float[ArrayLike, " classes"]]
    | None = None,
    action_dims: Sequence[int] | None = None,
)

_split_or_unpack_params staticmethod

_split_or_unpack_params(
    params: Float[ArrayLike, " sum_of_classes"]
    | Sequence[Float[ArrayLike, " classes"]],
    action_dims: tuple[int, ...] | None,
) -> tuple[
    tuple[Float[Array, "... classes"], ...], tuple[int, ...]
]

mask

mask(
    mask: Bool[Array, "... sum_of_classes"]
    | Sequence[Bool[Array, "... classes"]],
) -> MultiCategorical

log_prob

log_prob(
    value: Integer[ArrayLike, " dims"],
) -> Float[Array, "..."]

prob

prob(
    value: Integer[ArrayLike, " dims"],
) -> Float[Array, "..."]

sample

sample(key: Key) -> Integer[Array, ' ... dims']

entropy

entropy() -> Float[Array, '...']

mean

mean() -> Float[Array, ' ... dims']

mode

mode() -> Integer[Array, ' ... dims']

sample_and_log_prob

sample_and_log_prob(
    key: Key,
) -> tuple[
    Integer[Array, " ... dims"], Float[Array, "..."]
]