
Callback

lerax.callback.AbstractCallbackStepState

Bases: eqx.Module

Base class for callback states that are vectorized across environment steps.

lerax.callback.AbstractCallbackState

Bases: eqx.Module

Base class for callback states.

lerax.callback.ResetContext

Bases: eqx.Module

Context passed to the reset and step_reset methods of callbacks.

lerax.callback.StepContext

Bases: eqx.Module

Values passed to step-related callback methods.

Attributes:

state (StepStateType): The current callback step state.
env (AbstractEnvLike): The environment being interacted with.
policy (AbstractPolicy): The policy being used to interact with the environment.
done (Bool[Array, '']): Boolean indicating whether the episode terminated or was truncated.
reward (Float[Array, '']): Reward received from the environment at the current step.
locals (dict): A dictionary for storing additional information.
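The following sketch shows how an on_step implementation might consume the StepContext fields. It accumulates an episode return and resets it whenever `done` is true. It uses a hypothetical stand-in dataclass `StepCtx` that carries only `state`, `done`, and `reward`. This is not lerax's class. Plain Python floats and bools replace the JAX scalar arrays, and the `key` argument is omitted.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReturnState:
    # Hypothetical per-environment step state: return of the ongoing episode
    # and return of the most recently finished episode.
    running: float = 0.0
    last_episode: float = 0.0


@dataclass(frozen=True)
class StepCtx:
    # Hypothetical stand-in for a subset of lerax.callback.StepContext.
    state: ReturnState
    done: bool
    reward: float


def on_step(ctx: StepCtx) -> ReturnState:
    # Add this step's reward to the running episode return.
    total = ctx.state.running + ctx.reward
    if ctx.done:
        # Episode ended: record its return and start the next one at zero.
        return ReturnState(running=0.0, last_episode=total)
    return ReturnState(running=total, last_episode=ctx.state.last_episode)


state = ReturnState()
for reward, done in [(1.0, False), (2.0, False), (0.5, True), (3.0, False)]:
    state = on_step(StepCtx(state=state, done=done, reward=reward))
print(state)  # ReturnState(running=3.0, last_episode=3.5)
```

Under `jax.jit`, `done` is a traced array, so the Python `if` would have to become `jnp.where` or `jax.lax.cond`.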

lerax.callback.IterationContext

Bases: eqx.Module

Values passed to iteration-related callback methods.

Attributes:

state (StateType): The current callback state.
step_state (StepStateType): The current callback step state.
env (AbstractEnvLike): The environment being interacted with.
policy (AbstractPolicy): The policy being used to interact with the environment.
iteration_count (Int[Array, '']): The current training iteration count.
opt_state (optax.OptState): The current optimizer state.
training_log (dict[str, Array]): A dictionary containing training metrics.
locals (dict): A dictionary for storing additional information.
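The next sketch shows one way an iteration callback might use `iteration_count` and `training_log`: it formats the logged metrics every few iterations. `IterCtx` is a hypothetical stand-in that holds only those two fields. The metric names and the `every` cadence are illustrative assumptions, not keys that lerax guarantees.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class IterCtx:
    # Hypothetical stand-in for a subset of lerax.callback.IterationContext.
    iteration_count: int
    training_log: dict = field(default_factory=dict)


def format_iteration(ctx: IterCtx, every: int = 10) -> Optional[str]:
    # Only report every `every` iterations (hypothetical logging cadence).
    if ctx.iteration_count % every != 0:
        return None
    # Sort metric names so the output order is deterministic.
    metrics = ", ".join(f"{k}={v:.3f}" for k, v in sorted(ctx.training_log.items()))
    return f"iter {ctx.iteration_count}: {metrics}"


print(format_iteration(IterCtx(20, {"value_loss": 0.25, "policy_loss": -0.0123})))
# iter 20: policy_loss=-0.012, value_loss=0.250
```

Inside jitted code, Python string formatting cannot see concrete values. `jax.debug.print` or `jax.debug.callback` is the usual way to emit such logs from traced code.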

lerax.callback.TrainingContext

Bases: eqx.Module

Values passed to training-related callback methods.

Attributes:

state (StateType): The current callback state.
step_state (StepStateType): The current callback step state.
env (AbstractEnvLike): The environment being interacted with.
policy (AbstractPolicy): The policy being used to interact with the environment.
total_timesteps (int): Total number of timesteps for training.
iteration_count (Int[Array, '']): The current training iteration count.
locals (dict): A dictionary for storing additional information.
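As an illustration of combining `total_timesteps` with `iteration_count`, the sketch below builds an end-of-training progress summary. TrainingContext does not expose how many environment steps one iteration covers, so `steps_per_iteration` is a hypothetical extra parameter. `TrainCtx` is a hypothetical stand-in, not lerax's class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainCtx:
    # Hypothetical stand-in for a subset of lerax.callback.TrainingContext.
    total_timesteps: int
    iteration_count: int


def training_summary(ctx: TrainCtx, steps_per_iteration: int) -> str:
    # steps_per_iteration is an assumed input (e.g. num_envs * rollout length);
    # it is not a field of TrainingContext.
    done_steps = ctx.iteration_count * steps_per_iteration
    pct = 100.0 * done_steps / ctx.total_timesteps
    return f"{ctx.iteration_count} iterations, {done_steps}/{ctx.total_timesteps} steps ({pct:.1f}%)"


print(training_summary(TrainCtx(total_timesteps=100_000, iteration_count=48), steps_per_iteration=2048))
# 48 iterations, 98304/100000 steps (98.3%)
```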

lerax.callback.AbstractCallback

Bases: eqx.Module

Base class for RL algorithm callbacks.

Note

All concrete methods should work under JIT compilation.

reset abstractmethod

reset(ctx: ResetContext, *, key: Key) -> StateType

Initialize the callback state.

step_reset abstractmethod

step_reset(ctx: ResetContext, *, key: Key) -> StepStateType

Reset the callback step state for vectorized steps.

on_step abstractmethod

on_step(ctx: StepContext, *, key: Key) -> StepStateType

Called at the end of each environment step.

on_iteration abstractmethod

on_iteration(
    ctx: IterationContext, *, key: Key
) -> StateType

Called at the end of each training iteration.

on_training_start abstractmethod

on_training_start(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the start of training.

on_training_end abstractmethod

on_training_end(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the end of training.
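The signatures above show that each hook returns a fresh state rather than mutating the callback. The sketch below makes that threading explicit. `CountingCallback` mirrors the method names of AbstractCallback but is a hypothetical plain-Python class. It takes states directly instead of context objects and omits `key`. The `train` driver, including the order in which it calls the hooks, is an assumption about a typical on-policy loop. It is not taken from lerax's algorithms.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CountState:
    # Hypothetical callback state: completed iterations plus a log of lifecycle events.
    iterations: int = 0
    events: tuple = ()


@dataclass(frozen=True)
class CountStepState:
    # Hypothetical step state: environment steps seen since step_reset.
    steps: int = 0


class CountingCallback:
    # Mirrors AbstractCallback's method names; contexts and keys are omitted.
    def reset(self) -> CountState:
        return CountState()

    def step_reset(self) -> CountStepState:
        return CountStepState()

    def on_step(self, step_state: CountStepState) -> CountStepState:
        return replace(step_state, steps=step_state.steps + 1)

    def on_iteration(self, state: CountState, step_state: CountStepState) -> CountState:
        return replace(state, iterations=state.iterations + 1,
                       events=state.events + (f"iter:{step_state.steps}",))

    def on_training_start(self, state: CountState) -> CountState:
        return replace(state, events=state.events + ("start",))

    def on_training_end(self, state: CountState) -> CountState:
        return replace(state, events=state.events + ("end",))


def train(cb: CountingCallback, num_iterations: int, steps_per_iteration: int) -> CountState:
    # Hypothetical driver; the real call order inside lerax may differ.
    state = cb.on_training_start(cb.reset())
    step_state = cb.step_reset()
    for _ in range(num_iterations):
        for _ in range(steps_per_iteration):
            step_state = cb.on_step(step_state)
        state = cb.on_iteration(state, step_state)
    return cb.on_training_end(state)


final = train(CountingCallback(), num_iterations=2, steps_per_iteration=3)
print(final.events)  # ('start', 'iter:3', 'iter:6', 'end')
```

Returning new states instead of mutating `self` is what allows immutable `eqx.Module` callbacks and their states to be carried through `jax.lax.scan` and `jax.jit`.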

lerax.callback.AbstractStatelessCallback

Bases: AbstractCallback[EmptyCallbackState, EmptyCallbackStepState]

Callback that does not maintain any state.

on_step abstractmethod

on_step(ctx: StepContext, *, key: Key) -> StepStateType

Called at the end of each environment step.

on_iteration abstractmethod

on_iteration(
    ctx: IterationContext, *, key: Key
) -> StateType

Called at the end of each training iteration.

on_training_start abstractmethod

on_training_start(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the start of training.

on_training_end abstractmethod

on_training_end(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the end of training.
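A stateless callback fills in only the four hooks listed above. This assumes, as the abstract-method list suggests, that AbstractStatelessCallback already supplies reset and step_reset. Every hook hands back the empty state it received, so any effect is a side effect such as printing. `EmptyState` and `AnnounceTraining` are hypothetical plain-Python stand-ins, and context and key arguments are reduced to the state alone.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmptyState:
    # Hypothetical stand-in for EmptyCallbackState / EmptyCallbackStepState.
    pass


class AnnounceTraining:
    # Hypothetical stateless callback: every hook returns its input state unchanged.
    def on_step(self, state: EmptyState) -> EmptyState:
        return state

    def on_iteration(self, state: EmptyState) -> EmptyState:
        return state

    def on_training_start(self, state: EmptyState) -> EmptyState:
        print("training started")
        return state

    def on_training_end(self, state: EmptyState) -> EmptyState:
        print("training finished")
        return state


cb = AnnounceTraining()
s = cb.on_training_end(cb.on_training_start(EmptyState()))
```

In traced code, `print` would run only once at trace time, so `jax.debug.print` is the usual replacement.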

lerax.callback.AbstractStepCallback

Bases: AbstractCallback[EmptyCallbackState, StepStateType]

Callback that only implements step-related methods.

step_reset abstractmethod

step_reset(ctx: ResetContext, *, key: Key) -> StepStateType

Reset the callback step state for vectorized steps.

on_step abstractmethod

on_step(ctx: StepContext, *, key: Key) -> StepStateType

Called at the end of each environment step.
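A step callback's state is vectorized: each environment carries its own step state. The sketch below writes the logic for a single environment and maps it over a small batch with a list comprehension. A JAX implementation would more likely use `jax.vmap`. `EpisodeLength` and `EpisodeLengthCallback` are hypothetical names, and `done` is passed directly instead of through a StepContext.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpisodeLength:
    # Hypothetical per-environment step state.
    current: int = 0   # steps since the last episode boundary
    finished: int = 0  # number of episodes completed


class EpisodeLengthCallback:
    # Hypothetical step-only callback, written for a single environment.
    def step_reset(self) -> EpisodeLength:
        return EpisodeLength()

    def on_step(self, state: EpisodeLength, done: bool) -> EpisodeLength:
        if done:
            # Count the finished episode and restart the step counter.
            return EpisodeLength(current=0, finished=state.finished + 1)
        return EpisodeLength(current=state.current + 1, finished=state.finished)


cb = EpisodeLengthCallback()
num_envs = 2
states = [cb.step_reset() for _ in range(num_envs)]
# Each tuple holds the done flags for one step, one entry per environment.
for dones in [(False, False), (True, False), (False, False), (False, True)]:
    states = [cb.on_step(s, d) for s, d in zip(states, dones)]
print(states)  # [EpisodeLength(current=2, finished=1), EpisodeLength(current=0, finished=1)]
```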

lerax.callback.AbstractIterationCallback

Bases: AbstractCallback[StateType, EmptyCallbackStepState]

Callback that only implements iteration-related methods.

reset abstractmethod

reset(ctx: ResetContext, *, key: Key) -> StateType

Initialize the callback state.

on_iteration abstractmethod

on_iteration(
    ctx: IterationContext, *, key: Key
) -> StateType

Called at the end of each training iteration.
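An iteration-only callback keeps all of its state at the iteration level. The sketch below tracks the lowest value of one training metric and the iteration at which it occurred. The `"loss"` key is an assumed entry in `training_log`. The classes are hypothetical, and the IterationContext is reduced to `iteration_count` and `training_log` arguments.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class BestMetric:
    # Hypothetical iteration-level state: best (lowest) value seen and when.
    best: float = math.inf
    best_iteration: int = -1


class TrackBestLoss:
    # Hypothetical iteration-only callback; `metric` is an assumed training_log key.
    def __init__(self, metric: str = "loss"):
        self.metric = metric

    def reset(self) -> BestMetric:
        return BestMetric()

    def on_iteration(self, state: BestMetric, iteration_count: int, training_log: dict) -> BestMetric:
        value = training_log[self.metric]
        if value < state.best:
            return BestMetric(best=value, best_iteration=iteration_count)
        return state


cb = TrackBestLoss()
state = cb.reset()
for i, loss in enumerate([0.9, 0.4, 0.6, 0.3, 0.5]):
    state = cb.on_iteration(state, i, {"loss": loss})
print(state)  # BestMetric(best=0.3, best_iteration=3)
```

In traced code, the comparison would be expressed with `jnp.where` so both fields are selected without Python control flow.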

lerax.callback.AbstractTrainingCallback

Bases: AbstractCallback[StateType, EmptyCallbackStepState]

Callback that only implements training-related methods.

reset abstractmethod

reset(ctx: ResetContext, *, key: Key) -> StateType

Initialize the callback state.

on_training_start abstractmethod

on_training_start(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the start of training.

on_training_end abstractmethod

on_training_end(
    ctx: TrainingContext, *, key: Key
) -> StateType

Called at the end of training.
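A training-only callback does its work at the two ends of a run. As a small example, the sketch below records the iteration counter observed at training start and at training end, which can be informative when a run is resumed. `RunSpan` and `RecordRunSpan` are hypothetical, and the TrainingContext is reduced to an `iteration_count` argument.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunSpan:
    # Hypothetical training-level state; -1 means "not recorded yet".
    start_iteration: int = -1
    end_iteration: int = -1


class RecordRunSpan:
    # Hypothetical training-only callback.
    def reset(self) -> RunSpan:
        return RunSpan()

    def on_training_start(self, state: RunSpan, iteration_count: int) -> RunSpan:
        return replace(state, start_iteration=iteration_count)

    def on_training_end(self, state: RunSpan, iteration_count: int) -> RunSpan:
        return replace(state, end_iteration=iteration_count)


cb = RecordRunSpan()
state = cb.on_training_start(cb.reset(), iteration_count=10)
state = cb.on_training_end(state, iteration_count=35)
print(state.end_iteration - state.start_iteration)  # 25
```

The sketch uses an integer sentinel rather than `None` so that the state keeps the same type and structure from reset onward, which matters once such state is treated as a JAX pytree.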