NCDE
lerax.policy.NCDEActorCriticPolicy
Bases:
Actor–critic with a shared MLPNeuralCDE encoder and MLP heads.
Acts by encoding observations with a Neural CDE, then passing the encoded features to separate MLPs to produce action distributions and value estimates.
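A minimal construction sketch, grounded in the `__init__` signature below; `make_env()` stands in for any `AbstractEnvLike` environment and is not part of `lerax`, and the `Tsit5` solver choice is illustrative only:

```python
import diffrax
import jax

from lerax.policy import NCDEActorCriticPolicy

env = make_env()  # placeholder: any AbstractEnvLike environment

policy = NCDEActorCriticPolicy(
    env,
    solver=diffrax.Tsit5(),  # optional Diffrax solver; the default is None
    latent_size=4,
    history_length=4,  # condition on the last 4 observations
    dt=1.0,            # time step between observations
    key=jax.random.PRNGKey(0),
)
```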
Attributes:
| Name | Type | Description |
|---|---|---|
| `name` | | Name of the policy class. |
| `action_space` | | The action space of the environment. |
| `observation_space` | | The observation space of the environment. |
| `encoder` | `MLPNeuralCDE` | Neural CDE to encode observations into features. |
| `value_head` | `MLP` | MLP to produce value estimates from features. |
| `action_head` | `ActionLayer` | MLP to produce action distributions from features. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike[StateType, ActType, ObsType, MaskType]` | The environment to create the policy for. | *required* |
| `solver` | `diffrax.AbstractSolver \| None` | Diffrax solver to use for the Neural CDE. | `None` |
| `feature_size` | `int` | Size of the feature representation. | `4` |
| `latent_size` | `int` | Size of the latent state in the Neural CDE. | `4` |
| `field_width` | `int` | Width of the hidden layers in the Neural CDE vector field. | `8` |
| `field_depth` | `int` | Depth of the hidden layers in the Neural CDE vector field. | `1` |
| `initial_width` | `int` | Width of the hidden layers in the Neural CDE initial network. | `16` |
| `initial_depth` | `int` | Depth of the hidden layers in the Neural CDE initial network. | `1` |
| `value_width` | `int` | Width of the hidden layers in the value head. | `16` |
| `value_depth` | `int` | Depth of the hidden layers in the value head. | `1` |
| `action_width` | `int` | Width of the hidden layers in the action head. | `16` |
| `action_depth` | `int` | Depth of the hidden layers in the action head. | `1` |
| `history_length` | `int` | Number of past observations to condition on. | `4` |
| `dt` | `float` | Time step between observations for the Neural CDE. | `1.0` |
| `log_std_init` | `float` | Initial log standard deviation for continuous action spaces. | `0.0` |
| `key` | `Key` | JAX PRNG key for parameter initialization. | *required* |
action_space
instance-attribute
observation_space
instance-attribute
encoder
instance-attribute
```python
encoder: MLPNeuralCDE = MLPNeuralCDE(
    in_size=self.observation_space.flat_size,
    latent_size=latent_size,
    solver=solver,
    field_width=field_width,
    field_depth=field_depth,
    initial_width=initial_width,
    initial_depth=initial_depth,
    time_in_input=False,
    history_length=history_length,
    key=enc_key,
)
```
value_head
instance-attribute
```python
value_head: MLP = MLP(
    in_size=latent_size,
    out_size="scalar",
    width_size=value_width,
    depth=value_depth,
    key=val_key,
)
```
action_head
instance-attribute
```python
action_head: ActionLayer = ActionLayer(
    self.action_space,
    feature_size,
    action_width,
    action_depth,
    key=act_key,
    log_std_init=log_std_init,
)
```
serialize
Serialize the model to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to serialize to. | *required* |
| `no_suffix` | `bool` | If `True`, do not append the `.eqx` suffix. | `False` |
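A short usage sketch; the checkpoint path is illustrative:

```python
# Writes to "checkpoints/ncde_policy.eqx" (the ".eqx" suffix is appended).
policy.serialize("checkpoints/ncde_policy")

# Keep the path exactly as given, with no suffix appended:
policy.serialize("checkpoints/ncde_policy", no_suffix=True)
```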
deserialize
classmethod
```python
deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType
```
Deserialize the model from the specified path. Must provide any additional arguments required by the class constructor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | *required* |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |
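A hedged restore sketch: `deserialize` is a classmethod, and the constructor arguments (here the environment and a PRNG key) must be supplied again so the module can be rebuilt before the saved weights are loaded. Passing the suffixed path written by `serialize` is an assumption; adjust if the library resolves the suffix itself.

```python
restored = NCDEActorCriticPolicy.deserialize(
    "checkpoints/ncde_policy.eqx",  # assumed: the suffixed file written by serialize
    env,
    key=jax.random.PRNGKey(0),
)
```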
__init__
```python
__init__[StateType: AbstractEnvLikeState](
    env: AbstractEnvLike[
        StateType, ActType, ObsType, MaskType
    ],
    *,
    solver: diffrax.AbstractSolver | None = None,
    feature_size: int = 4,
    latent_size: int = 4,
    field_width: int = 8,
    field_depth: int = 1,
    initial_width: int = 16,
    initial_depth: int = 1,
    value_width: int = 16,
    value_depth: int = 1,
    action_width: int = 16,
    action_depth: int = 1,
    history_length: int = 4,
    dt: float = 1.0,
    log_std_init: float = 0.0,
    key: Key,
)
```
_step_encoder
```python
_step_encoder(
    state: NCDEPolicyState, obs: ObsType
) -> tuple[NCDEPolicyState, Float[Array, " feat"]]
```
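Advance the policy state with one observation and return the encoded features (internal helper; description inferred from the signature).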
__call__
```python
__call__(
    state: NCDEPolicyState,
    observation: ObsType,
    *,
    key: Key | None = None,
    action_mask: MaskType | None = None,
) -> tuple[NCDEPolicyState, ActType]
```
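A hedged rollout sketch for acting with `__call__`. This page does not show how the initial `NCDEPolicyState` is constructed, so `initial_state`, `first_obs`, and `env_step` below are placeholders, not `lerax` API:

```python
state = initial_state  # placeholder: an NCDEPolicyState (construction not shown here)
obs = first_obs        # placeholder: the first observation

key = jax.random.PRNGKey(1)
for _ in range(10):
    key, subkey = jax.random.split(key)
    # Threads the NCDE state through time: each call folds the new
    # observation into the policy's history and samples an action.
    state, action = policy(state, obs, key=subkey)
    obs = env_step(action)  # placeholder environment transition
```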
action_and_value
```python
action_and_value(
    state: NCDEPolicyState,
    observation: ObsType,
    *,
    key: Key,
    action_mask: MaskType | None = None,
) -> tuple[
    NCDEPolicyState,
    ActType,
    Float[Array, ""],
    Float[Array, ""],
]
```
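For training rollouts, `action_and_value` additionally returns two scalars. Naming them `log_prob` and `value` below follows the usual actor-critic convention; the signature itself does not name them, so treat the labels as an assumption:

```python
key, subkey = jax.random.split(key)
state, action, log_prob, value = policy.action_and_value(state, obs, key=subkey)
# log_prob and value are assumed labels for the two Float[Array, ""] returns.
```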