
NCDE

lerax.policy.NCDEActorCriticPolicy

Bases: AbstractActorCriticPolicy[NCDEPolicyState, ActType, ObsType, MaskType]

Actor–critic with a shared MLPNeuralCDE encoder and MLP heads.

Acts by encoding the recent observation history with a Neural CDE, then passing the encoded features through separate MLP heads to produce the action distribution and the value estimate.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `name` | `str` | Name of the policy class. |
| `action_space` | `AbstractSpace[ActType, MaskType]` | The action space of the environment. |
| `observation_space` | `AbstractSpace[ObsType, Any]` | The observation space of the environment. |
| `encoder` | `MLPNeuralCDE` | Neural CDE that encodes observations into features. |
| `value_head` | `MLP` | MLP that produces value estimates from features. |
| `action_head` | `ActionLayer` | MLP that produces action distributions from features. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `AbstractEnvLike[StateType, ActType, ObsType, MaskType]` | The environment to create the policy for. | required |
| `solver` | `diffrax.AbstractSolver \| None` | Diffrax solver to use for the Neural CDE. | `None` |
| `feature_size` | `int` | Size of the feature representation. | `4` |
| `latent_size` | `int` | Size of the latent state in the Neural CDE. | `4` |
| `field_width` | `int` | Width of the hidden layers in the Neural CDE vector field. | `8` |
| `field_depth` | `int` | Depth of the hidden layers in the Neural CDE vector field. | `1` |
| `initial_width` | `int` | Width of the hidden layers in the Neural CDE initial network. | `16` |
| `initial_depth` | `int` | Depth of the hidden layers in the Neural CDE initial network. | `1` |
| `value_width` | `int` | Width of the hidden layers in the value head. | `16` |
| `value_depth` | `int` | Depth of the hidden layers in the value head. | `1` |
| `action_width` | `int` | Width of the hidden layers in the action head. | `16` |
| `action_depth` | `int` | Depth of the hidden layers in the action head. | `1` |
| `history_length` | `int` | Number of past observations to condition on. | `4` |
| `dt` | `float` | Time step between observations for the Neural CDE. | `1.0` |
| `log_std_init` | `float` | Initial log standard deviation for continuous action spaces. | `0.0` |
| `key` | `Key` | JAX PRNG key for parameter initialization. | required |
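
A minimal construction sketch. How you obtain an environment depends on your setup; `make_env()` below is a hypothetical placeholder for anything returning an `AbstractEnvLike`.

```python
import jax

from lerax.policy import NCDEActorCriticPolicy

env = make_env()  # hypothetical factory returning an AbstractEnvLike

policy = NCDEActorCriticPolicy(
    env,
    latent_size=8,     # latent state of the Neural CDE
    history_length=8,  # condition on the last 8 observations
    dt=0.5,            # time between observations on the CDE's clock
    key=jax.random.key(0),
)
```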

name class-attribute

name: str = 'NCDEActorCriticPolicy'

action_space instance-attribute

action_space: AbstractSpace[ActType, MaskType] = (
    env.action_space
)

observation_space instance-attribute

observation_space: AbstractSpace[ObsType, Any] = (
    env.observation_space
)

encoder instance-attribute

encoder: MLPNeuralCDE = MLPNeuralCDE(
    in_size=self.observation_space.flat_size,
    latent_size=latent_size,
    solver=solver,
    field_width=field_width,
    field_depth=field_depth,
    initial_width=initial_width,
    initial_depth=initial_depth,
    time_in_input=False,
    history_length=history_length,
    key=enc_key,
)

value_head instance-attribute

value_head: MLP = MLP(
    in_size=latent_size,
    out_size="scalar",
    width_size=value_width,
    depth=value_depth,
    key=val_key,
)

action_head instance-attribute

action_head: ActionLayer = ActionLayer(
    self.action_space,
    feature_size,
    action_width,
    action_depth,
    key=act_key,
    log_std_init=log_std_init,
)

dt class-attribute instance-attribute

dt: float = float(dt)

serialize

serialize(
    path: str | Path, no_suffix: bool = False
) -> None

Serialize the model to the specified path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str \| Path` | The path to serialize to. | required |
| `no_suffix` | `bool` | If `True`, do not append the `.eqx` suffix. | `False` |
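
For example, both of the following write the same checkpoint (paths are illustrative):

```python
# Appends ".eqx" automatically, writing "checkpoints/policy.eqx".
policy.serialize("checkpoints/policy")

# Passes the path through unchanged.
policy.serialize("checkpoints/policy.eqx", no_suffix=True)
```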

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.

Parameters:

Name Type Description Default
path str | Path

The path to deserialize from.

required
*args Params.args

Additional arguments to pass to the class constructor

()
**kwargs Params.kwargs

Additional keyword arguments to pass to the class constructor

{}

Returns:

Type Description
ClassType

The deserialized model.
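
A loading sketch, assuming the usual Equinox pattern of rebuilding the module from the constructor arguments and then filling in the stored weights. The constructor arguments should match those used when the policy was built:

```python
import jax

policy = NCDEActorCriticPolicy.deserialize(
    "checkpoints/policy.eqx",
    env,                    # same environment (or one with identical spaces)
    key=jax.random.key(0),  # assumed to only build the skeleton before loading
)
```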

__init__

__init__[StateType: AbstractEnvLikeState](
    env: AbstractEnvLike[
        StateType, ActType, ObsType, MaskType
    ],
    *,
    solver: diffrax.AbstractSolver | None = None,
    feature_size: int = 4,
    latent_size: int = 4,
    field_width: int = 8,
    field_depth: int = 1,
    initial_width: int = 16,
    initial_depth: int = 1,
    value_width: int = 16,
    value_depth: int = 1,
    action_width: int = 16,
    action_depth: int = 1,
    history_length: int = 4,
    dt: float = 1.0,
    log_std_init: float = 0.0,
    key: Key,
)

_step_encoder

_step_encoder(
    state: NCDEPolicyState, obs: ObsType
) -> tuple[NCDEPolicyState, Float[Array, " feat"]]

Advance the Neural CDE encoder with a new observation, returning the updated policy state and the encoded feature vector.

reset

reset(*, key: Key) -> NCDEPolicyState

Create a fresh policy state for the start of an episode.

__call__

__call__(
    state: NCDEPolicyState,
    observation: ObsType,
    *,
    key: Key | None = None,
    action_mask: MaskType | None = None,
) -> tuple[NCDEPolicyState, ActType]

Select an action for the given observation and return it along with the advanced policy state.
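
A rollout sketch combining `reset` and `__call__`, continuing from the construction sketch above. The environment interface (`env.reset`/`env.step` and their return values) is hypothetical and stands in for whatever `AbstractEnvLike` actually provides; termination handling is omitted.

```python
import jax

key, policy_key, env_key = jax.random.split(jax.random.key(1), 3)

policy_state = policy.reset(key=policy_key)  # fresh encoder history
env_state, obs = env.reset(key=env_key)      # hypothetical env API

for _ in range(100):
    key, act_key, step_key = jax.random.split(key, 3)
    policy_state, action = policy(policy_state, obs, key=act_key)
    env_state, obs = env.step(env_state, action, key=step_key)  # hypothetical
```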

action_and_value

action_and_value(
    state: NCDEPolicyState,
    observation: ObsType,
    *,
    key: Key,
    action_mask: MaskType | None = None,
) -> tuple[
    NCDEPolicyState,
    ActType,
    Float[Array, ""],
    Float[Array, ""],
]

Sample an action and additionally return two per-step scalars for the transition.
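
Useful when collecting rollouts for actor-critic training, where each transition stores the sampled action's log-probability and the value estimate. The order of the two scalar returns is not documented here; this sketch, continuing the rollout above, assumes `(state, action, log_prob, value)`:

```python
key, sample_key = jax.random.split(key)
policy_state, action, log_prob, value = policy.action_and_value(
    policy_state, obs, key=sample_key
)
```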

evaluate_action

evaluate_action(
    state: NCDEPolicyState,
    observation: ObsType,
    action: ActType,
    *,
    action_mask: MaskType | None = None,
) -> tuple[
    NCDEPolicyState,
    Float[Array, ""],
    Float[Array, ""],
    Float[Array, ""],
]

Evaluate a given action against the current policy, returning the scalars needed for an actor-critic update.
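
The three scalars are what a PPO-style loss needs for a stored transition; this sketch assumes they are the log-probability of `action`, the entropy of the action distribution, and the value estimate, in that order:

```python
policy_state, log_prob, entropy, value = policy.evaluate_action(
    policy_state, obs, action
)
```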

value

value(
    state: NCDEPolicyState, observation: ObsType
) -> tuple[NCDEPolicyState, Float[Array, ""]]

Compute the value estimate for an observation, advancing the policy state.
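
For example, to bootstrap the return of the final observation when a rollout is truncated:

```python
policy_state, last_value = policy.value(policy_state, obs)
```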