Algorithm

lerax.algorithm.AbstractAlgorithm

Bases: eqx.Module

Base class for RL algorithms.

Provides the main training loop and abstract methods for algorithm-specific behavior.

Attributes:

Name Type Description
optimizer eqx.AbstractVar[optax.GradientTransformation]

The optimizer used for training.

num_envs eqx.AbstractVar[int]

The number of parallel environments.

num_steps eqx.AbstractVar[int]

The number of steps per environment per iteration.

optimizer instance-attribute

optimizer: eqx.AbstractVar[optax.GradientTransformation]

num_envs instance-attribute

num_envs: eqx.AbstractVar[int]

num_steps instance-attribute

num_steps: eqx.AbstractVar[int]

reset abstractmethod

reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key,
    callback: AbstractCallback,
) -> StateType

Return the initial algorithm state, which is the carry passed into the first training iteration.

Parameters:

Name Type Description Default
env AbstractEnvLike

The environment to train on.

required
policy PolicyType

The policy to train.

required
key Key

A JAX PRNG key.

required
callback AbstractCallback

The callback to use during training.

required

Returns:

Type Description
StateType

The initial algorithm state.

per_iteration abstractmethod

per_iteration(state: StateType) -> StateType

Process the algorithm state after each iteration.

Used for algorithm-specific bookkeeping.

Parameters:

Name Type Description Default
state StateType

The current algorithm state.

required

Returns:

Type Description
StateType

The updated algorithm state.

iteration abstractmethod

iteration(
    state: StateType,
    *,
    key: Key,
    callback: AbstractCallback,
) -> StateType

Perform a single iteration of training.

Parameters:

Name Type Description Default
state StateType

The current algorithm state.

required
key Key

A JAX PRNG key.

required
callback AbstractCallback

The callback to use during training.

required

Returns:

Type Description
StateType

The updated algorithm state.

num_iterations abstractmethod

num_iterations(total_timesteps: int) -> int

Return the number of training iterations to run for a session of total_timesteps timesteps.
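
The page does not say how the iteration count is derived from total_timesteps. As a minimal sketch, assuming an on-policy algorithm that collects num_envs * num_steps transitions per iteration, one plausible rule is floor division. The helper name and the rounding choice below are assumptions for illustration, not lerax behavior.

```python
def num_iterations_sketch(total_timesteps: int, num_envs: int, num_steps: int) -> int:
    """Hypothetical helper: count iterations when each consumes num_envs * num_steps transitions."""
    steps_per_iteration = num_envs * num_steps
    if steps_per_iteration <= 0:
        raise ValueError("num_envs and num_steps must be positive")
    # Floor division: a trailing partial batch is dropped (an assumption).
    return total_timesteps // steps_per_iteration


iterations = num_iterations_sketch(total_timesteps=100_000, num_envs=8, num_steps=128)
```

Under this rule, rounding down means the session may run slightly fewer timesteps than requested. A concrete algorithm could equally round up.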

learn

learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key,
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType

Train the policy on the environment for a given number of timesteps.

Parameters:

Name Type Description Default
env AbstractEnvLike

The environment to train on.

required
policy PolicyType

The policy to train.

required
total_timesteps int

The total number of timesteps to train for.

required
key Key

A JAX PRNG key.

required
callback Sequence[AbstractCallback] | AbstractCallback | None

A callback or list of callbacks to use during training.

None

Returns:

Name Type Description
policy PolicyType

The trained policy.
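
The way the abstract methods might fit together inside learn can be sketched in plain Python. This is only an illustration under stated assumptions: ToyState and ToyAlgorithm are hypothetical stand-ins and not lerax classes, the environment and policy are reduced to single numbers, and the actual implementation may differ (for example by running the loop under JAX transformations and splitting the PRNG key on each iteration).

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Stand-in for an algorithm state; not a lerax class."""
    iteration_count: int
    env: float     # a target value plays the role of the environment
    policy: float  # a single number plays the role of the policy parameters


class ToyAlgorithm:
    """Hypothetical algorithm that mirrors the method names documented above."""

    num_envs = 2
    num_steps = 4

    def reset(self, env, policy, *, key, callback):
        # Build the initial state (the carry for the loop).
        return ToyState(iteration_count=0, env=env, policy=policy)

    def iteration(self, state, *, key, callback):
        # Pretend one update moves the policy halfway toward the target.
        new_policy = state.policy + 0.5 * (state.env - state.policy)
        return replace(
            state,
            iteration_count=state.iteration_count + 1,
            policy=new_policy,
        )

    def per_iteration(self, state):
        # No extra bookkeeping in this toy example.
        return state

    def num_iterations(self, total_timesteps):
        return total_timesteps // (self.num_envs * self.num_steps)

    def learn(self, env, policy, total_timesteps, *, key, callback=None):
        state = self.reset(env, policy, key=key, callback=callback)
        for _ in range(self.num_iterations(total_timesteps)):
            # The same key is reused here only to keep the sketch short.
            state = self.iteration(state, key=key, callback=callback)
            state = self.per_iteration(state)
        return state.policy


trained = ToyAlgorithm().learn(env=1.0, policy=0.0, total_timesteps=24, key=0)
```

With 24 timesteps and 8 transitions per iteration, the loop runs three times and the toy policy moves from 0.0 to 0.875.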

lerax.algorithm.AbstractStepState

Bases: eqx.Module

Base class for algorithm state that is vectorized over environment steps.

Attributes:

Name Type Description
env_state eqx.AbstractVar[AbstractEnvLikeState]

The state of the environment.

policy_state eqx.AbstractVar[AbstractPolicyState]

The state of the policy.

callback_state eqx.AbstractVar[AbstractCallbackStepState]

The state of the callback for this step.

env_state instance-attribute

env_state: eqx.AbstractVar[AbstractEnvLikeState]

policy_state instance-attribute

policy_state: eqx.AbstractVar[AbstractPolicyState]

callback_state instance-attribute

callback_state: eqx.AbstractVar[AbstractCallbackStepState]

with_callback_state

with_callback_state[A: AbstractStepState](
    callback_state: AbstractCallbackStepState | None,
) -> A

Return a new step state with the given callback state.

Parameters:

Name Type Description Default
callback_state AbstractCallbackStepState | None

The new callback state. If None, the existing state is used.

required

Returns:

Type Description
A

A new step state with the updated callback state.
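
The out-of-place update described above can be illustrated with a frozen dataclass. ToyStepState is a hypothetical stand-in, not a lerax class, and the field types are placeholders. A real subclass would be an eqx.Module and may build the new object differently.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyStepState:
    """Stand-in for a concrete step state; not a lerax class."""
    env_state: int
    policy_state: int
    callback_state: str

    def with_callback_state(self, callback_state: Optional[str]) -> "ToyStepState":
        # None means "keep the existing callback state", as documented above.
        if callback_state is None:
            return self
        # Return a copy with one field changed; the original is left untouched.
        return replace(self, callback_state=callback_state)


old = ToyStepState(env_state=0, policy_state=0, callback_state="a")
new = old.with_callback_state("b")
same = old.with_callback_state(None)
```

Here old keeps its original callback state, new carries the replacement, and same is unchanged because None was passed.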

lerax.algorithm.AbstractAlgorithmState

Bases: eqx.Module

Base class for algorithm states.

Attributes:

Name Type Description
iteration_count eqx.AbstractVar[Int[Array, '']]

The current iteration count.

step_state eqx.AbstractVar[AbstractStepState]

The state for the current step.

env eqx.AbstractVar[AbstractEnvLike]

The environment being used.

policy eqx.AbstractVar[PolicyType]

The policy being trained.

opt_state eqx.AbstractVar[optax.OptState]

The optimizer state.

callback_state eqx.AbstractVar[AbstractCallbackState]

The state of the callback for this iteration.

iteration_count instance-attribute

iteration_count: eqx.AbstractVar[Int[Array, '']]

step_state instance-attribute

step_state: eqx.AbstractVar[AbstractStepState]

env instance-attribute

env: eqx.AbstractVar[AbstractEnvLike]

policy instance-attribute

policy: eqx.AbstractVar[PolicyType]

opt_state instance-attribute

opt_state: eqx.AbstractVar[optax.OptState]

callback_state instance-attribute

callback_state: eqx.AbstractVar[AbstractCallbackState]

next

next[A: AbstractAlgorithmState](
    step_state: AbstractStepState,
    policy: PolicyType,
    opt_state: optax.OptState,
) -> A

Return a new algorithm state for the next iteration.

Increments the iteration count and updates the step state, policy, and optimizer state.

Parameters:

Name Type Description Default
step_state AbstractStepState

The new step state.

required
policy PolicyType

The new policy.

required
opt_state optax.OptState

The new optimizer state.

required

Returns:

Type Description
A

A new algorithm state with the updated fields.
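
A minimal sketch of the behavior described for next, again using a frozen dataclass in place of an eqx.Module. ToyAlgorithmState and its string-valued fields are hypothetical placeholders. Only the pattern (increment the counter, swap three fields, keep the rest) is taken from the description above.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyAlgorithmState:
    """Stand-in for a concrete algorithm state; not a lerax class."""
    iteration_count: int
    step_state: str
    policy: str
    opt_state: str
    callback_state: str

    def next(self, step_state: str, policy: str, opt_state: str) -> "ToyAlgorithmState":
        # Increment the counter and swap in the three updated fields;
        # callback_state is carried over unchanged.
        return replace(
            self,
            iteration_count=self.iteration_count + 1,
            step_state=step_state,
            policy=policy,
            opt_state=opt_state,
        )


s0 = ToyAlgorithmState(0, "step0", "policy0", "opt0", "cb")
s1 = s0.next("step1", "policy1", "opt1")
```

After the call, s1 has iteration_count 1 and the new step state, policy, and optimizer state, while s0 is unchanged.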

with_callback_states

with_callback_states[A: AbstractAlgorithmState](
    callback_state: AbstractCallbackState,
) -> A

Return a new algorithm state with the given callback state.

Parameters:

Name Type Description Default
callback_state AbstractCallbackState

The new callback state.

required

Returns:

Type Description
A

A new algorithm state with the updated callback state.