Abstract Policies
lerax.policy.AbstractPolicy
Bases: Serializable
Base class for policies.
Policies map observations and internal states to actions and new internal states.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | eqx.AbstractClassVar[str] | The name of the policy. |
| action_space | eqx.AbstractVar[AbstractSpace[ActType, MaskType]] | The action space of the policy. |
| observation_space | eqx.AbstractVar[AbstractSpace[ObsType, Any]] | The observation space of the policy. |
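The contract can be illustrated without lerax itself. The sketch below is a hypothetical stand-in built on the standard library's abc module. PolicySketch, Discrete, and MirrorPolicy are illustrative names, not part of lerax, and plain Python attributes take the place of eqx.AbstractClassVar and eqx.AbstractVar. It shows the division of roles: name is fixed per class, while the spaces belong to each instance.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Optional


@dataclass(frozen=True)
class Discrete:
    """Hypothetical stand-in for a discrete space with n elements."""
    n: int

    def contains(self, x: int) -> bool:
        return 0 <= x < self.n


class PolicySketch(ABC):
    """Hypothetical stand-in mirroring the AbstractPolicy contract (not lerax code)."""

    name: ClassVar[str]            # one value per class, like eqx.AbstractClassVar[str]
    action_space: Discrete         # per instance, like eqx.AbstractVar[...]
    observation_space: Discrete

    @abstractmethod
    def reset(self, key: int) -> Any:
        """Return an initial internal state."""

    @abstractmethod
    def __call__(self, state: Any, observation: int, *, key: Optional[int] = None,
                 action_mask: Optional[list] = None) -> tuple[Any, int]:
        """Return (new_state, action)."""


class MirrorPolicy(PolicySketch):
    """Toy stateless policy: echoes the observation, clipped to the action space."""

    name = "mirror"

    def __init__(self, n_obs: int, n_actions: int):
        self.observation_space = Discrete(n_obs)
        self.action_space = Discrete(n_actions)

    def reset(self, key: int) -> None:
        return None  # a stateless policy has nothing to initialize

    def __call__(self, state, observation, *, key=None, action_mask=None):
        # This toy policy ignores key and action_mask.
        action = min(observation, self.action_space.n - 1)
        return state, action
```

Because PolicySketch still has abstract methods, instantiating it directly raises TypeError; only subclasses that implement reset and __call__ can be constructed.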
observation_space
instance-attribute
serialize
Serialize the model to the specified path.
Writes a 32-byte structural fingerprint followed by the Equinox
leaf data, so deserialize can verify that the skeleton it builds
matches what was saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to serialize to. The | required |
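A minimal sketch of the file layout described above. It assumes SHA-256 over a canonical JSON description of the static structure as the fingerprint (a SHA-256 digest is exactly 32 bytes) and raw little-endian float64 values as a stand-in for the Equinox leaf data. The fingerprint algorithm and the helper names structure_fingerprint and save_with_fingerprint are hypothetical. lerax's actual encoding may differ.

```python
import hashlib
import json
import struct
from pathlib import Path


def structure_fingerprint(structure: dict) -> bytes:
    """Hypothetical fingerprint: SHA-256 of a canonical JSON structure description."""
    canonical = json.dumps(structure, sort_keys=True).encode("utf-8")
    return hashlib.sha256(canonical).digest()  # always 32 bytes


def save_with_fingerprint(path, structure: dict, leaves: list[float]) -> None:
    """Write the 32-byte fingerprint, then the leaves as little-endian float64."""
    header = structure_fingerprint(structure)
    payload = struct.pack(f"<{len(leaves)}d", *leaves)
    Path(path).write_bytes(header + payload)
```

The header depends only on the static structure, not on the parameter values. Two models with the same architecture but different weights therefore share a fingerprint.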
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path.
The constructor arguments must reproduce the same static structure
(class, hyperparameters, network shapes, activations, ...) that the
model had when it was serialized. A 32-byte fingerprint stored in
the file is verified before loading; mismatches raise ValueError
instead of silently loading arrays into the wrong skeleton.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to deserialize from. | required |
| *args | Params.args | Additional positional arguments to pass to the class constructor. | () |
| **kwargs | Params.kwargs | Additional keyword arguments to pass to the class constructor. | {} |
Returns:
| Type | Description |
|---|---|
| ClassType | The deserialized model. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file. |
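The verification step can be sketched the same way. The sketch assumes the hypothetical SHA-256-over-JSON fingerprint and a float64 payload. ToyModel is an illustrative class, not lerax code. It rebuilds its skeleton from the constructor arguments, recomputes the fingerprint, and compares it against the first 32 bytes of the file before reading any leaves.

```python
import hashlib
import json
import struct
from pathlib import Path


class ToyModel:
    """Hypothetical model: constructor args fix the structure, weights are the leaves."""

    def __init__(self, n_inputs: int, n_outputs: int):
        self.n_inputs, self.n_outputs = n_inputs, n_outputs
        self.weights = [0.0] * (n_inputs * n_outputs)  # placeholder leaves

    def structure(self) -> dict:
        return {"class": type(self).__name__, "shape": [self.n_inputs, self.n_outputs]}

    @classmethod
    def deserialize(cls, path, *args, **kwargs) -> "ToyModel":
        skeleton = cls(*args, **kwargs)  # rebuild the static structure first
        canonical = json.dumps(skeleton.structure(), sort_keys=True).encode("utf-8")
        expected = hashlib.sha256(canonical).digest()
        data = Path(path).read_bytes()
        stored, payload = data[:32], data[32:]
        if stored != expected:
            raise ValueError("structural fingerprint mismatch: constructor arguments "
                             "do not reproduce the saved model's structure")
        n = len(skeleton.weights)
        skeleton.weights = list(struct.unpack(f"<{n}d", payload))
        return skeleton
```

A file saved from ToyModel(2, 3) holds six weights, so ToyModel(3, 2) could read it without error if the check were absent. The fingerprint comparison turns that silent misload into a ValueError.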
__call__
abstractmethod
__call__(
state: StateType,
observation: ObsType,
*,
key: Key[Array, ""] | None = None,
action_mask: MaskType | None = None,
) -> tuple[StateType, ActType]
Return the new internal state and the next action, given the current internal state and observation.
A key can be provided for stochastic policies. If no key is provided, the policy should behave deterministically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current internal state of the policy. | required |
| observation | ObsType | The current observation. | required |
| key | Key[Array, ''] \| None | An optional JAX random key for stochastic policies. | None |
| action_mask | MaskType \| None | An optional action mask. | None |
Returns:
| Type | Description |
|---|---|
| tuple[StateType, ActType] | The new internal state and the action to take. |
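The key convention above can be sketched for a discrete policy that scores each action with a logit. This is a hypothetical helper, not lerax's implementation. An integer seed passed to random.Random stands in for a JAX PRNG key. The mask is assumed to be a sequence of booleans in which True marks an allowed action. Without a key the helper returns the best allowed action. With a key it samples from a softmax over the allowed actions, so the same key reproduces the same action.

```python
import math
import random
from typing import Optional, Sequence


def select_action(
    logits: Sequence[float],
    *,
    key: Optional[int] = None,
    action_mask: Optional[Sequence[bool]] = None,
) -> int:
    """Pick an action index from per-action logits (hypothetical helper).

    key is an int seed standing in for a JAX PRNG key.
    """
    allowed = [i for i in range(len(logits)) if action_mask is None or action_mask[i]]
    if not allowed:
        raise ValueError("action_mask excludes every action")
    if key is None:
        # Deterministic: highest-logit allowed action.
        return max(allowed, key=lambda i: logits[i])
    # Stochastic: softmax over allowed actions (subtract the max for stability).
    top = max(logits[i] for i in allowed)
    weights = [math.exp(logits[i] - top) for i in allowed]
    rng = random.Random(key)  # same key -> same sample
    return rng.choices(allowed, weights=weights, k=1)[0]
```

Masking before both the argmax and the softmax ensures that a disallowed action can never be returned, whichever branch runs.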
reset
abstractmethod
Return an initial internal state for the policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Key[Array, ''] | A JAX random key for initializing the state. | required |
Returns:
| Type | Description |
|---|---|
| StateType | An initial internal state for the policy. |
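A typical rollout calls reset once at the start of an episode and then threads the returned state through every call. The sketch below assumes a toy recurrent policy whose state is a running sum and count of past observations. ThresholdPolicy, RunningMeanState, and rollout are hypothetical names, and an integer again stands in for the JAX key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningMeanState:
    """Hypothetical internal state: running sum and count of observations."""
    total: float
    count: int


class ThresholdPolicy:
    """Toy recurrent policy: action 1 if the observation exceeds the running mean."""

    def reset(self, key: int) -> RunningMeanState:
        # This toy state needs no randomness; a policy with a random initial
        # state (e.g. a sampled hidden vector) would consume the key here.
        return RunningMeanState(total=0.0, count=0)

    def __call__(self, state, observation, *, key=None, action_mask=None):
        mean = state.total / state.count if state.count else 0.0
        action = 1 if observation > mean else 0
        new_state = RunningMeanState(state.total + observation, state.count + 1)
        return new_state, action


def rollout(policy, observations, key: int = 0):
    state = policy.reset(key)  # fresh state at episode start
    actions = []
    for obs in observations:
        state, action = policy(state, obs)  # thread the state through every step
        actions.append(action)
    return state, actions
```

For observations [1.0, 0.0, 3.0] the running means seen are 0.0, 1.0, and 0.5, which gives actions [1, 0, 1].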
lerax.policy.AbstractPolicyState
Bases: eqx.Module
Base class for policy internal states.
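eqx.Module instances are immutable after construction, so a policy state is replaced rather than modified in place. The sketch below mimics that behaviour with frozen dataclasses. PolicyStateSketch, EmptyState, HiddenState, and step_hidden are hypothetical names, and the update rule is arbitrary.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class PolicyStateSketch:
    """Hypothetical stand-in for an AbstractPolicyState subclass."""


@dataclass(frozen=True)
class EmptyState(PolicyStateSketch):
    """State for a stateless (feed-forward) policy: carries nothing."""


@dataclass(frozen=True)
class HiddenState(PolicyStateSketch):
    """State for a recurrent policy, e.g. a hidden vector."""
    hidden: tuple[float, ...]


def step_hidden(state: HiddenState, observation: float) -> HiddenState:
    # Build a new state instead of mutating the old one.
    return replace(state, hidden=tuple(0.5 * h + observation for h in state.hidden))
```

The old state remains valid after an update, which keeps earlier timesteps reproducible. Assigning to a field of a frozen instance raises FrozenInstanceError.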