
Abstract Q Policies

lerax.policy.AbstractQPolicy

Bases: AbstractPolicy[StateType, Integer[Array, ''], ObsType, Bool[Array, ' actions']]

Base class for stateful epsilon-greedy Q-learning policies.

Epsilon-greedy policies select a random action with probability epsilon and the action with the highest Q-value with probability 1-epsilon.
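Written as a probability distribution, and assuming the random action is drawn uniformly from all |A| actions of the Discrete action space (uniformity is an assumption, not stated above), the rule assigns each action a in state s:

```latex
\pi(a \mid s) = \frac{\varepsilon}{|\mathcal{A}|} + (1 - \varepsilon)\,\mathbb{1}\!\left[a = \arg\max_{a'} Q(s, a')\right]
```

For instance, with 4 actions and epsilon = 0.1, the greedy action is chosen with probability 0.1/4 + 0.9 = 0.925 and each of the other three with probability 0.025 (ties in the argmax are ignored here).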

Attributes:

name (eqx.AbstractClassVar[str]): Name of the policy class.

action_space (eqx.AbstractVar[Discrete]): The action space of the environment.

observation_space (eqx.AbstractVar[AbstractSpace[ObsType, Any]]): The observation space of the environment.

epsilon (eqx.AbstractVar[float]): The epsilon value for epsilon-greedy action selection, i.e. the probability of selecting a random action instead of the greedy one.

name class-attribute

name: eqx.AbstractClassVar[str]

action_space instance-attribute

action_space: eqx.AbstractVar[Discrete]

observation_space instance-attribute

observation_space: eqx.AbstractVar[
    AbstractSpace[ObsType, Any]
]

epsilon instance-attribute

epsilon: eqx.AbstractVar[float]

serialize

serialize(path: str | Path) -> None

Serialize the model to the specified path.

Writes a 32-byte structural fingerprint followed by the Equinox leaf data, so deserialize can verify that the skeleton it builds matches what was saved.

Parameters:

path (str | Path, required): The path to serialize to. The .eqx suffix is appended if missing.
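The file layout described above (fingerprint header, then leaf data) can be sketched as follows. This is a minimal illustration under stated assumptions, not lerax's implementation: the fingerprint here is a hypothetical SHA-256 digest over the pytree definition and each leaf's shape and dtype (a SHA-256 digest is exactly 32 bytes), and `np.save` stands in for the Equinox leaf format. The helpers `structural_fingerprint` and `serialize_sketch` are hypothetical names.

```python
import hashlib
import io

import jax
import jax.numpy as jnp
import numpy as np


def structural_fingerprint(tree) -> bytes:
    """Hypothetical 32-byte fingerprint of a pytree's static structure.

    Hashes the tree definition plus every leaf's shape and dtype, so two
    pytrees with the same skeleton share a digest regardless of their values.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    h = hashlib.sha256(str(treedef).encode())
    for leaf in leaves:
        arr = np.asarray(leaf)
        h.update(f"{arr.shape}:{arr.dtype}".encode())
    return h.digest()  # SHA-256 digests are 32 bytes long


def serialize_sketch(tree, f) -> None:
    """Write the fingerprint header, then one .npy record per leaf."""
    f.write(structural_fingerprint(tree))
    for leaf in jax.tree_util.tree_leaves(tree):
        np.save(f, np.asarray(leaf), allow_pickle=False)


# A dict of arrays stands in for a model's parameters.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
buf = io.BytesIO()
serialize_sketch(params, buf)
header = buf.getvalue()[:32]
```

Because only the skeleton is hashed, two models with different weights but identical structure share a fingerprint, which is what allows a freshly constructed model to act as the template when loading.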

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path.

The constructor arguments must reproduce the same static structure (class, hyperparameters, network shapes, activations, ...) that the model had when it was serialized. A 32-byte fingerprint stored in the file is verified before loading; mismatches raise ValueError instead of silently loading arrays into the wrong skeleton.

Parameters:

path (str | Path, required): The path to deserialize from.

*args (Params.args, default ()): Additional positional arguments to pass to the class constructor.

**kwargs (Params.kwargs, default {}): Additional keyword arguments to pass to the class constructor.

Returns:

ClassType: The deserialized model.

Raises:

ValueError: If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file.
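The verify-before-load contract can be sketched as a self-contained round trip. Everything here is hypothetical and only illustrates the idea: `fingerprint`, `save`, and `load` are illustrative helpers (a SHA-256 digest over the pytree definition and leaf shapes/dtypes, with `np.save`/`np.load` standing in for Equinox's leaf format), and a plain dict of arrays stands in for a model rebuilt from constructor arguments.

```python
import hashlib
import io

import jax
import jax.numpy as jnp
import numpy as np


def fingerprint(tree) -> bytes:
    # Hypothetical 32-byte digest of the pytree definition plus leaf shapes/dtypes.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    h = hashlib.sha256(str(treedef).encode())
    for leaf in leaves:
        arr = np.asarray(leaf)
        h.update(f"{arr.shape}:{arr.dtype}".encode())
    return h.digest()


def save(tree, f) -> None:
    # Header first, then one .npy record per leaf.
    f.write(fingerprint(tree))
    for leaf in jax.tree_util.tree_leaves(tree):
        np.save(f, np.asarray(leaf), allow_pickle=False)


def load(skeleton, f):
    # Compare the stored header with the freshly built skeleton *before*
    # reading any arrays, so a mismatch never fills the wrong structure.
    stored = f.read(32)
    if stored != fingerprint(skeleton):
        raise ValueError(
            "structural fingerprint mismatch: rebuild the model with the "
            "constructor arguments used at save time"
        )
    leaves, treedef = jax.tree_util.tree_flatten(skeleton)
    loaded = [jnp.asarray(np.load(f, allow_pickle=False)) for _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, loaded)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
buf = io.BytesIO()
save(params, buf)

# Same skeleton (different placeholder values): loads the saved arrays.
buf.seek(0)
restored = load({"w": jnp.zeros((2, 3)), "b": jnp.ones(3)}, buf)

# Different skeleton (wrong shape for "w"): rejected before any array is read.
buf.seek(0)
mismatch_error = None
try:
    load({"w": jnp.zeros((4, 3)), "b": jnp.ones(3)}, buf)
except ValueError as err:
    mismatch_error = err
```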

reset abstractmethod

reset(*, key: Key[Array, '']) -> StateType

Return an initial internal state for the policy.

Parameters:

key (Key[Array, ''], required): A JAX random key for initializing the state.

Returns:

StateType: An initial internal state for the policy.

q_values abstractmethod

q_values(
    state: StateType, observation: ObsType
) -> tuple[StateType, Float[Array, " actions"]]

Return Q-values for all actions given an observation and state.

Parameters:

state (StateType, required): The current internal state of the policy.

observation (ObsType, required): The current observation from the environment.

Returns:

tuple[StateType, Float[Array, ' actions']]: A tuple of the next internal state and the Q-values for all actions.
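To illustrate how state threads through reset and q_values, here is a hypothetical stand-in written in plain JAX. It does not subclass AbstractQPolicy (its concrete constructor and remaining abstract members are not documented here); the class name `TabularQPolicySketch`, the integer-observation Q-table, and the step-counter state are assumptions chosen only for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TabularQPolicySketch(NamedTuple):
    """Hypothetical stand-in mirroring the reset/q_values contract.

    Observations are integer state indices into a (num_states, num_actions)
    Q-table; the internal state is a step counter, used only to show threading.
    """

    q_table: jax.Array  # shape (num_states, num_actions)

    def reset(self, *, key: jax.Array) -> jax.Array:
        # The initial counter does not depend on randomness, so the key is unused.
        del key
        return jnp.array(0, dtype=jnp.int32)

    def q_values(self, state: jax.Array, observation: jax.Array):
        # Return (next internal state, Q-values for every action).
        return state + 1, self.q_table[observation]


policy = TabularQPolicySketch(q_table=jnp.arange(12.0).reshape(4, 3))
state = policy.reset(key=jax.random.PRNGKey(0))


def step(state, obs):
    # Carry the policy state through the scan; collect Q-values per step.
    state, q = policy.q_values(state, obs)
    return state, q


final_state, qs = jax.lax.scan(step, state, jnp.array([0, 2, 3]))
```

Because q_values returns the next state alongside the Q-values, a rollout can carry the state through jax.lax.scan starting from the value returned by reset.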

__call__

__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]

Return the next state and action for a given observation and state.

Uses epsilon-greedy action selection when key is provided; if key is None, selection is purely greedy.

Parameters:

state (StateType, required): The current internal state of the policy.

observation (ObsType, required): The current observation from the environment.

action_mask (Bool[Array, ' actions'] | None, default None): Optional boolean mask with one entry per action, restricting which actions can be selected.

key (Array | None, default None): JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected.

Returns:

tuple[StateType, Integer[Array, '']]: A tuple of the next internal state and the selected action.
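The selection step inside __call__ can be sketched as a pure function. This is a minimal illustration, not lerax's implementation: `epsilon_greedy_action` is a hypothetical helper, and it assumes that True in action_mask marks an allowed action, that masked actions are excluded from both the greedy and the random choice, and that the random action is uniform over the allowed actions.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy_action(q_values, epsilon, *, action_mask=None, key=None):
    """Hypothetical sketch of epsilon-greedy selection with an optional mask.

    Assumes action_mask[i] == True means action i may be selected.
    """
    if action_mask is None:
        # No mask: every action is eligible.
        action_mask = jnp.ones(q_values.shape, dtype=bool)
    # Masked-out actions get -inf so they can never be the argmax.
    greedy = jnp.argmax(jnp.where(action_mask, q_values, -jnp.inf))
    if key is None:
        # No key: deterministic, always the highest (allowed) Q-value.
        return greedy
    explore_key, action_key = jax.random.split(key)
    # Equal logits over allowed actions, -inf elsewhere: a uniform draw.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    random_action = jax.random.categorical(action_key, logits)
    # With probability epsilon, take the random action instead of the greedy one.
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy)


q = jnp.array([1.0, 3.0, 2.0])
greedy_action = epsilon_greedy_action(q, 0.1)  # no key: argmax, index 1
```

Splitting the key keeps the explore/exploit coin flip independent of the random action draw, and with key=None the function short-circuits to the greedy action.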