On-Policy Algorithm
lerax.algorithm.AbstractOnPolicyAlgorithm
Bases:
Base class for on-policy algorithms.
Generates rollouts with the current policy, estimates advantages and returns using generalized advantage estimation (GAE), and trains the policy on the collected rollouts.
Attributes:

| Name | Type | Description |
|---|---|---|
| | | The optimizer used for training the policy. |
| | | The GAE lambda parameter. |
| | | The discount factor. |
| | | The number of parallel environments. |
| | | The number of steps to collect per environment. |
| | | The batch size for training. |
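As a point of reference for the GAE step mentioned above, here is a minimal sketch of the advantage recursion `delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)`, `A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}`. This is not lerax's implementation; the function and argument names are hypothetical, and it assumes flat per-step arrays for a single environment.

```python
import jax
import jax.numpy as jnp


def gae(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
    """Illustrative reverse-scan GAE over a single trajectory."""
    not_dones = 1.0 - dones
    # V(s_{t+1}) for each step, bootstrapping from the final value estimate.
    next_values = jnp.concatenate([values[1:], jnp.asarray([last_value])])
    deltas = rewards + gamma * next_values * not_dones - values

    def step(carry, inp):
        delta, not_done = inp
        adv = delta + gamma * gae_lambda * not_done * carry
        return adv, adv

    # Scan backwards in time so A_t can depend on A_{t+1}.
    _, advantages = jax.lax.scan(step, jnp.zeros(()), (deltas, not_dones), reverse=True)
    returns = advantages + values  # targets for the value function
    return advantages, returns
```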
per_iteration
abstractmethod
Process the algorithm state after each iteration.
Used for algorithm-specific bookkeeping.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | | The current algorithm state. | *required* |

Returns:

| Type | Description |
|---|---|
| | The updated algorithm state. |
learn
learn(
env: AbstractEnvLike,
policy: PolicyType,
total_timesteps: int,
*,
key: Key,
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> PolicyType
Train the policy on the environment for a given number of timesteps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment to train on. | *required* |
| `policy` | `PolicyType` | The policy to train. | *required* |
| `total_timesteps` | `int` | The total number of timesteps to train for. | *required* |
| `key` | `Key` | A JAX PRNG key. | *required* |
| `callback` | `Sequence[AbstractCallback] \| AbstractCallback \| None` | A callback or list of callbacks to use during training. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `policy` | `PolicyType` | The trained policy. |
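A hedged usage sketch, assuming a concrete subclass (the `PPO` name here is a stand-in, not a confirmed lerax class) and already-constructed `env` and `policy` objects:

```python
import jax

key = jax.random.PRNGKey(0)

# `PPO`, `env`, and `policy` are stand-ins: any concrete
# AbstractOnPolicyAlgorithm subclass and compatible objects would do.
algo = PPO()
trained_policy = algo.learn(
    env,
    policy,
    total_timesteps=1_000_000,
    key=key,
    callback=None,  # or an AbstractCallback / sequence of callbacks
)
```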
per_step
abstractmethod
Process the step carry after each step.
step
step(
env: AbstractEnvLike,
policy: PolicyType,
state: OnPolicyStepState[PolicyType],
*,
key: Key,
callback: AbstractCallback,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]
collect_rollout
collect_rollout(
env: AbstractEnvLike,
policy: PolicyType,
step_state: OnPolicyStepState[PolicyType],
callback: AbstractCallback,
key: Key,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]
Collect a rollout using the current policy.
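For intuition, a rollout-collection loop of this shape is often written as a `jax.lax.scan` over the environment step. The sketch below is illustrative only; `policy_apply` and `env_step` are hypothetical helpers, not lerax internals.

```python
import jax


def rollout(env_state, policy, key, num_steps):
    def step(carry, _):
        env_state, key = carry
        key, subkey = jax.random.split(key)
        action = policy_apply(policy, env_state.obs, subkey)   # hypothetical
        env_state, reward, done = env_step(env_state, action)  # hypothetical
        return (env_state, key), (env_state.obs, action, reward, done)

    # Scan carries the (env_state, key) pair and stacks per-step transitions.
    (env_state, key), transitions = jax.lax.scan(
        step, (env_state, key), None, length=num_steps
    )
    return env_state, transitions
```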
train
abstractmethod
train(
policy: PolicyType,
opt_state: optax.OptState,
buffer: RolloutBuffer,
*,
key: Key,
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]
Train the policy using the rollout buffer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `policy` | `PolicyType` | The current policy. | *required* |
| `opt_state` | `optax.OptState` | The current optimizer state. | *required* |
| `buffer` | `RolloutBuffer` | The rollout buffer containing collected experiences. | *required* |
| `key` | `Key` | A JAX PRNG key. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[PolicyType, optax.OptState, dict[str, Scalar]]` | A tuple containing the updated policy, the updated optimizer state, and a log dictionary. |
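A minimal sketch of what a concrete `train` implementation might look like, assuming an `optax` optimizer and a hypothetical `loss_fn`; a real subclass would typically shuffle the buffer into minibatches and apply its own surrogate objective.

```python
import jax
import optax

optimizer = optax.adam(3e-4)


def train(policy, opt_state, buffer, *, key):
    # One full-batch gradient step; `loss_fn` is hypothetical and
    # differentiates with respect to the policy (its first argument).
    loss, grads = jax.value_and_grad(loss_fn)(policy, buffer, key)
    updates, opt_state = optimizer.update(grads, opt_state, policy)
    policy = optax.apply_updates(policy, updates)
    return policy, opt_state, {"loss": loss}
```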
lerax.algorithm.OnPolicyStepState
Bases:
State for on-policy algorithm steps.
Attributes:

| Name | Type | Description |
|---|---|---|
| | | The state of the environment. |
| | | The state of the policy. |
| | | The state of the callback for this step. |
with_callback_state
Return a new step state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `callback_state` | | The new callback state. If None, the existing state is used. | *required* |

Returns:

| Type | Description |
|---|---|
| | A new step state with the updated callback state. |
initial
classmethod
initial(
env: AbstractEnvLike,
policy: PolicyType,
callback: AbstractCallback,
key: Key,
) -> OnPolicyStepState[PolicyType]
Initialize the step state for the on-policy algorithm.
Resets the environment, policy, and callback states.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment to initialize. | *required* |
| `policy` | `PolicyType` | The policy to initialize. | *required* |
| `callback` | `AbstractCallback` | The callback to initialize. | *required* |
| `key` | `Key` | A JAX PRNG key. | *required* |

Returns:

| Type | Description |
|---|---|
| `OnPolicyStepState[PolicyType]` | The initialized step state. |
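A brief usage sketch, assuming `env`, `policy`, and `callback` are already constructed:

```python
import jax

key = jax.random.PRNGKey(0)
# Resets env, policy, and callback states in one call.
step_state = OnPolicyStepState.initial(env, policy, callback, key)
```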
lerax.algorithm.OnPolicyState
Bases:
State for on-policy algorithms.
Attributes:

| Name | Type | Description |
|---|---|---|
| | | The current iteration count. |
| | | The state for the current step. |
| | | The environment being used. |
| | | The policy being trained. |
| | | The optimizer state. |
| | | The state of the callback for this iteration. |
next
next[A: AbstractAlgorithmState](
step_state: AbstractStepState,
policy: PolicyType,
opt_state: optax.OptState,
) -> A
Return a new algorithm state for the next iteration.
Increments the iteration count and updates the step state, policy, and optimizer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `step_state` | `AbstractStepState` | The new step state. | *required* |
| `policy` | `PolicyType` | The new policy. | *required* |
| `opt_state` | `optax.OptState` | The new optimizer state. | *required* |

Returns:

| Type | Description |
|---|---|
| `A` | A new algorithm state with the updated fields. |
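The update is functional: rather than mutating in place, `next` returns a fresh state. Below is a stand-alone sketch of the same pattern, using a frozen dataclass as a stand-in for the real pytree state; it is not lerax's implementation.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class AlgoState:  # stand-in for OnPolicyState
    iteration: int
    step_state: Any
    policy: Any
    opt_state: Any

    def next(self, step_state, policy, opt_state):
        # Return a new state with the iteration count bumped and the
        # step state, policy, and optimizer state swapped in.
        return dataclasses.replace(
            self,
            iteration=self.iteration + 1,
            step_state=step_state,
            policy=policy,
            opt_state=opt_state,
        )
```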
with_callback_states
Return a new algorithm state with the given callback state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `callback_state` | | The new callback state. | *required* |

Returns:

| Type | Description |
|---|---|
| | A new algorithm state with the updated callback state. |