TensorBoardCallback
TensorBoardCallback logs training statistics to TensorBoard via tensorboardX.SummaryWriter.
Its logging is JIT-safe because it goes through debug callbacks, but the callback itself must be constructed outside any JIT-compiled function.
Logged metrics
At the end of each iteration (on_iteration), it logs:
- `episode/return`: Exponential moving average of episode returns over all environments.
- `episode/length`: Exponential moving average of episode lengths.
- `train/*`: All scalars in `training_log` (provided by the algorithm), plus:
  - `train/learning_rate` (if present in the Optax optimizer state; otherwise NaN).
NaN or Inf scalar values are rejected with an error.
The smoothing factor for the exponential moving averages is controlled by alpha (default 0.9).
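As a rough illustration of the smoothing, here is a minimal sketch (not the callback's internal code; the variable names and initial value are illustrative) of how an exponential moving average with alpha = 0.9 evolves over per-iteration episode returns:

def update_ema(ema, value, alpha=0.9):
    # Keep 90% of the running estimate and take 10% from the newest value.
    return alpha * ema + (1.0 - alpha) * value

ema_return = 0.0                                # illustrative initial value
for iteration_return in [1.0, 3.0, 2.0]:        # returns averaged over all envs per iteration
    ema_return = update_ema(ema_return, iteration_return)
print(ema_return)                               # smoothed value; corresponds to episode/return in this sketch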
Initialization and log directory
from lerax.callback import TensorBoardCallback
callback = TensorBoardCallback(
name=None, # optional; log directory name
env=env, # optional; used for default name
policy=policy, # optional; used for default name
alpha=0.9, # EMA smoothing for episode stats
)
If name is None, a log directory under logs/ is created from the current time and whichever of the policy and environment names are available:
"logs/{policy.name}_{env.name}_{time}""logs/{env.name}_{time}"or"logs/{policy.name}_{time}""logs/training_{time}"as a fallback.
Behind the scenes, the callback uses a JITSummaryWriter wrapper and callback_with_numpy_wrapper so that logging can be triggered from within JAX-transformed code while the actual I/O happens on the Python side.
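The underlying pattern can be sketched with jax.debug.callback directly (the writer, tag, and function names below are illustrative, not lerax internals):

import jax
import jax.numpy as jnp
from tensorboardX import SummaryWriter

writer = SummaryWriter("logs/demo")  # constructed outside any JIT-compiled function

def _log_scalar(value, step):
    # Host-side: receives concrete NumPy values and performs the actual I/O.
    writer.add_scalar("train/loss", float(value), int(step))

@jax.jit
def train_step(loss, step):
    # Device-side: only schedules the host callback; no I/O happens here.
    jax.debug.callback(_log_scalar, loss, step)
    return loss * 0.5  # stand-in for a real update

train_step(jnp.float32(1.0), jnp.int32(0))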