On-Policy Algorithm
lerax.algorithm.AbstractOnPolicyAlgorithm
Bases: AbstractAlgorithm[PolicyType, AbstractOnPolicyState[PolicyType]]
Base class for on-policy algorithms.
Collects rollouts using the current policy and trains the policy
on the collected data. Subclasses implement `step` to define
how actions are selected and what data is collected, `post_collect`
to define any post-collection processing, and `train` to define
the training procedure.
Attributes:

| Name | Type | Description |
|---|---|---|
| `optimizer` | `eqx.AbstractVar[optax.GradientTransformation]` | The optimizer used for training the policy. |
| `gamma` | `eqx.AbstractVar[float]` | The discount factor. |
| `num_envs` | `eqx.AbstractVar[int]` | The number of parallel environments. |
| `num_steps` | `eqx.AbstractVar[int]` | The number of steps to collect per environment. |
| `batch_size` | `eqx.AbstractVar[int]` | The batch size for training. |
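The collect-then-train cycle described above has a simple generic shape. The sketch below is illustrative only: `collect_rollout` and `train` stand in for the abstract methods, and the loop-count arithmetic assumes each iteration gathers `num_envs * num_steps` timesteps.

```python
def learn_loop(policy, total_timesteps, num_envs, num_steps,
               collect_rollout, train):
    """Alternate rollout collection and training until the timestep budget is spent."""
    timesteps_per_iteration = num_envs * num_steps
    iterations = total_timesteps // timesteps_per_iteration
    logs = []
    for _ in range(iterations):
        buffer = collect_rollout(policy)     # on-policy: data comes from the current policy
        policy, log = train(policy, buffer)  # update using only this fresh data
        logs.append(log)
    return policy, logs
```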
per_iteration
abstractmethod
Process the algorithm state after each iteration.
Used for algorithm-specific bookkeeping.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `StateType` | The current algorithm state. | required |

Returns:

| Type | Description |
|---|---|
| `StateType` | The updated algorithm state. |
consolidate_callbacks
consolidate_callbacks(
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> AbstractCallback
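Judging from its signature, `consolidate_callbacks` normalizes its argument — `None`, a single callback, or a sequence — into one callback. The usual composite pattern can be sketched with plain callables standing in for `AbstractCallback` instances; this is an illustration, not lerax's implementation:

```python
def consolidate(callback=None):
    """Normalize None / single callable / sequence into one callable (sketch only)."""
    if callback is None:
        return lambda *args, **kwargs: None   # no-op callback
    if callable(callback):
        return callback                       # already a single callback
    def combined(*args, **kwargs):
        for cb in callback:                   # invoke each callback in order
            cb(*args, **kwargs)
    return combined
```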
learn
learn(
env: AbstractEnvLike,
policy: PolicyType,
total_timesteps: int,
*,
key: Key[Array, ""],
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> PolicyType
Train the policy on the environment for a given number of timesteps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment to train on. | required |
| `policy` | `PolicyType` | The policy to train. | required |
| `total_timesteps` | `int` | The total number of timesteps to train for. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |
| `callback` | `Sequence[AbstractCallback] \| AbstractCallback \| None` | A callback or list of callbacks to use during training. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `policy` | `PolicyType` | The trained policy. |
per_step
abstractmethod
per_step(
step_state: AbstractOnPolicyStepState[PolicyType],
) -> AbstractOnPolicyStepState[PolicyType]
Process the step carry after each step.
step
abstractmethod
step(
env: AbstractEnvLike,
policy: PolicyType,
state: AbstractOnPolicyStepState[PolicyType],
*,
key: Key[Array, ""],
callback: AbstractCallback,
) -> tuple[
AbstractOnPolicyStepState[PolicyType], RolloutBuffer
]
Perform a single environment step and collect rollout data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment. | required |
| `policy` | `PolicyType` | The current policy. | required |
| `state` | `AbstractOnPolicyStepState[PolicyType]` | The current step state. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |
| `callback` | `AbstractCallback` | The callback for this step. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[AbstractOnPolicyStepState[PolicyType], RolloutBuffer]` | A tuple of the new step state and a rollout buffer entry. |
post_collect
post_collect(
env: AbstractEnvLike,
policy: PolicyType,
step_state: AbstractOnPolicyStepState[PolicyType],
buffer: RolloutBuffer,
*,
key: Key[Array, ""],
) -> RolloutBuffer
Process the rollout buffer after collection.
Override to compute returns, advantages, or other post-collection processing. By default returns the buffer unchanged.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment. | required |
| `policy` | `PolicyType` | The current policy. | required |
| `step_state` | `AbstractOnPolicyStepState[PolicyType]` | The step state after the rollout. | required |
| `buffer` | `RolloutBuffer` | The collected rollout buffer. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
|---|---|
| `RolloutBuffer` | The processed rollout buffer. |
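`post_collect` is the natural place to compute returns and advantages. A common choice is generalized advantage estimation (GAE); the sketch below operates on flat per-step arrays rather than lerax's `RolloutBuffer`, and the discount/λ defaults are illustrative:

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over one trajectory of length T.

    rewards, values, dones have shape (T,); last_value bootstraps the value
    of the state reached after the final step.
    """
    next_values = jnp.append(values[1:], last_value)
    not_done = 1.0 - dones
    # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    deltas = rewards + gamma * next_values * not_done - values

    def backward(carry, x):
        delta, nd = x
        advantage = delta + gamma * lam * nd * carry  # A_t = delta_t + gamma*lam*A_{t+1}
        return advantage, advantage

    init = jnp.zeros((), dtype=deltas.dtype)
    _, advantages = jax.lax.scan(backward, init, (deltas, not_done), reverse=True)
    returns = advantages + values  # value-function regression targets
    return advantages, returns
```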
collect_rollout
collect_rollout(
env: AbstractEnvLike,
policy: PolicyType,
step_state: AbstractOnPolicyStepState[PolicyType],
callback: AbstractCallback,
key: Key[Array, ""],
) -> tuple[
AbstractOnPolicyStepState[PolicyType], RolloutBuffer
]
Collect a rollout using the current policy.
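Under JAX, the step loop inside `collect_rollout` is typically expressed with `jax.lax.scan` so the whole rollout can be jitted and the per-step outputs are stacked into a buffer automatically. A generic sketch — the names and the toy step function are placeholders, not lerax's API:

```python
import jax
import jax.numpy as jnp

def collect(step_fn, init_state, num_steps, key):
    """Apply step_fn num_steps times, stacking each step's output along axis 0."""
    keys = jax.random.split(key, num_steps)  # one independent PRNG key per step

    def body(state, step_key):
        state, transition = step_fn(state, step_key)
        return state, transition  # scan threads `state`, stacks `transition`

    final_state, buffer = jax.lax.scan(body, init_state, keys)
    return final_state, buffer
```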
train
abstractmethod
train(
policy: PolicyType,
opt_state: optax.OptState,
buffer: RolloutBuffer,
*,
key: Key[Array, ""],
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]
Train the policy using the rollout buffer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `policy` | `PolicyType` | The current policy. | required |
| `opt_state` | `optax.OptState` | The current optimizer state. | required |
| `buffer` | `RolloutBuffer` | The rollout buffer containing collected experiences. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[PolicyType, optax.OptState, dict[str, Scalar]]` | A tuple containing the updated policy, the updated optimizer state, and a log dictionary. |
lerax.algorithm.AbstractOnPolicyStepState
Bases: AbstractStepState
State for on-policy algorithm steps.
Attributes:

| Name | Type | Description |
|---|---|---|
| `env_state` | `eqx.AbstractVar[AbstractEnvLikeState]` | The state of the environment. |
| `policy_state` | `eqx.AbstractVar[AbstractPolicyState]` | The state of the policy. |
| `callback_state` | `eqx.AbstractVar[AbstractCallbackStepState]` | The state of the callback for this step. |
with_callback_state
Return a new step state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `callback_state` | `AbstractCallbackStepState \| None` | The new callback state. If None, the existing state is used. | required |

Returns:

| Type | Description |
|---|---|
| `A` | A new step state with the updated callback state. |
initial
classmethod
initial(
env: AbstractEnvLike,
policy: PolicyType,
callback: AbstractCallback,
key: Key[Array, ""],
) -> AbstractOnPolicyStepState[PolicyType]
Initialize the step state for the on-policy algorithm.
Resets the environment, policy, and callback states.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment to initialize. | required |
| `policy` | `PolicyType` | The policy to initialize. | required |
| `callback` | `AbstractCallback` | The callback to initialize. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
|---|---|
| `AbstractOnPolicyStepState[PolicyType]` | The initialized step state. |
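`initial` consumes a single PRNG key to reset several components; the standard JAX pattern is to split the key once per component so each reset is independently seeded. A sketch with placeholder initializers:

```python
import jax

def initial_sketch(key, init_env, init_policy, init_callback):
    """Derive one independent subkey per stateful component (placeholder initializers)."""
    env_key, policy_key, callback_key = jax.random.split(key, 3)
    return (init_env(env_key),
            init_policy(policy_key),
            init_callback(callback_key))
```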
lerax.algorithm.AbstractOnPolicyState
Bases: AbstractAlgorithmState[PolicyType]
State for on-policy algorithms.
Attributes:

| Name | Type | Description |
|---|---|---|
| `iteration_count` | `eqx.AbstractVar[Int[Array, '']]` | The current iteration count. |
| `step_state` | `eqx.AbstractVar[AbstractOnPolicyStepState[PolicyType]]` | The state for the current step. |
| `env` | `eqx.AbstractVar[AbstractEnvLike]` | The environment being used. |
| `policy` | `eqx.AbstractVar[PolicyType]` | The policy being trained. |
| `opt_state` | `eqx.AbstractVar[optax.OptState]` | The optimizer state. |
| `callback_state` | `eqx.AbstractVar[AbstractCallbackState]` | The state of the callback for this iteration. |
next
next[A: AbstractAlgorithmState](
step_state: AbstractStepState,
policy: PolicyType,
opt_state: optax.OptState,
) -> A
Return a new algorithm state for the next iteration.
Increments the iteration count and updates the step state, policy, and optimizer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `step_state` | `AbstractStepState` | The new step state. | required |
| `policy` | `PolicyType` | The new policy. | required |
| `opt_state` | `optax.OptState` | The new optimizer state. | required |

Returns:

| Type | Description |
|---|---|
| `A` | A new algorithm state with the updated fields. |
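`next` returns a fresh state rather than mutating in place, matching the functional style of Equinox modules. The same pattern with plain frozen dataclasses, as an illustration only:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class LoopState:
    iteration_count: int
    policy: object
    opt_state: object

def next_state(state, policy, opt_state):
    """Build the next iteration's state: bump the counter, swap in new fields."""
    return replace(
        state,
        iteration_count=state.iteration_count + 1,
        policy=policy,
        opt_state=opt_state,
    )
```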
with_callback_states
Return a new algorithm state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `callback_state` | `AbstractCallbackState` | The new callback state. | required |

Returns:

| Type | Description |
|---|---|
| `A` | A new algorithm state with the updated callback state. |