Logging Callback
lerax.callback.LoggingCallbackStepState
Bases: AbstractCallbackStepState
Per-environment step state for LoggingCallback.
Tracks cumulative episode returns and lengths, along with exponential moving averages of those quantities updated at episode boundaries.
Attributes:
| Name | Type | Description |
|---|---|---|
| `step` | `Int[Array, '']` | Total number of steps taken. |
| `episode_return` | `Float[Array, '']` | Cumulative return for the current (in-progress) episode. |
| `episode_length` | `Int[Array, '']` | Number of steps in the current episode. |
| `episode_done` | `Bool[Array, '']` | Whether the previous step ended an episode. |
| `average_return` | `Float[Array, '']` | EMA of episode returns. |
| `average_length` | `Float[Array, '']` | EMA of episode lengths. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `step` | `Int[Array, '']` | Total number of steps taken. | required |
| `episode_return` | `Float[ArrayLike, '']` | Cumulative return for the current episode. | required |
| `episode_length` | `Int[ArrayLike, '']` | Number of steps in the current episode. | required |
| `episode_done` | `Bool[ArrayLike, '']` | Whether the previous step ended an episode. | required |
| `average_return` | `Float[ArrayLike, '']` | EMA of episode returns. | required |
| `average_length` | `Float[ArrayLike, '']` | EMA of episode lengths. | required |
next
Advance the state by one environment step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `reward` | `Float[Array, '']` | Reward received at this step. | required |
| `done` | `Bool[Array, '']` | Whether the episode ended at this step. | required |
| `alpha` | `float` | EMA smoothing factor (weight on the new episode value). | required |
Returns:
| Type | Description |
|---|---|
| `LoggingCallbackStepState` | Updated step state. |
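The update that `next` performs can be illustrated with a plain-Python sketch. It is not lerax's implementation: the class name `StepStateSketch` is hypothetical, and it assumes (a) the per-episode counters restart on the step after `episode_done` was set, and (b) the averages change only when `done` is true, with the finished episode weighted by `alpha`. The real state holds JAX scalars, so a traceable version would replace the Python `if` branches with `jnp.where`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class StepStateSketch:
    """Hypothetical, immutable stand-in for LoggingCallbackStepState."""

    step: int = 0
    episode_return: float = 0.0
    episode_length: int = 0
    episode_done: bool = False
    average_return: float = 0.0
    average_length: float = 0.0

    def next(self, reward: float, done: bool, alpha: float) -> StepStateSketch:
        # Assumption: start a fresh episode if the previous step ended one.
        base_return = 0.0 if self.episode_done else self.episode_return
        base_length = 0 if self.episode_done else self.episode_length
        episode_return = base_return + reward
        episode_length = base_length + 1

        # Assumption: EMAs are updated only at episode boundaries,
        # weighting the just-finished episode by alpha.
        if done:
            average_return = alpha * episode_return + (1 - alpha) * self.average_return
            average_length = alpha * episode_length + (1 - alpha) * self.average_length
        else:
            average_return = self.average_return
            average_length = self.average_length

        # Return a new state instead of mutating, mirroring an immutable module.
        return StepStateSketch(
            step=self.step + 1,
            episode_return=episode_return,
            episode_length=episode_length,
            episode_done=done,
            average_return=average_return,
            average_length=average_length,
        )
```

For example, two steps with reward 1.0 where the second ends the episode, using `alpha=0.5`, give `episode_return == 2.0` and `average_return == 1.0`, because the average starts at 0.0 in this sketch.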
lerax.callback.AbstractLoggingBackend
Bases: eqx.Module
Abstract base class for logging backends.
Implementations receive already-converted Python/NumPy values; the `LoggingCallback` handles the JIT-to-NumPy boundary.

The lifecycle is `open` -> `log_*` -> `close`. The `open` method is called by `LoggingCallback.__init__` with the run name at construction time; `close` is called by `LoggingCallback.close()`.
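Because backends only see plain values, scalar arrays have to be turned into Python numbers before a backend is called. A minimal sketch of that conversion step, assuming metrics arrive as a mapping of 0-d arrays; NumPy stands in for JAX here, `to_python_scalars` is a hypothetical helper, and how lerax moves values out of the compiled function is not shown:

```python
from __future__ import annotations

from typing import Any, Mapping

import numpy as np


def to_python_scalars(metrics: Mapping[str, Any]) -> dict[str, float]:
    """Hypothetical helper: convert 0-d array-like metrics into Python floats."""
    out: dict[str, float] = {}
    for key, value in metrics.items():
        arr = np.asarray(value)  # also accepts concrete (non-traced) JAX arrays
        if arr.ndim != 0:
            raise ValueError(f"metric {key!r} is not a scalar (shape {arr.shape})")
        out[key] = float(arr.item())  # .item() yields a native Python number
    return out
```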
open
abstractmethod
Initialise the backend with the given run name.
Called exactly once before any `log_*` method. Backends should create their writer, run handle, or output directory here.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Human-readable run name. | required |
log_scalars
abstractmethod
log_video
abstractmethod
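The `open` -> `log_*` -> `close` lifecycle can be sketched with a small in-memory backend. This is a standalone illustration, not lerax code: `BackendSketch` uses `abc.ABC` in place of the real `eqx.Module` base, `InMemoryBackend` is hypothetical, and the `log_scalars`/`log_video` signatures are assumptions because this page does not document them.

```python
from __future__ import annotations

import abc
from typing import Any


class BackendSketch(abc.ABC):
    """Hypothetical stand-in for the documented backend lifecycle."""

    @abc.abstractmethod
    def open(self, name: str) -> None: ...

    # Assumed signature: a dict of scalar metrics plus the step they belong to.
    @abc.abstractmethod
    def log_scalars(self, scalars: dict[str, float], step: int) -> None: ...

    # Assumed signature: rendered frames, the step, and playback fps.
    @abc.abstractmethod
    def log_video(self, frames: Any, step: int, fps: float) -> None: ...

    def close(self) -> None:
        """Cleanup hook; assumed optional here, so the default does nothing."""


class InMemoryBackend(BackendSketch):
    """Records everything in lists, e.g. for tests."""

    def __init__(self) -> None:
        self.name: str | None = None
        self.records: list[tuple[int, dict[str, float]]] = []
        self.videos: list[tuple[int, float]] = []
        self.closed = False

    def open(self, name: str) -> None:
        # open must run exactly once, before any log_* call.
        if self.name is not None:
            raise RuntimeError("open() was already called")
        self.name = name

    def log_scalars(self, scalars: dict[str, float], step: int) -> None:
        if self.name is None:
            raise RuntimeError("call open() before logging")
        self.records.append((step, dict(scalars)))

    def log_video(self, frames: Any, step: int, fps: float) -> None:
        if self.name is None:
            raise RuntimeError("call open() before logging")
        self.videos.append((step, fps))  # frames are ignored in this sketch

    def close(self) -> None:
        self.closed = True
```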
lerax.callback.TensorBoardBackend
Bases: AbstractLoggingBackend
Logging backend that writes to TensorBoard via tensorboardX.
The log directory is `log_dir / name`, where `name` is provided by `open`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `log_dir` | `str \| Path` | Base directory for TensorBoard event files. | `'logs'` |
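The directory rule above amounts to a single path join. A sketch, where `resolve_log_dir` is a hypothetical helper rather than part of lerax:

```python
from __future__ import annotations

from pathlib import Path


def resolve_log_dir(log_dir: str | Path, name: str) -> Path:
    """Hypothetical helper: directory for a run's event files, i.e. log_dir / name."""
    return Path(log_dir) / name
```

Since each run name gets its own subdirectory, pointing TensorBoard at the base directory (for example `tensorboard --logdir logs`) lists the runs side by side.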
lerax.callback.WandbBackend
Bases: AbstractLoggingBackend
Logging backend that writes to Weights & Biases.
`wandb.init` is called in `open` and `wandb.finish` is called in `close`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `project` | `str \| None` | W&B project name. | `None` |
| `config` | `dict[str, Any] \| None` | Hyperparameter dictionary passed to | `None` |
| `quiet` | `bool` | Suppress W&B console output. Useful when combined with other logging backends like | `False` |
lerax.callback.ConsoleBackend
Bases: AbstractLoggingBackend
Logging backend that displays a live metrics table and progress bar.
Uses Rich's `Live` display to show a metrics table that updates in place on each `log_scalars` call, with a progress bar rendered below it. The table is replaced (not appended) on each update, so the display stays compact.
Progress bar display is automatically enabled when running in an interactive terminal. In non-interactive environments (logs, CI, redirected output), progress info (percentage, elapsed time, ETA) is logged with each scalar update.
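The two display modes can be sketched as follows. Both helpers are hypothetical: `is_interactive` relies on `isatty()`, a common way to detect a terminal (the backend's actual check is not documented here), and `progress_text` assumes a constant step rate when estimating the ETA.

```python
from __future__ import annotations

import sys
from typing import Any


def is_interactive(stream: Any = None) -> bool:
    """Hypothetical check: treat the stream as a terminal if isatty() says so."""
    stream = sys.stdout if stream is None else stream
    return bool(hasattr(stream, "isatty") and stream.isatty())


def progress_text(done: int, total: int, elapsed_s: float) -> str:
    """Percentage, elapsed time, and ETA as a single log-friendly line."""
    fraction = done / total if total > 0 else 0.0
    if done > 0:
        # Constant-rate assumption: remaining time scales with remaining steps.
        eta = f"{elapsed_s * (total - done) / done:.0f}s"
    else:
        eta = "?"  # no rate estimate before the first step completes
    return f"{100 * fraction:.1f}% | elapsed {elapsed_s:.0f}s | eta {eta}"
```

In a non-interactive run, a line like `progress_text(25, 100, 10.0)` would read `25.0% | elapsed 10s | eta 30s`.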
lerax.callback.LoggingCallback
Bases: AbstractCallback[EmptyCallbackState, LoggingCallbackStepState]
Callback that collects training metrics and progress and logs to a backend.
Note
This callback must be constructed outside any JIT-compiled function.
Attributes:
| Name | Type | Description |
|---|---|---|
| `alpha` | `float` | EMA smoothing factor for episode statistics. |
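Taking `alpha` as the weight on the newest episode value, as described for `next`, and writing $R_k$ for the return of episode $k$ and $\bar{R}_0$ for the initial average, the average after episode $k$ unrolls as follows (episode lengths behave the same way):

```latex
\bar{R}_k = \alpha R_k + (1-\alpha)\,\bar{R}_{k-1}
          = \sum_{j=0}^{k-1} \alpha (1-\alpha)^{j} R_{k-j} + (1-\alpha)^{k}\,\bar{R}_0
```

With the default $\alpha = 0.9$, the most recent episode carries weight 0.9, the one before it 0.09, and the one before that 0.009, so the logged averages follow recent episodes closely and retain little long-term memory.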
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `backend` | `AbstractLoggingBackend \| Sequence[AbstractLoggingBackend]` | Logging backend (or list of backends) to send metrics to. | required |
| `name` | `str \| None` | Explicit run name. When | `None` |
| `env` | `AbstractEnvLike \| None` | Environment used to derive the run name when | `None` |
| `policy` | `AbstractPolicy \| None` | Policy used to derive the run name when | `None` |
| `alpha` | `float` | EMA smoothing factor (higher = more weight on recent episodes). | `0.9` |
| `hparams` | `dict[str, Any] \| None` | Additional explicit hyperparameters. Merged last, so these values take precedence over auto-extracted ones. | `None` |
| `video_interval` | `int` | Number of iterations between recorded videos; | `0` |
| `video_num_steps` | `int` | Environment steps per recorded video. | `128` |
| `video_width` | `int` | Render width in pixels. | `640` |
| `video_height` | `int` | Render height in pixels. | `480` |
| `video_fps` | `float` | Playback frames per second. | `50.0` |
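The precedence rule for `hparams` matches ordinary dict unpacking, where later keys win. A sketch with a hypothetical `merge_hparams` helper; the `auto` dict stands in for whatever values lerax extracts automatically, which this page does not list:

```python
from __future__ import annotations

from typing import Any


def merge_hparams(auto: dict[str, Any], explicit: dict[str, Any] | None) -> dict[str, Any]:
    """Hypothetical helper: explicit values override auto-extracted ones."""
    # In {**a, **b}, keys from b replace matching keys from a.
    return {**auto, **(explicit or {})}
```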
__init__
__init__(
backend: AbstractLoggingBackend
| Sequence[AbstractLoggingBackend],
name: str | None = None,
env: AbstractEnvLike | None = None,
policy: AbstractPolicy | None = None,
alpha: float = 0.9,
hparams: dict[str, Any] | None = None,
video_interval: int = 0,
video_num_steps: int = 128,
video_width: int = 640,
video_height: int = 480,
video_fps: float = 50.0,
) -> None
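Since `backend` accepts either one backend or a sequence, the callback presumably normalizes it and forwards each batch of metrics to every entry. A sketch of that fan-out under those assumptions; `normalize_backends`, `log_to_all`, and `Recorder` are hypothetical, and the `log_scalars(scalars, step)` call shape is assumed:

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any


def normalize_backends(backend: Any) -> tuple[Any, ...]:
    """Accept one backend or a sequence of backends; always return a tuple."""
    if isinstance(backend, Sequence):
        return tuple(backend)
    return (backend,)


def log_to_all(backends: tuple[Any, ...], scalars: dict[str, float], step: int) -> None:
    """Send the same metrics to every backend, in order."""
    for b in backends:
        b.log_scalars(scalars, step)  # assumed signature, see lead-in


class Recorder:
    """Tiny hypothetical backend used only to demonstrate the fan-out."""

    def __init__(self) -> None:
        self.seen: list[tuple[int, dict[str, float]]] = []

    def log_scalars(self, scalars: dict[str, float], step: int) -> None:
        self.seen.append((step, scalars))
```

In lerax terms, this is what makes it possible to pass, for example, a list containing a `TensorBoardBackend` and a `WandbBackend` as `backend`.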