LoggingCallback
LoggingCallback logs training statistics to any supported backend via a pluggable AbstractLoggingBackend.
It is JIT-safe through debug callbacks, but must be constructed outside any JIT-compiled function.
Run naming
When env and policy are passed to the constructor, a run name is generated
automatically:
{PolicyName}_{EnvName}_{timestamp}
For example: MLPActorCriticPolicy_CartPole-v1_20260220_120000.
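The naming pattern above can be reproduced with a few lines of plain Python. This is only an illustrative sketch: `make_run_name` is a hypothetical helper, not part of lerax, and the timestamp format `%Y%m%d_%H%M%S` is inferred from the example name rather than taken from the library.

```python
from datetime import datetime


def make_run_name(policy_name: str, env_name: str, now: datetime) -> str:
    """Hypothetical helper: build {PolicyName}_{EnvName}_{timestamp}."""
    # Timestamp layout inferred from the documented example (assumption).
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    return f"{policy_name}_{env_name}_{timestamp}"


name = make_run_name(
    "MLPActorCriticPolicy", "CartPole-v1", datetime(2026, 2, 20, 12, 0, 0)
)
print(name)  # MLPActorCriticPolicy_CartPole-v1_20260220_120000
```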
Pass name= to the constructor to override the generated name.
Logged metrics
At the end of each iteration (on_iteration), it logs:
- episode/return: Exponential moving average of episode returns over all environments.
- episode/length: Exponential moving average of episode lengths.
- train/*: All scalars in training_log (from the algorithm), plus:
- train/learning_rate (if present in the Optax optimizer state; otherwise NaN).
The smoothing factor for the exponential moving averages is controlled by alpha (default 0.9).
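To make the role of alpha concrete, here is a minimal sketch of one common exponential moving average update. It assumes alpha weights the previous average (so alpha=0.9 means slow, heavily smoothed updates); the text above does not say which convention lerax uses, and `ema_update` is a hypothetical name.

```python
def ema_update(previous: float, value: float, alpha: float = 0.9) -> float:
    """One EMA step, assuming alpha is the weight on the previous average."""
    return alpha * previous + (1.0 - alpha) * value


# Smooth a short series of episode returns, seeding with the first value.
returns = [10.0, 20.0, 30.0]
smoothed = returns[0]
for episode_return in returns[1:]:
    smoothed = ema_update(smoothed, episode_return)
```

Under this convention, a larger alpha gives a smoother but slower-reacting curve.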
Usage
from lerax.callback import LoggingCallback, TensorBoardBackend

logger = LoggingCallback(
    backend=TensorBoardBackend(),
    env=env,
    policy=policy,
)
# ... training ...
logger.close()
Backends
All backends defer initialisation to their open method, which LoggingCallback
calls at construction time with the run name.
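The deferred-initialisation pattern can be sketched with a toy backend that only becomes usable once open is called with a run name. `InMemoryBackend`, `log_scalars`, and the backend-level `close` are hypothetical names chosen for illustration; the actual AbstractLoggingBackend interface may differ.

```python
class InMemoryBackend:
    """Hypothetical backend: stores scalars in a list instead of writing files."""

    def __init__(self, log_dir: str = "logs"):
        # The constructor only records configuration; no resources are opened.
        self.log_dir = log_dir
        self.run_name = None
        self.records = []

    def open(self, name: str) -> None:
        # Real initialisation happens here, once the run name is known.
        self.run_name = name

    def log_scalars(self, scalars: dict, step: int) -> None:
        if self.run_name is None:
            raise RuntimeError("open() must be called before logging")
        self.records.append((step, dict(scalars)))

    def close(self) -> None:
        self.run_name = None


backend = InMemoryBackend()
backend.open("MLPActorCriticPolicy_CartPole-v1_20260220_120000")
backend.log_scalars({"episode/return": 21.5}, step=0)
```

Deferring setup this way lets a backend be configured before the run name exists.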
TensorBoardBackend
Writes to TensorBoard via tensorboardX.
from lerax.callback import TensorBoardBackend

backend = TensorBoardBackend(
    log_dir="logs",  # base directory for event files
)
WandbBackend
Writes to Weights & Biases.
from lerax.callback import WandbBackend

backend = WandbBackend(
    project="my-project",
    config={"lr": 3e-4},  # optional hyperparameter dict
)
ConsoleBackend
Displays metrics, reward history, and progress in the console.
Video recording
LoggingCallback can periodically record evaluation videos and log them under eval/video. Set video_interval to a positive integer to enable it:
logger = LoggingCallback(
    backend=TensorBoardBackend(),
    env=env,
    policy=policy,
    video_interval=10,    # record every 10 iterations
    video_num_steps=256,  # 256 environment steps per video
    video_width=640,      # render width in pixels
    video_height=480,     # render height in pixels
    video_fps=50.0,       # playback fps
)
# ... training ...
logger.close()
At each recording iteration the callback runs an eager eval rollout with the
current policy using jax.debug.callback and forwards the frames to the
backend. If rendering fails for any reason a warning is emitted and training
continues uninterrupted.
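The "warn and keep training" behaviour can be illustrated in plain Python. This sketch leaves out JAX entirely and is not the library's implementation: `maybe_record` and `record_fn` are hypothetical names, and the rule that recording happens when the iteration is a multiple of video_interval is an assumption.

```python
import warnings


def maybe_record(iteration: int, video_interval: int, record_fn):
    """Run record_fn on recording iterations; never let a failure propagate."""
    # Assumed schedule: record when the iteration is a multiple of the interval.
    if video_interval <= 0 or iteration % video_interval != 0:
        return None
    try:
        return record_fn()
    except Exception as error:
        # Rendering problems are downgraded to a warning so training continues.
        warnings.warn(f"Video recording failed: {error}")
        return None


def broken_renderer():
    raise RuntimeError("no renderer available")


frames = maybe_record(10, 10, lambda: ["frame0", "frame1"])
skipped = maybe_record(3, 10, lambda: ["frame0"])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    failed = maybe_record(20, 10, broken_renderer)
```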
Two renderer backends are supported:
- MuJoCo environments (whose unwrapped exposes a mujoco_model attribute) use HeadlessMujocoRenderer for off-screen rendering.
- PyGame environments (e.g. CartPole, MountainCar) use HeadlessPygameRenderer, which renders into an off-screen pygame.Surface without requiring a display server. The correct world transform is extracted from the environment's default_renderer().
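The selection rule described above amounts to an attribute check on the unwrapped environment. The sketch below returns renderer names as strings rather than constructing real renderers; `pick_renderer_name` is a hypothetical helper, and treating every non-MuJoCo environment as a PyGame one is a simplifying assumption.

```python
from types import SimpleNamespace


def pick_renderer_name(env) -> str:
    """Hypothetical dispatch: MuJoCo if unwrapped has mujoco_model, else PyGame."""
    unwrapped = getattr(env, "unwrapped", env)
    if hasattr(unwrapped, "mujoco_model"):
        return "HeadlessMujocoRenderer"
    # Simplification: everything else is assumed to be a PyGame environment.
    return "HeadlessPygameRenderer"


# Stand-in environments used only for this illustration.
mujoco_env = SimpleNamespace(unwrapped=SimpleNamespace(mujoco_model=object()))
cartpole_env = SimpleNamespace(unwrapped=SimpleNamespace())

print(pick_renderer_name(mujoco_env))    # HeadlessMujocoRenderer
print(pick_renderer_name(cartpole_env))  # HeadlessPygameRenderer
```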
Display initialisation
HeadlessPygameRenderer sets SDL_VIDEODRIVER=dummy before initialising
pygame when no display is available. If pygame is already initialised with a
real driver, the environment's default_renderer() may briefly open a
window; this is harmless during training, where interactive pygame rendering
is unlikely.
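A minimal sketch of the display check follows. SDL_VIDEODRIVER=dummy is SDL's standard way to run without a window system, but how HeadlessPygameRenderer decides that no display is available is not documented here; checking the DISPLAY and WAYLAND_DISPLAY variables is an assumption, and `configure_headless_sdl` is a hypothetical name.

```python
from typing import MutableMapping


def configure_headless_sdl(environ: MutableMapping[str, str]) -> bool:
    """Set SDL_VIDEODRIVER=dummy when no display variable is present.

    Returns True if the dummy driver was selected.
    """
    # Assumed detection: an X11 or Wayland display variable means a display exists.
    has_display = bool(environ.get("DISPLAY") or environ.get("WAYLAND_DISPLAY"))
    if has_display:
        return False
    environ["SDL_VIDEODRIVER"] = "dummy"
    return True


# Demonstrated on plain dicts; real code would pass os.environ.
headless_env = {}
desktop_env = {"DISPLAY": ":0"}
configure_headless_sdl(headless_env)
configure_headless_sdl(desktop_env)
```

In practice a call like this has to happen before pygame is initialised, since SDL reads the variable when the video subsystem starts.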