Multi Categorical Distribution
lerax.distribution.MultiCategorical
Bases:
Product of independent Categorical distributions.
Attributes:
| Name | Type | Description |
|---|---|---|
| | `tuple` | Tuple of underlying Categorical distributions. |
| | `tuple[int, ...]` | Tuple of the number of classes for each categorical dimension. |
__init__
__init__(
logits: Float[ArrayLike, " sum_of_classes"]
| Sequence[Float[ArrayLike, " classes"]]
| None = None,
probs: Float[ArrayLike, " sum_of_classes"]
| Sequence[Float[ArrayLike, " classes"]]
| None = None,
action_dims: Sequence[int] | None = None,
)
_split_or_unpack_params
staticmethod
_split_or_unpack_params(
params: Float[ArrayLike, " sum_of_classes"]
| Sequence[Float[ArrayLike, " classes"]],
action_dims: tuple[int, ...] | None,
) -> tuple[
tuple[Float[Array, "... classes"], ...], tuple[int, ...]
]