Abstract Actor Critic Policies

lerax.policy.AbstractActorCriticPolicy

Bases: AbstractPolicy[StateType, ActType, ObsType, MaskType]

Base class for stateful actor-critic policies.

Actor-critic policies map observations and internal states to actions, values, and new internal states.

Attributes:

name (eqx.AbstractClassVar[str]): The name of the policy.
action_space (eqx.AbstractVar[AbstractSpace[ActType, MaskType]]): The action space of the policy.
observation_space (eqx.AbstractVar[AbstractSpace[ObsType, Any]]): The observation space of the policy.

name instance-attribute

name: eqx.AbstractClassVar[str]

action_space instance-attribute

action_space: eqx.AbstractVar[
    AbstractSpace[ActType, MaskType]
]

observation_space instance-attribute

observation_space: eqx.AbstractVar[
    AbstractSpace[ObsType, Any]
]

serialize

serialize(path: str | Path) -> None

Serialize the model to the specified path.

Writes a 32-byte structural fingerprint followed by the Equinox leaf data, so deserialize can verify that the skeleton it builds matches what was saved.

Parameters:

path (str | Path): The path to serialize to. The .eqx suffix is appended if missing. Required.
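
The file layout described above (a fixed-size fingerprint followed by the leaf data) can be pictured with a small standard-library sketch. This is not lerax's implementation: the helper names `with_eqx_suffix`, `structural_fingerprint`, and `write_with_fingerprint` are hypothetical, and hashing a textual structure description with SHA-256 is an assumption chosen only because it yields exactly 32 bytes.

```python
from __future__ import annotations

import hashlib
import tempfile
from pathlib import Path


def with_eqx_suffix(path: str | Path) -> Path:
    """Append ".eqx" when the path does not already end with it (hypothetical helper)."""
    path = Path(path)
    return path if path.suffix == ".eqx" else path.with_name(path.name + ".eqx")


def structural_fingerprint(structure: str) -> bytes:
    """Hash a description of the static structure (assumed scheme: SHA-256, 32 bytes)."""
    return hashlib.sha256(structure.encode("utf-8")).digest()


def write_with_fingerprint(path: str | Path, structure: str, leaf_bytes: bytes) -> Path:
    """Write the 32-byte fingerprint first, then the leaf payload."""
    target = with_eqx_suffix(path)
    with open(target, "wb") as f:
        f.write(structural_fingerprint(structure))
        f.write(leaf_bytes)
    return target


with tempfile.TemporaryDirectory() as tmp:
    # "policy" has no suffix, so ".eqx" is appended before writing.
    saved = write_with_fingerprint(Path(tmp) / "policy", "MLP(4->64->2, tanh)", b"\x00" * 8)
    saved_name = saved.name
    saved_size = saved.stat().st_size  # 32 fingerprint bytes + 8 payload bytes
```

Putting the fingerprint at the start of the file means a reader can check it before touching any array data.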

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path.

The constructor arguments must reproduce the same static structure (class, hyperparameters, network shapes, activations, ...) that the model had when it was serialized. A 32-byte fingerprint stored in the file is verified before loading; mismatches raise ValueError instead of silently loading arrays into the wrong skeleton.

Parameters:

path (str | Path): The path to deserialize from. Required.
*args (Params.args): Additional arguments to pass to the class constructor. Default: ().
**kwargs (Params.kwargs): Additional keyword arguments to pass to the class constructor. Default: {}.

Returns:

ClassType: The deserialized model.

Raises:

ValueError: If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file.
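
The verify-before-load behaviour can be illustrated with an in-memory sketch. It is a stand-in under stated assumptions rather than the library's code: `fingerprint`, `save`, and `load` are hypothetical helpers, the structure is represented by a plain string, and SHA-256 is assumed as the 32-byte hash.

```python
from __future__ import annotations

import hashlib
import io

FINGERPRINT_BYTES = 32


def fingerprint(structure: str) -> bytes:
    # Assumed scheme: SHA-256 of a structure description (32 bytes).
    return hashlib.sha256(structure.encode("utf-8")).digest()


def save(stream: io.BytesIO, structure: str, payload: bytes) -> None:
    stream.write(fingerprint(structure))
    stream.write(payload)


def load(stream: io.BytesIO, structure: str) -> bytes:
    # Compare the stored fingerprint with that of the freshly built skeleton.
    stored = stream.read(FINGERPRINT_BYTES)
    if stored != fingerprint(structure):
        raise ValueError("structural fingerprint mismatch")
    return stream.read()


buffer = io.BytesIO()
save(buffer, "hidden=64,activation=tanh", b"weights")

# Same constructor arguments: the payload loads.
buffer.seek(0)
restored = load(buffer, "hidden=64,activation=tanh")

# Different hidden size: loading is refused instead of filling the wrong skeleton.
buffer.seek(0)
try:
    load(buffer, "hidden=128,activation=tanh")
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```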

__call__ abstractmethod

__call__(
    state: StateType,
    observation: ObsType,
    *,
    key: Key[Array, ""] | None = None,
    action_mask: MaskType | None = None,
) -> tuple[StateType, ActType]

Return the next action and new internal state given the current observation and internal state.

A key can be provided for stochastic policies. If no key is provided, the policy should behave deterministically.

Parameters:

state (StateType): The current internal state of the policy. Required.
observation (ObsType): The current observation. Required.
key (Key[Array, ""] | None): An optional JAX random key for stochastic policies. Default: None.
action_mask (MaskType | None): An optional action mask. Default: None.

Returns:

tuple[StateType, ActType]: The new internal state and the action to take.
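
One way a discrete policy might honour the "no key means deterministic" contract is sketched below in JAX. The helper `select_action` is hypothetical and ignores the internal state for brevity. Treating the mask as a boolean array where True marks an allowed action is also an assumption, since `MaskType` is generic.

```python
import jax
import jax.numpy as jnp


def select_action(logits, *, key=None, action_mask=None):
    """Pick a discrete action from logits (hypothetical helper)."""
    if action_mask is not None:
        # Disallowed actions get -inf so they can never be chosen.
        logits = jnp.where(action_mask, logits, -jnp.inf)
    if key is None:
        # Deterministic path: greedy action.
        return jnp.argmax(logits)
    # Stochastic path: sample from the categorical distribution.
    return jax.random.categorical(key, logits)


logits = jnp.array([2.0, 0.5, -1.0])
mask = jnp.array([False, True, True])  # action 0 is not allowed

greedy = select_action(logits)
greedy_masked = select_action(logits, action_mask=mask)
sampled = select_action(logits, key=jax.random.PRNGKey(0), action_mask=mask)
```

Without a mask the greedy choice is action 0. With the mask it falls back to action 1, and a sampled action can only be 1 or 2.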

reset abstractmethod

reset(*, key: Key[Array, '']) -> StateType

Return an initial internal state for the policy.

Parameters:

key (Key[Array, ""]): A JAX random key for initializing the state. Required.

Returns:

StateType: An initial internal state for the policy.
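
The intended usage pattern, as far as the signatures suggest, is to call `reset` once and then thread the returned state through successive calls. The toy functions below only mimic that shape: `reset` and `step` here are free functions written for illustration, and the two-element hidden state and blending update are assumptions, not a lerax policy.

```python
import jax
import jax.numpy as jnp


def reset(*, key):
    """Initial internal state: a zero hidden vector (the key is unused in this toy)."""
    del key
    return jnp.zeros(2)


def step(state, observation, *, key=None):
    """Toy recurrent update: blend the old state with the observation."""
    new_state = 0.5 * state + 0.5 * observation
    logits = new_state
    if key is None:
        action = jnp.argmax(logits)
    else:
        action = jax.random.categorical(key, logits)
    return new_state, action


key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
state = reset(key=reset_key)

observations = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
actions = []
for observation in observations:
    # Each call consumes the previous state and returns the next one.
    state, action = step(state, observation)  # no key: deterministic
    actions.append(int(action))
```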

action_and_value abstractmethod

action_and_value(
    state: StateType,
    observation: ObsType,
    *,
    key: Key[Array, ""],
    action_mask: MaskType | None = None,
) -> tuple[
    StateType, ActType, Float[Array, ""], Float[Array, ""]
]

Get an action, its log probability, and the value estimate from an observation and the current policy state.

Parameters:

state (StateType): The current policy state. Required.
observation (ObsType): The observation to get the action and value for. Required.
key (Key[Array, ""]): A JAX PRNG key. Required.
action_mask (MaskType | None): An optional action mask. Default: None.

Returns:

new_state (StateType): The new policy state.
action (ActType): The action to take.
value (Float[Array, ""]): The value of the observation.
log_prob (Float[Array, ""]): The log probability of the action.
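
For a discrete action space, the returned quantities could be computed roughly as follows. This is a sketch that assumes a categorical distribution over logits and omits the internal state. The function is a hypothetical free-standing stand-in, and the `logits` and `value` arrays take the place of actor and critic outputs.

```python
import jax
import jax.numpy as jnp


def action_and_value(logits, value, *, key):
    """Sample an action and report its log probability alongside the value."""
    log_probs = jax.nn.log_softmax(logits)
    action = jax.random.categorical(key, logits)
    return action, value, log_probs[action]


logits = jnp.array([1.0, 1.0, 1.0, 1.0])  # stand-in for an actor head output
value = jnp.array(0.3)                     # stand-in for a critic head output

action, value_out, log_prob = action_and_value(
    logits, value, key=jax.random.PRNGKey(1)
)
```

With four equally likely actions, the log probability of whichever action is sampled is log(1/4). Both `value_out` and `log_prob` are scalar arrays, matching the `Float[Array, ""]` annotations.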

evaluate_action abstractmethod

evaluate_action(
    state: StateType,
    observation: ObsType,
    action: ActType,
    *,
    action_mask: MaskType | None = None,
) -> tuple[
    StateType,
    Float[Array, ""],
    Float[Array, ""],
    Float[Array, ""],
]

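Evaluate an action given an observation, returning the value estimate, the action's log probability, and the entropy of the action distribution.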

Parameters:

state (StateType): The current policy state. Required.
observation (ObsType): The observation to evaluate the action for. Required.
action (ActType): The action to evaluate. Required.
action_mask (MaskType | None): An optional action mask. Default: None.

Returns:

new_state (StateType): The new policy state.
value (Float[Array, ""]): The value of the observation.
log_prob (Float[Array, ""]): The log probability of the action.
entropy (Float[Array, ""]): The entropy of the action distribution.
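
Unlike `action_and_value`, this method takes no key, since nothing is sampled: the action is already given. A minimal sketch for a categorical distribution follows. It assumes discrete actions and unmasked logits and omits the internal state, and the free function is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def evaluate_action(logits, value, action):
    """Score a given action: value, log probability, and distribution entropy."""
    log_probs = jax.nn.log_softmax(logits)
    probs = jnp.exp(log_probs)
    entropy = -jnp.sum(probs * log_probs)
    return value, log_probs[action], entropy


# Logits chosen so the action probabilities are exactly 0.5, 0.25, 0.25.
logits = jnp.log(jnp.array([0.5, 0.25, 0.25]))
value, log_prob, entropy = evaluate_action(logits, jnp.array(1.5), jnp.array(0))
```

Here `log_prob` is log(0.5) for action 0, and the entropy is 1.5 * log(2), about 1.04 nats.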

value abstractmethod

value(
    state: StateType, observation: ObsType
) -> tuple[StateType, Float[Array, ""]]

Get the value of an observation.

Parameters:

state (StateType): The current policy state. Required.
observation (ObsType): The observation to get the value for. Required.

Returns:

new_state (StateType): The new policy state.
value (Float[Array, ""]): The value of the observation.