Off-Policy Algorithm

lerax.algorithm.AbstractOffPolicyAlgorithm

Bases: AbstractAlgorithm[PolicyType, OffPolicyState[PolicyType]]

Base class for off-policy algorithms.

Generates experience using a policy and environment, stores it in a replay buffer, and trains the policy using samples from the replay buffer.

Attributes:

    optimizer (eqx.AbstractVar[optax.GradientTransformation]): The optimizer used for training the policy.
    buffer_size (eqx.AbstractVar[int]): The size of the replay buffer.
    gamma (eqx.AbstractVar[float]): The discount factor for future rewards.
    learning_starts (eqx.AbstractVar[int]): The number of initial steps to collect before training.
    num_envs (eqx.AbstractVar[int]): The number of parallel environments.
    batch_size (eqx.AbstractVar[int]): The batch size for training.
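
A concrete algorithm supplies values for these abstract attributes and implements the abstract methods documented below. As a minimal, hypothetical sketch (the class name MyOffPolicy, its default values, and the stub method bodies are illustrative only, not part of lerax):

import optax

from lerax.algorithm import AbstractOffPolicyAlgorithm


class MyOffPolicy(AbstractOffPolicyAlgorithm):
    # Concrete values for the abstract attributes listed above.
    optimizer: optax.GradientTransformation = optax.adam(3e-4)
    buffer_size: int = 100_000
    gamma: float = 0.99
    learning_starts: int = 1_000
    num_envs: int = 4
    num_steps: int = 1
    batch_size: int = 256

    def per_iteration(self, state):
        # No algorithm-specific bookkeeping in this sketch.
        return state

    def per_step(self, step_state):
        return step_state

    def train(self, policy, opt_state, buffer, *, key):
        # A real implementation samples from the buffer and applies the
        # optimizer; see the train sketch further below.
        raise NotImplementedError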

optimizer instance-attribute

optimizer: eqx.AbstractVar[optax.GradientTransformation]

buffer_size instance-attribute

buffer_size: eqx.AbstractVar[int]

gamma instance-attribute

gamma: eqx.AbstractVar[float]

learning_starts instance-attribute

learning_starts: eqx.AbstractVar[int]

num_envs instance-attribute

num_envs: eqx.AbstractVar[int]

num_steps instance-attribute

num_steps: eqx.AbstractVar[int]

batch_size instance-attribute

batch_size: eqx.AbstractVar[int]

per_iteration abstractmethod

per_iteration(state: StateType) -> StateType

Process the algorithm state after each iteration.

Used for algorithm-specific bookkeeping.

Parameters:

    state (StateType): The current algorithm state. (required)

Returns:

    StateType: The updated algorithm state.
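
For example, an algorithm with no extra bookkeeping can implement this as a pass-through (a hypothetical minimal override, not taken from a specific lerax algorithm):

def per_iteration(self, state):
    # Nothing to update between iterations; return the state unchanged.
    return state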

learn

learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key,
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType

Train the policy on the environment for a given number of timesteps.

Parameters:

    env (AbstractEnvLike): The environment to train on. (required)
    policy (PolicyType): The policy to train. (required)
    total_timesteps (int): The total number of timesteps to train for. (required)
    key (Key): A JAX PRNG key. (required)
    callback (Sequence[AbstractCallback] | AbstractCallback | None): A callback or list of callbacks to use during training. (default: None)

Returns:

    policy (PolicyType): The trained policy.
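
As a usage sketch, assuming env, policy, and algo are an AbstractEnvLike, a policy, and a concrete off-policy algorithm constructed elsewhere (only the call itself follows the documented signature):

import jax

trained_policy = algo.learn(
    env,
    policy,
    total_timesteps=1_000_000,
    key=jax.random.key(0),
    callback=None,  # or a single AbstractCallback, or a sequence of callbacks
)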

num_iterations

num_iterations(total_timesteps: int) -> int

per_step abstractmethod

per_step(
    step_state: OffPolicyStepState[PolicyType],
) -> OffPolicyStepState[PolicyType]

Process the step carry after each step.

step

step(
    env: AbstractEnvLike,
    policy: PolicyType,
    state: OffPolicyStepState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> OffPolicyStepState[PolicyType]

collect_learning_starts

collect_learning_starts(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: OffPolicyStepState[PolicyType],
    callback: AbstractCallback,
    key: Key,
) -> OffPolicyStepState[PolicyType]

collect_rollout

collect_rollout(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: OffPolicyStepState[PolicyType],
    callback: AbstractCallback,
    key: Key,
) -> OffPolicyStepState[PolicyType]

train abstractmethod

train(
    policy: PolicyType,
    opt_state: optax.OptState,
    buffer: ReplayBuffer,
    *,
    key: Key,
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]

Trains the policy using data from the replay buffer.

Parameters:

    policy (PolicyType): The policy to train. (required)
    opt_state (optax.OptState): The current optimizer state. (required)
    buffer (ReplayBuffer): The replay buffer containing experience. (required)
    key (Key): A JAX PRNG key. (required)

Returns:

    tuple[PolicyType, optax.OptState, dict[str, Scalar]]: A tuple containing the updated policy, the updated optimizer state, and a log dictionary with training information.
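
A typical implementation samples a minibatch from the replay buffer, computes a loss, and applies the optimizer. The sketch below assumes a hypothetical buffer.sample(batch_size, key=...) method and a hypothetical self.loss(policy, batch) helper; neither is part of the documented interface, and the real ReplayBuffer API may differ.

import equinox as eqx


def train(self, policy, opt_state, buffer, *, key):
    # Sample a batch of transitions (hypothetical ReplayBuffer method).
    batch = buffer.sample(self.batch_size, key=key)

    # Differentiate a hypothetical algorithm-specific loss w.r.t. the policy.
    loss, grads = eqx.filter_value_and_grad(self.loss)(policy, batch)

    # Standard Equinox + optax update.
    params = eqx.filter(policy, eqx.is_inexact_array)
    updates, opt_state = self.optimizer.update(grads, opt_state, params)
    policy = eqx.apply_updates(policy, updates)

    return policy, opt_state, {"loss": loss}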

reset

reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key,
    callback: AbstractCallback,
) -> OffPolicyState[PolicyType]

iteration

iteration(
    state: OffPolicyState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> OffPolicyState[PolicyType]

lerax.algorithm.OffPolicyStepState

Bases: AbstractStepState

State object for off-policy algorithm steps.

Attributes:

    env_state (AbstractEnvLikeState): The state of the environment.
    policy_state (AbstractPolicyState): The state of the policy.
    callback_state (AbstractCallbackStepState): The state of the callback for this step.
    buffer (ReplayBuffer): The replay buffer storing experience.

env_state instance-attribute

env_state: AbstractEnvLikeState

policy_state instance-attribute

policy_state: AbstractPolicyState

callback_state instance-attribute

callback_state: AbstractCallbackStepState

buffer instance-attribute

buffer: ReplayBuffer

with_callback_state

with_callback_state[A: AbstractStepState](
    callback_state: AbstractCallbackStepState | None,
) -> A

Return a new step state with the given callback state.

Parameters:

    callback_state (AbstractCallbackStepState | None): The new callback state. If None, the existing state is used. (required)

Returns:

    A: A new step state with the updated callback state.

initial classmethod

initial(
    size: int,
    env: AbstractEnvLike,
    policy: PolicyType,
    callback: AbstractCallback,
    key: Key,
) -> OffPolicyStepState[PolicyType]

Initialize the off-policy step state.

Resets the environment and policy, initializes the callback state, and creates an empty replay buffer.

Parameters:

    size (int): The size of the replay buffer. (required)
    env (AbstractEnvLike): The environment to initialize. (required)
    policy (PolicyType): The policy to initialize. (required)
    callback (AbstractCallback): The callback to initialize. (required)
    key (Key): A JAX PRNG key. (required)

Returns:

    OffPolicyStepState[PolicyType]: The initialized step state.
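
For example, assuming env, policy, and callback have been constructed elsewhere, a call following the documented signature might look like:

import jax

step_state = OffPolicyStepState.initial(
    size=100_000,  # replay buffer capacity
    env=env,
    policy=policy,
    callback=callback,
    key=jax.random.key(0),
)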

lerax.algorithm.OffPolicyState

Bases: AbstractAlgorithmState[PolicyType]

State for off-policy algorithms.

Attributes:

    iteration_count (Int[Array, '']): The current iteration count.
    step_state (OffPolicyStepState[PolicyType]): The current step state.
    env (AbstractEnvLike): The environment being used.
    policy (PolicyType): The policy being trained.
    opt_state (optax.OptState): The optimizer state.
    callback_state (AbstractCallbackState): The state of the callback for this iteration.

iteration_count instance-attribute

iteration_count: Int[Array, '']

step_state instance-attribute

step_state: OffPolicyStepState[PolicyType]

env instance-attribute

env: AbstractEnvLike

policy instance-attribute

policy: PolicyType

opt_state instance-attribute

opt_state: optax.OptState

callback_state instance-attribute

callback_state: AbstractCallbackState

next

next[A: AbstractAlgorithmState](
    step_state: AbstractStepState,
    policy: PolicyType,
    opt_state: optax.OptState,
) -> A

Return a new algorithm state for the next iteration.

Increments the iteration count and updates the step state, policy, and optimizer state.

Parameters:

    step_state (AbstractStepState): The new step state. (required)
    policy (PolicyType): The new policy. (required)
    opt_state (optax.OptState): The new optimizer state. (required)

Returns:

    A: A new algorithm state with the updated fields.
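
For instance, the end of an iteration might advance the state like this (new_step_state, new_policy, and new_opt_state are assumed to come from collecting experience and training earlier in the iteration; this is a usage sketch, not lerax's own iteration implementation):

state = state.next(
    step_state=new_step_state,
    policy=new_policy,
    opt_state=new_opt_state,
)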

with_callback_states

with_callback_states[A: AbstractAlgorithmState](
    callback_state: AbstractCallbackState,
) -> A

Return a new algorithm state with the given callback state.

Parameters:

    callback_state (AbstractCallbackState): The new callback state. (required)

Returns:

    A: A new algorithm state with the updated callback state.