Logging Callback

lerax.callback.LoggingCallbackStepState

Bases: AbstractCallbackStepState

Per-environment step state for LoggingCallback.

Tracks cumulative episode returns and lengths, along with exponential moving averages of those quantities updated at episode boundaries.

Attributes:

Name Type Description
step Int[Array, '']

Total number of steps taken.

episode_return Float[Array, '']

Cumulative return for the current (in-progress) episode.

episode_length Int[Array, '']

Number of steps in the current episode.

episode_done Bool[Array, '']

Whether the previous step ended an episode.

average_return Float[Array, '']

EMA of episode returns.

average_length Float[Array, '']

EMA of episode lengths.

Parameters:

Name Type Description Default
step Int[Array, '']

Total number of steps taken.

required
episode_return Float[ArrayLike, '']

Cumulative return for the current episode.

required
episode_length Int[ArrayLike, '']

Number of steps in the current episode.

required
episode_done Bool[ArrayLike, '']

Whether the previous step ended an episode.

required
average_return Float[ArrayLike, '']

EMA of episode returns.

required
average_length Float[ArrayLike, '']

EMA of episode lengths.

required

initial classmethod

initial() -> LoggingCallbackStepState

Return a zeroed initial state.

next

next(
    reward: Float[Array, ""],
    done: Bool[Array, ""],
    alpha: float,
) -> LoggingCallbackStepState

Advance the state by one environment step.

Parameters:

Name Type Description Default
reward Float[Array, '']

Reward received at this step.

required
done Bool[Array, '']

Whether the episode ended at this step.

required
alpha float

EMA smoothing factor (weight on the new episode value).

required

Returns:

Type Description
LoggingCallbackStepState

Updated step state.
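The update rule described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the class name StepStateSketch is hypothetical, the real state holds JAX arrays rather than Python scalars, and the exact point at which the per-episode accumulators reset is an assumption (here, on the step after an episode ends).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepStateSketch:
    """Plain-Python stand-in for LoggingCallbackStepState (illustrative only)."""

    step: int = 0
    episode_return: float = 0.0
    episode_length: int = 0
    episode_done: bool = False
    average_return: float = 0.0
    average_length: float = 0.0

    def next(self, reward: float, done: bool, alpha: float) -> "StepStateSketch":
        # Assumption: accumulators restart on the step after an episode ended.
        base_return = 0.0 if self.episode_done else self.episode_return
        base_length = 0 if self.episode_done else self.episode_length
        episode_return = base_return + reward
        episode_length = base_length + 1

        # The EMAs change only at episode boundaries; alpha is the weight
        # placed on the newly finished episode's value.
        if done:
            average_return = alpha * episode_return + (1 - alpha) * self.average_return
            average_length = alpha * episode_length + (1 - alpha) * self.average_length
        else:
            average_return = self.average_return
            average_length = self.average_length

        return StepStateSketch(
            step=self.step + 1,
            episode_return=episode_return,
            episode_length=episode_length,
            episode_done=done,
            average_return=average_return,
            average_length=average_length,
        )


# Two-step episode followed by the first step of a new episode.
state = StepStateSketch()
for reward, done in [(1.0, False), (1.0, True), (3.0, False)]:
    state = state.next(reward, done, alpha=0.5)
```

In a JIT-compiled setting the two if/else branches would presumably be written as branchless selections over arrays; the arithmetic is the same.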

lerax.callback.AbstractLoggingBackend

Bases: eqx.Module

Abstract base class for logging backends.

Implementations receive already-converted Python/numpy values; the LoggingCallback handles the JIT-to-numpy boundary.

The lifecycle is open -> log_* -> close. The open method is called by LoggingCallback.__init__ with the run name at construction time. close is called by LoggingCallback.close().

open abstractmethod

open(name: str) -> None

Initialise the backend with the given run name.

Called exactly once before any log_* method. Backends should create their writer, run handle, or output directory here.

Parameters:

Name Type Description Default
name str

Human-readable run name.

required

log_scalars abstractmethod

log_scalars(
    scalars: dict[str, np.ndarray], step: int
) -> None

Log a dictionary of scalar values.

Parameters:

Name Type Description Default
scalars dict[str, np.ndarray]

Scalar values keyed by metric name.

required
step int

Current training step.

required

log_video abstractmethod

log_video(
    tag: str, frames: np.ndarray, step: int, fps: float
) -> None

Log a video clip.

Parameters:

Name Type Description Default
tag str

Metric name for the video.

required
frames np.ndarray

Video frames as a uint8 array of shape (T, H, W, C).

required
step int

Current training step.

required
fps float

Playback frames per second.

required

close abstractmethod

close() -> None

Flush pending data and release any held resources.
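The open -> log_* -> close lifecycle can be illustrated with a small in-memory backend. This is a duck-typed sketch that only mirrors the documented method signatures; it deliberately does not subclass AbstractLoggingBackend, and the class name InMemoryBackend and the metric names are hypothetical.

```python
from typing import Any, Optional


class InMemoryBackend:
    """Duck-typed sketch that mirrors the documented backend methods."""

    def __init__(self) -> None:
        self.name: Optional[str] = None
        self.scalars = []  # (step, metric name, value) tuples
        self.videos = []   # (step, tag, frame count, fps) tuples
        self.is_open = False

    def open(self, name: str) -> None:
        # Called once, before any log_* method.
        self.name = name
        self.is_open = True

    def log_scalars(self, scalars: dict, step: int) -> None:
        if not self.is_open:
            raise RuntimeError("open() must be called before log_scalars()")
        for key, value in scalars.items():
            # Values arrive already converted; float() also accepts 0-d arrays.
            self.scalars.append((step, key, float(value)))

    def log_video(self, tag: str, frames: Any, step: int, fps: float) -> None:
        if not self.is_open:
            raise RuntimeError("open() must be called before log_video()")
        # Keep only metadata; frames are expected as uint8 of shape (T, H, W, C).
        self.videos.append((step, tag, len(frames), fps))

    def close(self) -> None:
        self.is_open = False


backend = InMemoryBackend()
backend.open("demo-run")
backend.log_scalars({"episode/return": 1.5, "episode/length": 12}, step=100)
backend.close()
```

A backend like this can be handy in tests, since the recorded tuples can be asserted on directly.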

lerax.callback.TensorBoardBackend

Bases: AbstractLoggingBackend

Logging backend that writes to TensorBoard via tensorboardX.

Event files are written to log_dir / name, where name is the run name passed to open.

Parameters:

Name Type Description Default
log_dir str | Path

Base directory for TensorBoard event files.

'logs'

__init__

__init__(log_dir: str | Path = 'logs') -> None

log_scalars

log_scalars(
    scalars: dict[str, np.ndarray], step: int
) -> None

log_video

log_video(
    tag: str, frames: np.ndarray, step: int, fps: float
) -> None

close

close() -> None
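The log_dir / name composition described for TensorBoardBackend can be reproduced with pathlib. The helper name resolve_run_dir and the run name below are hypothetical; the sketch only shows where event files for a run would be expected to land.

```python
from pathlib import Path


def resolve_run_dir(log_dir, name: str) -> Path:
    # Accepts str or Path for log_dir, matching the documented parameter type.
    return Path(log_dir) / name


run_dir = resolve_run_dir("logs", "cartpole-ppo")
```

With the default log_dir of 'logs', pointing TensorBoard at the base directory (tensorboard --logdir logs) lists each run name as a separate run.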

lerax.callback.WandbBackend

Bases: AbstractLoggingBackend

Logging backend that writes to Weights & Biases.

wandb.init is called in open and wandb.finish is called in close.

Parameters:

Name Type Description Default
project str | None

W&B project name.

None
config dict[str, Any] | None

Hyperparameter dictionary passed to wandb.init.

None
quiet bool

Suppress W&B console output. Useful when combined with other logging backends like ConsoleBackend.

False

__init__

__init__(
    project: str | None = None,
    config: dict[str, Any] | None = None,
    quiet: bool = False,
) -> None

log_scalars

log_scalars(
    scalars: dict[str, np.ndarray], step: int
) -> None

log_video

log_video(
    tag: str, frames: np.ndarray, step: int, fps: float
) -> None

close

close() -> None

lerax.callback.ConsoleBackend

Bases: AbstractLoggingBackend

Logging backend that displays a live metrics table and progress bar.

Uses Rich's Live display to show a metrics table that updates in-place on each log_scalars call, with a progress bar rendered below it. The table is replaced (not appended) on each update so the display stays compact.

Progress bar display is automatically enabled when running in an interactive terminal. In non-interactive environments (logs, CI, redirected output), progress info (percentage, elapsed time, ETA) is logged with each scalar update.

log_scalars

log_scalars(
    scalars: dict[str, np.ndarray], step: int
) -> None

log_hparams

log_hparams(hparams: dict[str, Any]) -> None

close

close() -> None
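The interactive versus non-interactive behaviour of ConsoleBackend can be approximated as follows. This is a sketch under assumptions: the helper names, the total_steps argument, and the exact line format are hypothetical, and the library's own detection logic may differ.

```python
import io
import sys


def is_interactive(stream=sys.stdout) -> bool:
    # A TTY check is one common way to detect an interactive terminal.
    return hasattr(stream, "isatty") and stream.isatty()


def format_progress(step: int, total_steps: int, elapsed_seconds: float) -> str:
    # Percentage done, elapsed time, and a linear-extrapolation ETA.
    fraction = step / total_steps
    if fraction > 0:
        eta = elapsed_seconds * (1 - fraction) / fraction
    else:
        eta = float("inf")
    return f"{fraction:.1%} | elapsed {elapsed_seconds:.0f}s | eta {eta:.0f}s"


line = format_progress(250, 1000, 30.0)
redirected = is_interactive(io.StringIO())  # an in-memory stream is not a TTY
```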

lerax.callback.LoggingCallback

Bases: AbstractCallback[EmptyCallbackState, LoggingCallbackStepState]

Callback that collects training metrics and progress and logs to a backend.

Note

This callback must be constructed outside any JIT-compiled function.

Attributes:

Name Type Description
alpha float

EMA smoothing factor for episode statistics.

Parameters:

Name Type Description Default
backend AbstractLoggingBackend | Sequence[AbstractLoggingBackend]

Logging backend (or list of backends) to send metrics to.

required
name str | None

Explicit run name. When None, a name is generated from the environment name, policy name, and a timestamp. If neither env nor policy is provided, the name falls back to a plain timestamp.

None
env AbstractEnvLike | None

Environment used to derive the run name when name is None.

None
policy AbstractPolicy | None

Policy used to derive the run name when name is None.

None
alpha float

EMA smoothing factor (higher = more weight on recent episodes).

0.9
hparams dict[str, Any] | None

Additional explicit hyperparameters. Merged last so these values take precedence over auto-extracted ones.

None
video_interval int

Number of iterations between video recordings; 0 disables video recording.

0
video_num_steps int

Environment steps per recorded video.

128
video_width int

Render width in pixels.

640
video_height int

Render height in pixels.

480
video_fps float

Playback frames per second.

50.0
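The run-name fallback and the hparams precedence rule can be sketched as below. The function names, the separator, and the timestamp format are hypothetical; only the behaviour stated in the parameter descriptions (timestamp-only fallback, explicit hparams merged last) is taken from this page.

```python
from datetime import datetime
from typing import Any, Optional


def make_run_name(env_name: Optional[str], policy_name: Optional[str], now: datetime) -> str:
    # Hypothetical format: available names joined with the timestamp.
    timestamp = now.strftime("%Y%m%d-%H%M%S")
    parts = [part for part in (env_name, policy_name) if part]
    return "_".join(parts + [timestamp])


def merge_hparams(auto: dict, explicit: Optional[dict]) -> dict:
    # Explicit values are merged last, so they override auto-extracted ones.
    merged = dict(auto)
    merged.update(explicit or {})
    return merged


now = datetime(2024, 1, 2, 3, 4, 5)
full_name = make_run_name("CartPole", "MLPPolicy", now)
fallback_name = make_run_name(None, None, now)
hparams = merge_hparams({"alpha": 0.9, "lr": 3e-4}, {"lr": 1e-3})
```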

__init__

__init__(
    backend: AbstractLoggingBackend
    | Sequence[AbstractLoggingBackend],
    name: str | None = None,
    env: AbstractEnvLike | None = None,
    policy: AbstractPolicy | None = None,
    alpha: float = 0.9,
    hparams: dict[str, Any] | None = None,
    video_interval: int = 0,
    video_num_steps: int = 128,
    video_width: int = 640,
    video_height: int = 480,
    video_fps: float = 50.0,
) -> None
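Because backend accepts either one backend or a sequence of them, a callback has to normalise the argument and then fan each log call out to every backend. A minimal sketch of that pattern follows; RecordingBackend, as_backend_list, and log_to_all are hypothetical names used for illustration, not part of the library.

```python
from collections.abc import Sequence


class RecordingBackend:
    """Hypothetical stand-in backend that records the calls it receives."""

    def __init__(self, label: str) -> None:
        self.label = label
        self.calls = []

    def log_scalars(self, scalars: dict, step: int) -> None:
        self.calls.append((step, dict(scalars)))


def as_backend_list(backend) -> list:
    # Accept either a single backend or a sequence of backends.
    if isinstance(backend, Sequence):
        return list(backend)
    return [backend]


def log_to_all(backends, scalars: dict, step: int) -> None:
    # Every backend receives the same already-converted values.
    for backend in backends:
        backend.log_scalars(scalars, step)


single = as_backend_list(RecordingBackend("console"))
several = as_backend_list([RecordingBackend("tensorboard"), RecordingBackend("wandb")])
log_to_all(several, {"episode/return": 2.0}, step=10)
```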

reset

reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> EmptyCallbackState

step_reset

step_reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> LoggingCallbackStepState

on_step

on_step(
    ctx: StepContext, *, key: Key[Array, ""]
) -> LoggingCallbackStepState

on_iteration

on_iteration(
    ctx: IterationContext, *, key: Key[Array, ""]
) -> EmptyCallbackState
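The video_interval setting implies a simple schedule check on each iteration. The sketch below assumes recording happens when the iteration count is a multiple of video_interval; the function name and that exact alignment are assumptions, not confirmed behaviour.

```python
def should_record_video(iteration: int, video_interval: int) -> bool:
    # An interval of 0 disables recording entirely.
    if video_interval <= 0:
        return False
    return iteration % video_interval == 0


recorded = [i for i in range(1, 11) if should_record_video(i, 5)]
disabled = [i for i in range(1, 11) if should_record_video(i, 0)]

# With the documented defaults, a clip of 128 steps played back at 50 fps
# lasts video_num_steps / video_fps seconds.
clip_seconds = 128 / 50.0
```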

on_training_start

on_training_start(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> EmptyCallbackState

on_training_end

on_training_end(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> EmptyCallbackState