MLP
lerax.policy.MLPQPolicy
Bases: AbstractQPolicy[None, ObsType]
Q-learning policy with an MLP Q-network.
Attributes:
| Name | Type | Description |
|---|---|---|
| `name` | `str` | Name of the policy class. |
| `action_space` | `Discrete` | The action space of the environment. |
| `observation_space` | `AbstractSpace[ObsType, Any]` | The observation space of the environment. |
| `epsilon` | `float` | The epsilon value for epsilon-greedy action selection. |
| `q_network` | `eqx.nn.MLP` | The MLP Q-network used for action value estimation. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike[StateType, Integer[Array, ''], ObsType, Any]` | The environment to create the policy for. | required |
| `epsilon` | `float` | The epsilon value for epsilon-greedy action selection. | `0.1` |
| `width_size` | `int` | The width of the hidden layers in the MLP. | `64` |
| `depth` | `int` | The number of hidden layers in the MLP. | `2` |
| `key` | `Key[Array, '']` | JAX PRNG key for parameter initialization. | required |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the environment's action space is not Discrete. |
observation_space
instance-attribute
q_network
instance-attribute
    q_network: eqx.nn.MLP = eqx.nn.MLP(
        in_size=self.observation_space.flat_size,
        out_size=self.action_space.n,
        width_size=width_size,
        depth=depth,
        key=key,
    )
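As a rough illustration of what these arguments imply for network size, the pure-Python sketch below lists the linear-layer shapes and the parameter count of such an MLP. It assumes the layer layout described in the Equinox `eqx.nn.MLP` documentation (`depth` hidden layers followed by an output layer) and that every layer has a bias. The helper names `mlp_layer_shapes` and `mlp_param_count`, and the example sizes (flat observation size 4, 2 actions), are hypothetical and not part of lerax.

```python
def mlp_layer_shapes(in_size, out_size, width_size, depth):
    """Return (fan_in, fan_out) for each linear layer of the MLP."""
    if depth == 0:
        # No hidden layers: a single linear map from input to output.
        return [(in_size, out_size)]
    shapes = [(in_size, width_size)]
    shapes += [(width_size, width_size)] * (depth - 1)
    shapes.append((width_size, out_size))
    return shapes


def mlp_param_count(in_size, out_size, width_size, depth):
    """Count weights plus biases across all linear layers (biases assumed)."""
    return sum(
        fan_in * fan_out + fan_out
        for fan_in, fan_out in mlp_layer_shapes(in_size, out_size, width_size, depth)
    )


# Hypothetical environment: flat observation size 4, 2 discrete actions,
# combined with the documented defaults width_size=64 and depth=2.
shapes = mlp_layer_shapes(4, 2, 64, 2)
total = mlp_param_count(4, 2, 64, 2)
```

The output layer has one unit per discrete action, so each output is the estimated Q-value of that action for the flattened observation.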
serialize
Serialize the model to the specified path.
Writes a 32-byte structural fingerprint followed by the Equinox
leaf data, so deserialize can verify that the skeleton it builds
matches what was saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to serialize to. The | required |
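The file layout described above (a 32-byte fingerprint, then the leaf data) can be sketched with the standard library alone. This is only an illustration under stated assumptions: the documentation does not say how the fingerprint is computed, so the sketch uses a SHA-256 digest (which happens to be 32 bytes) of a text description of the structure. The names `structure_fingerprint` and `write_with_fingerprint` are hypothetical, and `leaf_bytes` stands in for the Equinox leaf data.

```python
import hashlib
from pathlib import Path

FINGERPRINT_SIZE = 32  # a SHA-256 digest is exactly 32 bytes


def structure_fingerprint(structure: str) -> bytes:
    """Hash a textual description of the model's static structure (assumed scheme)."""
    return hashlib.sha256(structure.encode("utf-8")).digest()


def write_with_fingerprint(path, structure: str, leaf_bytes: bytes) -> None:
    """Write the fingerprint header first, then the serialized leaf data."""
    header = structure_fingerprint(structure)
    Path(path).write_bytes(header + leaf_bytes)
```

Placing the fingerprint at a fixed offset means a reader can check it before touching any array data.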
deserialize
classmethod
    deserialize[**Params, ClassType](
        path: str | Path,
        *args: Params.args,
        **kwargs: Params.kwargs,
    ) -> ClassType
Deserialize the model from the specified path.
The constructor arguments must reproduce the same static structure
(class, hyperparameters, network shapes, activations, ...) that the
model had when it was serialized. A 32-byte fingerprint stored in
the file is verified before loading; mismatches raise ValueError
instead of silently loading arrays into the wrong skeleton.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file. |
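The verify-before-load behavior can be sketched as follows. This is a minimal stand-in, not lerax's implementation: `read_verified` is a hypothetical helper, and `expected_fingerprint` represents whatever 32-byte value the freshly built skeleton produces. The point it illustrates is the ordering, in which the header is compared first and the payload is only returned when it matches.

```python
from pathlib import Path

FINGERPRINT_SIZE = 32  # size of the header described in the documentation


def read_verified(path, expected_fingerprint: bytes) -> bytes:
    """Return the payload only if the stored fingerprint matches the expected one."""
    data = Path(path).read_bytes()
    if len(data) < FINGERPRINT_SIZE:
        raise ValueError("file is too short to contain a fingerprint")
    stored = data[:FINGERPRINT_SIZE]
    payload = data[FINGERPRINT_SIZE:]
    if stored != expected_fingerprint:
        # Refuse to load arrays into a skeleton with a different structure.
        raise ValueError("structural fingerprint mismatch")
    return payload
```

Failing loudly here is what prevents, for example, weights saved with one `width_size` from being loaded into a model rebuilt with another.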
__call__
    __call__(
        state: StateType,
        observation: ObsType,
        *,
        action_mask: Bool[Array, " actions"] | None = None,
        key: Array | None = None,
    ) -> tuple[StateType, Integer[Array, ""]]
Return the next state and action for a given observation and state.
Uses epsilon-greedy action selection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | required |
| `observation` | `ObsType` | The current observation from the environment. | required |
| `key` | `Array \| None` | JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[StateType, Integer[Array, '']]` | A tuple of the next internal state and the selected action. |
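The selection rule can be sketched in plain Python. This is an illustration rather than the library's code: the real method works on JAX arrays and PRNG keys, whereas this sketch uses a list of Q-values and a `random.Random` instance, with `rng=None` playing the role of `key=None`. The handling of `action_mask` (masked-out actions are never chosen, either greedily or at random) is an assumption, since the mask's semantics are not documented above. The function name `epsilon_greedy` is hypothetical.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=None, action_mask=None):
    """Pick an action index; with rng=None the choice is always greedy."""
    # Assumed mask semantics: only indices marked True may be selected.
    allowed = [
        i for i in range(len(q_values))
        if action_mask is None or action_mask[i]
    ]
    if not allowed:
        raise ValueError("action_mask excludes every action")
    greedy = max(allowed, key=lambda i: q_values[i])
    if rng is None:
        return greedy
    # With probability epsilon explore uniformly, otherwise exploit.
    if rng.random() < epsilon:
        return rng.choice(allowed)
    return greedy


q = [0.1, 0.9, 0.3]
greedy_action = epsilon_greedy(q, epsilon=0.1)
masked_action = epsilon_greedy(q, epsilon=0.1, action_mask=[True, False, True])
explored = epsilon_greedy(q, epsilon=1.0, rng=random.Random(0))
```

With `epsilon=0.1`, roughly one call in ten explores at random when a random source is supplied; without one, the result is deterministic.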