Getting Started with Lerax
Do you want to leverage the power of JAX for high-performance reinforcement learning? Lerax is a reinforcement learning library built on top of JAX, designed to make it easy to implement and experiment with RL algorithms while taking advantage of JAX's speed and scalability.
Lerax provides environments, policies, and training algorithms, all with a modular design that makes it easy to compose components.
Work in Progress
Lerax is very much a work in progress, but it is already usable for training simple RL agents. The API is still evolving, and many features are yet to be implemented. The documentation is also still being written, so please bear with me as I continue to improve it.
Installation
Train a policy
```python
from jax import random as jr

from lerax.algorithm import PPO
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
from lerax.env.classic_control import CartPole
from lerax.policy import MLPActorCriticPolicy

policy_key, learn_key = jr.split(jr.key(0), 2)

env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=policy_key)
algo = PPO()
logger = LoggingCallback(
    [TensorBoardBackend(), ConsoleBackend()], env=env, policy=policy
)

policy = algo.learn(env, policy, total_timesteps=2**16, key=learn_key, callback=logger)
logger.close()
```
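As background on the PPO algorithm instantiated above: PPO limits each policy update by clipping the ratio between the new and old policy probabilities before weighting it by the advantage. The following is a minimal pure-Python sketch of that per-sample clipped surrogate objective, not Lerax's actual implementation (which operates on JAX arrays inside `algo.learn`):

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """PPO's per-sample objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A).

    Illustrative sketch only; `ratio` is new_prob / old_prob for an action,
    `advantage` is its estimated advantage, and `eps` is the clip range.
    """
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps))
    # Taking the min makes the objective pessimistic: large ratios cannot
    # produce arbitrarily large (positive) updates.
    return min(ratio * advantage, clipped * advantage)


# A large probability ratio is clipped, limiting the update size:
print(clipped_surrogate(1.5, 1.0))   # → 1.2 (clipped to 1.2 * 1.0)
print(clipped_surrogate(0.9, 1.0))   # → 0.9 (within the clip range)
```

Taking the minimum keeps updates conservative: for negative advantages, the unclipped term dominates, so the penalty is never reduced by clipping.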
Acknowledgements
A ton of the code is a fairly direct translation of code from the Stable Baselines 3 and Gymnasium libraries. The developers of these excellent libraries have built a solid foundation for reinforcement learning in Python, and I have learned a lot from their design decisions.
In addition, the NDE code is heavily inspired by the work of Patrick Kidger and the entire library is based on his excellent Equinox library along with some use of Diffrax and jaxtyping.