Algorithm

lerax.algorithm.AbstractAlgorithm

Bases: eqx.Module

Base class for RL algorithms.

Provides the main training loop and abstract methods for algorithm-specific behavior.

Attributes:

| Name | Type | Description |
|---|---|---|
| optimizer | eqx.AbstractVar[optax.GradientTransformation] | The optimizer used for training. |
| num_envs | eqx.AbstractVar[int] | The number of parallel environments. |
| num_steps | eqx.AbstractVar[int] | The number of steps per environment per iteration. |
reset abstractmethod

```python
reset(
    env: AbstractEnvLike,
    policy: PolicyType,
    *,
    key: Key[Array, ""],
    callback: AbstractCallback,
) -> StateType
```

Return the initial carry for the training iteration.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |
| callback | AbstractCallback | A callback or list of callbacks to use during training. | required |

Returns:

| Type | Description |
|---|---|
| StateType | The initial algorithm state. |
per_iteration abstractmethod

Process the algorithm state after each iteration.

Used for algorithm-specific bookkeeping.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current algorithm state. | required |

Returns:

| Type | Description |
|---|---|
| StateType | The updated algorithm state. |
iteration abstractmethod

Perform a single iteration of training.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | StateType | The current algorithm state. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |
| callback | AbstractCallback | A callback or list of callbacks to use during training. | required |

Returns:

| Type | Description |
|---|---|
| StateType | The updated algorithm state. |
num_iterations abstractmethod

Number of iterations per training session.
consolidate_callbacks

```python
consolidate_callbacks(
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> AbstractCallback
```
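No description accompanies `consolidate_callbacks`; from the signature it normalizes `None`, a single callback, or a sequence of callbacks into one `AbstractCallback`. A plain-Python sketch of that normalization (the `Callback`, `NoOpCallback`, and `CallbackList` classes are illustrative stand-ins, not the actual lerax types):

```python
# Illustrative normalization of the callback argument; the concrete
# classes used by lerax for the no-op and composite cases are assumptions.
from typing import Sequence


class Callback:
    """Stand-in for AbstractCallback."""


class NoOpCallback(Callback):
    """Returned when no callback is supplied."""


class CallbackList(Callback):
    """Wraps several callbacks behind the single-callback interface."""

    def __init__(self, callbacks: Sequence[Callback]):
        self.callbacks = list(callbacks)


def consolidate_callbacks(callback=None) -> Callback:
    if callback is None:
        return NoOpCallback()          # nothing supplied: use a no-op
    if isinstance(callback, Callback):
        return callback                # already a single callback
    return CallbackList(callback)      # a sequence: wrap into one


single = consolidate_callbacks()
merged = consolidate_callbacks([Callback(), Callback()])
```

Either way, the rest of the training loop only ever sees one `AbstractCallback`.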
learn

```python
learn(
    env: AbstractEnvLike,
    policy: PolicyType,
    total_timesteps: int,
    *,
    key: Key[Array, ""],
    callback: Sequence[AbstractCallback]
    | AbstractCallback
    | None = None,
) -> PolicyType
```

Train the policy on the environment for a given number of timesteps.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | AbstractEnvLike | The environment to train on. | required |
| policy | PolicyType | The policy to train. | required |
| total_timesteps | int | The total number of timesteps to train for. | required |
| key | Key[Array, ''] | A JAX PRNG key. | required |
| callback | Sequence[AbstractCallback] \| AbstractCallback \| None | A callback or list of callbacks to use during training. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| policy | PolicyType | The trained policy. |
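The tables above give the contract but not the control flow. As a rough, framework-free sketch (plain Python, no JAX/Equinox; every class here is a toy, and the assumption that one iteration consumes `num_envs * num_steps` timesteps is inferred from the attribute descriptions, not confirmed by the source), `learn` can be pictured as calling `reset` once and then folding `iteration` and `per_iteration` over the state:

```python
# Plain-Python sketch of the learn() control flow. All classes and the
# timestep arithmetic are illustrative assumptions, not the lerax API.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    iteration_count: int
    policy: str


class ToyAlgorithm:
    num_envs = 4
    num_steps = 8

    def reset(self, env, policy):
        return ToyState(iteration_count=0, policy=policy)

    def iteration(self, state):
        # A real implementation would roll out num_envs * num_steps
        # transitions and take an optimizer step here.
        return replace(state, iteration_count=state.iteration_count + 1)

    def per_iteration(self, state):
        return state  # algorithm-specific bookkeeping hook

    def learn(self, env, policy, total_timesteps):
        state = self.reset(env, policy)
        # Assumed relation: one iteration = num_envs * num_steps timesteps.
        for _ in range(total_timesteps // (self.num_envs * self.num_steps)):
            state = self.per_iteration(self.iteration(state))
        return state.policy


trained = ToyAlgorithm().learn(env=None, policy="policy-v0", total_timesteps=128)
```

With `num_envs=4` and `num_steps=8`, 128 timesteps amount to 4 iterations; the real loop is JIT-compiled and threads PRNG keys and callbacks through each call.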
lerax.algorithm.AbstractStepState

Bases: eqx.Module

Base class for algorithm state that is vectorized over environment steps.

Attributes:

| Name | Type | Description |
|---|---|---|
| env_state | eqx.AbstractVar[AbstractEnvLikeState] | The state of the environment. |
| policy_state | eqx.AbstractVar[AbstractPolicyState] | The state of the policy. |
| callback_state | eqx.AbstractVar[AbstractCallbackStepState] | The state of the callback for this step. |
with_callback_state

Return a new step state with the given callback state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | AbstractCallbackStepState \| None | The new callback state. If None, the existing state is used. | required |

Returns:

| Type | Description |
|---|---|
| A | A new step state with the updated callback state. |
lerax.algorithm.AbstractAlgorithmState

Bases: eqx.Module

Base class for algorithm states.

Attributes:

| Name | Type | Description |
|---|---|---|
| iteration_count | eqx.AbstractVar[Int[Array, '']] | The current iteration count. |
| step_state | eqx.AbstractVar[AbstractStepState] | The state for the current step. |
| env | eqx.AbstractVar[AbstractEnvLike] | The environment being used. |
| policy | eqx.AbstractVar[PolicyType] | The policy being trained. |
| opt_state | eqx.AbstractVar[optax.OptState] | The optimizer state. |
| callback_state | eqx.AbstractVar[AbstractCallbackState] | The state of the callback for this iteration. |
next

```python
next[A: AbstractAlgorithmState](
    step_state: AbstractStepState,
    policy: PolicyType,
    opt_state: optax.OptState,
) -> A
```

Return a new algorithm state for the next iteration.

Increments the iteration count and updates the step state, policy, and optimizer state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_state | AbstractStepState | The new step state. | required |
| policy | PolicyType | The new policy. | required |
| opt_state | optax.OptState | The new optimizer state. | required |

Returns:

| Type | Description |
|---|---|
| A | A new algorithm state with the updated fields. |
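Because the state is an `eqx.Module` (an immutable pytree), `next` does not mutate anything in place: it returns a fresh state with the counter incremented and the three fields swapped out. A plain-Python analogue using a frozen dataclass (field names follow the attribute table above; the rest is illustrative, not the Equinox implementation):

```python
# Functional state update in the style of AbstractAlgorithmState.next.
# Frozen dataclass stands in for an immutable eqx.Module.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AlgorithmState:
    iteration_count: int
    step_state: object
    policy: object
    opt_state: object

    def next(self, step_state, policy, opt_state):
        # Build a new state: bump the counter, swap the updated fields,
        # and leave the original state untouched.
        return replace(
            self,
            iteration_count=self.iteration_count + 1,
            step_state=step_state,
            policy=policy,
            opt_state=opt_state,
        )


s0 = AlgorithmState(0, "steps0", "policy0", "opt0")
s1 = s0.next("steps1", "policy1", "opt1")
```

This functional style is what makes the training loop safe to `jax.lax.scan` or JIT-compile: each iteration consumes one state and produces the next.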
with_callback_states

Return a new algorithm state with the given callback state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| callback_state | AbstractCallbackState | The new callback state. | required |

Returns:

| Type | Description |
|---|---|
| A | A new algorithm state with the updated callback state. |