
Utils

lerax.utils.filter_scan

filter_scan(
    f,
    init,
    xs=None,
    length=None,
    reverse: bool = False,
    unroll: int | bool = 1,
    _split_transpose: bool = False,
)

An easier-to-use version of lax.scan. All JAX and NumPy arrays are traced, and only the non-array parts of the carry are treated as static.

Parameters:

Name Type Description Default
f

a Python function to be scanned of type c -> a -> (c, b), meaning that f accepts two arguments where the first is a value of the loop carry and the second is a slice of xs along its leading axis, and that f returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output.

required
init

an initial loop carry value of type c, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof. This value must have the same structure as the first element of the pair returned by f.

required
xs

the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.

None
length

optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform scans where no input xs are needed).

None
reverse bool

optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both xs and ys.

False
unroll int | bool

optional non-negative int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. unroll=0 unrolls the entire loop. If a boolean is provided, it will determine if the loop is completely unrolled (i.e. unroll=True) or left completely rolled (i.e. unroll=False).

1
_split_transpose bool

experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients), and a map (computing gradients corresponding to the array arguments). Enabling this may increase memory requirements, and so is an experimental feature that may evolve or even be rolled back.

False

Returns:

Type Description

A pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs.
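The description above suggests that filter_scan splits the carry into traced array leaves and static non-array leaves around jax.lax.scan. The following is a minimal pure-JAX sketch of that idea, not lerax's implementation. It assumes the carry's pytree structure and its non-array leaves never change between iterations, omits _split_transpose, and uses the hypothetical names sketch_filter_scan and _is_array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_array(leaf):
    # Hypothetical helper: JAX and NumPy arrays are dynamic (traced) values.
    return isinstance(leaf, (jax.Array, np.ndarray, np.generic))


def sketch_filter_scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Hypothetical sketch: only the array leaves of the carry go through lax.scan."""
    leaves, treedef = jax.tree_util.tree_flatten(init)
    mask = [_is_array(leaf) for leaf in leaves]
    static = [leaf for leaf, m in zip(leaves, mask) if not m]

    def split(tree):
        # Keep only the array leaves (assumes the carry structure is unchanged).
        return [leaf for leaf, m in zip(jax.tree_util.tree_leaves(tree), mask) if m]

    def merge(dynamic):
        # Re-insert the static (non-array) leaves captured from `init`.
        dyn, sta = iter(dynamic), iter(static)
        return jax.tree_util.tree_unflatten(
            treedef, [next(dyn) if m else next(sta) for m in mask]
        )

    def body(dynamic, x):
        new_carry, y = f(merge(dynamic), x)
        return split(new_carry), y

    final, ys = jax.lax.scan(
        body, split(init), xs, length=length, reverse=reverse, unroll=unroll
    )
    return merge(final), ys


def step(carry, x):
    total, label = carry  # `label` is a plain string, so it stays static
    return (total + x, label), total * 2


final_carry, ys = sketch_filter_scan(step, (jnp.zeros(()), "running-sum"), jnp.arange(4.0))
```

Passing the same carry straight to jax.lax.scan would fail, because a Python string is not a valid JAX type; keeping it out of the scanned state is what makes the mixed carry work.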

lerax.utils.callback_wrapper

callback_wrapper[**InType](
    func: Callable[InType, Any], ordered: bool = False
) -> Callable[InType, None]

Return a JIT-safe version of func.

Wraps func in a jax.debug.callback so that it can be used inside JIT-compiled code.

Parameters:

Name Type Description Default
func Callable[InType, Any]

The callback function to wrap.

required
ordered bool

Whether to enforce ordered execution of callbacks.

False

Returns:

Type Description
Callable[InType, None]

A wrapped version of func that is JIT-safe.
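Since the description says func is wrapped in jax.debug.callback, the wrapper can be sketched roughly as follows. sketch_callback_wrapper is a hypothetical name and this is an assumption about the approach, not lerax's source; the wrapped callable schedules func on the host and always returns None.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp


def sketch_callback_wrapper(func: Callable[..., Any], ordered: bool = False) -> Callable[..., None]:
    """Hypothetical sketch: defer `func` to the host via jax.debug.callback."""

    @functools.wraps(func)
    def wrapped(*args, **kwargs) -> None:
        # Runs on the host when the compiled program executes; any value
        # returned by `func` is ignored.
        jax.debug.callback(func, *args, ordered=ordered, **kwargs)

    return wrapped


seen = []
record = sketch_callback_wrapper(lambda x: seen.append(float(x)))


@jax.jit
def increment(x):
    record(x)  # side effect that survives jit
    return x + 1


out = increment(jnp.float32(2.0))
jax.effects_barrier()  # wait for pending host callbacks to finish
```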

lerax.utils.callback_with_numpy_wrapper

callback_with_numpy_wrapper(
    func: Callable[..., Any], ordered: bool = False
) -> Callable[..., None]

Like callback_wrapper, but converts every jax.Array/jnp.ndarray argument to a plain numpy.ndarray before calling func.

Python's current type system cannot express this argument transformation, so the parameter types of func are not preserved in the wrapper's signature.

Parameters:

Name Type Description Default
func Callable[..., Any]

The callback function to wrap.

required
ordered bool

Whether to enforce ordered execution of callbacks.

False

Returns:

Type Description
Callable[..., None]

A wrapped version of func that converts array arguments to numpy arrays and is JIT-safe.
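A minimal sketch of the NumPy-converting variant, assuming it layers a host-side conversion step on top of jax.debug.callback and converts only top-level positional and keyword arguments. sketch_callback_with_numpy_wrapper is hypothetical; whether lerax also converts arrays nested inside containers is not stated here.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np


def sketch_callback_with_numpy_wrapper(func: Callable[..., Any], ordered: bool = False) -> Callable[..., None]:
    """Hypothetical sketch: call `func` on the host with numpy.ndarray arguments."""

    def to_numpy(x):
        # Depending on the JAX version, callbacks receive jax.Array or np.ndarray.
        return np.asarray(x) if isinstance(x, (jax.Array, np.ndarray)) else x

    def host_fn(*args, **kwargs):
        func(*[to_numpy(a) for a in args], **{k: to_numpy(v) for k, v in kwargs.items()})

    @functools.wraps(func)
    def wrapped(*args, **kwargs) -> None:
        jax.debug.callback(host_fn, *args, ordered=ordered, **kwargs)

    return wrapped


seen = []
record = sketch_callback_with_numpy_wrapper(lambda x: seen.append(x))


@jax.jit
def double(x):
    record(x)
    return x * 2


result = double(jnp.arange(3))
jax.effects_barrier()  # make sure the host callback has run
```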

lerax.utils.callback_with_list_wrapper

callback_with_list_wrapper(
    func: Callable[..., Any], ordered: bool = False
) -> Callable[..., None]

Like callback_wrapper, but converts every jax.Array/jnp.ndarray argument to a plain list before calling func.

Python's current type system cannot express this argument transformation, so the parameter types of func are not preserved in the wrapper's signature.

Parameters:

Name Type Description Default
func Callable[..., Any]

The callback function to wrap.

required
ordered bool

Whether to enforce ordered execution of callbacks.

False

Returns:

Type Description
Callable[..., None]

A wrapped version of func that converts array arguments to lists and is JIT-safe.
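The list-converting variant can be sketched the same way, assuming the conversion is NumPy's tolist() applied on the host to each top-level argument. sketch_callback_with_list_wrapper is hypothetical. Note that tolist() turns a 0-d array into a Python scalar rather than a list, so how lerax treats scalar arguments is an open assumption here.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np


def sketch_callback_with_list_wrapper(func: Callable[..., Any], ordered: bool = False) -> Callable[..., None]:
    """Hypothetical sketch: call `func` on the host with nested Python lists."""

    def to_list(x):
        # tolist() gives nested lists for n-d arrays and a scalar for 0-d arrays.
        return np.asarray(x).tolist() if isinstance(x, (jax.Array, np.ndarray)) else x

    def host_fn(*args, **kwargs):
        func(*[to_list(a) for a in args], **{k: to_list(v) for k, v in kwargs.items()})

    @functools.wraps(func)
    def wrapped(*args, **kwargs) -> None:
        jax.debug.callback(host_fn, *args, ordered=ordered, **kwargs)

    return wrapped


seen = []
record = sketch_callback_with_list_wrapper(lambda x: seen.append(x))


@jax.jit
def square(x):
    record(x)
    return x * x


result = square(jnp.arange(3))
jax.effects_barrier()  # make sure the host callback has run
```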

lerax.utils.unstack_pytree

unstack_pytree[T](tree: T, *, axis: int = 0) -> Sequence[T]

Split a stacked pytree along axis into a tuple of pytrees with the same structure.

Parameters:

Name Type Description Default
tree T

A pytree with array leaves stacked along axis.

required
axis int

The axis along which to unstack the arrays.

0

Returns:

Type Description
Sequence[T]

A sequence of pytrees with the same structure, each corresponding to one slice along axis.
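One way to unstack a pytree is to index every leaf at position i along axis and rebuild the tree with jax.tree_util.tree_map. The sketch below is hypothetical (sketch_unstack_pytree is not lerax's code) and assumes the tree has at least one array leaf and that all leaves share the same size along axis.

```python
import jax
import jax.numpy as jnp


def sketch_unstack_pytree(tree, *, axis=0):
    """Hypothetical sketch: split every leaf along `axis` into separate pytrees."""
    leaves = jax.tree_util.tree_leaves(tree)
    n = leaves[0].shape[axis]  # assumes all leaves share this size along `axis`
    return tuple(
        # jnp.take with a scalar index drops `axis` from each leaf.
        jax.tree_util.tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), tree)
        for i in range(n)
    )


stacked = {"w": jnp.arange(6).reshape(3, 2), "b": jnp.arange(3)}
parts = sketch_unstack_pytree(stacked)
```

This is the inverse of stacking a list of same-structured pytrees with jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees).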

lerax.utils.Serializable

Bases: eqx.Module

serialize

serialize(
    path: str | Path, no_suffix: bool = False
) -> None

Serialize the model to the specified path.

Parameters:

Name Type Description Default
path str | Path

The path to serialize to.

required
no_suffix bool

If True, do not append the ".eqx" suffix to path.

False

deserialize classmethod

deserialize[**Params, ClassType](
    path: str | Path,
    *args: Params.args,
    **kwargs: Params.kwargs,
) -> ClassType

Deserialize the model from the specified path. Any additional arguments required by the class constructor must be passed through *args and **kwargs.

Parameters:

Name Type Description Default
path str | Path

The path to deserialize from.

required
*args Params.args

Additional arguments to pass to the class constructor.

()
**kwargs Params.kwargs

Additional keyword arguments to pass to the class constructor.

{}

Returns:

Type Description
ClassType

The deserialized model.