
MLP

lerax.policy.MLPQPolicy

Bases: AbstractQPolicy[None, ObsType]

Q-learning policy with an MLP Q-network.

Attributes:

    name (str): Name of the policy class.
    action_space (Discrete): The action space of the environment.
    observation_space (AbstractSpace[ObsType, Any]): The observation space of the environment.
    epsilon (float): The epsilon value for epsilon-greedy action selection.
    q_network (MLP): The MLP Q-network used for action value estimation.

Parameters:

    env (AbstractEnvLike[StateType, Integer[Array, ""], ObsType, Bool[Array, " n"]]): The environment to create the policy for. Required.
    epsilon (float): The epsilon value for epsilon-greedy action selection. Default: 0.1.
    width_size (int): The width of the hidden layers in the MLP. Default: 64.
    depth (int): The number of hidden layers in the MLP. Default: 2.
    key (Key): JAX PRNG key for parameter initialization. Required.

Raises:

    ValueError: If the environment's action space is not Discrete.
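
For orientation, a minimal construction sketch (not taken from the lerax docs): make_env below is a hypothetical stand-in for any AbstractEnvLike whose action space is Discrete.

import jax

from lerax.policy import MLPQPolicy

# Hypothetical: make_env() stands in for any AbstractEnvLike with a
# Discrete action space; substitute a real lerax environment here.
env = make_env()

# The key initializes the parameters of the underlying MLP Q-network.
policy = MLPQPolicy(
    env,
    epsilon=0.1,
    width_size=64,
    depth=2,
    key=jax.random.key(0),
)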

name class-attribute

name: str = 'MLPQPolicy'

action_space instance-attribute

action_space: Discrete = env.action_space

observation_space instance-attribute

observation_space: AbstractSpace[ObsType, Any] = (
    env.observation_space
)

epsilon instance-attribute

epsilon: float = epsilon

q_network instance-attribute

q_network: MLP = MLP(
    in_size=self.observation_space.flat_size,
    out_size=self.action_space.n,
    width_size=width_size,
    depth=depth,
    key=key,
)

serialize

serialize(
    path: str | Path, no_suffix: bool = False
) -> None

Serialize the model to the specified path.

Parameters:

    path (str | Path): The path to serialize to. Required.
    no_suffix (bool): If True, do not append the ".eqx" suffix. Default: False.
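
A usage sketch, continuing from the construction example above; the checkpoint path is illustrative only.

# Writes the model to "checkpoints/qpolicy.eqx"; the ".eqx" suffix is
# appended automatically unless no_suffix=True.
policy.serialize("checkpoints/qpolicy")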

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.

Parameters:

    path (str | Path): The path to deserialize from. Required.
    *args (Params.args): Additional arguments to pass to the class constructor. Default: ().
    **kwargs (Params.kwargs): Additional keyword arguments to pass to the class constructor. Default: {}.

Returns:

    ClassType: The deserialized model.
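
A sketch under the same assumptions as the construction example above: because deserialization must rebuild the object, the constructor arguments (here env and key) are passed along with the path.

import jax

from lerax.policy import MLPQPolicy

# env is the hypothetical environment from the construction example.
restored = MLPQPolicy.deserialize(
    "checkpoints/qpolicy.eqx",  # path written by serialize above
    env,                        # positional constructor argument
    key=jax.random.key(0),      # keyword constructor argument
)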

__call__

__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]

Return the next state and action for a given observation and state.

Uses epsilon-greedy action selection.

Parameters:

    state (StateType): The current internal state of the policy. Required.
    observation (ObsType): The current observation from the environment. Required.
    action_mask (Bool[Array, " actions"] | None): Optional boolean mask over the action space restricting which actions may be selected. Default: None.
    key (Array | None): JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected. Default: None.

Returns:

    tuple[StateType, Integer[Array, ""]]: A tuple of the next internal state and the selected action.
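
An illustrative rollout step, reusing the hypothetical policy from the construction example; how the observation is obtained from the environment is sketched rather than taken from the lerax API.

import jax

# Hypothetical: obs is an observation produced by the environment,
# e.g. via its reset/step API.
obs = ...

# MLPQPolicy carries no internal state, so None is passed through.
# With a key, selection is epsilon-greedy: a uniformly random action
# with probability epsilon, otherwise the argmax of the Q-values.
state, action = policy(None, obs, key=jax.random.key(1))

# With key=None, the highest-Q action is always selected.
state, action = policy(None, obs, key=None)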

__init__

__init__[StateType: AbstractEnvLikeState](
    env: AbstractEnvLike[
        StateType,
        Integer[Array, ""],
        ObsType,
        Bool[Array, " n"],
    ],
    *,
    epsilon: float = 0.1,
    width_size: int = 64,
    depth: int = 2,
    key: Key,
)

reset

reset(*, key: Key) -> None

Reset the internal state of the policy.

q_values

q_values(
    state: None, observation: ObsType
) -> tuple[None, Float[Array, " actions"]]

Return the Q-values for each action given the observation, along with the (unchanged) policy state.
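
A sketch of inspecting the value estimates directly, with obs as in the previous example.

# MLPQPolicy keeps no internal state, so None is passed and returned.
_, values = policy.q_values(None, obs)

# values has one entry per action; the greedy action is its argmax.
greedy_action = values.argmax()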