Callbacks
Callbacks let you monitor and control training without modifying algorithms. They are Equinox modules called from the training loop and are designed to be JAX-friendly (I/O goes through JAX debug callbacks).
All callbacks subclass AbstractCallback and implement some or all of:
- reset / step_reset: called before training and before each rollout.
- on_step: called after each environment step.
- on_iteration: called after each training iteration / update.
- on_training_start / on_training_end: called at the start and end of training.
- continue_training: optional early-stopping hook.
Each hook receives a context object (ResetContext, StepContext, IterationContext, TrainingContext) containing the current environment, policy, optimizer state, training log, callback state, and a locals dict.
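The hook timing described above can be illustrated with a toy loop. This is only a sketch in plain Python: RecordingCallback and toy_training_loop are hypothetical stand-ins, not lerax's AbstractCallback or its training loop, the hooks take no context objects, and the relative order of reset and on_training_start is an assumption.

```python
# Toy illustration of hook ordering only. These names are hypothetical
# stand-ins and do not use lerax's AbstractCallback or context types.


class RecordingCallback:
    """Records the name of every hook in the order it is called."""

    def __init__(self, max_iterations):
        self.calls = []
        self.max_iterations = max_iterations
        self.iterations_seen = 0

    def reset(self):
        self.calls.append("reset")

    def on_training_start(self):
        self.calls.append("on_training_start")

    def step_reset(self):
        self.calls.append("step_reset")

    def on_step(self):
        self.calls.append("on_step")

    def on_iteration(self):
        self.calls.append("on_iteration")
        self.iterations_seen += 1

    def continue_training(self):
        # Early-stopping hook: returning False ends training.
        return self.iterations_seen < self.max_iterations

    def on_training_end(self):
        self.calls.append("on_training_end")


def toy_training_loop(callback, num_iterations, steps_per_rollout):
    callback.reset()              # before training (order vs. start is assumed)
    callback.on_training_start()
    for _ in range(num_iterations):
        callback.step_reset()     # before each rollout
        for _ in range(steps_per_rollout):
            callback.on_step()    # after each environment step
        callback.on_iteration()   # after each training iteration / update
        if not callback.continue_training():
            break
    callback.on_training_end()


callback = RecordingCallback(max_iterations=2)
toy_training_loop(callback, num_iterations=5, steps_per_rollout=3)
print(callback.calls)
```

Here the loop is asked for five iterations but stops after two, because continue_training returns False once the callback has seen two iterations.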
Using callbacks with algorithms
Algorithms accept either:
- A single callback instance, or
- A Python list of callbacks (internally wrapped in a CallbackList)
via the callback argument to learn.
from jax import random as jr

from lerax.algorithm import PPO
from lerax.callback import ProgressBarCallback, TensorBoardCallback
from lerax.env import CartPole
from lerax.policy import MLPActorCriticPolicy

policy_key, learn_key = jr.split(jr.key(0), 2)

env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=policy_key)
algo = PPO()

callbacks = [
    ProgressBarCallback(total_timesteps=2**16, env=env, policy=policy),
    TensorBoardCallback(env=env, policy=policy),
]

policy = algo.learn(
    env,
    policy,
    total_timesteps=2**16,
    key=learn_key,
    callback=callbacks,
)
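To make the "single callback or list" behaviour concrete, the following plain-Python sketch shows one way a list argument could be wrapped so that each hook call is forwarded to every member. ForwardingList, as_single_callback, and LoggingCallback are hypothetical names used for illustration; this is not lerax's CallbackList implementation, and only one hook is shown.

```python
# Hypothetical sketch of wrapping a list of callbacks behind one object.
# None of these names come from lerax.


class LoggingCallback:
    """Appends (name, iteration) to a shared log on every iteration."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_iteration(self, iteration):
        self.log.append((self.name, iteration))


class ForwardingList:
    """Forwards each hook call to every wrapped callback, in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def on_iteration(self, iteration):
        for callback in self.callbacks:
            callback.on_iteration(iteration)


def as_single_callback(callback):
    """Return a single callback unchanged; wrap a list or tuple."""
    if isinstance(callback, (list, tuple)):
        return ForwardingList(callback)
    return callback


log = []
single = LoggingCallback("a", log)
wrapped = as_single_callback([single, LoggingCallback("b", log)])
wrapped.on_iteration(0)
print(log)
```

With this shape, the training loop only ever talks to one callback object, regardless of how many the caller supplied.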
Built-in callbacks
- ProgressBarCallback: Rich-based progress bar showing iterations, elapsed/remaining time, and iterations per second.
- TensorBoardCallback: Logs training metrics (learning rate, training log entries, episode return/length EMAs) to TensorBoard.
- CallbackList: Aggregates multiple callbacks and forwards all hooks to each one. Used automatically when you pass a list of callbacks.
- EmptyCallback: No-op callback that can be used as a placeholder or default.
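For readers unfamiliar with the episode return/length EMAs mentioned for TensorBoardCallback, an exponential moving average can be sketched as below. The update_ema helper, the smoothing factor alpha, and the choice to seed the average with the first value are assumptions for illustration; the values and initialization lerax actually uses are not stated here.

```python
# Generic exponential moving average; not taken from lerax's source.


def update_ema(previous, value, alpha=0.1):
    """Blend a new value into the running average.

    alpha is a hypothetical smoothing factor in (0, 1]; larger values
    weight recent episodes more heavily.
    """
    if previous is None:
        # Assumption: seed the average with the first observed value.
        return float(value)
    return alpha * value + (1.0 - alpha) * previous


ema = None
for episode_return in [10.0, 20.0, 30.0]:
    ema = update_ema(ema, episode_return, alpha=0.5)

print(ema)  # 22.5
```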