PPO
lerax.algorithm.PPO
Bases:
Proximal Policy Optimization (PPO) algorithm.
Attributes:

| Name | Type | Description |
|---|---|---|
| | | The optimizer used for training. |
| `gae_lambda` | `float` | Lambda parameter for Generalized Advantage Estimation (GAE). |
| `gamma` | `float` | Discount factor. |
| `num_envs` | `int` | Number of parallel environments. |
| `num_steps` | `int` | Number of steps to run for each environment per update. |
| | `int` | Size of each training batch. |
| `num_epochs` | `int` | Number of epochs to train the policy per update. |
| `normalize_advantages` | `bool` | Whether to normalize advantages. |
| `clip_coefficient` | `float` | Clipping coefficient for policy and value function updates. |
| `clip_value_loss` | `bool` | Whether to clip the value function loss. |
| `entropy_loss_coefficient` | `float` | Coefficient for the entropy loss term. |
| `value_loss_coefficient` | `float` | Coefficient for the value function loss term. |
| `max_grad_norm` | `float` | Maximum gradient norm for gradient clipping. |
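The `learning_rate` parameter configures the optimizer, and `max_grad_norm` bounds the gradient norm. As an illustration of how these two settings are conventionally combined in optax (a sketch of the usual pattern, not necessarily lerax's internal construction):

```python
import optax

# Conventional optax setup combining global-norm gradient clipping with Adam.
# 0.5 and 3e-4 mirror the max_grad_norm and learning_rate defaults documented below.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(learning_rate=3e-4),
)
```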
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_envs` | `int` | Number of parallel environments. | `4` |
| `num_steps` | `int` | Number of steps to run for each environment per update. | `512` |
| `num_epochs` | `int` | Number of epochs to train the policy per update. | `16` |
| `num_batches` | `int` | Number of batches to split the rollout buffer into for training. | `32` |
| `gae_lambda` | `float` | Lambda parameter for Generalized Advantage Estimation (GAE). | `0.95` |
| `gamma` | `float` | Discount factor. | `0.99` |
| `clip_coefficient` | `float` | Clipping coefficient for policy and value function updates. | `0.2` |
| `clip_value_loss` | `bool` | Whether to clip the value function loss. | `False` |
| `entropy_loss_coefficient` | `float` | Coefficient for the entropy loss term. | `0.0` |
| `value_loss_coefficient` | `float` | Coefficient for the value function loss term. | `0.5` |
| `max_grad_norm` | `float` | Maximum gradient norm for gradient clipping. | `0.5` |
| `normalize_advantages` | `bool` | Whether to normalize advantages. | `True` |
| `learning_rate` | `optax.ScalarOrSchedule` | Learning rate for the optimizer. | `0.0003` |
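To make the roles of `gamma` and `gae_lambda` concrete, the sketch below implements the conventional GAE recursion in plain JAX. The `gae` function is illustrative only and not part of the lerax API; it assumes flat per-step arrays for a single environment.

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
    """Standard GAE: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}."""
    next_values = jnp.concatenate([values[1:], jnp.atleast_1d(last_value)])
    deltas = rewards + gamma * next_values * (1.0 - dones) - values

    def backward(carry, x):
        delta, done = x
        advantage = delta + gamma * gae_lambda * (1.0 - done) * carry
        return advantage, advantage

    _, advantages = jax.lax.scan(backward, 0.0, (deltas, dones), reverse=True)
    returns = advantages + values  # value-function regression targets
    return advantages, returns
```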
entropy_loss_coefficient instance-attribute

ppo_loss_grad class-attribute instance-attribute
reset
reset(
env: AbstractEnvLike,
policy: PolicyType,
*,
key: Key,
callback: AbstractCallback,
) -> OnPolicyState[PolicyType]
iteration
iteration(
state: OnPolicyState[PolicyType],
*,
key: Key,
callback: AbstractCallback,
) -> OnPolicyState[PolicyType]
learn
learn(
env: AbstractEnvLike,
policy: PolicyType,
total_timesteps: int,
*,
key: Key,
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> PolicyType
Train the policy on the environment for a given number of timesteps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `env` | `AbstractEnvLike` | The environment to train on. | required |
| `policy` | `PolicyType` | The policy to train. | required |
| `total_timesteps` | `int` | The total number of timesteps to train for. | required |
| `key` | `Key` | A JAX PRNG key. | required |
| `callback` | `Sequence[AbstractCallback] \| AbstractCallback \| None` | A callback or list of callbacks to use during training. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `policy` | `PolicyType` | The trained policy. |
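A minimal usage sketch, assuming you already have an environment implementing `AbstractEnvLike` and a compatible policy. `make_env()` and `make_policy()` are hypothetical stand-ins for however those objects are constructed in your project; the `PPO` constructor arguments and the `learn` call follow the signatures documented on this page.

```python
import jax

from lerax.algorithm import PPO

env = make_env()        # hypothetical: any AbstractEnvLike
policy = make_policy()  # hypothetical: any policy of the expected PolicyType

algo = PPO(num_envs=8, num_steps=256, learning_rate=3e-4)
trained_policy = algo.learn(
    env,
    policy,
    total_timesteps=1_000_000,
    key=jax.random.key(0),
)
```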
step
step(
env: AbstractEnvLike,
policy: PolicyType,
state: OnPolicyStepState[PolicyType],
*,
key: Key,
callback: AbstractCallback,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]
collect_rollout
collect_rollout(
env: AbstractEnvLike,
policy: PolicyType,
step_state: OnPolicyStepState[PolicyType],
callback: AbstractCallback,
key: Key,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]
Collect a rollout using the current policy.
__init__
__init__(
*,
num_envs: int = 4,
num_steps: int = 512,
num_epochs: int = 16,
num_batches: int = 32,
gae_lambda: float = 0.95,
gamma: float = 0.99,
clip_coefficient: float = 0.2,
clip_value_loss: bool = False,
entropy_loss_coefficient: float = 0.0,
value_loss_coefficient: float = 0.5,
max_grad_norm: float = 0.5,
normalize_advantages: bool = True,
learning_rate: optax.ScalarOrSchedule = 0.0003,
)
ppo_loss
staticmethod
ppo_loss(
policy: PolicyType,
rollout_buffer: RolloutBuffer,
normalize_advantages: bool,
clip_coefficient: float,
clip_value_loss: bool,
value_loss_coefficient: float,
entropy_loss_coefficient: float,
) -> tuple[Float[Array, ""], PPOStats]
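For orientation, the clipped surrogate objective that a PPO loss of this shape conventionally computes is reproduced below. Mapping `clip_coefficient`, `value_loss_coefficient`, and `entropy_loss_coefficient` to $\epsilon$, $c_v$, and $c_e$ follows the standard PPO formulation; lerax's exact sign and averaging conventions may differ.

$$
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}
$$

$$
L(\theta) = -\,\mathbb{E}_t\!\left[\min\bigl(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\bigr)\right]
+ c_v\,L^{\mathrm{VF}}(\theta) - c_e\,\mathbb{E}_t\bigl[H[\pi_\theta](s_t)\bigr]
$$

where $\hat{A}_t$ are the (optionally normalized) advantages from the rollout buffer and $L^{\mathrm{VF}}$ is the (optionally clipped) value-function loss.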
train_batch
train_batch(
policy: PolicyType,
opt_state: optax.OptState,
rollout_buffer: RolloutBuffer,
) -> tuple[PolicyType, optax.OptState, PPOStats]
train_epoch
train_epoch(
policy: PolicyType,
opt_state: optax.OptState,
rollout_buffer: RolloutBuffer,
*,
key: Key,
) -> tuple[PolicyType, optax.OptState, PPOStats]
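A training epoch of this shape typically shuffles the flattened rollout buffer with the provided key and splits it into `num_batches` minibatches, each handed to `train_batch`. The index-shuffling sketch below illustrates that split; it is not taken from the lerax source.

```python
import jax
import jax.numpy as jnp

def minibatch_indices(key, buffer_size, num_batches):
    # Shuffle all transition indices, then split them into equal-sized minibatches.
    perm = jax.random.permutation(key, buffer_size)
    return jnp.reshape(perm, (num_batches, buffer_size // num_batches))

# With the documented defaults: 4 envs * 512 steps = 2048 transitions, 32 batches of 64.
batches = minibatch_indices(jax.random.key(0), buffer_size=4 * 512, num_batches=32)
assert batches.shape == (32, 64)
```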