MLP

lerax.policy.MLPActorCriticPolicy

Bases: AbstractActorCriticPolicy[None, ActType, ObsType, MaskType]

Actor–critic policy with MLP components.

Uses an MLP to encode observations into features, then separate MLPs to produce action distributions and value estimates from those features.

Action distributions are produced by mapping action head outputs to the parameters of the appropriate distribution for the action space.
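As a rough illustration of that data flow, the forward pass is conceptually as follows. This is a minimal sketch, not the library's actual implementation: the attribute names match those documented below, but the observation flattening and the distribution's sample/mode interface are assumptions.

import jax.numpy as jnp

def forward_sketch(policy, observation, key=None):
    # Hypothetical illustration only; `dist.sample` and `dist.mode`
    # are assumed helpers, not confirmed lerax API.
    features = policy.encoder(jnp.ravel(observation))
    value = policy.value_head(features)        # scalar value estimate
    dist = policy.action_head(features)        # distribution for the action space
    action = dist.sample(key) if key is not None else dist.mode()
    return action, value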

Attributes:

    name (str): Name of the policy class.
    action_space (AbstractSpace[ActType, MaskType]): The action space of the environment.
    observation_space (AbstractSpace[ObsType, Any]): The observation space of the environment.
    encoder (MLP): MLP to encode observations into features.
    value_head (MLP): MLP to produce value estimates from features.
    action_head (ActionLayer): MLP to produce action distributions from features.

Parameters:

    env (AbstractEnvLike[S, ActType, ObsType, MaskType]): The environment to create the policy for. Required.
    feature_size (int): Size of the feature representation. Default: 16.
    feature_width (int): Width of the hidden layers in the feature encoder. Default: 64.
    feature_depth (int): Depth of the hidden layers in the feature encoder. Default: 2.
    value_width (int): Width of the hidden layers in the value head. Default: 64.
    value_depth (int): Depth of the hidden layers in the value head. Default: 2.
    action_width (int): Width of the hidden layers in the action head. Default: 64.
    action_depth (int): Depth of the hidden layers in the action head. Default: 2.
    log_std_init (float): Initial log standard deviation for continuous action spaces. Default: 0.0.
    key (Key): JAX PRNG key for parameter initialization. Required.
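For example, a policy might be constructed as follows; the environment factory is a stand-in for illustration, and only the keyword arguments above come from this page.

import jax

from lerax.policy import MLPActorCriticPolicy

env = make_env()  # any AbstractEnvLike environment; `make_env` is a hypothetical stand-in
key = jax.random.key(0)

policy = MLPActorCriticPolicy(
    env,
    feature_size=16,
    feature_width=64,
    feature_depth=2,
    value_width=64,
    value_depth=2,
    action_width=64,
    action_depth=2,
    log_std_init=0.0,
    key=key,
)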

name class-attribute

name: str = 'MLPActorCriticPolicy'

action_space instance-attribute

action_space: AbstractSpace[ActType, MaskType] = (
    env.action_space
)

observation_space instance-attribute

observation_space: AbstractSpace[ObsType, Any] = (
    env.observation_space
)

encoder instance-attribute

encoder: MLP = MLP(
    in_size=self.observation_space.flat_size,
    out_size=feature_size,
    width_size=feature_width,
    depth=feature_depth,
    key=feat_key,
)

value_head instance-attribute

value_head: MLP = MLP(
    in_size=feature_size,
    out_size="scalar",
    width_size=value_width,
    depth=value_depth,
    key=val_key,
)

action_head instance-attribute

action_head: ActionLayer = ActionLayer(
    self.action_space,
    feature_size,
    action_width,
    action_depth,
    key=act_key,
    log_std_init=log_std_init,
)

serialize

serialize(
    path: str | Path, no_suffix: bool = False
) -> None

Serialize the model to the specified path.

Parameters:

    path (str | Path): The path to serialize to. Required.
    no_suffix (bool): If True, do not append the ".eqx" suffix. Default: False.
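A typical call, assuming `policy` is an existing instance:

# Writes the model to "checkpoints/policy.eqx"; with no_suffix=True the
# given path would be used verbatim.
policy.serialize("checkpoints/policy")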

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.

Parameters:

    path (str | Path): The path to deserialize from. Required.
    *args (Params.args): Additional arguments to pass to the class constructor. Default: ().
    **kwargs (Params.kwargs): Additional keyword arguments to pass to the class constructor. Default: {}.

Returns:

    ClassType: The deserialized model.
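Since this policy's constructor requires an environment and a key, both must be supplied when deserializing. A minimal sketch:

import jax

# `env` should match the environment the policy was originally built for.
restored = MLPActorCriticPolicy.deserialize(
    "checkpoints/policy.eqx",
    env,
    key=jax.random.key(0),
)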

__init__

__init__[S: AbstractEnvLikeState](
    env: AbstractEnvLike[S, ActType, ObsType, MaskType],
    *,
    feature_size: int = 16,
    feature_width: int = 64,
    feature_depth: int = 2,
    value_width: int = 64,
    value_depth: int = 2,
    action_width: int = 64,
    action_depth: int = 2,
    log_std_init: float = 0.0,
    key: Key,
)

reset

reset(*, key: Key) -> None

__call__

__call__(
    state: None,
    observation: ObsType,
    *,
    key: Key | None = None,
    action_mask: MaskType | None = None,
) -> tuple[None, ActType]
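A usage sketch; the observation is assumed to come from the environment, and the leading None is the (unused) state of this stateless policy:

import jax

# Sample an action from the policy's action distribution.
_, action = policy(None, observation, key=jax.random.key(1))

# Without a key the call is deterministic (presumably the most likely action).
_, greedy_action = policy(None, observation)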

action_and_value

action_and_value(
    state: None,
    observation: ObsType,
    *,
    key: Key,
    action_mask: MaskType | None = None,
) -> tuple[
    None, ActType, Float[Array, ""], Float[Array, ""]
]

Get an action and value from an observation.

If a key is provided, it is used to sample the action; if no key is provided, the policy returns the most likely action.
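A usage sketch; reading the two returned scalars as the action's log-probability and the value estimate follows common actor-critic convention and is an assumption, as this page does not name them:

import jax

_, action, log_prob, value = policy.action_and_value(
    None, observation, key=jax.random.key(2)
)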

value

value(
    state: None, observation: ObsType
) -> tuple[None, Float[Array, ""]]

Get the value of an observation.
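For example:

_, value = policy.value(None, observation)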

evaluate_action

evaluate_action(
    state: None,
    observation: ObsType,
    action: ActType,
    *,
    action_mask: MaskType | None = None,
) -> tuple[
    None,
    Float[Array, ""],
    Float[Array, ""],
    Float[Array, ""],
]

Evaluate an action given an observation.
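A usage sketch; reading the three returned scalars as log-probability, entropy, and value follows common actor-critic convention and is an assumption here:

_, log_prob, entropy, value = policy.evaluate_action(
    None, observation, action
)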