Abstract Actor-Critic Policies
lerax.policy.AbstractActorCriticPolicy
Bases:
Base class for stateful actor-critic policies.
Actor-critic policies map observations and internal states to actions, values, and new internal states.
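The state-passing contract can be illustrated with a framework-free sketch. `ToyCounterPolicy`, `CounterState`, and `rollout` are hypothetical names, not part of lerax, and the toy policy does not subclass `AbstractActorCriticPolicy`; it only mirrors the shapes of `reset` and `__call__` documented below, where every call returns a new state instead of mutating the policy.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CounterState:
    """Hypothetical internal state: the number of steps taken so far."""

    steps: int


class ToyCounterPolicy:
    """Hypothetical stateful policy (not part of lerax).

    The policy object never changes; each call returns a *new* state, mirroring
    the (state, observation) -> (new_state, action) contract of __call__.
    """

    def reset(self, key=None) -> CounterState:
        # Initial internal state; this toy policy ignores the key.
        return CounterState(steps=0)

    def __call__(self, state, observation, *, key=None, action_mask=None):
        # Deterministic rule: act 1 on positive observations, 0 otherwise.
        action = 1 if observation > 0 else 0
        return CounterState(steps=state.steps + 1), action


def rollout(policy, observations, key=None):
    """Thread the internal state through a sequence of observations."""
    state = policy.reset(key)
    actions = []
    for observation in observations:
        state, action = policy(state, observation)
        actions.append(action)
    return state, actions


final_state, actions = rollout(ToyCounterPolicy(), [1.0, -1.0, 2.0])
print(final_state, actions)  # CounterState(steps=3) [1, 0, 1]
```

Because the state is an explicit input and output, the same pattern composes with functional transformations such as JAX's `jit` or `scan`, which is presumably why the interface is shaped this way.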
Attributes:
| Name | Type | Description |
|---|---|---|
| `name` | | The name of the policy. |
| `action_space` | | The action space of the policy. |
| `observation_space` | | The observation space of the policy. |
observation_space
instance-attribute
serialize
Serialize the model to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | | The path to serialize to. | *required* |
| `no_suffix` | | If `True`, do not append the `.eqx` suffix. | `False` |
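A minimal sketch of how the `no_suffix` flag could determine the file that is written. `resolve_serialize_path` is a hypothetical helper, and the exact suffix rule is an assumption: this version always appends `.eqx` to the full file name unless `no_suffix` is true, and it does not special-case names that already end in `.eqx`.

```python
from __future__ import annotations

from pathlib import Path


def resolve_serialize_path(path: str | Path, no_suffix: bool = False) -> Path:
    """Hypothetical helper: the file name that serialize would write to."""
    path = Path(path)
    if no_suffix:
        return path
    # Assumption: ".eqx" is appended to the full name, so "policy.v1"
    # becomes "policy.v1.eqx" instead of having ".v1" replaced.
    return path.with_name(path.name + ".eqx")


target = resolve_serialize_path("checkpoints/policy")
raw_target = resolve_serialize_path("checkpoints/policy", no_suffix=True)
```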
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path. Any additional arguments required by the class constructor must be provided.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | *required* |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
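The `.eqx` suffix suggests the files may be written with Equinox's `eqx.tree_serialise_leaves`, whose counterpart `eqx.tree_deserialise_leaves(path, like)` needs a template model of the right structure; that would explain why constructor arguments are required, although this is an inference rather than something stated here. The sketch below reproduces the pattern with the standard library: only parameter values are stored, so deserialization first rebuilds a skeleton from the constructor arguments and then fills it in. `ToyLinearCritic` is hypothetical, and JSON stands in for the real file format.

```python
import json
import tempfile
from pathlib import Path


class ToyLinearCritic:
    """Hypothetical model whose structure is fixed by a constructor argument."""

    def __init__(self, num_features: int):
        self.num_features = num_features
        self.weights = [0.0] * num_features  # the "leaves" that get stored

    def serialize(self, path):
        # Only the parameter values are written; num_features is not stored.
        Path(path).write_text(json.dumps(self.weights))

    @classmethod
    def deserialize(cls, path, *args, **kwargs):
        # 1. Rebuild a skeleton of the right shape from the constructor arguments.
        model = cls(*args, **kwargs)
        # 2. Fill the skeleton with the stored values, checking that they fit.
        weights = json.loads(Path(path).read_text())
        if len(weights) != len(model.weights):
            raise ValueError("stored values do not match the model built from the arguments")
        model.weights = weights
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "critic.json"
    critic = ToyLinearCritic(3)
    critic.weights = [0.5, -1.0, 2.0]
    critic.serialize(path)
    restored = ToyLinearCritic.deserialize(path, 3)  # 3 must be supplied again
```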
__call__
abstractmethod
__call__(
state: StateType,
observation: ObsType,
*,
key: Key | None = None,
action_mask: MaskType | None = None,
) -> tuple[StateType, ActType]
Return the next action and new internal state given the current observation and internal state.
A key can be provided for stochastic policies. If no key is provided, the policy should behave deterministically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation. | *required* |
| `key` | `Key \| None` | An optional JAX random key for stochastic policies. | `None` |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[StateType, ActType]` | The new internal state and the action to take. |
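A framework-free sketch of the key semantics described for `__call__`: when a source of randomness is supplied the policy samples, and without one it deterministically takes the most likely allowed action. A `random.Random` instance stands in for a JAX key (in JAX, sampling would typically use `jax.random.categorical(key, logits)`), `select_action` is a hypothetical name, and treating `True` mask entries as allowed actions is an assumption about the mask convention.

```python
import math
import random
from typing import Optional, Sequence


def select_action(
    logits: Sequence[float],
    *,
    rng: Optional[random.Random] = None,
    action_mask: Optional[Sequence[bool]] = None,
) -> int:
    """Hypothetical categorical action selection; at least one action must be allowed."""
    scores = list(logits)
    if action_mask is not None:
        # Assumption: True marks an allowed action; disallowed ones get -inf.
        scores = [s if allowed else -math.inf for s, allowed in zip(scores, action_mask)]
    if rng is None:
        # No randomness supplied: behave deterministically (greedy action).
        return max(range(len(scores)), key=lambda i: scores[i])
    # Randomness supplied: sample from the softmax over the allowed actions.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # masked actions get weight 0.0
    return rng.choices(range(len(scores)), weights=weights, k=1)[0]


logits = [0.1, 2.0, 0.5]
greedy = select_action(logits)  # 1
masked = select_action(logits, action_mask=[True, False, True])  # 2
sampled = select_action(logits, rng=random.Random(0), action_mask=[True, False, True])
```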
reset
abstractmethod
Return an initial internal state for the policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | | A JAX random key for initializing the state. | *required* |
Returns:
| Type | Description |
|---|---|
| | An initial internal state for the policy. |
action_and_value
abstractmethod
action_and_value(
state: StateType,
observation: ObsType,
*,
key: Key,
action_mask: MaskType | None = None,
) -> tuple[
StateType, ActType, Float[Array, ""], Float[Array, ""]
]
Get an action, its log probability, and the value of an observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current policy state. | *required* |
| `observation` | `ObsType` | The observation to get the action and value for. | *required* |
| `key` | `Key` | A JAX PRNG key. | *required* |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `new_state` | `StateType` | The new policy state. |
| `action` | `ActType` | The action to take. |
| `value` | `Float[Array, ""]` | The value of the observation. |
| `log_prob` | `Float[Array, ""]` | The log probability of the action. |
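The four return values fit together as in this sketch of a stateless toy policy on a scalar observation: a linear actor head produces logits, the action is sampled from their softmax, `log_prob` is the log-softmax entry of the sampled action, and a linear critic head produces `value`. `toy_action_and_value`, `actor_weights`, and `critic_weight` are hypothetical, and `random.Random` again stands in for a JAX key.

```python
import math
import random


def log_softmax(logits):
    # Numerically stable: subtract the max logit before exponentiating.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return [l - log_norm for l in logits]


def toy_action_and_value(state, observation, *, rng, actor_weights, critic_weight):
    """Hypothetical stateless actor-critic on a scalar observation."""
    logits = [w * observation for w in actor_weights]  # actor head
    log_probs = log_softmax(logits)
    probs = [math.exp(lp) for lp in log_probs]
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]  # sample
    value = critic_weight * observation  # critic head
    # No memory, so the "new" state is the incoming state unchanged.
    return state, action, value, log_probs[action]


new_state, action, value, log_prob = toy_action_and_value(
    (), 2.0, rng=random.Random(0), actor_weights=[0.5, -0.5], critic_weight=0.5
)
```

Returning the log probability alongside the sampled action lets a training loop store it during collection and compare it later against a freshly recomputed value.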
evaluate_action
abstractmethod
evaluate_action(
state: StateType,
observation: ObsType,
action: ActType,
*,
action_mask: MaskType | None = None,
) -> tuple[
StateType,
Float[Array, ""],
Float[Array, ""],
Float[Array, ""],
]
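Evaluate a given action for an observation, returning the value of the observation, the log probability of the action, and the entropy of the action distribution.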
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current policy state. | *required* |
| `observation` | `ObsType` | The observation to evaluate the action for. | *required* |
| `action` | `ActType` | The action to evaluate. | *required* |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `new_state` | `StateType` | The new policy state. |
| `value` | `Float[Array, ""]` | The value of the observation. |
| `log_prob` | `Float[Array, ""]` | The log probability of the action. |
| `entropy` | `Float[Array, ""]` | The entropy of the action distribution. |
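Unlike `action_and_value`, `evaluate_action` scores an action that was already chosen, which is what PPO-style updates use when recomputing log probabilities for stored actions. For a categorical action distribution the quantities reduce to the sketch below; `evaluate_categorical` is hypothetical, and the importance ratio at the end is shown only as one common consumer of `log_prob`, not as something this class computes.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list:
    # Numerically stable: subtract the max logit before exponentiating.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return [l - log_norm for l in logits]


def evaluate_categorical(logits: Sequence[float], action: int) -> tuple:
    """Hypothetical helper: log probability of a given action and the entropy."""
    log_probs = log_softmax(logits)
    # H = -sum_a p(a) log p(a); masked (-inf) entries contribute 0 by convention.
    entropy = -sum(math.exp(lp) * lp for lp in log_probs if lp != -math.inf)
    return log_probs[action], entropy


# Usage: action 2 was recorded at sampling time with probability 0.25.
old_log_prob = math.log(0.25)
new_log_prob, entropy = evaluate_categorical([0.0, 0.0, 0.0, 0.0], action=2)
ratio = math.exp(new_log_prob - old_log_prob)  # about 1.0: the distribution is unchanged
```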
value
abstractmethod
Get the value of an observation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | | The current policy state. | *required* |
| `observation` | | The observation to get the value for. | *required* |
Returns:
| Name | Type | Description |
|---|---|---|
| `new_state` | | The new policy state. |
| `value` | | The value of the observation. |
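One common use of `value` is bootstrapping: when a rollout is cut off mid-episode, the critic's estimate for the final observation stands in for the unobserved remainder of the return. A minimal sketch, assuming a hypothetical `discounted_returns` helper and `dones` flags that are `1.0` at episode ends; neither is part of the API documented here.

```python
def discounted_returns(rewards, dones, bootstrap_value, gamma=0.99):
    """Hypothetical helper: discounted returns for a rollout cut off after the last step.

    bootstrap_value is the critic's estimate for the observation that follows the
    final step, e.g. the value obtained from value(state, last_observation).
    """
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        # A done flag of 1.0 stops value from flowing back across an episode boundary.
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0], [0.0, 0.0], bootstrap_value=10.0, gamma=0.5))  # [4.0, 6.0]
```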