MLP
lerax.policy.MLPActorCriticPolicy
Bases:
Actor–critic policy with MLP components.
Uses an MLP to encode observations into features, then separate MLPs to produce action distributions and value estimates from those features.
Action distributions are produced by mapping action head outputs to the parameters of the appropriate distribution for the action space.
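For orientation, here is a minimal sketch of the data flow this class implements. It assumes the `MLP` used in the attribute definitions below has the same constructor as `equinox.nn.MLP` (the keyword arguments line up with it) and uses an 8-dimensional flattened observation as a stand-in; the `ActionLayer` step is only described in a comment because its interface is not shown on this page.

```python
# Illustrative sketch only, not library code. Sizes mirror the
# constructor defaults (feature_size=16, width 64, depth 2).
import jax
import equinox as eqx

feat_key, val_key, obs_key = jax.random.split(jax.random.PRNGKey(0), 3)

encoder = eqx.nn.MLP(in_size=8, out_size=16, width_size=64, depth=2, key=feat_key)
value_head = eqx.nn.MLP(in_size=16, out_size="scalar", width_size=64, depth=2, key=val_key)

obs_flat = jax.random.normal(obs_key, (8,))  # stand-in for a flattened observation
features = encoder(obs_flat)                 # observation -> shared feature vector
value = value_head(features)                 # features -> scalar value estimate
# The action head (an ActionLayer) maps the same features to the parameters of
# the distribution appropriate for the action space, e.g. logits for a discrete
# space or mean and log-std for a continuous one.
```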
Attributes:

| Name | Type | Description |
|---|---|---|
| | | Name of the policy class. |
| `action_space` | | The action space of the environment. |
| `observation_space` | | The observation space of the environment. |
| `encoder` | `MLP` | MLP to encode observations into features. |
| `value_head` | `MLP` | MLP to produce value estimates from features. |
| `action_head` | `ActionLayer` | MLP to produce action distributions from features. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike[S, ActType, ObsType, MaskType]` | The environment to create the policy for. | required |
| `feature_size` | `int` | Size of the feature representation. | `16` |
| `feature_width` | `int` | Width of the hidden layers in the feature encoder. | `64` |
| `feature_depth` | `int` | Depth of the hidden layers in the feature encoder. | `2` |
| `value_width` | `int` | Width of the hidden layers in the value head. | `64` |
| `value_depth` | `int` | Depth of the hidden layers in the value head. | `2` |
| `action_width` | `int` | Width of the hidden layers in the action head. | `64` |
| `action_depth` | `int` | Depth of the hidden layers in the action head. | `2` |
| `log_std_init` | `float` | Initial log standard deviation for continuous action spaces. | `0.0` |
| `key` | `Key` | JAX PRNG key for parameter initialization. | required |
action_space
instance-attribute
observation_space
instance-attribute
encoder
instance-attribute
encoder: MLP = MLP(
in_size=self.observation_space.flat_size,
out_size=feature_size,
width_size=feature_width,
depth=feature_depth,
key=feat_key,
)
value_head
instance-attribute
value_head: MLP = MLP(
in_size=feature_size,
out_size="scalar",
width_size=value_width,
depth=value_depth,
key=val_key,
)
action_head
instance-attribute
action_head: ActionLayer = ActionLayer(
self.action_space,
feature_size,
action_width,
action_depth,
key=act_key,
log_std_init=log_std_init,
)
serialize
Serialize the model to the specified path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | | The path to serialize to. | required |
| `no_suffix` | | If True, do not append the ".eqx" suffix. | `False` |
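A usage sketch, assuming `policy` is an already-constructed `MLPActorCriticPolicy` and that a plain string path is accepted; the checkpoint names are illustrative.

```python
# Save the policy; with the default no_suffix=False the ".eqx" suffix is appended.
policy.serialize("checkpoints/mlp_policy")
# Keep the filename exactly as given.
policy.serialize("checkpoints/mlp_policy.ckpt", no_suffix=True)
```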
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |

Returns:

| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
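A loading sketch to pair with `serialize` above. Since extra arguments are forwarded to the constructor, the environment and a PRNG key are assumed here to be what `MLPActorCriticPolicy.__init__` needs (see its signature below); the checkpoint path is illustrative.

```python
import jax
from lerax.policy import MLPActorCriticPolicy

# `env` stands in for the same AbstractEnvLike instance used at construction time.
restored = MLPActorCriticPolicy.deserialize(
    "checkpoints/mlp_policy",
    env,
    key=jax.random.PRNGKey(0),
)
```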
__init__
__init__[S: AbstractEnvLikeState](
env: AbstractEnvLike[S, ActType, ObsType, MaskType],
*,
feature_size: int = 16,
feature_width: int = 64,
feature_depth: int = 2,
value_width: int = 64,
value_depth: int = 2,
action_width: int = 64,
action_depth: int = 2,
log_std_init: float = 0.0,
key: Key,
)
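A construction sketch; `env` stands in for any `AbstractEnvLike` instance (how it is built depends on the environment side of `lerax` and is not shown here), and the overridden hyperparameters are arbitrary.

```python
import jax
from lerax.policy import MLPActorCriticPolicy

policy = MLPActorCriticPolicy(
    env,                  # any AbstractEnvLike instance
    feature_size=32,      # larger shared feature vector than the default 16
    action_width=128,     # wider action-head hidden layers
    key=jax.random.PRNGKey(0),
)
```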
__call__
__call__(
state: None,
observation: ObsType,
*,
key: Key | None = None,
action_mask: MaskType | None = None,
) -> tuple[None, ActType]
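A call sketch using the policy and an observation from the environment; the leading `None` is the (stateless) policy state, and the first element of the returned tuple mirrors it back.

```python
# Sample an action; with key=None the most likely action is returned instead.
_, action = policy(None, observation, key=jax.random.PRNGKey(1))
```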
action_and_value
action_and_value(
state: None,
observation: ObsType,
*,
key: Key,
action_mask: MaskType | None = None,
) -> tuple[
None, ActType, Float[Array, ""], Float[Array, ""]
]
Get an action and value from an observation.
If a key is provided, it is used to sample an action; if no key is provided, the policy returns its most likely action.
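A rollout-style sketch; the two trailing scalars in the return tuple are read here as the action's log-probability and the value estimate, which is an assumption about their order since this page does not name them.

```python
_, action, log_prob, value = policy.action_and_value(
    None,
    observation,
    key=jax.random.PRNGKey(2),
)
# log_prob and value are scalar JAX arrays (Float[Array, ""]).
```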
value
Get the value of an observation.