On Policy Algorithm

lerax.algorithm.AbstractOnPolicyAlgorithm

Bases: AbstractAlgorithm[PolicyType, AbstractOnPolicyState[PolicyType]]

Base class for on-policy algorithms.

Collects rollouts with the current policy and trains the policy on the collected data. Subclasses implement step to define how actions are selected and what data is collected, post_collect to perform any post-collection processing, and train to define the training procedure.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `optimizer` | `eqx.AbstractVar[optax.GradientTransformation]` | The optimizer used for training the policy. |
| `gamma` | `eqx.AbstractVar[float]` | The discount factor. |
| `num_envs` | `eqx.AbstractVar[int]` | The number of parallel environments. |
| `num_steps` | `eqx.AbstractVar[int]` | The number of steps to collect per environment. |
| `batch_size` | `eqx.AbstractVar[int]` | The batch size for training. |
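Taken together, these attributes fix the shape of one training iteration: a rollout gathers num_steps transitions from each of num_envs parallel environments, and those num_envs * num_steps samples are then consumed in minibatches of batch_size. A small standalone sketch of that arithmetic (the helper is illustrative, not part of lerax):

```python
def rollout_arithmetic(num_envs: int, num_steps: int, batch_size: int) -> tuple[int, int]:
    """Return (samples per rollout, full minibatches per training pass)."""
    samples = num_envs * num_steps       # transitions collected per rollout
    minibatches = samples // batch_size  # full minibatches available for training
    return samples, minibatches

# e.g. 8 parallel envs, 128 steps each, minibatches of 256 samples
print(rollout_arithmetic(8, 128, 256))
```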

optimizer instance-attribute

optimizer: eqx.AbstractVar[optax.GradientTransformation]

gamma instance-attribute

gamma: eqx.AbstractVar[float]

num_envs instance-attribute

num_envs: eqx.AbstractVar[int]

num_steps instance-attribute

num_steps: eqx.AbstractVar[int]

batch_size instance-attribute

batch_size: eqx.AbstractVar[int]

per_iteration abstractmethod

per_iteration(state: StateType) -> StateType

Process the algorithm state after each iteration.

Used for algorithm-specific bookkeeping.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `StateType` | The current algorithm state. | required |

Returns:

| Type | Description |
| --- | --- |
| `StateType` | The updated algorithm state. |

consolidate_callbacks

consolidate_callbacks(
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> AbstractCallback

learn

learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key[Array, ""],
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType

Train the policy on the environment for a given number of timesteps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `AbstractEnvLike` | The environment to train on. | required |
| `policy` | `PolicyType` | The policy to train. | required |
| `total_timesteps` | `int` | The total number of timesteps to train for. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |
| `callback` | `Sequence[AbstractCallback] \| AbstractCallback \| None` | A callback or list of callbacks to use during training. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `policy` | `PolicyType` | The trained policy. |

num_iterations

num_iterations(total_timesteps: int) -> int
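The page gives no docstring for num_iterations, but for an on-policy loop it presumably maps total_timesteps to the number of rollout/train iterations, with each iteration consuming num_envs * num_steps environment timesteps. A standalone sketch of that arithmetic (the rounding direction is an assumption, not confirmed by the source):

```python
import math

def num_iterations(total_timesteps: int, num_envs: int, num_steps: int) -> int:
    # Each iteration collects num_envs * num_steps environment timesteps,
    # so round up so that at least total_timesteps are covered.
    return math.ceil(total_timesteps / (num_envs * num_steps))

print(num_iterations(100_000, 8, 128))  # 98 iterations of 1024 timesteps each
```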

per_step abstractmethod

per_step(
    step_state: AbstractOnPolicyStepState[PolicyType],
) -> AbstractOnPolicyStepState[PolicyType]

Process the step carry after each step.

step abstractmethod

step(
    env: AbstractEnvLike,
    policy: PolicyType,
    state: AbstractOnPolicyStepState[PolicyType],
    *,
    key: Key[Array, ""],
    callback: AbstractCallback,
) -> tuple[
    AbstractOnPolicyStepState[PolicyType], RolloutBuffer
]

Perform a single environment step and collect rollout data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `AbstractEnvLike` | The environment. | required |
| `policy` | `PolicyType` | The current policy. | required |
| `state` | `AbstractOnPolicyStepState[PolicyType]` | The current step state. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |
| `callback` | `AbstractCallback` | The callback for this step. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[AbstractOnPolicyStepState[PolicyType], RolloutBuffer]` | A tuple of the new step state and a rollout buffer entry. |

post_collect

post_collect(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: AbstractOnPolicyStepState[PolicyType],
    buffer: RolloutBuffer,
    *,
    key: Key[Array, ""],
) -> RolloutBuffer

Process the rollout buffer after collection.

Override to compute returns, advantages, or other post-collection processing. By default returns the buffer unchanged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `AbstractEnvLike` | The environment. | required |
| `policy` | `PolicyType` | The current policy. | required |
| `step_state` | `AbstractOnPolicyStepState[PolicyType]` | The step state after the rollout. | required |
| `buffer` | `RolloutBuffer` | The collected rollout buffer. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| `RolloutBuffer` | The processed rollout buffer. |
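A typical override of post_collect computes discounted returns or advantage estimates from the collected rewards using the gamma attribute. As a standalone illustration of that kind of processing (plain Python over flat lists, not the lerax RolloutBuffer API):

```python
def discounted_returns(rewards: list[float], dones: list[bool], gamma: float) -> list[float]:
    """Backward pass computing the discounted return at every step.

    `dones` marks episode ends so returns do not leak across episode boundaries.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # reset the accumulator at episode boundaries
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Three steps of reward 1.0 ending one episode, discounted at gamma = 0.9
print(discounted_returns([1.0, 1.0, 1.0], [False, False, True], 0.9))
```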

collect_rollout

collect_rollout(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: AbstractOnPolicyStepState[PolicyType],
    callback: AbstractCallback,
    key: Key[Array, ""],
) -> tuple[
    AbstractOnPolicyStepState[PolicyType], RolloutBuffer
]

Collect a rollout using the current policy.

train abstractmethod

train(
    policy: PolicyType,
    opt_state: optax.OptState,
    buffer: RolloutBuffer,
    *,
    key: Key[Array, ""],
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]

Train the policy using the rollout buffer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `policy` | `PolicyType` | The current policy. | required |
| `opt_state` | `optax.OptState` | The current optimizer state. | required |
| `buffer` | `RolloutBuffer` | The rollout buffer containing collected experiences. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[PolicyType, optax.OptState, dict[str, Scalar]]` | A tuple containing the updated policy, the updated optimizer state, and a log dictionary. |
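A concrete train implementation would minibatch the buffer and apply the configured optax optimizer; the essential contract is taking the current parameters and optimizer state and returning updated versions plus a dict of scalar logs. A toy plain-Python analogue of that contract (the quadratic loss and manual SGD step are purely illustrative, not the lerax or optax API):

```python
def train_step(param: float, opt_state: dict, lr: float = 0.1) -> tuple[float, dict, dict]:
    """One gradient step on loss(param) = (param - 3)^2, mirroring the
    (policy, opt_state, logs) return contract of `train`."""
    grad = 2.0 * (param - 3.0)                    # d/dparam of (param - 3)^2
    new_param = param - lr * grad                 # SGD update
    new_opt_state = {"step": opt_state["step"] + 1}
    logs = {"loss": (param - 3.0) ** 2, "grad": grad}
    return new_param, new_opt_state, logs

param, state = 0.0, {"step": 0}
for _ in range(50):
    param, state, logs = train_step(param, state)
print(round(param, 3), state["step"])  # param converges toward 3.0
```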

reset

reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key[Array, ""],
    callback: AbstractCallback,
) -> AbstractOnPolicyState[PolicyType]

iteration

iteration(
    state: AbstractOnPolicyState[PolicyType],
    *,
    key: Key[Array, ""],
    callback: AbstractCallback,
) -> AbstractOnPolicyState[PolicyType]

lerax.algorithm.AbstractOnPolicyStepState

Bases: AbstractStepState

State for on-policy algorithm steps.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `env_state` | `eqx.AbstractVar[AbstractEnvLikeState]` | The state of the environment. |
| `policy_state` | `eqx.AbstractVar[AbstractPolicyState]` | The state of the policy. |
| `callback_state` | `eqx.AbstractVar[AbstractCallbackStepState]` | The state of the callback for this step. |

env_state instance-attribute

env_state: eqx.AbstractVar[AbstractEnvLikeState]

policy_state instance-attribute

policy_state: eqx.AbstractVar[AbstractPolicyState]

callback_state instance-attribute

callback_state: eqx.AbstractVar[AbstractCallbackStepState]

with_callback_state

with_callback_state[A: AbstractStepState](
    callback_state: AbstractCallbackStepState | None,
) -> A

Return a new step state with the given callback state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `callback_state` | `AbstractCallbackStepState \| None` | The new callback state. If None, the existing state is used. | required |

Returns:

| Type | Description |
| --- | --- |
| `A` | A new step state with the updated callback state. |

initial classmethod

initial(
    env: AbstractEnvLike,
    policy: PolicyType,
    callback: AbstractCallback,
    key: Key[Array, ""],
) -> AbstractOnPolicyStepState[PolicyType]

Initialize the step state for the on-policy algorithm.

Resets the environment, policy, and callback states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `env` | `AbstractEnvLike` | The environment to initialize. | required |
| `policy` | `PolicyType` | The policy to initialize. | required |
| `callback` | `AbstractCallback` | The callback to initialize. | required |
| `key` | `Key[Array, '']` | A JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| `AbstractOnPolicyStepState[PolicyType]` | The initialized step state. |

lerax.algorithm.AbstractOnPolicyState

Bases: AbstractAlgorithmState[PolicyType]

State for on-policy algorithms.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `iteration_count` | `eqx.AbstractVar[Int[Array, '']]` | The current iteration count. |
| `step_state` | `eqx.AbstractVar[AbstractOnPolicyStepState[PolicyType]]` | The state for the current step. |
| `env` | `eqx.AbstractVar[AbstractEnvLike]` | The environment being used. |
| `policy` | `eqx.AbstractVar[PolicyType]` | The policy being trained. |
| `opt_state` | `eqx.AbstractVar[optax.OptState]` | The optimizer state. |
| `callback_state` | `eqx.AbstractVar[AbstractCallbackState]` | The state of the callback for this iteration. |

iteration_count instance-attribute

iteration_count: eqx.AbstractVar[Int[Array, '']]

step_state instance-attribute

step_state: eqx.AbstractVar[
    AbstractOnPolicyStepState[PolicyType]
]

env instance-attribute

env: eqx.AbstractVar[AbstractEnvLike]

policy instance-attribute

policy: eqx.AbstractVar[PolicyType]

opt_state instance-attribute

opt_state: eqx.AbstractVar[optax.OptState]

callback_state instance-attribute

callback_state: eqx.AbstractVar[AbstractCallbackState]

next

next[A: AbstractAlgorithmState](
    step_state: AbstractStepState,
    policy: PolicyType,
    opt_state: optax.OptState,
) -> A

Return a new algorithm state for the next iteration.

Increments the iteration count and updates the step state, policy, and optimizer state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step_state` | `AbstractStepState` | The new step state. | required |
| `policy` | `PolicyType` | The new policy. | required |
| `opt_state` | `optax.OptState` | The new optimizer state. | required |

Returns:

| Type | Description |
| --- | --- |
| `A` | A new algorithm state with the updated fields. |

with_callback_states

with_callback_states[A: AbstractAlgorithmState](
    callback_state: AbstractCallbackState,
) -> A

Return a new algorithm state with the given callback state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `callback_state` | `AbstractCallbackState` | The new callback state. | required |

Returns:

| Type | Description |
| --- | --- |
| `A` | A new algorithm state with the updated callback state. |