PPO

lerax.algorithm.PPO

Bases: AbstractOnPolicyAlgorithm[PolicyType]

Proximal Policy Optimization (PPO) algorithm.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| optimizer | optax.GradientTransformation | The optimizer used for training. |
| gae_lambda | float | Lambda parameter for Generalized Advantage Estimation (GAE). |
| gamma | float | Discount factor. |
| num_envs | int | Number of parallel environments. |
| num_steps | int | Number of steps to run for each environment per update. |
| batch_size | int | Size of each training batch. |
| num_epochs | int | Number of epochs to train the policy per update. |
| normalize_advantages | bool | Whether to normalize advantages. |
| clip_coefficient | float | Clipping coefficient for policy and value function updates. |
| clip_value_loss | bool | Whether to clip the value function loss. |
| entropy_loss_coefficient | float | Coefficient for the entropy loss term. |
| value_loss_coefficient | float | Coefficient for the value function loss term. |
| max_grad_norm | float | Maximum gradient norm for gradient clipping. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_envs | int | Number of parallel environments. | 4 |
| num_steps | int | Number of steps to run for each environment per update. | 512 |
| num_epochs | int | Number of epochs to train the policy per update. | 16 |
| num_batches | int | Number of batches to split the rollout buffer into for training. | 32 |
| gae_lambda | float | Lambda parameter for Generalized Advantage Estimation (GAE). | 0.95 |
| gamma | float | Discount factor. | 0.99 |
| clip_coefficient | float | Clipping coefficient for policy and value function updates. | 0.2 |
| clip_value_loss | bool | Whether to clip the value function loss. | False |
| entropy_loss_coefficient | float | Coefficient for the entropy loss term. | 0.0 |
| value_loss_coefficient | float | Coefficient for the value function loss term. | 0.5 |
| max_grad_norm | float | Maximum gradient norm for gradient clipping. | 0.5 |
| normalize_advantages | bool | Whether to normalize advantages. | True |
| learning_rate | optax.ScalarOrSchedule | Learning rate for the optimizer. | 0.0003 |
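
Every hyperparameter is keyword-only and falls back to the default listed above, so a minimal construction sketch (the overridden values here are purely illustrative) looks like:

from lerax.algorithm import PPO

algo = PPO(
    num_envs=8,            # override the default of 4
    num_steps=256,         # override the default of 512
    learning_rate=2.5e-4,  # a float or an optax schedule (optax.ScalarOrSchedule)
)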

optimizer instance-attribute

optimizer: optax.GradientTransformation = optax.chain(
    clip, adam
)
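
The rendered default only names the chained transformations; assuming clip is global-norm gradient clipping at max_grad_norm and adam is Adam at learning_rate, the equivalent optax construction would be roughly:

import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),  # assumed: clip at max_grad_norm
    optax.adam(3e-4),                # assumed: Adam at learning_rate
)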

gae_lambda instance-attribute

gae_lambda: float = gae_lambda

gamma instance-attribute

gamma: float = gamma

num_envs instance-attribute

num_envs: int = num_envs

num_steps instance-attribute

num_steps: int = num_steps

batch_size instance-attribute

batch_size: int = (
    self.num_steps * self.num_envs // num_batches
)
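
With the default hyperparameters this works out to:

batch_size = 512 * 4 // 32  # num_steps * num_envs // num_batches = 64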

num_epochs instance-attribute

num_epochs: int = num_epochs

normalize_advantages instance-attribute

normalize_advantages: bool = normalize_advantages

clip_coefficient instance-attribute

clip_coefficient: float = clip_coefficient

clip_value_loss instance-attribute

clip_value_loss: bool = clip_value_loss

entropy_loss_coefficient instance-attribute

entropy_loss_coefficient: float = entropy_loss_coefficient

value_loss_coefficient instance-attribute

value_loss_coefficient: float = value_loss_coefficient

max_grad_norm instance-attribute

max_grad_norm: float = max_grad_norm

ppo_loss_grad class-attribute instance-attribute

ppo_loss_grad = staticmethod(
    eqx.filter_value_and_grad(ppo_loss, has_aux=True)
)
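
Since the gradient function is built with has_aux=True, it returns the loss and auxiliary statistics together with gradients with respect to the first argument of ppo_loss (the policy); a sketch of the call pattern:

(loss, stats), grads = PPO.ppo_loss_grad(
    policy,
    rollout_buffer,
    normalize_advantages,
    clip_coefficient,
    clip_value_loss,
    value_loss_coefficient,
    entropy_loss_coefficient,
)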

reset

reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key,
    callback: AbstractCallback,
) -> OnPolicyState[PolicyType]

iteration

iteration(
    state: OnPolicyState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> OnPolicyState[PolicyType]

num_iterations

num_iterations(total_timesteps: int) -> int
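
Each iteration consumes num_envs * num_steps environment steps, so the iteration count is presumably total_timesteps divided by that product (the exact rounding is not shown here); with the defaults:

steps_per_iteration = 4 * 512                  # num_envs * num_steps = 2048
iterations = 1_000_000 // steps_per_iteration  # about 488 for one million timesteps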

learn

learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key,
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType

Train the policy on the environment for a given number of timesteps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| total_timesteps | int | The total number of timesteps to train for. | required |
| key | Key | A JAX PRNG key. | required |
| callback | Sequence[AbstractCallback] \| AbstractCallback \| None | A callback or list of callbacks to use during training. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| policy | PolicyType | The trained policy. |
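
A minimal end-to-end sketch; make_env and make_policy are hypothetical placeholders for whichever AbstractEnvLike and compatible PolicyType you use:

import jax

from lerax.algorithm import PPO

env = make_env()        # hypothetical: returns an AbstractEnvLike
policy = make_policy()  # hypothetical: returns a PolicyType

algo = PPO()
trained_policy = algo.learn(
    env,
    policy,
    total_timesteps=1_000_000,
    key=jax.random.PRNGKey(0),
)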

step

step(
    env: AbstractEnvLike,
    policy: PolicyType,
    state: OnPolicyStepState[PolicyType],
    *,
    key: Key,
    callback: AbstractCallback,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]

collect_rollout

collect_rollout(
    env: AbstractEnvLike,
    policy: PolicyType,
    step_state: OnPolicyStepState[PolicyType],
    callback: AbstractCallback,
    key: Key,
) -> tuple[OnPolicyStepState[PolicyType], RolloutBuffer]

Collect a rollout using the current policy.

__init__

__init__(
    *,
    num_envs: int = 4,
    num_steps: int = 512,
    num_epochs: int = 16,
    num_batches: int = 32,
    gae_lambda: float = 0.95,
    gamma: float = 0.99,
    clip_coefficient: float = 0.2,
    clip_value_loss: bool = False,
    entropy_loss_coefficient: float = 0.0,
    value_loss_coefficient: float = 0.5,
    max_grad_norm: float = 0.5,
    normalize_advantages: bool = True,
    learning_rate: optax.ScalarOrSchedule = 0.0003,
)

per_step

per_step(
    step_state: OnPolicyStepState[PolicyType],
) -> OnPolicyStepState[PolicyType]

per_iteration

per_iteration(
    state: OnPolicyState[PolicyType],
) -> OnPolicyState[PolicyType]

ppo_loss staticmethod

ppo_loss(
    policy: PolicyType,
    rollout_buffer: RolloutBuffer,
    normalize_advantages: bool,
    clip_coefficient: float,
    clip_value_loss: bool,
    value_loss_coefficient: float,
    entropy_loss_coefficient: float,
) -> tuple[Float[Array, ""], PPOStats]
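
The library's implementation is not reproduced here; as a reference point, the clipped policy term of the standard PPO surrogate that this loss is built around looks roughly like:

import jax.numpy as jnp

def clipped_policy_loss(log_prob, old_log_prob, advantages, clip_coefficient):
    # Standard PPO clipped surrogate (a sketch, not the library's exact code).
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1 - clip_coefficient, 1 + clip_coefficient) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))

The full objective additionally weights the value loss by value_loss_coefficient and the entropy term by entropy_loss_coefficient, per the parameters above.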

train_batch

train_batch(
    policy: PolicyType,
    opt_state: optax.OptState,
    rollout_buffer: RolloutBuffer,
) -> tuple[PolicyType, optax.OptState, PPOStats]

train_epoch

train_epoch(
    policy: PolicyType,
    opt_state: optax.OptState,
    rollout_buffer: RolloutBuffer,
    *,
    key: Key,
) -> tuple[PolicyType, optax.OptState, PPOStats]

explained_variance staticmethod

explained_variance(
    returns: Float[Array, ""], values: Float[Array, ""]
) -> Float[Array, ""]
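
Explained variance is conventionally 1 - Var(returns - values) / Var(returns): close to 1 when the value estimates track the returns well, and 0 or below otherwise. A sketch under that assumption:

import jax.numpy as jnp

def explained_variance(returns, values):
    # Assumed convention; not necessarily the library's exact implementation.
    return 1.0 - jnp.var(returns - values) / jnp.var(returns)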

train

train(
    policy: PolicyType,
    opt_state: optax.OptState,
    buffer: RolloutBuffer,
    *,
    key: Key,
) -> tuple[PolicyType, optax.OptState, dict[str, Scalar]]