Abstract Actor-Critic Policies
lerax.policy.AbstractActorCriticPolicy
Bases: AbstractPolicy[StateType, ActType, ObsType, MaskType]
Base class for stateful actor-critic policies.
Actor-critic policies map observations and internal states to actions, values, and new internal states.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | eqx.AbstractClassVar[str] | The name of the policy. |
| action_space | eqx.AbstractVar[AbstractSpace[ActType, MaskType]] | The action space of the policy. |
| observation_space | eqx.AbstractVar[AbstractSpace[ObsType, Any]] | The observation space of the policy. |
observation_space
instance-attribute
serialize
Serialize the model to the specified path.
Writes a 32-byte structural fingerprint followed by the Equinox
leaf data, so deserialize can verify that the skeleton it builds
matches what was saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to serialize to. The | required |
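The on-disk layout described above (a 32-byte fingerprint, then the leaf data) can be sketched in plain Python. This is not lerax's implementation. The hypothetical `write_fingerprinted` helper, hashing a textual structure description with SHA-256 (chosen only because its digest happens to be 32 bytes), and packing the leaves as little-endian float64 values are all assumptions made for illustration.

```python
from __future__ import annotations

import hashlib
import struct
import tempfile
from pathlib import Path


def write_fingerprinted(path: str | Path, structure: str, leaves: list[float]) -> None:
    """Hypothetical helper: write a 32-byte fingerprint, then the leaf data.

    `structure` stands in for whatever static description (class, shapes,
    hyperparameters) the real library fingerprints; that is an assumption.
    """
    fingerprint = hashlib.sha256(structure.encode("utf-8")).digest()  # 32 bytes
    payload = struct.pack(f"<{len(leaves)}d", *leaves)  # leaves as float64
    with open(path, "wb") as f:
        f.write(fingerprint)  # header first, so a reader can check it before loading
        f.write(payload)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.bin"
    write_fingerprinted(path, "MLP(in=4, hidden=8, out=2)", [0.1, -0.2, 0.3])
    raw = path.read_bytes()

print(len(raw))  # 56: 32-byte fingerprint + 3 float64 leaves
```

Putting the fingerprint first means a reader can reject an incompatible file after reading only 32 bytes, before touching any array data.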
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path.
The constructor arguments must reproduce the same static structure
(class, hyperparameters, network shapes, activations, ...) that the
model had when it was serialized. A 32-byte fingerprint stored in
the file is verified before loading; mismatches raise ValueError
instead of silently loading arrays into the wrong skeleton.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to deserialize from. | required |
| *args | Params.args | Additional positional arguments to pass to the class constructor. | () |
| **kwargs | Params.kwargs | Additional keyword arguments to pass to the class constructor. | {} |
Returns:
| Type | Description |
|---|---|
| ClassType | The deserialized model. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file. |
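The reason `deserialize` needs the original constructor arguments can be illustrated with a toy class whose static structure is derived entirely from those arguments. Everything here (`ToyModel`, its `fingerprint` method, SHA-256 over a structure string, the float64 leaf packing) is a hypothetical stand-in rather than lerax code. It only shows the pattern: rebuild a skeleton from the arguments, compare fingerprints, and raise `ValueError` on a mismatch instead of loading data into the wrong shape.

```python
import hashlib
import struct
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical model whose number of leaves depends on a hyperparameter."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.weights = [0.0] * hidden_size  # the "leaves"

    def fingerprint(self) -> bytes:
        # Hash only the static structure, never the leaf values themselves.
        structure = f"{type(self).__name__}(hidden_size={self.hidden_size})"
        return hashlib.sha256(structure.encode("utf-8")).digest()  # 32 bytes

    def serialize(self, path: Path) -> None:
        data = struct.pack(f"<{len(self.weights)}d", *self.weights)
        path.write_bytes(self.fingerprint() + data)

    @classmethod
    def deserialize(cls, path: Path, *args, **kwargs) -> "ToyModel":
        skeleton = cls(*args, **kwargs)  # rebuild the structure from the given args
        raw = path.read_bytes()
        stored, data = raw[:32], raw[32:]
        if stored != skeleton.fingerprint():
            raise ValueError("structural fingerprint mismatch")
        skeleton.weights = list(struct.unpack(f"<{len(skeleton.weights)}d", data))
        return skeleton


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy.bin"
    original = ToyModel(hidden_size=3)
    original.weights = [1.0, 2.0, 3.0]
    original.serialize(path)

    restored = ToyModel.deserialize(path, hidden_size=3)  # same args: loads
    try:
        ToyModel.deserialize(path, hidden_size=4)  # different skeleton: rejected
        mismatch_detected = False
    except ValueError:
        mismatch_detected = True
```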
__call__
abstractmethod
__call__(
state: StateType,
observation: ObsType,
*,
key: Key[Array, ""] | None = None,
action_mask: MaskType | None = None,
) -> tuple[StateType, ActType]
Return the next action and new internal state given the current observation and internal state.
A key can be provided for stochastic policies. If no key is provided, the policy should behave deterministically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current internal state of the policy. | required |
| observation | ObsType | The current observation. | required |
| key | Key[Array, ''] \| None | An optional JAX random key for stochastic policies. | None |
| action_mask | MaskType \| None | An optional action mask. | None |
Returns:
| Type | Description |
|---|---|
| tuple[StateType, ActType] | The new internal state and the action to take. |
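The key convention (sample when a key is given, act deterministically otherwise) and the role of the action mask can be sketched for a discrete action space. This plain-Python stand-in is an assumption throughout: `random.Random` replaces a JAX key, a fixed list of logits replaces a network, and a step counter plays the internal state. The function name `toy_call` is hypothetical.

```python
import math
import random
from typing import Optional, Sequence


def toy_call(
    state: int,
    logits: Sequence[float],
    *,
    key: Optional[random.Random] = None,
    action_mask: Optional[Sequence[bool]] = None,
) -> tuple[int, int]:
    """Return (new_state, action) for a categorical policy over len(logits) actions."""
    n = len(logits)
    mask = action_mask if action_mask is not None else [True] * n
    # Masked-out actions get a logit of -inf, i.e. zero probability.
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    if key is None:
        # No key: deterministic, pick the most likely valid action.
        action = max(range(n), key=lambda i: masked[i])
    else:
        # Key given: sample from the softmax over valid actions.
        top = max(masked)
        weights = [math.exp(l - top) for l in masked]
        action = key.choices(range(n), weights=weights, k=1)[0]
    return state + 1, action  # the new state counts steps taken


logits = [0.5, 2.0, 1.0]
s1, a1 = toy_call(0, logits)  # no key: argmax, action 1
s2, a2 = toy_call(s1, logits, action_mask=[True, False, True])  # action 1 masked: action 2
s3, a3 = toy_call(s2, logits, key=random.Random(0), action_mask=[True, False, True])
```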
reset
abstractmethod
Return an initial internal state for the policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Key[Array, ''] | A JAX random key for initializing the state. | required |
Returns:
| Type | Description |
|---|---|
| StateType | An initial internal state for the policy. |
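Because the policy's state is passed in and returned rather than stored, a caller obtains the initial state from `reset` once and then threads each returned state into the next call. The sketch below shows that loop with a hypothetical `ToyRecurrentPolicy` whose state is an exponential moving average of past observations. The class, the averaging rule, the keyword-only `reset(*, key)` signature, and the use of `random.Random` as a stand-in for a JAX key are assumptions, not lerax code.

```python
import random
from typing import NamedTuple


class PolicyState(NamedTuple):
    average: float  # running summary of past observations


class ToyRecurrentPolicy:
    """Hypothetical stateful policy: acts on a moving average of observations."""

    def __init__(self, decay: float = 0.5):
        self.decay = decay

    def reset(self, *, key: random.Random) -> PolicyState:
        # This toy ignores the key; a real policy might use it to randomise the state.
        return PolicyState(average=0.0)

    def __call__(self, state: PolicyState, observation: float) -> tuple[PolicyState, int]:
        average = self.decay * state.average + (1 - self.decay) * observation
        action = 1 if average > 0 else 0
        return PolicyState(average), action  # build a new state; never mutate the old one


policy = ToyRecurrentPolicy()
state = policy.reset(key=random.Random(0))
actions = []
for obs in [1.0, -3.0, 2.0]:
    state, action = policy(state, obs)  # thread the new state into the next step
    actions.append(action)
```

Returning a fresh state instead of mutating one keeps each step a pure function, which is what lets this pattern compose with JAX transformations such as `jax.lax.scan` in a real implementation.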
action_and_value
abstractmethod
action_and_value(
state: StateType,
observation: ObsType,
*,
key: Key[Array, ""],
action_mask: MaskType | None = None,
) -> tuple[
StateType, ActType, Float[Array, ""], Float[Array, ""]
]
Get an action, its log probability, and the value of an observation, along with the new policy state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current policy state. | required |
| observation | ObsType | The observation to get the action and value for. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |
| action_mask | MaskType \| None | An optional action mask. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| new_state | StateType | The new policy state. |
| action | ActType | The action to take. |
| value | Float[Array, ''] | The value of the observation. |
| log_prob | Float[Array, ''] | The log probability of the action. |
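For a discrete action space, an implementation of `action_and_value` would plausibly sample from a categorical distribution, report the sampled action's log probability, and evaluate a critic on the same observation. The stand-in below is a sketch under stated assumptions: fixed logits that ignore the observation for brevity, a linear critic, `random.Random` in place of a JAX key, no action mask, and the hypothetical names `toy_action_and_value` and `log_softmax`. It returns its results in the documented order.

```python
import math
import random
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return [l - log_norm for l in logits]


def toy_action_and_value(
    state: int,
    observation: float,
    *,
    key: random.Random,
    logits: Sequence[float] = (0.0, 1.0, 2.0),
    critic_weight: float = 0.5,
) -> tuple[int, int, float, float]:
    """Return (new_state, action, value, log_prob)."""
    log_probs = log_softmax(logits)
    probs = [math.exp(lp) for lp in log_probs]
    action = key.choices(range(len(logits)), weights=probs, k=1)[0]  # sample the actor
    value = critic_weight * observation  # linear critic V(obs)
    return state + 1, action, value, log_probs[action]


new_state, action, value, log_prob = toy_action_and_value(0, 4.0, key=random.Random(42))
```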
evaluate_action
abstractmethod
evaluate_action(
state: StateType,
observation: ObsType,
action: ActType,
*,
action_mask: MaskType | None = None,
) -> tuple[
StateType,
Float[Array, ""],
Float[Array, ""],
Float[Array, ""],
]
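Evaluate a given action for an observation, returning the new policy state, the value of the observation, the log probability of the action, and the entropy of the action distribution.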
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current policy state. | required |
| observation | ObsType | The observation to evaluate the action for. | required |
| action | ActType | The action to evaluate. | required |
| action_mask | MaskType \| None | An optional action mask. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| new_state | StateType | The new policy state. |
| value | Float[Array, ''] | The value of the observation. |
| log_prob | Float[Array, ''] | The log probability of the action. |
| entropy | Float[Array, ''] | The entropy of the action distribution. |
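Unlike `action_and_value`, `evaluate_action` receives the action instead of sampling it, so it needs no key. This is useful, for example, when re-scoring previously collected actions under updated parameters. For a categorical distribution, the log probability is the log-softmax entry of the given action and the entropy is H = -Σ p·log p. The function below is a hypothetical plain-Python stand-in (fixed logits, linear critic, name `toy_evaluate_action`), not lerax's implementation.

```python
import math
from typing import Optional, Sequence


def toy_evaluate_action(
    state: int,
    observation: float,
    action: int,
    *,
    logits: Sequence[float],
    action_mask: Optional[Sequence[bool]] = None,
    critic_weight: float = 0.5,
) -> tuple[int, float, float, float]:
    """Return (new_state, value, log_prob, entropy)."""
    mask = action_mask if action_mask is not None else [True] * len(logits)
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    top = max(masked)
    log_norm = top + math.log(sum(math.exp(l - top) for l in masked))
    log_probs = [l - log_norm for l in masked]
    # Masked actions have p = 0 and contribute nothing; skip them to avoid 0 * -inf.
    entropy = -sum(math.exp(lp) * lp for lp in log_probs if lp != -math.inf)
    value = critic_weight * observation
    return state + 1, value, log_probs[action], entropy


# Uniform logits over 4 actions: log_prob = -log 4, entropy = log 4.
_, value, log_prob, entropy = toy_evaluate_action(0, 2.0, 3, logits=[0.0] * 4)
# Masking two actions leaves a uniform distribution over 2: entropy = log 2.
_, _, masked_log_prob, masked_entropy = toy_evaluate_action(
    0, 2.0, 0, logits=[0.0] * 4, action_mask=[True, True, False, False]
)
```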
value
abstractmethod
Get the value of an observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current policy state. | required |
| observation | ObsType | The observation to get the value for. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| new_state | StateType | The new policy state. |
| value | Float[Array, ''] | The value of the observation. |
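`value` returns only the critic's estimate, with no key or action involved. One common use, given here as an assumption about how a trainer might call it rather than anything this page specifies, is bootstrapping a truncated n-step return from the value of the final observation: G = r_0 + γ·r_1 + ... + γ^(n-1)·r_(n-1) + γ^n·V(s_n). The linear critic `toy_value` and the helper `bootstrapped_return` are hypothetical.

```python
from typing import Sequence


def toy_value(state: int, observation: float, weight: float = 0.5) -> tuple[int, float]:
    """Hypothetical critic: returns (new_state, value) in the documented order."""
    return state + 1, weight * observation


def bootstrapped_return(
    rewards: Sequence[float], last_value: float, gamma: float = 0.9
) -> float:
    """n-step return: discounted rewards plus the discounted value of the final observation."""
    g = last_value
    for r in reversed(rewards):
        g = r + gamma * g  # fold rewards in from the end
    return g


state, last_value = toy_value(0, observation=2.0)  # V(s_n) = 1.0
g = bootstrapped_return([1.0, 0.0, 1.0], last_value, gamma=0.5)
# 1.0 + 0.5 * 0.0 + 0.25 * 1.0 + 0.125 * 1.0 = 1.375
```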