Abstract Q Policies
lerax.policy.AbstractQPolicy
Bases: AbstractPolicy[StateType, Integer[Array, ''], ObsType, Bool[Array, ' actions']]
Base class for stateful epsilon-greedy Q-learning policies.
Epsilon-greedy policies select a random action with probability epsilon and the action with the highest Q-value with probability 1-epsilon.
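The selection rule described above can be sketched as a standalone JAX function. This is a minimal illustration and not the library's implementation: the function name `epsilon_greedy` and its signature are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(q_values, epsilon, key=None):
    """Pick an action index from a 1-D array of Q-values (hypothetical helper)."""
    greedy = jnp.argmax(q_values)
    if key is None:
        # Without a PRNG key, act deterministically: always take the greedy action.
        return greedy
    explore_key, action_key = jax.random.split(key)
    # Uniformly random action index in [0, num_actions).
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    # Explore with probability epsilon, otherwise exploit.
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy)


q = jnp.array([0.1, 2.0, -1.0])
action = epsilon_greedy(q, epsilon=0.1, key=jax.random.PRNGKey(0))
```

In this sketch the random draw can also land on the greedy action, so the greedy action is chosen slightly more often than `1 - epsilon`.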
Attributes:
| Name | Type | Description |
|---|---|---|
| `name` | `eqx.AbstractClassVar[str]` | Name of the policy class. |
| `action_space` | `eqx.AbstractVar[Discrete]` | The action space of the environment. |
| `observation_space` | `eqx.AbstractVar[AbstractSpace[ObsType, Any]]` | The observation space of the environment. |
| `epsilon` | `eqx.AbstractVar[float]` | The epsilon value for epsilon-greedy action selection. |
observation_space
instance-attribute
serialize
Serialize the model to the specified path.
Writes a 32-byte structural fingerprint followed by the Equinox
leaf data, so deserialize can verify that the skeleton it builds
matches what was saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to serialize to. The | required |
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path.
The constructor arguments must reproduce the same static structure
(class, hyperparameters, network shapes, activations, ...) that the
model had when it was serialized. A 32-byte fingerprint stored in
the file is verified before loading; mismatches raise ValueError
instead of silently loading arrays into the wrong skeleton.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file. |
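The fingerprint check described for `serialize` and `deserialize` can be illustrated with a small standard-library sketch. Everything here is an assumption for illustration: the helper names, the use of SHA-256 (chosen only because it yields 32 bytes), and the idea of hashing a textual description of the structure. The actual file layout beyond "a 32-byte fingerprint followed by leaf data" is not specified above.

```python
import hashlib
import io


def fingerprint(structure: str) -> bytes:
    # Assumption: SHA-256 over a textual structure description (32-byte digest).
    return hashlib.sha256(structure.encode("utf-8")).digest()


def save(f, structure: str, payload: bytes) -> None:
    # Header first, then the raw leaf data.
    f.write(fingerprint(structure))
    f.write(payload)


def load(f, structure: str) -> bytes:
    stored = f.read(32)
    # Refuse to load if the rebuilt skeleton differs from what was saved.
    if stored != fingerprint(structure):
        raise ValueError("structural fingerprint mismatch")
    return f.read()


buffer = io.BytesIO()
save(buffer, "MLP(width=64, depth=2)", b"leaf-bytes")
buffer.seek(0)
restored = load(buffer, "MLP(width=64, depth=2)")
```

Calling `load` with a different structure string, for example a different width, raises `ValueError` before any leaf data is read.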
reset
abstractmethod
Return an initial internal state for the policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Key[Array, '']` | A JAX random key for initializing the state. | required |
Returns:
| Type | Description |
|---|---|
| `StateType` | An initial internal state for the policy. |
q_values
abstractmethod
Return Q-values for all actions given an observation and state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | required |
| `observation` | `ObsType` | The current observation from the environment. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[StateType, Float[Array, ' actions']]` | A tuple of the next internal state and the Q-values for all actions. |
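To make the `(state, observation) -> (next_state, q_values)` contract concrete, here is a small standalone sketch. It is not a subclass of `AbstractQPolicy`; the free function, the `weights` argument, and the moving-average state are all hypothetical choices for illustration.

```python
import jax.numpy as jnp


def q_values(weights, state, observation):
    """Hypothetical Q-function following the (state, observation) contract."""
    # Toy internal state: an exponential moving average of past observations.
    next_state = 0.9 * state + 0.1 * observation
    # Linear Q-function with one row of weights per action.
    q = weights @ next_state
    return next_state, q


weights = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 actions, 2 features
state = jnp.zeros(2)
state, q = q_values(weights, state, jnp.array([1.0, 2.0]))
```

The returned `q` has one entry per action, matching the `' actions'` axis in the return type, and the updated state is passed back in on the next call.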
__call__
__call__(
state: StateType,
observation: ObsType,
*,
action_mask: Bool[Array, " actions"] | None = None,
key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]
Return the next state and action for a given observation and state.
Uses epsilon-greedy action selection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | required |
| `observation` | `ObsType` | The current observation from the environment. | required |
| `action_mask` | `Bool[Array, ' actions'] \| None` | Optional boolean mask over the actions. | `None` |
| `key` | `Array \| None` | JAX PRNG key for stochastic action selection. If `None`, the action with the highest Q-value is always selected. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[StateType, Integer[Array, '']]` | A tuple of the next internal state and the selected action. |
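The calling pattern can be sketched as follows: obtain a state from `reset`, then thread the state through successive calls with a fresh key per step. `ToyQPolicy` is a hypothetical stand-in that only mirrors the documented method names; it does not inherit from `AbstractQPolicy`. The mask semantics (`True` means the action may be selected) are also an assumption, since they are not stated above.

```python
import jax
import jax.numpy as jnp


class ToyQPolicy:
    """Hypothetical stand-in that mirrors the documented method names."""

    def __init__(self, table, epsilon):
        self.table = table  # Q-table, shape (num_observations, num_actions)
        self.epsilon = epsilon

    def reset(self, key):
        # Toy internal state: a step counter (the key is unused here).
        return jnp.zeros((), dtype=jnp.int32)

    def q_values(self, state, observation):
        return state + 1, self.table[observation]

    def __call__(self, state, observation, *, action_mask=None, key=None):
        state, q = self.q_values(state, observation)
        if action_mask is None:
            action_mask = jnp.ones_like(q, dtype=bool)
        # Assumption: True marks an action that may be selected.
        greedy = jnp.argmax(jnp.where(action_mask, q, -jnp.inf))
        if key is None:
            return state, greedy
        explore_key, action_key = jax.random.split(key)
        # Uniform draw over the allowed actions only.
        logits = jnp.where(action_mask, 0.0, -jnp.inf)
        random_action = jax.random.categorical(action_key, logits)
        explore = jax.random.uniform(explore_key) < self.epsilon
        return state, jnp.where(explore, random_action, greedy)


policy = ToyQPolicy(jnp.array([[0.0, 1.0, 5.0], [2.0, 0.0, 1.0]]), epsilon=0.1)
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
state = policy.reset(reset_key)

mask = jnp.array([True, True, False])  # forbid action 2
actions = []
for observation in [0, 1, 0]:
    key, step_key = jax.random.split(key)  # fresh key for every step
    state, action = policy(state, observation, action_mask=mask, key=step_key)
    actions.append(int(action))

# Without a key the call is deterministic and purely greedy.
_, greedy_action = policy(state, 0, action_mask=mask)
```

Splitting the key before each call keeps the per-step randomness independent. Omitting `key` is the natural choice for evaluation, where exploration is not wanted.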