
On-Policy Algorithm

lerax.algorithm.AbstractOnPolicyAlgorithm

Bases: AbstractAlgorithm[PolicyType, OnPolicyState[PolicyType]]

Base class for on-policy algorithms.

Generates rollouts with the current policy, estimates advantages and returns using generalized advantage estimation (GAE), and trains the policy on the collected rollouts.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| optimizer | eqx.AbstractVar[optax.GradientTransformation] | The optimizer used for training the policy. |
| gae_lambda | eqx.AbstractVar[float] | The GAE lambda parameter. |
| gamma | eqx.AbstractVar[float] | The discount factor. |
| num_envs | eqx.AbstractVar[int] | The number of parallel environments. |
| num_steps | eqx.AbstractVar[int] | The number of steps to collect per environment. |
| batch_size | eqx.AbstractVar[int] | The batch size for training. |

optimizer instance-attribute

optimizer: eqx.AbstractVar[optax.GradientTransformation]

gae_lambda instance-attribute

gae_lambda: eqx.AbstractVar[float]

gamma instance-attribute

gamma: eqx.AbstractVar[float]

num_envs instance-attribute

num_envs: eqx.AbstractVar[int]

num_steps instance-attribute

num_steps: eqx.AbstractVar[int]

batch_size instance-attribute

batch_size: eqx.AbstractVar[int]
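
To make the interface concrete, here is a minimal sketch of a subclass. It assumes the only abstract members are the attributes above and the methods documented below (per_iteration, per_step, and train); the class name and hyperparameter values are purely illustrative.

```python
import optax

from lerax.algorithm import AbstractOnPolicyAlgorithm


class MySimplePG(AbstractOnPolicyAlgorithm):
    # Concrete fields satisfying the eqx.AbstractVar declarations above.
    optimizer: optax.GradientTransformation
    gae_lambda: float
    gamma: float
    num_envs: int
    num_steps: int
    batch_size: int

    def per_iteration(self, state):
        # No algorithm-specific bookkeeping in this sketch.
        return state

    def per_step(self, step_state):
        return step_state

    def train(self, policy, opt_state, buffer, *, key):
        # A fuller sketch of a possible `train` appears later on this page.
        raise NotImplementedError


algo = MySimplePG(
    optimizer=optax.adam(3e-4),
    gae_lambda=0.95,
    gamma=0.99,
    num_envs=8,
    num_steps=128,
    batch_size=256,
)
```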

per_iteration abstractmethod

per_iteration(state: StateType) -> StateType

Process the algorithm state after each iteration.

Used for algorithm-specific bookkeeping.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | StateType | The current algorithm state. | required |

Returns:

| Type | Description |
| --- | --- |
| StateType | The updated algorithm state. |

learn

learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key,
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType

Train the policy on the environment for a given number of timesteps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| total_timesteps | int | The total number of timesteps to train for. | required |
| key | Key | A JAX PRNG key. | required |
| callback | Sequence[AbstractCallback] \| AbstractCallback \| None | A callback or list of callbacks to use during training. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| policy | PolicyType | The trained policy. |
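
A hedged usage sketch. Constructing the environment and policy is not covered on this page, so `make_env` and `make_policy` are placeholders for however you build an AbstractEnvLike environment and a PolicyType policy; `algo` is the MySimplePG instance from the sketch above.

```python
import jax.random as jr

# Placeholders: how the environment and policy are built is not documented here.
env = make_env()                      # any AbstractEnvLike implementation
policy = make_policy(key=jr.key(0))   # any PolicyType instance

trained_policy = algo.learn(
    env,
    policy,
    total_timesteps=1_000_000,
    key=jr.key(1),
    callback=None,  # or an AbstractCallback / sequence of AbstractCallbacks
)
```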

num_iterations

num_iterations(total_timesteps: int) -> int
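
No docstring is given for num_iterations; presumably it converts a timestep budget into a number of rollout-and-train iterations. Assuming each iteration collects num_envs * num_steps timesteps, which matches the attribute descriptions above but is not stated explicitly, the relationship would be roughly:

```python
# Assumption, not a documented formula: one iteration collects
# num_envs * num_steps environment steps.
iterations = total_timesteps // (algo.num_envs * algo.num_steps)
```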

per_step abstractmethod

per_step(
    step_state: OnPolicyStepState[PolicyType],
) -> OnPolicyStepState[PolicyType]

Process the step state (the per-step carry) after each step.

step

step(
    env: AbstractEnvLike,
    policy: PolicyType,
    state: OnPolicyStepState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]

collect_rollout

collect_rollout(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: OnPolicyStepState[PolicyType],
    callback: AbstractCallback,
    key: Key,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]

Collect a rollout using the current policy.
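
The class docstring above says advantages and returns are estimated with GAE using gamma and gae_lambda. The RolloutBuffer layout is not documented here, so the following standalone sketch only illustrates the GAE recursion itself, not lerax's internal implementation; the array names and shapes are assumptions.

```python
import jax
import jax.numpy as jnp


def gae_sketch(rewards, values, last_value, dones, gamma, gae_lambda):
    """Illustrative GAE over a single environment's (num_steps,) rollout."""
    next_values = jnp.concatenate([values[1:], last_value[None]])

    def scan_fn(next_advantage, inputs):
        reward, value, next_value, done = inputs
        nonterminal = 1.0 - done
        delta = reward + gamma * next_value * nonterminal - value
        advantage = delta + gamma * gae_lambda * nonterminal * next_advantage
        return advantage, advantage

    # Scan backwards over the rollout to accumulate advantages.
    _, advantages = jax.lax.scan(
        scan_fn,
        jnp.zeros(()),
        (rewards, values, next_values, dones),
        reverse=True,
    )
    returns = advantages + values
    return advantages, returns
```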

train abstractmethod

train(
    policy: PolicyType,
    opt_state: optax.OptState,
    buffer: RolloutBuffer,
    *,
    key: Key,
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]

Train the policy using the rollout buffer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| policy | PolicyType | The current policy. | required |
| opt_state | optax.OptState | The current optimizer state. | required |
| buffer | RolloutBuffer | The rollout buffer containing collected experiences. | required |
| key | Key | A JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[PolicyType, optax.OptState, dict[str, Scalar]] | A tuple containing the updated policy, the updated optimizer state, and a log dictionary. |
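
A hedged sketch of what a concrete train might look like. The RolloutBuffer field names (observations, actions, advantages, returns), the policy's log_prob and value methods, and the single full-batch update (no batch_size minibatching) are assumptions for illustration; only the signature, the optimizer attribute, and the return convention come from this page.

```python
import equinox as eqx
import jax
import jax.numpy as jnp


def train(self, policy, opt_state, buffer, *, key):
    # Assumed buffer fields; the real RolloutBuffer layout is not shown here.
    obs, actions = buffer.observations, buffer.actions
    advantages, returns = buffer.advantages, buffer.returns

    def loss_fn(policy):
        # Hypothetical policy API with log_prob and value heads.
        log_probs = jax.vmap(policy.log_prob)(obs, actions)
        values = jax.vmap(policy.value)(obs)
        policy_loss = -(log_probs * advantages).mean()
        value_loss = jnp.mean((values - returns) ** 2)
        return policy_loss + 0.5 * value_loss

    loss, grads = eqx.filter_value_and_grad(loss_fn)(policy)
    updates, opt_state = self.optimizer.update(
        grads, opt_state, params=eqx.filter(policy, eqx.is_inexact_array)
    )
    policy = eqx.apply_updates(policy, updates)
    return policy, opt_state, {"loss": loss}
```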

reset

reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key,
    callback: AbstractCallback,
) -> OnPolicyState[PolicyType]

iteration

iteration(
    state: OnPolicyState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> OnPolicyState[PolicyType]
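
reset and iteration can also be driven by hand rather than through learn; whether learn does exactly this internally is not documented here. In the sketch below, `env`, `policy`, and `callback` are placeholders for concrete AbstractEnvLike, PolicyType, and AbstractCallback instances, and `algo` is a concrete algorithm such as the MySimplePG sketch above.

```python
import jax.random as jr

key = jr.key(0)
reset_key, key = jr.split(key)
state = algo.reset(env, policy, key=reset_key, callback=callback)

for _ in range(algo.num_iterations(100_000)):
    iter_key, key = jr.split(key)
    state = algo.iteration(state, key=iter_key, callback=callback)

trained_policy = state.policy
```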

lerax.algorithm.OnPolicyStepState

Bases: AbstractStepState

State for on-policy algorithm steps.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| env_state | AbstractEnvLikeState | The state of the environment. |
| policy_state | AbstractPolicyState | The state of the policy. |
| callback_state | AbstractCallbackStepState | The state of the callback for this step. |

env_state instance-attribute

env_state: AbstractEnvLikeState

policy_state instance-attribute

policy_state: AbstractPolicyState

callback_state instance-attribute

callback_state: AbstractCallbackStepState

with_callback_state

with_callback_state[A: AbstractStepState](
    callback_state: AbstractCallbackStepState | None,
) -> A

Return a new step state with the given callback state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| callback_state | AbstractCallbackStepState \| None | The new callback state. If None, the existing state is used. | required |

Returns:

| Type | Description |
| --- | --- |
| A | A new step state with the updated callback state. |

initial classmethod

initial(
    env: AbstractEnvLike,
    policy: PolicyType,
    callback: AbstractCallback,
    key: Key,
) -> OnPolicyStepState[PolicyType]

Initialize the step state for the on-policy algorithm.

Resets the environment, policy, and callback states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| env | AbstractEnvLike | The environment to initialize. | required |
| policy | PolicyType | The policy to initialize. | required |
| callback | AbstractCallback | The callback to initialize. | required |
| key | Key | A JAX PRNG key. | required |

Returns:

| Type | Description |
| --- | --- |
| OnPolicyStepState[PolicyType] | The initialized step state. |
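
A short usage sketch, with `env`, `policy`, and `callback` as placeholders for concrete AbstractEnvLike, PolicyType, and AbstractCallback instances:

```python
import jax.random as jr

from lerax.algorithm import OnPolicyStepState

# Resets the environment, policy, and callback states in one call.
step_state = OnPolicyStepState.initial(env, policy, callback, jr.key(0))

# The resulting state bundles the documented fields.
step_state.env_state, step_state.policy_state, step_state.callback_state
```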

lerax.algorithm.OnPolicyState

Bases: AbstractAlgorithmState[PolicyType]

State for on-policy algorithms.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| iteration_count | Int[Array, ''] | The current iteration count. |
| step_state | OnPolicyStepState[PolicyType] | The state for the current step. |
| env | AbstractEnvLike | The environment being used. |
| policy | PolicyType | The policy being trained. |
| opt_state | optax.OptState | The optimizer state. |
| callback_state | AbstractCallbackState | The state of the callback for this iteration. |

iteration_count instance-attribute

iteration_count: Int[Array, '']

step_state instance-attribute

step_state: OnPolicyStepState[PolicyType]

env instance-attribute

env: AbstractEnvLike

policy instance-attribute

policy: PolicyType

opt_state instance-attribute

opt_state: optax.OptState

callback_state instance-attribute

callback_state: AbstractCallbackState

next

next[A: AbstractAlgorithmState](
    step_state: AbstractStepState,
    policy: PolicyType,
    opt_state: optax.OptState,
) -> A

Return a new algorithm state for the next iteration.

Increments the iteration count and updates the step state, policy, and optimizer state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| step_state | AbstractStepState | The new step state. | required |
| policy | PolicyType | The new policy. | required |
| opt_state | optax.OptState | The new optimizer state. | required |

Returns:

| Type | Description |
| --- | --- |
| A | A new algorithm state with the updated fields. |
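
A sketch of advancing the algorithm state at the end of an iteration; `new_step_state`, `new_policy`, and `new_opt_state` are placeholders for the outputs of rollout collection and train.

```python
# Placeholders: outputs of rollout collection and `train` for this iteration.
state = state.next(new_step_state, new_policy, new_opt_state)
# state.iteration_count has been incremented by one.
```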

with_callback_states

with_callback_states[A: AbstractAlgorithmState](
    callback_state: AbstractCallbackState,
) -> A

Return a new algorithm state with the given callback state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| callback_state | AbstractCallbackState | The new callback state. | required |

Returns:

| Type | Description |
| --- | --- |
| A | A new algorithm state with the updated callback state. |