# Abstract Actor Critic Policies

## lerax.policy.AbstractActorCriticPolicy

Bases: `AbstractPolicy[StateType, ActType, ObsType, MaskType]`

Base class for stateful actor-critic policies.

Actor-critic policies map observations and internal states to actions, values, and new internal states.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `name` | `eqx.AbstractClassVar[str]` | The name of the policy. |
| `action_space` | `eqx.AbstractVar[AbstractSpace[ActType, MaskType]]` | The action space of the policy. |
| `observation_space` | `eqx.AbstractVar[AbstractSpace[ObsType, Any]]` | The observation space of the policy. |

### `name` *(instance attribute)*

```python
name: eqx.AbstractClassVar[str]
```

### `action_space` *(instance attribute)*

```python
action_space: eqx.AbstractVar[AbstractSpace[ActType, MaskType]]
```

### `observation_space` *(instance attribute)*

```python
observation_space: eqx.AbstractVar[AbstractSpace[ObsType, Any]]
```

### `serialize`

```python
serialize(path: str | Path, no_suffix: bool = False) -> None
```

Serialize the model to the specified path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str \| Path` | The path to serialize to. | *required* |
| `no_suffix` | `bool` | If `True`, do not append the `".eqx"` suffix. | `False` |

### `deserialize` *(classmethod)*

```python
deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType
```

Deserialize the model from the specified path. Any additional arguments required by the class constructor must be provided.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str \| Path` | The path to deserialize from. | *required* |
| `*args` | `Params.args` | Additional positional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `ClassType` | The deserialized model. |

### `__call__` *(abstractmethod)*

```python
__call__(
    state: StateType,
    observation: ObsType,
    *,
    key: Key | None = None,
    action_mask: MaskType | None = None,
) -> tuple[StateType, ActType]
```

Return the next action and new internal state given the current observation and internal state.

A key can be provided for stochastic policies. If no key is provided, the policy should behave deterministically.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `StateType` | The current internal state of the policy. | *required* |
| `observation` | `ObsType` | The current observation. | *required* |
| `key` | `Key \| None` | An optional JAX random key for stochastic policies. | `None` |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[StateType, ActType]` | The new internal state and the action to take. |
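
The key-optional contract above can be sketched with a dependency-free stand-in (not lerax code): an integer seed plays the role of a JAX PRNG key, and the toy linear actor is a pure assumption for illustration.

```python
import math
import random

# Minimal stand-in (NOT lerax code) for the __call__ contract.
class ToyActorCritic:
    def __init__(self, weights):
        self.weights = weights  # one logit weight per action (hypothetical)

    def __call__(self, state, observation, *, key=None, action_mask=None):
        logits = [w * observation for w in self.weights]
        if action_mask is not None:
            # Masked actions get -inf logits, so they can never be chosen.
            logits = [l if ok else -math.inf for l, ok in zip(logits, action_mask)]
        if key is None:
            # No key: behave deterministically (greedy action).
            action = max(range(len(logits)), key=logits.__getitem__)
        else:
            # Key given: sample from the softmax distribution.
            z = max(logits)
            probs = [math.exp(l - z) for l in logits]
            total = sum(probs)
            probs = [p / total for p in probs]
            action = random.Random(key).choices(range(len(logits)), probs)[0]
        return state, action  # this toy policy is stateless: state passes through
```

A recurrent policy would instead thread a hidden state through `state`; the pass-through here is only the degenerate stateless case.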

### `reset` *(abstractmethod)*

```python
reset(*, key: Key) -> StateType
```

Return an initial internal state for the policy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `Key` | A JAX random key for initializing the state. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `StateType` | An initial internal state for the policy. |
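
For a recurrent policy the initial state is typically a zero hidden vector; the key permits randomized initializations. A dependency-free sketch (not lerax code; the hidden size, the `stochastic` flag, and the integer seed standing in for a JAX key are all assumptions):

```python
import random

# Minimal stand-in (NOT lerax code) for reset: build a fresh internal state.
def reset(*, key, hidden_size=4, stochastic=False):
    if stochastic:
        rng = random.Random(key)  # same key -> same initial state
        return [rng.gauss(0.0, 1.0) for _ in range(hidden_size)]
    return [0.0] * hidden_size    # common default: zero hidden vector
```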

### `action_and_value` *(abstractmethod)*

```python
action_and_value(
    state: StateType,
    observation: ObsType,
    *,
    key: Key,
    action_mask: MaskType | None = None,
) -> tuple[StateType, ActType, Float[Array, ""], Float[Array, ""]]
```

Sample an action for an observation and return it together with the critic's value estimate and the action's log probability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `StateType` | The current policy state. | *required* |
| `observation` | `ObsType` | The observation to get the action and value for. | *required* |
| `key` | `Key` | A JAX PRNG key. | *required* |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `new_state` | `StateType` | The new policy state. |
| `action` | `ActType` | The action to take. |
| `value` | `Float[Array, ""]` | The value of the observation. |
| `log_prob` | `Float[Array, ""]` | The log probability of the action. |
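
In on-policy training, `action_and_value` is typically called during rollout so the sampled action's value and log-probability can be stored for the later update. A dependency-free sketch (not lerax code; the linear actor/critic heads and the integer seed standing in for a JAX key are assumptions):

```python
import math
import random

# Minimal stand-in (NOT lerax code) for the action_and_value contract.
class ToyActorCritic:
    def __init__(self, actor_w, critic_w):
        self.actor_w = actor_w    # one logit weight per action (hypothetical)
        self.critic_w = critic_w  # scalar value head (hypothetical)

    def action_and_value(self, state, observation, *, key, action_mask=None):
        logits = [w * observation for w in self.actor_w]
        if action_mask is not None:
            logits = [l if ok else -math.inf for l, ok in zip(logits, action_mask)]
        z = max(logits)
        probs = [math.exp(l - z) for l in logits]
        total = sum(probs)
        probs = [p / total for p in probs]
        action = random.Random(key).choices(range(len(probs)), probs)[0]
        value = self.critic_w * observation
        return state, action, value, math.log(probs[action])

# Rollout-style usage: store value and log_prob alongside each action so a
# later update (e.g. PPO) can use them.
policy = ToyActorCritic([1.0, 2.0], critic_w=0.5)
trajectory, state = [], None
for t, obs in enumerate([1.0, 2.0, 3.0]):
    state, action, value, log_prob = policy.action_and_value(state, obs, key=t)
    trajectory.append((obs, action, value, log_prob))
```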

### `evaluate_action` *(abstractmethod)*

```python
evaluate_action(
    state: StateType,
    observation: ObsType,
    action: ActType,
    *,
    action_mask: MaskType | None = None,
) -> tuple[StateType, Float[Array, ""], Float[Array, ""], Float[Array, ""]]
```

Evaluate an action given an observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `StateType` | The current policy state. | *required* |
| `observation` | `ObsType` | The observation to evaluate the action for. | *required* |
| `action` | `ActType` | The action to evaluate. | *required* |
| `action_mask` | `MaskType \| None` | An optional action mask. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `new_state` | `StateType` | The new policy state. |
| `value` | `Float[Array, ""]` | The value of the observation. |
| `log_prob` | `Float[Array, ""]` | The log probability of the action. |
| `entropy` | `Float[Array, ""]` | The entropy of the action distribution. |
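
`evaluate_action` is the training-time counterpart of `action_and_value`: given a stored action, it recomputes `value`, `log_prob`, and `entropy` under the current parameters. The arithmetic below sketches the typical PPO-style use of those three outputs; all numbers are illustrative placeholders, not lerax defaults.

```python
import math

# Illustrative PPO-style use of evaluate_action's outputs. The "stored"
# quantities would come from a rollout via action_and_value; the "fresh"
# ones from evaluate_action under the current parameters.
old_log_prob = -1.20   # stored at collection time
new_log_prob = -1.05   # recomputed by evaluate_action
value = 0.8            # fresh critic estimate
entropy = 0.95         # fresh policy entropy
advantage, ret = 0.5, 1.1
clip, value_coef, entropy_coef = 0.2, 0.5, 0.01

# Importance ratio between current and collection-time policies.
ratio = math.exp(new_log_prob - old_log_prob)
clipped = max(min(ratio, 1 + clip), 1 - clip)
policy_loss = -min(ratio * advantage, clipped * advantage)
value_loss = (value - ret) ** 2
loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
```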

### `value` *(abstractmethod)*

```python
value(state: StateType, observation: ObsType) -> tuple[StateType, Float[Array, ""]]
```

Get the value of an observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `StateType` | The current policy state. | *required* |
| `observation` | `ObsType` | The observation to get the value for. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `new_state` | `StateType` | The new policy state. |
| `value` | `Float[Array, ""]` | The value of the observation. |
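
A common use of `value` on its own (no action sampled) is bootstrapping the return at the end of a truncated rollout: the critic's estimate for the final observation stands in for the unobserved remainder of the episode. A dependency-free sketch (not lerax code; the linear critic and `gamma` default are assumptions):

```python
# Bootstrapped discounted returns computed backwards through the rollout.
def bootstrapped_returns(rewards, last_value, gamma=0.99):
    returns, running = [], last_value
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns

# Toy stateless value function matching the (state, obs) -> (state, v) shape.
def value(state, observation, critic_w=0.5):
    return state, critic_w * observation
```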