
Export

lerax.export.to_onnx

to_onnx(
    policy: AbstractPolicy,
    *,
    output_path: str | Path | None = None,
    model_name: str = "lerax_policy",
    opset: int = 21,
) -> onnx_lib.ModelProto

Export a policy's deterministic inference path to an ONNX model.

Traces policy(None, observation) with key=None (deterministic mode) and converts the resulting JAX computation graph to ONNX format using jax2onnx.
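To make the role of key=None concrete, here is a toy JAX function in which passing no key selects the mean action. It is hypothetical and is not lerax's policy class or its internals. The key check runs in Python at trace time, so the deterministic trace contains only the matrix product and no random-number primitives. That means the graph handed to a converter has no sampling step.

```python
import jax
import jax.numpy as jnp


def toy_policy(observation, *, key=None):
    """Hypothetical Gaussian policy: mean action if key is None, else a sample."""
    # Fixed linear "network" standing in for trained weights.
    weights = jnp.array([[0.5, -1.0], [2.0, 0.25]])
    log_std = jnp.array([-1.0, -1.0])
    mean = weights @ observation
    if key is None:
        # Deterministic path: no random ops end up in the traced graph.
        return mean
    noise = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(log_std) * noise


obs = jnp.array([1.0, 2.0])
# Tracing with key=None yields a graph containing only the matmul.
print(jax.make_jaxpr(lambda o: toy_policy(o, key=None))(obs))
```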

The exported model maps a flat observation array to an action array.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| policy | AbstractPolicy | The trained policy to export. | required |
| output_path | str \| Path \| None | If provided, save the ONNX model to this file path. | None |
| model_name | str | Name embedded in the ONNX model metadata. | 'lerax_policy' |
| opset | int | Target ONNX opset version. | 21 |

Returns:

| Type | Description |
|---|---|
| onnx_lib.ModelProto | The ONNX ModelProto object. |

Raises:

| Type | Description |
|---|---|
| ImportError | If jax2onnx is not installed. |
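Because a missing jax2onnx raises ImportError, calling code may want to check for it before exporting. The following stdlib-only sketch uses importlib.util.find_spec, which locates the module without importing it. The name module_available is a hypothetical helper.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if module_available("jax2onnx"):
    print("jax2onnx found: lerax.export.to_onnx can be used")
else:
    print("jax2onnx missing: lerax.export.to_onnx would raise ImportError")
```

Catching ImportError around the to_onnx call works equally well. The find_spec check simply lets you decide up front, for example to skip an export step in CI.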

Example
from lerax.export import to_onnx

# `policy` is a trained AbstractPolicy (training not shown)
proto = to_onnx(policy, output_path="policy.onnx")