# Abstract Q Policies

## `lerax.policy.AbstractQPolicy`
Bases:
Base class for stateful epsilon-greedy Q-learning policies.
Epsilon-greedy policies select a random action with probability epsilon and the action with the highest Q-value with probability 1-epsilon.
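As a rough sketch of the selection rule (not the library's exact implementation), epsilon-greedy action selection over a vector of Q-values looks like this in JAX:

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(key, q_values, epsilon):
    """Random action with probability epsilon, otherwise the argmax."""
    explore_key, action_key = jax.random.split(key)
    explore = jax.random.uniform(explore_key) < epsilon
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    greedy_action = jnp.argmax(q_values)
    return jnp.where(explore, random_action, greedy_action)
```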
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | | Name of the policy class. |
| `action_space` | | The action space of the environment. |
| `observation_space` | | The observation space of the environment. |
| `epsilon` | | The epsilon value for epsilon-greedy action selection. |
### `observation_space` *instance-attribute*
### `serialize`

Serialize the model to the specified path.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to serialize to. | *required* |
| `no_suffix` | `bool` | If `True`, do not append the `.eqx` suffix. | `False` |
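A usage sketch (assuming `policy` is an instance of a concrete subclass):

```python
from pathlib import Path

# Saved as checkpoints/q_policy.eqx; the suffix is appended by default.
policy.serialize(Path("checkpoints/q_policy"))

# Saved at exactly this path; no ".eqx" suffix is appended.
policy.serialize("checkpoints/q_policy.ckpt", no_suffix=True)
```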
### `deserialize` *classmethod*

```python
deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType
```
Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | *required* |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:

| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
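For example (a sketch; `MyQPolicy` is a hypothetical concrete subclass, and the constructor arguments shown are placeholders for whatever that subclass actually requires):

```python
policy = MyQPolicy.deserialize("checkpoints/q_policy.eqx")

# If the subclass constructor needs extra arguments, forward them:
policy = MyQPolicy.deserialize(
    "checkpoints/q_policy.eqx",
    epsilon=0.1,
)
```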
### `reset` *abstractmethod*

Return an initial internal state for the policy.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array` | A JAX random key for initializing the state. | *required* |

Returns:

| Type | Description |
|---|---|
| `StateType` | An initial internal state for the policy. |
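For instance, obtaining a fresh state at the start of an episode (sketch; `policy` is a concrete subclass instance):

```python
import jax

state = policy.reset(key=jax.random.key(0))
```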
### `q_values` *abstractmethod*

Return Q-values for all actions given an observation and state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation from the environment. | *required* |

Returns:

| Type | Description |
|---|---|
| | A tuple of the next internal state and the Q-values for all actions. |
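To make the two abstract methods concrete, here is a hypothetical tabular policy that mirrors the interface. It assumes the library is built on Equinox (suggested by the `.eqx` suffix above), keeps no internal state, and is a sketch rather than a drop-in subclass:

```python
import equinox as eqx
from jaxtyping import Array, Float, Integer


class TabularQPolicy(eqx.Module):
    """Hypothetical tabular policy mirroring the AbstractQPolicy interface."""

    table: Float[Array, "states actions"]  # one row of Q-values per state
    epsilon: float

    def reset(self, *, key: Array) -> None:
        # This toy policy carries no internal state between steps.
        return None

    def q_values(
        self, state: None, observation: Integer[Array, ""]
    ) -> tuple[None, Float[Array, " actions"]]:
        # Look up the row of Q-values for the observed discrete state.
        return state, self.table[observation]
```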
### `__call__`

```python
__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]
```

Return the next state and action for a given observation and state. Uses epsilon-greedy action selection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation from the environment. | *required* |
| `action_mask` | `Bool[Array, " actions"] \| None` | Optional boolean mask of valid actions. | `None` |
| `key` | `Array \| None` | JAX PRNG key for stochastic action selection. If `None`, the action with the highest Q-value is always selected. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[StateType, Integer[Array, ""]]` | A tuple of the next internal state and the selected action. |
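Putting the pieces together, an interaction loop might look like the following sketch; `policy` is a concrete subclass instance and `env` is a placeholder with an assumed Gym-style `reset`/`step` API:

```python
import jax

key = jax.random.key(0)
key, reset_key = jax.random.split(key)
state = policy.reset(key=reset_key)
observation = env.reset()  # assumed Gym-style environment

for _ in range(100):
    key, step_key = jax.random.split(key)
    # Epsilon-greedy step; passing key=None would select greedily instead.
    state, action = policy(state, observation, key=step_key)
    observation, reward, done, info = env.step(action)  # assumed API
    if done:
        break
```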