
Work in Progress

Lerax is usable for training simple RL agents, but the API is still evolving and the documentation is incomplete. Expect rough edges.

Getting Started with Lerax

Lerax is a reinforcement learning library built on JAX and Equinox. It provides functional environments, policies, and training algorithms that compose cleanly under jax.jit and jax.vmap.
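To illustrate the functional style (this is a toy sketch, not Lerax's actual interface), an environment can be expressed as a pure step function from state and action to new state and reward, with state threaded explicitly through the loop. Pure, state-passing functions like this are exactly the shape that jit-compilation and batched mapping require:

```python
# A minimal sketch of a purely functional environment. All names here are
# illustrative assumptions, not part of Lerax's API.

def step(state: float, action: int) -> tuple[float, float]:
    """Toy 1-D environment: the action nudges the state left or right."""
    new_state = state + (1.0 if action == 1 else -1.0)
    reward = 1.0 if abs(new_state) < 3.0 else 0.0  # reward for staying near the origin
    return new_state, reward

def rollout(state: float, actions: list[int]) -> float:
    """Accumulate reward by threading state through successive steps."""
    total = 0.0
    for action in actions:
        state, reward = step(state, action)
        total += reward
    return total
```

Because `step` has no hidden state, the same function works unchanged whether it is called once, traced by a compiler, or mapped over a batch of states.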

Installation

pip install lerax

Train a policy

The example below trains a PPO agent on CartPole and streams metrics to both the terminal and TensorBoard:

examples/ppo.py
from jax import random as jr

from lerax.algorithm import PPO
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
from lerax.env.classic_control import CartPole
from lerax.policy import MLPActorCriticPolicy

# Split one PRNG key into independent keys for initialization and training.
policy_key, learn_key = jr.split(jr.key(0), 2)

env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=policy_key)
algo = PPO()

# Stream metrics to both TensorBoard and the terminal during training.
logger = LoggingCallback(
    [TensorBoardBackend(), ConsoleBackend()], env=env, policy=policy
)

policy = algo.learn(env, policy, total_timesteps=2**16, key=learn_key, callback=logger)
logger.close()

That's a complete training run: no separate config file, no custom training loop. The same structure works for any combination of environment, policy, and algorithm in the library.
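On-policy algorithms like PPO generally alternate between collecting experience with the current policy and updating the policy on that batch. As a rough schematic of what a `learn`-style call does (plain Python, assuming nothing about Lerax's internals; `collect` and `update` are hypothetical stand-ins):

```python
# Schematic of a generic on-policy training loop. This is a conceptual
# sketch, not Lerax's implementation.

def learn(policy, collect, update, total_timesteps, batch_size):
    """Run collect/update cycles until total_timesteps is reached."""
    steps = 0
    while steps < total_timesteps:
        batch = collect(policy, batch_size)   # gather a rollout with the current policy
        policy = update(policy, batch)        # e.g. PPO's clipped-objective update
        steps += batch_size
    return policy
```

Keeping this loop inside the algorithm is what lets the one-call example above stay so short: the user supplies the environment, policy, and callback, and the algorithm owns the cycle.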

Next steps

  • Environments — built-in environments and how to write your own.
  • Compatibility — using Gymnasium, Gymnax, and Stable Baselines3 environments and algorithms with Lerax.
  • Callbacks — logging, progress bars, and custom training hooks.
  • Saving & Loading — serializing policies and exporting to ONNX.

Acknowledgements

Much of the code translates patterns from Stable Baselines 3 and Gymnasium; both libraries are excellent foundations for RL in Python, and Lerax owes a great deal to their design.

The neural differential equation (NDE) code is heavily inspired by the work of Patrick Kidger, and the entire library is built on his Equinox, Diffrax, and jaxtyping libraries.