Off-Policy Algorithm
lerax.algorithm.AbstractOffPolicyAlgorithm
Bases: AbstractAlgorithm[PolicyType, AbstractOffPolicyState[PolicyType]]
Base class for off-policy algorithms.
Generates experience by running a policy in an environment, stores it in a replay buffer, and trains the policy on samples drawn from that buffer; a minimal, library-independent sketch of this loop follows the attributes table.
Attributes:

| Name | Type | Description |
|---|---|---|
| optimizer | eqx.AbstractVar[optax.GradientTransformation] | The optimizer used for training the policy. |
| buffer_size | eqx.AbstractVar[int] | The size of the replay buffer. |
| gamma | eqx.AbstractVar[float] | The discount factor for future rewards. |
| learning_starts | eqx.AbstractVar[int] | The number of initial steps to collect before training. |
| num_envs | eqx.AbstractVar[int] | The number of parallel environments. |
| batch_size | eqx.AbstractVar[int] | The batch size for training. |
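The sketch below shows the generic off-policy pattern this class encapsulates, not lerax's actual implementation; every name in it (env_reset, env_step, act, td_update, and the list-based buffer) is a hypothetical stand-in chosen to mirror the attributes above.

```python
import jax


def off_policy_loop(params, env_reset, env_step, act, td_update, *,
                    buffer_size, learning_starts, batch_size,
                    total_timesteps, key):
    buffer = []  # stand-in for a fixed-size replay buffer
    key, reset_key = jax.random.split(key)
    obs = env_reset(reset_key)
    for t in range(total_timesteps):
        key, act_key, sample_key, reset_key = jax.random.split(key, 4)
        action = act(params, obs, act_key)             # behaviour policy
        next_obs, reward, done = env_step(obs, action)
        buffer.append((obs, action, reward, next_obs, done))
        buffer = buffer[-buffer_size:]                 # evict oldest entries
        obs = env_reset(reset_key) if done else next_obs
        if t >= learning_starts:                       # warm-up phase first
            idx = jax.random.choice(sample_key, len(buffer), (batch_size,))
            batch = [buffer[int(i)] for i in idx]
            params = td_update(params, batch)          # one gradient step
    return params
```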
per_iteration
abstractmethod
Process the algorithm state after each iteration.
Used for algorithm-specific bookkeeping.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current algorithm state. | required |

Returns:

| Type | Description |
|---|---|
| StateType | The updated algorithm state. |
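As a concrete (and hypothetical) example of such bookkeeping, a DQN-style subclass might use per_iteration to refresh a target network on a fixed schedule. The fields accessed below (q_params and target_params on the policy, target_update_interval on the algorithm) are illustrative assumptions, not documented lerax attributes; only iteration_count comes from this page.

```python
import equinox as eqx
import jax

from lerax.algorithm import AbstractOffPolicyAlgorithm


class MyDQN(AbstractOffPolicyAlgorithm):  # hypothetical subclass
    # (other abstract fields and methods omitted for brevity)
    target_update_interval: int

    def per_iteration(self, state):
        def sync(s):
            # copy the online Q-parameters into the target network
            return eqx.tree_at(
                lambda st: st.policy.target_params, s, s.policy.q_params
            )

        due = state.iteration_count % self.target_update_interval == 0
        # lax.cond keeps this jit-compatible: iteration_count is a traced scalar
        return jax.lax.cond(due, sync, lambda s: s, state)
```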
consolidate_callbacks
consolidate_callbacks(
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> AbstractCallback
learn
learn(
env: AbstractEnvLike,
policy: PolicyType,
total_timesteps: int,
*,
key: Key[Array, ""],
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> PolicyType
Train the policy on the environment for a given number of timesteps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| total_timesteps | int | The total number of timesteps to train for. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |
| callback | Sequence[AbstractCallback] \| AbstractCallback \| None | A callback or list of callbacks to use during training. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| policy | PolicyType | The trained policy. |
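A hedged usage sketch: only learn()'s signature comes from this page. make_env() and make_policy() are hypothetical stand-ins for constructing an AbstractEnvLike and a matching PolicyType, and `algorithm` is any concrete subclass instance.

```python
import jax

key = jax.random.key(0)  # typed PRNG key, matching Key[Array, '']
env = make_env()         # hypothetical: any AbstractEnvLike
policy = make_policy()   # hypothetical: the PolicyType for this algorithm

trained_policy = algorithm.learn(
    env,
    policy,
    total_timesteps=100_000,
    key=key,
    callback=None,  # or an AbstractCallback / sequence of them
)
```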
per_step
abstractmethod
per_step(
step_state: AbstractOffPolicyStepState[PolicyType],
) -> AbstractOffPolicyStepState[PolicyType]
Process the step carry after each step.
step
step(
env: AbstractEnvLike,
policy: PolicyType,
state: AbstractOffPolicyStepState[PolicyType],
*,
key: Key[Array, ""],
callback: AbstractCallback,
) -> AbstractOffPolicyStepState[PolicyType]
collect_learning_starts
collect_learning_starts(
env: AbstractEnvLike,
policy: PolicyType,
step_state: AbstractOffPolicyStepState[PolicyType],
callback: AbstractCallback,
key: Key[Array, ""],
) -> AbstractOffPolicyStepState[PolicyType]
collect_rollout
collect_rollout(
env: AbstractEnvLike,
policy: PolicyType,
step_state: AbstractOffPolicyStepState[PolicyType],
callback: AbstractCallback,
key: Key[Array, ""],
) -> AbstractOffPolicyStepState[PolicyType]
train
abstractmethod
train(
policy: PolicyType,
opt_state: optax.OptState,
buffer: ReplayBuffer,
*,
key: Key[Array, ""],
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]
Train the policy using data sampled from the replay buffer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| policy | PolicyType | The policy to train. | required |
| opt_state | optax.OptState | The current optimizer state. | required |
| buffer | ReplayBuffer | The replay buffer containing experience. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |

Returns:

| Type | Description |
|---|---|
| tuple[PolicyType, optax.OptState, dict[str, Scalar]] | A tuple containing the updated policy, the updated optimizer state, and a log dictionary with training information. |
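A hedged sketch of one way a subclass might implement train(): a single Q-learning gradient step on one sampled batch. Only the signature is taken from this page; the buffer's sampling API (buffer.sample), the batch field names, and the policy methods (q_value, target_value) are assumptions for illustration, and the Q-learning loss is a stand-in, not lerax's prescribed update.

```python
import equinox as eqx
import jax
import jax.numpy as jnp


def train(self, policy, opt_state, buffer, *, key):
    batch = buffer.sample(self.batch_size, key=key)  # assumed sampling API

    def loss_fn(policy):
        q = jax.vmap(policy.q_value)(batch.obs, batch.action)
        # one-step TD target; (1 - done) masks terminal transitions
        bootstrap = jax.vmap(policy.target_value)(batch.next_obs)
        target = batch.reward + self.gamma * (1.0 - batch.done) * bootstrap
        return jnp.mean((q - jax.lax.stop_gradient(target)) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(policy)
    updates, opt_state = self.optimizer.update(grads, opt_state)
    policy = eqx.apply_updates(policy, updates)
    return policy, opt_state, {"loss": loss}
```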
lerax.algorithm.AbstractOffPolicyStepState
Bases: AbstractStepState
State object for off-policy algorithms steps.
Attributes:

| Name | Type | Description |
|---|---|---|
| env_state | eqx.AbstractVar[AbstractEnvLikeState] | The state of the environment. |
| policy_state | eqx.AbstractVar[AbstractPolicyState] | The state of the policy. |
| callback_state | eqx.AbstractVar[AbstractCallbackStepState] | The state of the callback for this step. |
| buffer | eqx.AbstractVar[ReplayBuffer] | The replay buffer storing experience. |
with_callback_state
Return a new step state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | AbstractCallbackStepState \| None | The new callback state. If None, the existing state is used. | required |

Returns:

| Type | Description |
|---|---|
| A | A new step state with the updated callback state. |
initial
classmethod
initial(
size: int,
env: AbstractEnvLike,
policy: PolicyType,
callback: AbstractCallback,
key: Key[Array, ""],
) -> AbstractOffPolicyStepState[PolicyType]
Initialize the off-policy step state.
Resets the environment and policy, initializes the callback state, and creates an empty replay buffer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| size | int | The size of the replay buffer. | required |
| env | AbstractEnvLike | The environment to initialize. | required |
| policy | PolicyType | The policy to initialize. | required |
| callback | AbstractCallback | The callback to initialize. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |

Returns:

| Type | Description |
|---|---|
| AbstractOffPolicyStepState[PolicyType] | The initialized step state. |
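A hedged sketch of constructing the initial step state, using only the classmethod signature above; ConcreteStepState, env, policy, and callback are hypothetical placeholders, while buffer_size is the documented algorithm attribute.

```python
import jax

key = jax.random.key(0)
step_state = ConcreteStepState.initial(
    size=algorithm.buffer_size,  # size the buffer to match the algorithm
    env=env,
    policy=policy,
    callback=callback,
    key=key,
)
```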
lerax.algorithm.AbstractOffPolicyState
Bases: AbstractAlgorithmState[PolicyType]
State for off-policy algorithms.
Attributes:

| Name | Type | Description |
|---|---|---|
| iteration_count | eqx.AbstractVar[Int[Array, '']] | The current iteration count. |
| step_state | eqx.AbstractVar[AbstractOffPolicyStepState[PolicyType]] | The current step state. |
| env | eqx.AbstractVar[AbstractEnvLike] | The environment being used. |
| policy | eqx.AbstractVar[PolicyType] | The policy being trained. |
| opt_state | eqx.AbstractVar[optax.OptState] | The optimizer state. |
| callback_state | eqx.AbstractVar[AbstractCallbackState] | The state of the callback for this iteration. |
next
next[A: AbstractAlgorithmState](
step_state: AbstractStepState,
policy: PolicyType,
opt_state: optax.OptState,
) -> A
Return a new algorithm state for the next iteration.
Increments the iteration count and updates the step state, policy, and optimizer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_state | AbstractStepState | The new step state. | required |
| policy | PolicyType | The new policy. | required |
| opt_state | optax.OptState | The new optimizer state. | required |

Returns:

| Type | Description |
|---|---|
| A | A new algorithm state with the updated fields. |
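A hedged sketch of how an outer loop might thread these pieces together, using only signatures documented on this page (step, train, next, and the step state's buffer attribute); the key splitting and variable names are illustrative.

```python
import jax

key, step_key, train_key = jax.random.split(key, 3)
# collect one step of experience into the replay buffer
new_step_state = algorithm.step(
    env, state.policy, state.step_state, key=step_key, callback=callback
)
# take a gradient step on data sampled from the buffer
policy, opt_state, logs = algorithm.train(
    state.policy, state.opt_state, new_step_state.buffer, key=train_key
)
# advance the algorithm state to the next iteration
state = state.next(new_step_state, policy, opt_state)
```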
with_callback_states
Return a new algorithm state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | AbstractCallbackState | The new callback state. | required |

Returns:

| Type | Description |
|---|---|
| A | A new algorithm state with the updated callback state. |