Callback
lerax.callback.AbstractCallbackStepState
Bases: eqx.Module
Base class for callback states that are vectorized across environment steps.
lerax.callback.AbstractCallbackState
Bases: eqx.Module
Base class for callback states.
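Both state base classes derive from `eqx.Module`, and Equinox modules are immutable after construction. A callback therefore never edits its state in place. It returns a new state object instead. The sketch below shows this with frozen standard-library dataclasses so that it runs without JAX. The class names `EpisodeCountState` and `EpisodeReturnStepState` are hypothetical. The Python list of step states stands in for a batched pytree with a leading environment axis.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class EpisodeCountState:
    """Hypothetical global callback state: one value for the whole run."""

    episodes_finished: int = 0


@dataclass(frozen=True)
class EpisodeReturnStepState:
    """Hypothetical step state for a single environment.

    With JAX, a batch of these would be one pytree whose leaves carry a
    leading environment axis; here a Python list of instances plays that role.
    """

    running_return: float = 0.0


state = EpisodeCountState()
# Immutable update: build a new object instead of assigning to a field.
new_state = replace(state, episodes_finished=state.episodes_finished + 1)

# One step state per parallel environment (4 environments here).
step_states = [EpisodeReturnStepState() for _ in range(4)]
```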
lerax.callback.ResetContext
Bases: eqx.Module
Context passed to the reset method of callbacks.
lerax.callback.StepContext
Bases: eqx.Module
Values passed to step-related callback methods.
Attributes:
| Name | Type | Description |
|---|---|---|
| `state` | `StepStateType` | The current callback step state. |
| `env` | `AbstractEnvLike` | The environment being interacted with. |
| `policy` | `AbstractPolicy` | The policy being used to interact with the environment. |
| `done` | `Bool[Array, '']` | Boolean indicating if the episode has terminated or truncated. |
| `reward` | `Float[Array, '']` | Reward received from the environment at the current step. |
| `locals` | `dict` | A dictionary for storing additional information. |
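A common use of `StepContext` is per-episode bookkeeping built from `reward` and `done`. Both fields are typed as scalars (`Float[Array, '']`, `Bool[Array, '']`), so `on_step` appears to see one environment at a time. Vectorization across environments is presumably handled outside the callback, for example with `jax.vmap`. The sketch below keeps a running return and stores the total whenever an episode ends. `ReturnStepState` and `MiniStepContext` are hypothetical, and `MiniStepContext` holds only the fields this example reads. The list comprehension imitates vmapping over two environments. Under JIT, the Python `if` would have to become `jnp.where`, because Python control flow cannot branch on traced array values.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ReturnStepState:
    # Hypothetical per-environment step state.
    running_return: float = 0.0
    last_episode_return: float = 0.0


@dataclass(frozen=True)
class MiniStepContext:
    # Hypothetical subset of StepContext: only the fields this sketch reads.
    state: ReturnStepState
    done: bool
    reward: float


def on_step(ctx: MiniStepContext) -> ReturnStepState:
    """Accumulate reward; when the episode ends, store the total and restart."""
    total = ctx.state.running_return + ctx.reward
    if ctx.done:
        # Under JIT this branch would be expressed with jnp.where(ctx.done, ...).
        return replace(ctx.state, running_return=0.0, last_episode_return=total)
    return replace(ctx.state, running_return=total)


# Two environments stepped once; the second one finishes its episode.
states = [ReturnStepState(running_return=1.0), ReturnStepState(running_return=2.5)]
rewards = [0.5, 1.0]
dones = [False, True]
states = [
    on_step(MiniStepContext(state=s, done=d, reward=r))
    for s, d, r in zip(states, dones, rewards)
]
```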
lerax.callback.IterationContext
Bases: eqx.Module
Values passed to iteration-related callback methods.
Attributes:
| Name | Type | Description |
|---|---|---|
| `state` | `StateType` | The current callback state. |
| `step_state` | `StepStateType` | The current callback step state. |
| `env` | `AbstractEnvLike` | The environment being interacted with. |
| `policy` | `AbstractPolicy` | The policy being used to interact with the environment. |
| `iteration_count` | `Int[Array, '']` | The current training iteration count. |
| `opt_state` | `optax.OptState` | The current optimizer state. |
| `training_log` | `dict[str, Array]` | A dictionary containing training metrics. |
| `algorithm` | `AbstractAlgorithm` | The algorithm performing the training. |
| `locals` | `dict` | A dictionary for storing additional information. |
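Because `on_iteration` receives the current `training_log`, an iteration callback can fold those metrics into its own state. The sketch below tracks the best value of one metric and the iteration at which it occurred. The metric key `"mean_return"`, the `BestMetricState` class, and the `MiniIterationContext` subset are all hypothetical. Real `training_log` values are JAX arrays, so a JIT-compatible version would compare them with `jnp.where` instead of a Python `if`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class BestMetricState:
    # Hypothetical callback state: best metric seen so far and when.
    best_value: float = float("-inf")
    best_iteration: int = -1


@dataclass(frozen=True)
class MiniIterationContext:
    # Hypothetical subset of IterationContext.
    state: BestMetricState
    iteration_count: int
    training_log: dict[str, float]


def on_iteration(ctx: MiniIterationContext, key: str = "mean_return") -> BestMetricState:
    """Keep the best value of training_log[key] seen across iterations."""
    value = ctx.training_log[key]
    if value > ctx.state.best_value:
        return replace(ctx.state, best_value=value, best_iteration=ctx.iteration_count)
    return ctx.state


state = BestMetricState()
for it, mean_return in enumerate([1.0, 4.0, 2.0]):
    state = on_iteration(
        MiniIterationContext(
            state=state, iteration_count=it, training_log={"mean_return": mean_return}
        )
    )
```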
lerax.callback.TrainingContext
Bases: eqx.Module
Values passed to training-related callback methods.
Attributes:
| Name | Type | Description |
|---|---|---|
| `state` | `StateType` | The current callback state. |
| `step_state` | `StepStateType` | The current callback step state. |
| `env` | `AbstractEnvLike` | The environment being interacted with. |
| `policy` | `AbstractPolicy` | The policy being used to interact with the environment. |
| `total_timesteps` | `int` | Total number of timesteps for training. |
| `iteration_count` | `Int[Array, '']` | The current training iteration count. |
| `algorithm` | `AbstractAlgorithm` | The algorithm performing the training. |
| `locals` | `dict` | A dictionary for storing additional information. |
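`TrainingContext` carries both the overall budget (`total_timesteps`) and the current `iteration_count`. Together they are enough to estimate training progress, provided you know how many environment steps one iteration consumes. The helper below makes a hypothetical assumption: each iteration collects `num_envs * num_steps` transitions, as in typical on-policy rollouts. Neither `num_envs` nor `num_steps` is a documented `TrainingContext` field. Both would have to come from the algorithm's configuration.

```python
def training_progress(
    iteration_count: int, total_timesteps: int, num_envs: int, num_steps: int
) -> float:
    """Fraction of the timestep budget consumed after `iteration_count` iterations.

    Assumes (hypothetically) that each iteration collects num_envs * num_steps
    transitions.
    """
    steps_per_iteration = num_envs * num_steps
    consumed = iteration_count * steps_per_iteration
    # Clamp, since the last iteration may overshoot a budget that is not a
    # multiple of steps_per_iteration.
    return min(consumed / total_timesteps, 1.0)


# 8 environments x 128 steps = 1024 transitions per iteration.
progress = training_progress(
    iteration_count=50, total_timesteps=1_000_000, num_envs=8, num_steps=128
)
```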
lerax.callback.AbstractCallback
Bases: eqx.Module
Base class for RL algorithm callbacks.
Note: All concrete methods should work under JIT compilation.

Methods:
| Name | Kind | Description |
|---|---|---|
| `reset` | abstract method | Initialize the callback state. |
| `step_reset` | abstract method | Reset the callback state for vectorized steps. |
| `on_step` | abstract method | Called at the end of each environment step. |
| `on_iteration` | abstract method | Called at the end of each training iteration. |
| `on_training_start` | abstract method | Called at the start of training. |
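The five abstract methods describe a lifecycle. The sketch below shows one plausible reading of it, written with Python's `abc` module instead of `eqx.Module` so that it runs on its own:

- `reset` and `step_reset` create the two states.
- `on_training_start` runs once.
- `on_step` runs after every environment step.
- `on_iteration` runs after every training iteration.

Several parts of the sketch are assumptions for illustration, not the actual lerax signatures or training loop. In the sketch, every hook takes the current state and returns the updated one, and `step_reset` is called only once. The `SketchCallback` base and the `run_training` driver are hypothetical.

```python
from abc import ABC, abstractmethod


class SketchCallback(ABC):
    """Hypothetical stand-in for AbstractCallback with simplified signatures."""

    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def step_reset(self): ...

    @abstractmethod
    def on_step(self, step_state): ...

    @abstractmethod
    def on_iteration(self, state, step_state): ...

    @abstractmethod
    def on_training_start(self, state): ...


class CallOrderRecorder(SketchCallback):
    """Records hook calls in tuples, returning new tuples instead of mutating."""

    def reset(self):
        return ("reset",)

    def step_reset(self):
        return ()

    def on_step(self, step_state):
        return step_state + ("step",)

    def on_iteration(self, state, step_state):
        return state + ("iteration",)

    def on_training_start(self, state):
        return state + ("training_start",)


def run_training(callback, num_iterations, steps_per_iteration):
    # Hypothetical driver: create both states, then thread them through the hooks.
    state = callback.reset()
    step_state = callback.step_reset()
    state = callback.on_training_start(state)
    for _ in range(num_iterations):
        for _ in range(steps_per_iteration):
            step_state = callback.on_step(step_state)
        state = callback.on_iteration(state, step_state)
    return state, step_state


state, step_state = run_training(CallOrderRecorder(), num_iterations=2, steps_per_iteration=3)
```

Threading state through return values, rather than mutating attributes, is what keeps hooks of this shape compatible with JIT tracing.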
lerax.callback.AbstractStatelessCallback
Bases: AbstractCallback[EmptyCallbackState, EmptyCallbackStepState]
Callback that does not maintain any state.
Methods:
| Name | Kind | Description |
|---|---|---|
| `on_step` | abstract method | Called at the end of each environment step. |
| `on_iteration` | abstract method | Called at the end of each training iteration. |
| `on_training_start` | abstract method | Called at the start of training. |
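`AbstractStatelessCallback` lists only the three hook methods as abstract. Its `reset` and `step_reset` are therefore presumably supplied by the base class and return the empty states. The sketch below models that arrangement. `EmptyState` and `EmptyStepState` are hypothetical stand-ins for `EmptyCallbackState` and `EmptyCallbackStepState`, and the hook signatures are simplified. The subclass only reacts to events and hands the empty state back unchanged. `LargeRewardAlert` is a hypothetical side-effect-only callback. In JIT-compiled code its `print` would need to go through something like `jax.debug.print`.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class EmptyState:
    """Stand-in for EmptyCallbackState: carries no data."""


@dataclass(frozen=True)
class EmptyStepState:
    """Stand-in for EmptyCallbackStepState: carries no data."""


class SketchStatelessCallback(ABC):
    # The base class owns state creation, so subclasses never construct state.
    def reset(self) -> EmptyState:
        return EmptyState()

    def step_reset(self) -> EmptyStepState:
        return EmptyStepState()

    @abstractmethod
    def on_step(self, step_state: EmptyStepState, reward: float) -> EmptyStepState: ...

    @abstractmethod
    def on_iteration(self, state: EmptyState) -> EmptyState: ...

    @abstractmethod
    def on_training_start(self, state: EmptyState) -> EmptyState: ...


class LargeRewardAlert(SketchStatelessCallback):
    """Hypothetical callback that only reports unusually large rewards."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def on_step(self, step_state, reward):
        if reward > self.threshold:
            print(f"large reward: {reward}")
        return step_state  # nothing to update

    def on_iteration(self, state):
        return state

    def on_training_start(self, state):
        return state
```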
lerax.callback.AbstractStepCallback
Bases: AbstractCallback[EmptyCallbackState, StepStateType]
Callback that only implements step-related methods.
lerax.callback.AbstractIterationCallback
Bases: AbstractCallback[StateType, EmptyCallbackStepState]
Callback that only implements iteration-related methods.
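The type parameters explain how the work is split between these two classes:

- `AbstractStepCallback` fixes the global state to `EmptyCallbackState`, so only `step_reset` and `on_step` need real logic.
- `AbstractIterationCallback` fixes the step state to `EmptyCallbackStepState`, so only `reset` and `on_iteration` do.

The sketch below illustrates the step-side half with a hypothetical per-environment episode-length tracker. The no-op defaults in `SketchStepCallback` are an assumption about how the unused hooks behave, not the documented lerax implementation. The simplified `on_step(step_state, done)` signature is also hypothetical.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class EmptyState:
    """Stand-in for EmptyCallbackState."""


class SketchStepCallback(ABC):
    # Assumed no-op defaults for the hooks a step callback does not use.
    def reset(self) -> EmptyState:
        return EmptyState()

    def on_iteration(self, state: EmptyState) -> EmptyState:
        return state

    def on_training_start(self, state: EmptyState) -> EmptyState:
        return state

    @abstractmethod
    def step_reset(self): ...

    @abstractmethod
    def on_step(self, step_state, done: bool): ...


@dataclass(frozen=True)
class EpisodeLengthStepState:
    current_length: int = 0
    last_length: int = 0


class EpisodeLengthCallback(SketchStepCallback):
    """Hypothetical: tracks how many steps the most recent episode lasted."""

    def step_reset(self) -> EpisodeLengthStepState:
        return EpisodeLengthStepState()

    def on_step(self, step_state, done):
        length = step_state.current_length + 1
        if done:
            return replace(step_state, current_length=0, last_length=length)
        return replace(step_state, current_length=length)


cb = EpisodeLengthCallback()
s = cb.step_reset()
for done in [False, False, True, False]:
    s = cb.on_step(s, done)
```

An iteration callback would mirror this arrangement. It would default `step_reset` and `on_step`, and implement `reset` and `on_iteration`.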
lerax.callback.AbstractTrainingCallback
Bases: AbstractCallback[StateType, EmptyCallbackStepState]
Callback that only implements training-related methods.
Methods:
| Name | Kind | Description |
|---|---|---|
| `reset` | abstract method | Initialize the callback state. |
| `on_training_start` | abstract method | Called at the start of training. |
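`AbstractTrainingCallback` is the lightest option, since only `reset` and `on_training_start` are abstract. The sketch below uses those two hooks to record the timestep budget and the iteration count at which training began. The starting count can be nonzero, for example if training is resumed. `RunInfoState`, `MiniTrainingContext`, `RunInfoCallback`, and the simplified signatures are hypothetical. The two fields the sketch reads mirror the documented `TrainingContext.total_timesteps` and `TrainingContext.iteration_count`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunInfoState:
    # Hypothetical callback state, filled in once at the start of training.
    total_timesteps: int = 0
    start_iteration: int = 0


@dataclass(frozen=True)
class MiniTrainingContext:
    # Hypothetical subset of TrainingContext.
    state: RunInfoState
    total_timesteps: int
    iteration_count: int


class RunInfoCallback:
    """Hypothetical training callback: only reset and on_training_start do work."""

    def reset(self) -> RunInfoState:
        return RunInfoState()

    def on_training_start(self, ctx: MiniTrainingContext) -> RunInfoState:
        return replace(
            ctx.state,
            total_timesteps=ctx.total_timesteps,
            start_iteration=ctx.iteration_count,
        )


cb = RunInfoCallback()
state = cb.on_training_start(
    MiniTrainingContext(state=cb.reset(), total_timesteps=1_000_000, iteration_count=0)
)
```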