Saving and Loading Policies
Lerax uses Equinox to store models as PyTrees on disk.
Any class that inherits from lerax.utils.Serializable gains:
serialize(path)— save the object to a.eqxfile.deserialize(path, ...)(class method) — load the object back from disk.
All policies (AbstractStatelessPolicy / AbstractStatefulPolicy subclasses) are Serializable.
Saving a Model
from jax import random as jr
from lerax.env.classic_control import CartPole
from lerax.policy import MLPActorCriticPolicy
policy_key = jr.key(0)
env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=policy_key)
policy.serialize("model.eqx")
Loading a Model
Use the deserialize class method on the policy (or other Serializable subclass):
from jax import random as jr
from lerax.env.classic_control import CartPole
from lerax.policy import MLPActorCriticPolicy
env = CartPole()
loaded_policy = MLPActorCriticPolicy.deserialize("model.eqx", env, key=jr.key(1)) # (1)!
- Constructor arguments must match the original class signature
After the path argument, you must provide usual constructor arguments for the class.
Values that determine PyTree structure must match the original object (e.g. the env argument above), but other arguments (e.g. random keys) can differ and will be overwritten by the loaded parameters.
Serializable Mixin
If you write your own modules and want the same save/load API, inherit from Serializable:
import equinox as eqx
from lerax.utils import Serializable
class MyModule(Serializable):
a: float
b: float
def __init__(self, a: float, b: float):
self.a = a
self.b = b
m = MyModule(1.0, 2.0)
m.serialize("my_module.eqx")
m = MyModule.deserialize("my_module.eqx", 0.0, 0.0)
This works for policies, algorithms, or any other Equinox module that you want to persist.
ONNX Export
Lerax policies can be exported to the ONNX format for deployment in production runtimes (ONNX Runtime, TensorRT, etc.) that do not require JAX.
Installation
ONNX export requires additional dependencies:
Exporting a Policy
Use to_onnx to export a trained policy's deterministic inference path:
from jax import random as jr
from lerax.env.classic_control import CartPole
from lerax.export import to_onnx
from lerax.policy import MLPActorCriticPolicy
env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=jr.key(0))
to_onnx(policy, output_path="policy.onnx")
The exported model maps a flat observation array to an action. It uses the deterministic mode of the policy (equivalent to calling policy(None, observation) with no random key).