Algorithm
lerax.algorithm.AbstractAlgorithm
Bases:
Base class for RL algorithms.
Provides the main training loop and abstract methods for algorithm-specific behavior.
Attributes:
| Name | Type | Description |
|---|---|---|
| | | The optimizer used for training. |
| | | The number of parallel environments. |
| | | The number of steps per environment per iteration. |
reset
abstractmethod
reset(
env: AbstractEnvLike,
policy: PolicyType,
*,
key: Key,
callback: AbstractCallback,
) -> StateType
Return the initial carry for the training iteration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| key | Key | A JAX PRNG key. | required |
| callback | AbstractCallback | A callback or list of callbacks to use during training. | required |
Returns:
| Type | Description |
|---|---|
| StateType | The initial algorithm state. |
per_iteration
abstractmethod
Process the algorithm state after each iteration.
Used for algorithm-specific bookkeeping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | | The current algorithm state. | required |
Returns:
| Type | Description |
|---|---|
| | The updated algorithm state. |
iteration
abstractmethod
Perform a single iteration of training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | | The current algorithm state. | required |
| key | | A JAX PRNG key. | required |
| callback | | A callback or list of callbacks to use during training. | required |
Returns:
| Type | Description |
|---|---|
| | The updated algorithm state. |
num_iterations
abstractmethod
Number of iterations per training session.
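As an illustration only: many on-policy algorithms collect one batch of (number of parallel environments) × (steps per environment) transitions per iteration, so a concrete subclass might derive the iteration count by integer division. The helper name, its arguments, and the choice to drop a trailing partial batch are assumptions of this sketch, not confirmed lerax behavior.

```python
def on_policy_num_iterations(total_timesteps: int, num_envs: int, num_steps: int) -> int:
    """Hypothetical helper: iterations needed when each one collects
    num_envs * num_steps transitions."""
    if num_envs <= 0 or num_steps <= 0:
        raise ValueError("num_envs and num_steps must be positive")
    # Floor division: a trailing partial batch is dropped (an assumption of this sketch).
    return total_timesteps // (num_envs * num_steps)
```

With 4 environments and 128 steps each, every iteration consumes 512 timesteps, so a budget of 10,000 timesteps yields 19 full iterations under this rounding rule.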
learn
learn(
env: AbstractEnvLike,
policy: PolicyType,
total_timesteps: int,
*,
key: Key,
callback: Sequence[AbstractCallback]
| AbstractCallback
| None = None,
) -> PolicyType
Train the policy on the environment for a given number of timesteps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| total_timesteps | int | The total number of timesteps to train for. | required |
| key | Key | A JAX PRNG key. | required |
| callback | Sequence[AbstractCallback] \| AbstractCallback \| None | A callback or list of callbacks to use during training. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| policy | PolicyType | The trained policy. |
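The following toy sketch shows one way a training loop could compose the abstract methods documented above: reset builds the initial state, then iteration and per_iteration run once per iteration. It does not import lerax. The classes ToyState and ToyAlgorithm, the field names, the num_iterations signature, the call order, and the tuple used in place of a split PRNG key are all assumptions made for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for an algorithm state."""
    iteration_count: int
    env: float     # toy "environment": a target value
    policy: float  # toy "policy": a single parameter


class ToyAlgorithm:
    """Hypothetical algorithm; only the method names mirror the reference."""
    num_envs = 2
    num_steps = 4

    def num_iterations(self, total_timesteps):
        # Assumed rule: one iteration consumes num_envs * num_steps timesteps.
        return total_timesteps // (self.num_envs * self.num_steps)

    def reset(self, env, policy, *, key, callback):
        return ToyState(iteration_count=0, env=env, policy=policy)

    def iteration(self, state, *, key, callback):
        # Placeholder update: move the parameter halfway toward the target.
        new_policy = state.policy + 0.5 * (state.env - state.policy)
        return replace(
            state,
            iteration_count=state.iteration_count + 1,
            policy=new_policy,
        )

    def per_iteration(self, state):
        # No extra bookkeeping in this toy example.
        return state

    def learn(self, env, policy, total_timesteps, *, key, callback=None):
        state = self.reset(env, policy, key=key, callback=callback)
        for i in range(self.num_iterations(total_timesteps)):
            # Real JAX code would split the PRNG key; a tuple stands in here.
            state = self.iteration(state, key=(key, i), callback=callback)
            state = self.per_iteration(state)
        return state.policy


trained = ToyAlgorithm().learn(env=1.0, policy=0.0, total_timesteps=16, key=0)
```

With 16 timesteps and 8 timesteps per iteration, the loop runs twice, so the toy parameter moves from 0.0 to 0.5 and then to 0.75.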
lerax.algorithm.AbstractStepState
Bases:
Base class for algorithm state that is vectorized over environment steps.
Attributes:
| Name | Type | Description |
|---|---|---|
| | | The state of the environment. |
| | | The state of the policy. |
| | | The state of the callback for this step. |
with_callback_state
Return a new step state with the given callback state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | | The new callback state. If None, the existing state is used. | required |
Returns:
| Type | Description |
|---|---|
| | A new step state with the updated callback state. |
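The "return a new state" wording suggests an immutable update. A minimal sketch of that pattern follows, using a frozen dataclass instead of the library's own base class. ToyStepState and its field names are hypothetical, since the attribute names are not shown in this reference.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class ToyStepState:
    """Hypothetical step state; field names are assumptions."""
    env_state: Any
    policy_state: Any
    callback_state: Any

    def with_callback_state(self, callback_state: Optional[Any]) -> "ToyStepState":
        # None means "keep the existing callback state", as described above.
        chosen = self.callback_state if callback_state is None else callback_state
        # replace() builds a new instance; the original is left untouched.
        return replace(self, callback_state=chosen)


step = ToyStepState(env_state="e", policy_state="p", callback_state="c0")
updated = step.with_callback_state("c1")
unchanged = step.with_callback_state(None)
```

Because the dataclass is frozen, step still holds its original callback state after both calls.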
lerax.algorithm.AbstractAlgorithmState
Bases:
Base class for algorithm states.
Attributes:
| Name | Type | Description |
|---|---|---|
| | | The current iteration count. |
| | | The state for the current step. |
| | | The environment being used. |
| | | The policy being trained. |
| | | The optimizer state. |
| | | The state of the callback for this iteration. |
next
next[A: AbstractAlgorithmState](
step_state: AbstractStepState,
policy: PolicyType,
opt_state: optax.OptState,
) -> A
Return a new algorithm state for the next iteration.
Increments the iteration count and updates the step state, policy, and optimizer state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_state | AbstractStepState | The new step state. | required |
| policy | PolicyType | The new policy. | required |
| opt_state | optax.OptState | The new optimizer state. | required |
Returns:
| Type | Description |
|---|---|
| A | A new algorithm state with the updated fields. |
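To make the described behavior concrete, here is a small sketch of a next-style method on a frozen dataclass: it increments the iteration count and swaps in the three new values while leaving the original state unchanged. ToyAlgorithmState and its field names are hypothetical, and plain strings stand in for the real step state, policy, and optax optimizer state.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyAlgorithmState:
    """Hypothetical algorithm state; field names are assumptions."""
    iteration_count: int
    step_state: Any
    policy: Any
    opt_state: Any

    def next(self, step_state: Any, policy: Any, opt_state: Any) -> "ToyAlgorithmState":
        # Build a new state: bump the counter and replace the three fields.
        return replace(
            self,
            iteration_count=self.iteration_count + 1,
            step_state=step_state,
            policy=policy,
            opt_state=opt_state,
        )


s0 = ToyAlgorithmState(iteration_count=0, step_state="s0", policy="p0", opt_state="o0")
s1 = s0.next("s1", "p1", "o1")
```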
with_callback_states
Return a new algorithm state with the given callback state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | | The new callback state. | required |
Returns:
| Type | Description |
|---|---|
| | A new algorithm state with the updated callback state. |