MLP
lerax.policy.MLPQPolicy
Bases:
Q-learning policy with an MLP Q-network.
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | `str` | Name of the policy class. |
| `action_space` | `Discrete` | The action space of the environment. |
| `observation_space` | | The observation space of the environment. |
| `epsilon` | `float` | The epsilon value for epsilon-greedy action selection. |
| `q_network` | `MLP` | The MLP Q-network used for action value estimation. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | | The environment to create the policy for. | required |
| `epsilon` | `float` | The epsilon value for epsilon-greedy action selection. | `0.1` |
| `width_size` | `int` | The width of the hidden layers in the MLP. | `64` |
| `depth` | `int` | The number of hidden layers in the MLP. | `2` |
| `key` | `Array` | JAX PRNG key for parameter initialization. | required |
Raises:

| Type | Description |
|---|---|
| | If the environment's action space is not Discrete. |
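Below is a minimal construction sketch. The environment constructor is a hypothetical placeholder (environment setup is not covered on this page); only `MLPQPolicy` and its parameters come from the documentation above.

```python
import jax

from lerax.policy import MLPQPolicy

# Hypothetical placeholder: any lerax-compatible environment whose
# action space is Discrete. Substitute your own environment setup.
env = make_discrete_env()

policy = MLPQPolicy(
    env,
    epsilon=0.1,    # probability of taking a random action
    width_size=64,  # hidden-layer width of the Q-network MLP
    depth=2,        # number of hidden layers
    key=jax.random.PRNGKey(0),
)
```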
observation_space
instance-attribute
q_network
instance-attribute
```python
q_network: MLP = MLP(
    in_size=self.observation_space.flat_size,
    out_size=self.action_space.n,
    width_size=width_size,
    depth=depth,
    key=key,
)
```
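Because `q_network` maps a flattened observation to one Q-value per discrete action, it can be queried directly. A sketch, continuing from the construction example above (the pre-flattened observation is an assumption; the flattening step is not documented here):

```python
import jax.numpy as jnp

# Assumes the observation is already flattened to shape
# (observation_space.flat_size,).
obs_flat = jnp.zeros(policy.observation_space.flat_size)
q_values = policy.q_network(obs_flat)  # shape: (action_space.n,)
greedy = jnp.argmax(q_values)          # index of the highest Q-value
```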
serialize
Serialize the model to the specified path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to serialize to. | required |
| `no_suffix` | `bool` | If True, do not append the ".eqx" suffix. | `False` |
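A serialization sketch, continuing from the example above (file names are illustrative):

```python
# By default the ".eqx" suffix is appended to the given path.
policy.serialize("checkpoints/mlp_q_policy")

# With no_suffix=True the path is used verbatim.
policy.serialize("checkpoints/mlp_q_policy.ckpt", no_suffix=True)
```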
deserialize
classmethod
```python
deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType
```
Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:

| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
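A matching round-trip sketch. Which constructor arguments must be re-supplied is an assumption based on the note above; `env` and `key` are shown here because both are required by the constructor:

```python
# Forwarded *args/**kwargs go to the MLPQPolicy constructor.
restored = MLPQPolicy.deserialize(
    "checkpoints/mlp_q_policy",
    env,
    key=jax.random.PRNGKey(0),
)
```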
__call__
```python
__call__(
    state: StateType,
    observation: ObsType,
    *,
    action_mask: Bool[Array, " actions"] | None = None,
    key: Array | None = None,
) -> tuple[StateType, Integer[Array, ""]]
```
Return the next state and action for a given observation and state.
Uses epsilon-greedy action selection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current internal state of the policy. | required |
| `observation` | `ObsType` | The current observation from the environment. | required |
| `key` | `Array \| None` | JAX PRNG key for stochastic action selection. If None, the action with the highest Q-value is always selected. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[StateType, Integer[Array, ""]]` | A tuple of the next internal state and the selected action. |
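A rollout-step sketch. How the initial policy state and the observation are obtained is environment-specific and not covered on this page, so the first two assignments are placeholders:

```python
key = jax.random.PRNGKey(1)
state = ...        # initial policy state (loop/environment specific)
observation = ...  # current observation from the environment

# Stochastic, epsilon-greedy selection: split the key each step.
key, action_key = jax.random.split(key)
state, action = policy(state, observation, key=action_key)

# Deterministic selection: with key=None the argmax action is chosen.
state, greedy_action = policy(state, observation)
```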