
Abstract Q Policies

lerax.policy.AbstractQPolicy

Bases: AbstractPolicy[StateType, Integer[Array, ''], ObsType, Bool[Array, ' actions']]

Base class for stateful epsilon-greedy Q-learning policies.

Epsilon-greedy policies select a random action with probability epsilon and the action with the highest Q-value with probability 1-epsilon.
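Written out as a formula, the per-action selection probability is shown below. This assumes the random action is drawn uniformly from all |A| actions (the greedy one included) and that ties in the argmax are broken deterministically. Both are assumptions, not statements about lerax:

```latex
\pi(a \mid s) =
  \frac{\epsilon}{|\mathcal{A}|}
  + (1 - \epsilon)\,\mathbb{1}\!\left[a = \arg\max_{a'} Q(s, a')\right]
```

For example, with four actions and epsilon = 0.1, the greedy action is chosen with probability 0.1/4 + 0.9 = 0.925, and each of the other three actions with probability 0.025.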

Attributes:

Name               Type                                          Description
name               eqx.AbstractClassVar[str]                     Name of the policy class.
action_space       eqx.AbstractVar[Discrete]                     The action space of the environment.
observation_space  eqx.AbstractVar[AbstractSpace[ObsType, Any]]  The observation space of the environment.
epsilon            eqx.AbstractVar[float]                        The epsilon value for epsilon-greedy action selection.

name instance-attribute

name: eqx.AbstractClassVar[str]

action_space instance-attribute

action_space: eqx.AbstractVar[Discrete]

observation_space instance-attribute

observation_space: eqx.AbstractVar[
    AbstractSpace[ObsType, Any]
]

epsilon instance-attribute

epsilon: eqx.AbstractVar[float]

serialize

serialize(
    path: str | Path, no_suffix: bool = False
) -> None

Serialize the model to the specified path.

Parameters:

Name       Type        Description                                Default
path       str | Path  The path to serialize to.                  required
no_suffix  bool        If True, do not append the ".eqx" suffix.  False

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path. Any additional arguments required by the class constructor must also be provided.

Parameters:

Name      Type           Description                                                        Default
path      str | Path     The path to deserialize from.                                      required
*args     Params.args    Additional positional arguments to pass to the class constructor.  ()
**kwargs  Params.kwargs  Additional keyword arguments to pass to the class constructor.     {}

Returns:

Type       Description
ClassType  The deserialized model.

reset abstractmethod

reset(*, key: Key) -> StateType

Return an initial internal state for the policy.

Parameters:

Name  Type  Description                                   Default
key   Key   A JAX random key for initializing the state.  required

Returns:

Type       Description
StateType  An initial internal state for the policy.
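The following sketch batches a `reset`-style function over several environments with `jax.vmap`, using one independent key per environment. The recurrent state, its size `HIDDEN_SIZE`, and the small random initialization are hypothetical choices made for illustration:

```python
import jax

HIDDEN_SIZE = 4  # hypothetical size of a recurrent policy state


def reset(*, key: jax.Array) -> jax.Array:
    """Hypothetical reset: a small random initial hidden state."""
    return 0.01 * jax.random.normal(key, (HIDDEN_SIZE,))


# One independent key per environment in the batch.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
# Map reset over the batch of keys; the result has shape (3, HIDDEN_SIZE).
states = jax.vmap(lambda k: reset(key=k))(keys)
```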

q_values abstractmethod

q_values(
    state: StateType, observation: ObsType
) -> tuple[StateType, Float[Array, " actions"]]

Return Q-values for all actions given an observation and state.

Parameters:

Name         Type       Description                                    Default
state        StateType  The current internal state of the policy.      required
observation  ObsType    The current observation from the environment.  required

Returns:

Type                                        Description
tuple[StateType, Float[Array, ' actions']]  A tuple of the next internal state and the Q-values for all actions.
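Returning the next state alongside the Q-values lets a recurrent policy carry memory between steps. The following sketch shows a hypothetical recurrent Q-function built from plain `jax.numpy` operations. The weights, sizes, and tanh update are illustrative assumptions, not a lerax network:

```python
import jax
import jax.numpy as jnp

OBS_SIZE, HIDDEN_SIZE, NUM_ACTIONS = 3, 4, 2  # hypothetical sizes

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w_h": 0.1 * jax.random.normal(k1, (HIDDEN_SIZE, HIDDEN_SIZE)),
    "w_x": 0.1 * jax.random.normal(k2, (HIDDEN_SIZE, OBS_SIZE)),
    "w_q": 0.1 * jax.random.normal(k3, (NUM_ACTIONS, HIDDEN_SIZE)),
}


def q_values(state: jax.Array, observation: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical recurrent Q-function: update the hidden state, then read out Q."""
    next_state = jnp.tanh(params["w_h"] @ state + params["w_x"] @ observation)
    return next_state, params["w_q"] @ next_state


state = jnp.zeros(HIDDEN_SIZE)
# The caller threads the returned state into the next step.
state, q = q_values(state, jnp.ones(OBS_SIZE))
```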

__call__

__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]

Return the next state and action for a given observation and state.

Uses epsilon-greedy action selection.

Parameters:

Name         Type                            Description                                                                                      Default
state        StateType                       The current internal state of the policy.                                                        required
observation  ObsType                         The current observation from the environment.                                                    required
action_mask  Bool[Array, ' actions'] | None  Optional boolean mask over the actions that restricts which actions may be selected.             None
key          Array | None                    JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected.  None

Returns:

Type                                  Description
tuple[StateType, Integer[Array, '']]  A tuple of the next internal state and the selected action.
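The following sketch shows how the documented behaviour could fit together: compute Q-values, act greedily when `key` is None, and act epsilon-greedily otherwise. The mask handling is an assumption: True marks an allowed action, and masked actions are excluded from both the greedy and the random choice. `epsilon_greedy_call` and the tabular Q-function are hypothetical:

```python
import jax
import jax.numpy as jnp


def epsilon_greedy_call(q_values_fn, epsilon, state, observation, *, action_mask=None, key=None):
    """Hypothetical __call__: returns (next_state, action)."""
    state, q = q_values_fn(state, observation)
    if action_mask is None:
        action_mask = jnp.ones(q.shape, dtype=bool)
    # Assumption: masked-out actions can never be the greedy choice.
    greedy = jnp.argmax(jnp.where(action_mask, q, -jnp.inf))
    if key is None:
        return state, greedy
    explore_key, action_key = jax.random.split(key)
    # Uniform over allowed actions: equal logits where allowed, -inf elsewhere.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    random_action = jax.random.categorical(action_key, logits)
    explore = jax.random.uniform(explore_key) < epsilon
    return state, jnp.where(explore, random_action, greedy)


# Hypothetical tabular Q-function: stateless, observation is a state index.
table = jnp.array([[0.0, 1.0, 5.0], [2.0, 0.0, 0.0]])


def tabular_q(state, observation):
    return state, table[observation]


_, a_greedy = epsilon_greedy_call(tabular_q, 0.1, None, 0)  # → 2
_, a_masked = epsilon_greedy_call(
    tabular_q, 0.1, None, 0, action_mask=jnp.array([True, True, False])
)  # → 1
_, a_explore = epsilon_greedy_call(
    tabular_q, 1.0, None, 0, action_mask=jnp.array([True, False, False]), key=jax.random.PRNGKey(0)
)  # → 0, the only allowed action
```

Sampling the random action from `jax.random.categorical` with `-inf` logits is one way to draw uniformly over the allowed actions. A plain `jax.random.randint` over all actions would ignore the mask.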