
MLP

lerax.policy.MLPQPolicy

Bases: AbstractQPolicy[None, ObsType]

Q-learning policy with an MLP Q-network.

Attributes:

name (str): Name of the policy class.
action_space (Discrete): The action space of the environment.
observation_space (AbstractSpace[ObsType, Any]): The observation space of the environment.
epsilon (float): The exploration rate for epsilon-greedy action selection, i.e. the probability of choosing a random action.
q_network (eqx.nn.MLP): The MLP Q-network used to estimate action values, with one output per discrete action.

Parameters:

env (AbstractEnvLike[StateType, Integer[Array, ''], ObsType, Any], required): The environment to create the policy for.
epsilon (float, default 0.1): The exploration rate for epsilon-greedy action selection, i.e. the probability of choosing a random action.
width_size (int, default 64): The width of the hidden layers in the MLP.
depth (int, default 2): The number of hidden layers in the MLP.
key (Key[Array, ''], required): JAX PRNG key for parameter initialization.

Raises:

ValueError: If the environment's action space is not Discrete.
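The Discrete requirement follows from how the Q-network is built. It uses `out_size=self.action_space.n`, so it needs a finite, known number of actions with one output each. The check can be sketched with the standard library as shown here. `Discrete`, `Box` and `check_action_space` are hypothetical stand-ins, not lerax's classes. The only assumption is that a discrete space exposes its action count as `n`.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-ins for lerax's space classes (illustration only).
@dataclass(frozen=True)
class Discrete:
    n: int  # number of actions


@dataclass(frozen=True)
class Box:
    shape: Tuple[int, ...]


def check_action_space(action_space: object) -> int:
    """Return the number of actions, or raise ValueError for non-Discrete spaces."""
    if not isinstance(action_space, Discrete):
        raise ValueError(
            f"MLPQPolicy requires a Discrete action space, got {type(action_space).__name__}"
        )
    return action_space.n  # becomes the Q-network's out_size


print(check_action_space(Discrete(4)))  # 4
try:
    check_action_space(Box(shape=(2,)))
except ValueError as err:
    print(err)  # MLPQPolicy requires a Discrete action space, got Box
```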

name class-attribute

name: str = 'MLPQPolicy'

action_space instance-attribute

action_space: Discrete = env.action_space

observation_space instance-attribute

observation_space: AbstractSpace[ObsType, Any] = (
    env.observation_space
)

epsilon instance-attribute

epsilon: float = epsilon

q_network instance-attribute

q_network: eqx.nn.MLP = eqx.nn.MLP(
    in_size=self.observation_space.flat_size,
    out_size=self.action_space.n,
    width_size=width_size,
    depth=depth,
    key=key,
)
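To make the sizes concrete: in Equinox's `eqx.nn.MLP`, `depth` counts hidden layers, so the network has `depth + 1` linear layers, each with a bias by default. The stdlib-only sketch below lists the weight shapes and parameter count implied by `in_size`, `out_size`, `width_size` and `depth`. The helper names are hypothetical, and the example sizes (a 4-dimensional flat observation and 2 actions) are arbitrary.

```python
from typing import List, Tuple


def mlp_layer_shapes(in_size: int, out_size: int, width_size: int, depth: int) -> List[Tuple[int, int]]:
    """(fan_in, fan_out) of each linear layer in an MLP with `depth` hidden layers."""
    if depth == 0:
        return [(in_size, out_size)]  # no hidden layers: a single linear map
    shapes = [(in_size, width_size)]
    shapes += [(width_size, width_size)] * (depth - 1)
    shapes.append((width_size, out_size))
    return shapes


def mlp_param_count(shapes: List[Tuple[int, int]]) -> int:
    """Weights plus biases, assuming every layer has a bias (Equinox's default)."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


# Hypothetical example: flat observation of size 4, 2 actions, default width/depth.
shapes = mlp_layer_shapes(in_size=4, out_size=2, width_size=64, depth=2)
print(shapes)  # [(4, 64), (64, 64), (64, 2)]
print(mlp_param_count(shapes))  # 4610
```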

serialize

serialize(path: str | Path) -> None

Serialize the model to the specified path.

Writes a 32-byte structural fingerprint followed by the Equinox leaf data, so deserialize can verify that the skeleton it builds matches what was saved.

Parameters:

path (str | Path, required): The path to serialize to. The .eqx suffix is appended if missing.
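The file layout described for `serialize` (a 32-byte fingerprint followed by the leaf data) can be sketched with the standard library. This is not lerax's implementation, and several parts are assumptions:

- Hashing a textual description of the structure with SHA-256 is an assumption. SHA-256 digests happen to be exactly 32 bytes.
- `json` stands in for Equinox's leaf serialization.
- `serialize_sketch` and its helpers are hypothetical names.
- Appending `.eqx` after an existing suffix, rather than replacing it, is one reading of "appended if missing".

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import List, Union


def structure_fingerprint(structure: str) -> bytes:
    """32-byte SHA-256 digest of a textual description of the static structure."""
    return hashlib.sha256(structure.encode("utf-8")).digest()


def with_eqx_suffix(path: Union[str, Path]) -> Path:
    """Append .eqx unless the path already ends with it."""
    path = Path(path)
    return path if path.suffix == ".eqx" else path.with_name(path.name + ".eqx")


def serialize_sketch(path: Union[str, Path], structure: str, leaves: List[float]) -> Path:
    path = with_eqx_suffix(path)
    with open(path, "wb") as f:
        f.write(structure_fingerprint(structure))  # header: exactly 32 bytes
        f.write(json.dumps(leaves).encode("utf-8"))  # stand-in for Equinox's leaf data
    return path


structure = "MLPQPolicy(in_size=4, out_size=2, width_size=64, depth=2)"
out = serialize_sketch(Path(tempfile.mkdtemp()) / "policy", structure, [0.1, -0.2])
print(out.name)  # policy.eqx
print(out.read_bytes()[:32] == structure_fingerprint(structure))  # True
```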

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path.

The constructor arguments must reproduce the same static structure (class, hyperparameters, network shapes, activations, ...) that the model had when it was serialized. A 32-byte fingerprint stored in the file is verified before loading; mismatches raise ValueError instead of silently loading arrays into the wrong skeleton.

Parameters:

path (str | Path, required): The path to deserialize from.
*args (Params.args, default ()): Additional positional arguments to pass to the class constructor.
**kwargs (Params.kwargs, default {}): Additional keyword arguments to pass to the class constructor.

Returns:

ClassType: The deserialized model.

Raises:

ValueError: If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file.
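The verification step can be sketched as follows. Recompute the fingerprint for the skeleton described by the constructor arguments, then compare it with the first 32 bytes of the file before reading any leaf data. As a sketch, it assumes a SHA-256 digest of a structure string and JSON-encoded leaves; neither is necessarily lerax's actual format. `deserialize_sketch` is a hypothetical name.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import List, Union

FINGERPRINT_SIZE = 32  # bytes, as stated in the documentation


def structure_fingerprint(structure: str) -> bytes:
    return hashlib.sha256(structure.encode("utf-8")).digest()


def deserialize_sketch(path: Union[str, Path], expected_structure: str) -> List[float]:
    data = Path(path).read_bytes()
    stored, payload = data[:FINGERPRINT_SIZE], data[FINGERPRINT_SIZE:]
    # Verify the skeleton before touching the leaf data.
    if stored != structure_fingerprint(expected_structure):
        raise ValueError("structural fingerprint mismatch: constructor arguments differ from the saved model")
    return json.loads(payload.decode("utf-8"))


# Write a file by hand: fingerprint for a depth=2 model, followed by two leaves.
path = Path(tempfile.mkdtemp()) / "policy.eqx"
path.write_bytes(structure_fingerprint("MLPQPolicy(depth=2)") + json.dumps([0.1, -0.2]).encode("utf-8"))

print(deserialize_sketch(path, "MLPQPolicy(depth=2)"))  # [0.1, -0.2]
try:
    deserialize_sketch(path, "MLPQPolicy(depth=3)")  # different hyperparameters
except ValueError as err:
    print("ValueError:", err)
```

Checking the header first means a wrong `depth` or `width_size` fails loudly. Otherwise, arrays could be read into a skeleton of the wrong shape.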

__call__

__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]

Return the next state and action for a given observation and state.

Uses epsilon-greedy action selection.

Parameters:

state (StateType, required): The current internal state of the policy.
observation (ObsType, required): The current observation from the environment.
action_mask (Bool[Array, " actions"] | None, default None): Optional boolean mask with one entry per action.
key (Array | None, default None): JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected.

Returns:

tuple[StateType, Integer[Array, '']]: A tuple of the next internal state and the selected action.
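A stdlib-only sketch of epsilon-greedy selection as described for `__call__`:

- With probability `epsilon`, a uniformly random action is drawn.
- Otherwise, the action with the highest Q-value is taken.
- With no source of randomness, the choice is always greedy.

`random.Random` stands in for a JAX PRNG key here. Two details are assumptions, because the documentation does not state them:

- A `True` entry in `action_mask` marks an allowed action.
- Random actions are drawn only from the allowed actions.

```python
import random
from typing import Optional, Sequence


def epsilon_greedy(
    q_values: Sequence[float],
    epsilon: float,
    rng: Optional[random.Random] = None,
    action_mask: Optional[Sequence[bool]] = None,
) -> int:
    """Return an action index; assumes action_mask[i] is True when action i is allowed."""
    allowed = [i for i in range(len(q_values)) if action_mask is None or action_mask[i]]
    if not allowed:
        raise ValueError("action_mask excludes every action")
    # Explore only when randomness is supplied (mirrors key=None -> always greedy).
    if rng is not None and rng.random() < epsilon:
        return rng.choice(allowed)  # uniform over the allowed actions
    # Exploit: the allowed action with the highest Q-value.
    return max(allowed, key=lambda i: q_values[i])


q = [0.5, 2.0, 1.0]
print(epsilon_greedy(q, epsilon=0.1))  # 1: no rng, so always greedy
print(epsilon_greedy(q, epsilon=0.1, action_mask=[True, False, True]))  # 2: best allowed action
rng = random.Random(0)
picks = [epsilon_greedy(q, epsilon=0.1, rng=rng) for _ in range(1000)]
print(picks.count(1) / len(picks))  # close to 0.9 + 0.1 / 3, i.e. about 0.933
```

The greedy action is still picked during exploration with probability 1/n. This is why its overall frequency exceeds `1 - epsilon`.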

__init__

__init__[StateType: AbstractEnvLikeState](
    env: AbstractEnvLike[
        StateType, Integer[Array, ""], ObsType, Any
    ],
    *,
    epsilon: float = 0.1,
    width_size: int = 64,
    depth: int = 2,
    key: Key[Array, ""],
)

reset

reset(*, key: Key[Array, '']) -> None

q_values

q_values(
    state: None, observation: ObsType
) -> tuple[None, Float[Array, " actions"]]
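`q_values` passes the `None` state through and returns one Q-value per action. A pure-Python sketch of that forward pass follows. It assumes `eqx.nn.MLP`'s default activations (ReLU between layers, identity on the output). It also assumes the observation is already flattened to `observation_space.flat_size` values. The weights are arbitrary toy numbers, and `q_values_sketch` is hypothetical.

```python
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weight rows: fan_out x fan_in, bias)


def relu(xs: List[float]) -> List[float]:
    return [max(0.0, x) for x in xs]


def linear(layer: Layer, x: Sequence[float]) -> List[float]:
    weight, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def q_values_sketch(state: None, observation: Sequence[float], layers: List[Layer]) -> Tuple[None, List[float]]:
    x = list(observation)  # assumes the observation is already flattened
    for layer in layers[:-1]:
        x = relu(linear(layer, x))  # hidden layers use ReLU
    return state, linear(layers[-1], x)  # output layer: identity, one value per action


# Toy network: 2 inputs -> 2 hidden units -> 3 actions (arbitrary weights).
layers = [
    ([[1.0, 0.0], [0.0, -1.0]], [0.0, 0.0]),
    ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.5, 0.0]),
]
state, q = q_values_sketch(None, [2.0, 3.0], layers)
print(state, q)  # None [2.0, 0.5, 2.0]
```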