Abstract Q Policies
`lerax.policy.AbstractQPolicy`

Bases: `AbstractPolicy[StateType, Integer[Array, ''], ObsType, Bool[Array, ' actions']]`
Base class for stateful epsilon-greedy Q-learning policies.
Epsilon-greedy policies select a random action with probability epsilon and the action with the highest Q-value with probability 1-epsilon.
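The selection rule can be sketched in plain JAX (a standalone illustration, not part of the lerax API; the function name `epsilon_greedy` is hypothetical):

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(q_values, epsilon, key):
    """Select an action epsilon-greedily from a vector of Q-values."""
    explore_key, action_key = jax.random.split(key)
    # With probability epsilon, sample a uniformly random action;
    # otherwise take the argmax of the Q-values.
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    greedy_action = jnp.argmax(q_values)
    return jnp.where(
        jax.random.uniform(explore_key) < epsilon,
        random_action,
        greedy_action,
    )


q = jnp.array([0.1, 0.9, 0.3])
# With epsilon=0.0 the greedy action (index 1) is always chosen.
action = epsilon_greedy(q, epsilon=0.0, key=jax.random.PRNGKey(0))
```

Using `jnp.where` rather than Python branching keeps the rule compatible with `jax.jit` and `jax.vmap`.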
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | `eqx.AbstractClassVar[str]` | Name of the policy class. |
| `action_space` | `eqx.AbstractVar[Discrete]` | The action space of the environment. |
| `observation_space` | `eqx.AbstractVar[AbstractSpace[ObsType, Any]]` | The observation space of the environment. |
| `epsilon` | `eqx.AbstractVar[float]` | The epsilon value for epsilon-greedy action selection. |
observation_space (instance-attribute)

serialize

deserialize (classmethod)
    deserialize[**Params, ClassType](
        path: str | Path,
        *args: Params.args,
        **kwargs: Params.kwargs,
    ) -> ClassType
Deserialize the model from the specified path. Any additional arguments required by the class constructor must be provided.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | *required* |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:

| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
reset (abstractmethod)
Return an initial internal state for the policy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Key[Array, '']` | A JAX random key for initializing the state. | *required* |
Returns:

| Type | Description |
|---|---|
| `StateType` | An initial internal state for the policy. |
q_values (abstractmethod)
Return Q-values for all actions given an observation and state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation from the environment. | *required* |
Returns:

| Type | Description |
|---|---|
| `tuple[StateType, Float[Array, ' actions']]` | A tuple of the next internal state and the Q-values for all actions. |
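A minimal tabular sketch of the `reset` / `q_values` contract (hypothetical: `TabularQPolicy` does not subclass `AbstractQPolicy`, and the plain Q-table is an assumption, not what lerax prescribes):

```python
import jax
import jax.numpy as jnp


class TabularQPolicy:
    """Hypothetical stand-in illustrating the reset / q_values contract."""

    def __init__(self, n_states, n_actions):
        # Q-table indexed by (discrete observation, action).
        self.table = jnp.zeros((n_states, n_actions))

    def reset(self, key):
        # This sketch is stateless, so the internal state is empty.
        return None

    def q_values(self, state, observation):
        # Look up the Q-values for the observed discrete state and
        # return them alongside the (unchanged) internal state.
        return state, self.table[observation]


policy = TabularQPolicy(n_states=4, n_actions=2)
state = policy.reset(jax.random.PRNGKey(0))
state, q = policy.q_values(state, 1)
```

A recurrent policy would instead carry, say, an RNN hidden state through `StateType` and return the updated state from `q_values`.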
__call__

    __call__(
        state: StateType,
        observation: ObsType,
        *,
        action_mask: Bool[Array, " actions"] | None = None,
        key: Array | None = None,
    ) -> tuple[StateType, Integer[Array, ""]]
Return the next state and action for a given observation and state.
Uses epsilon-greedy action selection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation from the environment. | *required* |
| `action_mask` | `Bool[Array, " actions"] \| None` | Optional boolean mask over the action space. | `None` |
| `key` | `Array \| None` | JAX PRNG key for stochastic action selection. If `None`, the action with the highest Q-value is always selected. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[StateType, Integer[Array, '']]` | A tuple of the next internal state and the selected action. |
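The `key=None` branch can be illustrated with a standalone stand-in (the function `policy_call` is hypothetical and takes Q-values directly instead of computing them from an observation):

```python
import jax
import jax.numpy as jnp


def policy_call(state, q_values, epsilon, key=None):
    """Mimic the __call__ contract: return (next_state, action)."""
    if key is None:
        # No key: deterministic, always pick the highest-valued action.
        return state, jnp.argmax(q_values)
    explore_key, action_key = jax.random.split(key)
    # With probability epsilon, sample a uniformly random action.
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    return state, jnp.where(
        jax.random.uniform(explore_key) < epsilon,
        random_action,
        jnp.argmax(q_values),
    )


# Deterministic (evaluation) mode: no key, greedy action every time.
state, action = policy_call(None, jnp.array([0.2, 0.5, 0.1]), epsilon=0.1, key=None)
```

Passing a key during training and `None` during evaluation is a common way to switch between exploratory and greedy behavior without changing the policy object.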