Utils
lerax.utils.filter_cond
filter_cond[**ParamType, RetType](
pred: Bool[ArrayLike, ""],
true_fun: Callable[ParamType, RetType],
false_fun: Callable[ParamType, RetType],
*args: ParamType.args,
**kwargs: ParamType.kwargs,
) -> RetType
Like lax.cond but handles non-array leaves (e.g. activation functions
inside Equinox modules).
Note
The non-array leaves of the outputs of true_fun and false_fun
must be identical.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred | Bool[ArrayLike, ''] | A boolean scalar determining which branch to select. | required |
| true_fun | Callable[ParamType, RetType] | A callable to be applied if `pred` is `True`. | required |
| false_fun | Callable[ParamType, RetType] | A callable to be applied if `pred` is `False`. | required |
| args | ParamType.args | Positional arguments to be passed to both `true_fun` and `false_fun`. | () |
| kwargs | ParamType.kwargs | Keyword arguments to be passed to both `true_fun` and `false_fun`. | {} |
Returns:
| Type | Description |
|---|---|
| RetType | The result of `true_fun(*args, **kwargs)` or `false_fun(*args, **kwargs)`, with array leaves selected according to `pred`. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the non-array leaves of the outputs of `true_fun` and `false_fun` are not identical. |
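A minimal sketch of the idea behind `filter_cond`, written against plain JAX. This is not lerax's implementation, and `filter_cond_sketch` and `_is_array` are hypothetical names. It assumes that only JAX and NumPy arrays count as array leaves: each branch is traced by `lax.cond`, only its array leaves pass through, and the non-array skeleton recorded from each branch is compared afterwards, raising `ValueError` if the two differ.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _is_array(x):
    # Assumption: JAX arrays (including tracers) and NumPy arrays are the dynamic leaves.
    return isinstance(x, (jax.Array, np.ndarray))


def filter_cond_sketch(pred, true_fun, false_fun, *args, **kwargs):
    """Hypothetical sketch of a cond that tolerates non-array output leaves."""
    static_parts = {}

    def make_branch(name, fun):
        def branch():
            out = fun(*args, **kwargs)
            leaves, treedef = jax.tree_util.tree_flatten(out)
            mask = tuple(_is_array(leaf) for leaf in leaves)
            # Record this branch's non-array skeleton while lax.cond traces it.
            static_parts[name] = (
                treedef,
                mask,
                tuple(None if m else leaf for m, leaf in zip(mask, leaves)),
            )
            # Only the array leaves are returned through lax.cond.
            return [leaf for m, leaf in zip(mask, leaves) if m]

        return branch

    dynamic = lax.cond(pred, make_branch("true", true_fun), make_branch("false", false_fun))

    # lax.cond traces both branches, so both skeletons are available here.
    if static_parts["true"] != static_parts["false"]:
        raise ValueError("non-array leaves of the true_fun and false_fun outputs differ")

    # Re-insert the selected array leaves into the shared non-array skeleton.
    treedef, mask, statics = static_parts["true"]
    dyn_iter = iter(dynamic)
    leaves = [next(dyn_iter) if m else s for m, s in zip(mask, statics)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Usage: both branches return the same activation function as a non-array leaf.
def scaled_up(scale):
    return jnp.array([1.0, 2.0]) * scale, jax.nn.relu


def scaled_down(scale):
    return jnp.array([1.0, 2.0]) * -scale, jax.nn.relu


values, activation = filter_cond_sketch(True, scaled_up, scaled_down, 3.0)
```

Because `lax.cond` traces both branches, branches whose array leaves differ in shape or dtype fail inside `lax.cond` itself, before the non-array comparison runs.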
lerax.utils.filter_scan
filter_scan[Carry, X, Y](
f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X | None = None,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
) -> tuple[Carry, Y]
An easier-to-use version of `lax.scan`. All JAX and NumPy arrays are
traced, and only the non-array parts of the carry are treated as static.
Parameters:
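| Name | Type | Description | Default |
|---|---|---|---|
| f | Callable[[Carry, X], tuple[Carry, Y]] | A Python function to be scanned, of type `(Carry, X) -> (Carry, Y)`: it takes the current loop carry and a slice of `xs` along its leading axis, and returns a pair whose first element is the new loop carry and whose second element is a slice of the output. | required |
| init | Carry | An initial loop carry value of type `Carry`, which can be a scalar, an array, or any pytree thereof. It must have the same structure as the first element of the pair returned by `f`. | required |
| xs | X \| None | The value of type `X` over which to scan along the leading axis; it can be an array or any pytree of arrays with consistent leading-axis sizes. | None |
| length | int \| None | Optional integer specifying the number of loop iterations, which must agree with the sizes of the leading axes of the arrays in `xs` (but can be used to perform scans where no input `xs` are needed). | None |
| reverse | bool | Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both `xs` and the stacked outputs. | False |
| unroll | int \| bool | Optional non-negative int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, `True` unrolls the loop completely and `False` leaves it completely rolled. | 1 |
| _split_transpose | bool | Experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients) and a map (computing gradients corresponding to the array arguments). Enabling this may increase memory requirements, so it is an experimental feature that may evolve or even be rolled back. | False |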
Returns:
| Type | Description |
|---|---|
| tuple[Carry, Y] | A pair whose first element is the final loop carry value and whose second element holds the stacked outputs of the second output of `f`, scanned over the leading axis of the inputs. |
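A hypothetical sketch of a scan that traces only the array part of its carry, assuming (as stated above) that JAX and NumPy arrays are dynamic and everything else is static. `filter_scan_sketch` is an illustrative name rather than lerax's code, and the `ValueError` raised when `f` changes the static part of the carry is this sketch's own choice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _is_array(x):
    # Assumption: JAX and NumPy arrays are traced; every other leaf is static.
    return isinstance(x, (jax.Array, np.ndarray))


def filter_scan_sketch(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Hypothetical sketch: scan over the array leaves of the carry only."""
    leaves, treedef = jax.tree_util.tree_flatten(init)
    mask = tuple(_is_array(leaf) for leaf in leaves)
    statics = tuple(None if m else leaf for m, leaf in zip(mask, leaves))

    def combine(dynamic):
        # Rebuild the full carry from traced array leaves plus the fixed static leaves.
        it = iter(dynamic)
        return jax.tree_util.tree_unflatten(
            treedef, [next(it) if m else s for m, s in zip(mask, statics)]
        )

    def body(dynamic, x):
        new_carry, y = f(combine(dynamic), x)
        new_leaves, new_treedef = jax.tree_util.tree_flatten(new_carry)
        new_mask = tuple(_is_array(leaf) for leaf in new_leaves)
        new_statics = tuple(None if m else leaf for m, leaf in zip(new_mask, new_leaves))
        # The non-array part of the carry must stay the same on every iteration.
        if (new_treedef, new_mask, new_statics) != (treedef, mask, statics):
            raise ValueError("f changed the non-array part of the carry")
        return [leaf for m, leaf in zip(new_mask, new_leaves) if m], y

    dynamic_init = [leaf for m, leaf in zip(mask, leaves) if m]
    final_dynamic, ys = lax.scan(
        body, dynamic_init, xs, length=length, reverse=reverse, unroll=unroll
    )
    return combine(final_dynamic), ys


# Usage: a running sum whose carry also holds a string label (a non-array leaf).
def running_sum(carry, x):
    total, label = carry
    total = total + x
    return (total, label), total


(final_total, final_label), partial_sums = filter_scan_sketch(
    running_sum, (jnp.zeros(()), "running-sum"), jnp.arange(4.0)
)
```

A plain `lax.scan` would reject the string in the carry; here it never enters the traced loop state.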
lerax.utils.callback_wrapper
callback_wrapper[**InType](
func: Callable[InType, Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[InType, None]
Return a JIT-safe version of `func`.
Wraps `func` in `jax.debug.callback` so that it can be used inside JIT-compiled
code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[InType, Any] | The callback function to wrap. | required |
| ordered | bool | Whether to enforce ordered execution of callbacks. | False |
| partitioned | bool | If True, the callback receives only local shards, which avoids an all-gather of the operands. If False, it receives the logical (global) operands, which requires an all-gather of the operands first. | False |
Returns:
| Type | Description |
|---|---|
| Callable[InType, None] | A wrapped version of `func` that is JIT-safe. |
lerax.utils.callback_with_numpy_wrapper
callback_with_numpy_wrapper(
func: Callable[..., Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[..., None]
Like `callback_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument
to a plain `numpy.ndarray` before calling `func`.
Python's current type system cannot express this argument transformation, so the wrapper's parameter information is lost (it is typed as `Callable[..., None]`).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | The callback function to wrap. | required |
| ordered | bool | Whether to enforce ordered execution of callbacks. | False |
| partitioned | bool | If True, the callback receives only local shards, which avoids an all-gather of the operands. If False, it receives the logical (global) operands, which requires an all-gather of the operands first. | False |
Returns:
| Type | Description |
|---|---|
| Callable[..., None] | A wrapped version of `func` that converts array arguments to NumPy arrays and is JIT-safe. |
lerax.utils.callback_with_list_wrapper
callback_with_list_wrapper(
func: Callable[..., Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[..., None]
Like `callback_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument
to a plain `list` before calling `func`.
Python's current type system cannot express this argument transformation, so the wrapper's parameter information is lost (it is typed as `Callable[..., None]`).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | The callback function to wrap. | required |
| ordered | bool | Whether to enforce ordered execution of callbacks. | False |
| partitioned | bool | If True, the callback receives only local shards, which avoids an all-gather of the operands. If False, it receives the logical (global) operands, which requires an all-gather of the operands first. | False |
Returns:
| Type | Description |
|---|---|
| Callable[..., None] | A wrapped version of `func` that converts array arguments to lists and is JIT-safe. |
lerax.utils.unstack_pytree
Split a stacked pytree along axis into a tuple of pytrees with the same
structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | A pytree with array leaves stacked along `axis`. | required |
| axis | int | The axis along which to unstack the arrays. | 0 |
Returns:
| Type | Description |
|---|---|
| Sequence[T] | A sequence of pytrees with the same structure, each corresponding to one slice along `axis`. |
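A hypothetical sketch of unstacking, assuming every leaf of `tree` is an array with the same size along `axis`. `unstack_pytree_sketch` is an illustrative name, not the library function.

```python
import jax
import jax.numpy as jnp


def unstack_pytree_sketch(tree, axis=0):
    """Hypothetical sketch: one pytree per index along `axis`."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    if not leaves:
        raise ValueError("tree has no array leaves to unstack")
    sizes = {leaf.shape[axis] for leaf in leaves}
    if len(sizes) != 1:
        raise ValueError(f"leaves disagree on the size of axis {axis}: {sorted(sizes)}")
    (n,) = sizes
    # jnp.take with a scalar index drops `axis` from every leaf.
    return tuple(
        jax.tree_util.tree_unflatten(treedef, [jnp.take(leaf, i, axis=axis) for leaf in leaves])
        for i in range(n)
    )


# Usage: split a batch of three transitions into three single transitions.
batch = {"obs": jnp.arange(6).reshape(3, 2), "reward": jnp.arange(3)}
items = unstack_pytree_sketch(batch)
```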
lerax.utils.polyak_average
Polyak-average the parameters of two modules.
Returns a new module whose inexact-array leaves are
`tau * online + (1 - tau) * target`, with all other leaves
taken from `online`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| online | T | The online (source) module. | required |
| target | T | The target module to update towards. | required |
| tau | float | Interpolation coefficient in [0, 1]. | required |
Returns:
| Type | Description |
|---|---|
| T | The updated target module. |
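The update rule above can be sketched with `jax.tree_util.tree_map` on plain dict pytrees; Equinox modules are pytrees too, so the same call would apply to them. `polyak_average_sketch` is a hypothetical helper, not lerax's implementation. It keeps every leaf that is not an inexact (floating or complex) array, such as integer counters or non-array fields, from `online`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_inexact_array(x):
    # Floating-point or complex arrays are averaged; everything else is copied.
    return isinstance(x, (jax.Array, np.ndarray)) and jnp.issubdtype(x.dtype, jnp.inexact)


def polyak_average_sketch(online, target, tau):
    """Hypothetical sketch: tau * online + (1 - tau) * target on inexact leaves."""
    return jax.tree_util.tree_map(
        lambda o, t: tau * o + (1 - tau) * t if _is_inexact_array(o) else o,
        online,
        target,
    )


# Usage: a soft target-network update with tau = 0.25.
online = {"w": jnp.array([1.0, 1.0]), "step": jnp.array(5)}
target = {"w": jnp.array([0.0, 0.0]), "step": jnp.array(2)}
updated = polyak_average_sketch(online, target, tau=0.25)
```

With `tau = 1` the result is a copy of `online`; with `tau = 0` the inexact leaves of `target` are kept unchanged.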
lerax.utils.Serializable
Bases: eqx.Module
serialize
Serialize the model to the specified path.
Writes a 32-byte structural fingerprint followed by the Equinox
leaf data, so deserialize can verify that the skeleton it builds
matches what was saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to serialize to. The … | required |
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path.
The constructor arguments must reproduce the same static structure
(class, hyperparameters, network shapes, activations, ...) that the
model had when it was serialized. A 32-byte fingerprint stored in
the file is verified before loading; mismatches raise ValueError
instead of silently loading arrays into the wrong skeleton.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | The path to deserialize from. | required |
| `*args` | Params.args | Additional arguments to pass to the class constructor. | () |
| `**kwargs` | Params.kwargs | Additional keyword arguments to pass to the class constructor. | {} |
Returns:
| Type | Description |
|---|---|
| ClassType | The deserialized model. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the structural fingerprint of the rebuilt skeleton does not match the one stored in the file. |
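A sketch of the fingerprint-then-data layout that `serialize` and `deserialize` describe, under assumptions not documented here: the fingerprint is a SHA-256 digest (exactly 32 bytes) of the tree structure plus leaf shapes and dtypes, and leaves are written with `np.save` instead of Equinox's serialisation. `structural_fingerprint`, `save_with_fingerprint` and `load_with_fingerprint` are hypothetical names, all leaves are assumed to be arrays, and lerax's actual fingerprint may be computed differently.

```python
import hashlib
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np


def structural_fingerprint(tree) -> bytes:
    """Hypothetical fingerprint: SHA-256 over tree structure, leaf shapes and dtypes."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    digest = hashlib.sha256(str(treedef).encode())
    for leaf in leaves:
        arr = np.asarray(leaf)
        digest.update(f"{arr.shape}:{arr.dtype}".encode())
    return digest.digest()  # SHA-256 digests are exactly 32 bytes


def save_with_fingerprint(path, tree):
    with open(path, "wb") as f:
        f.write(structural_fingerprint(tree))  # 32-byte header
        for leaf in jax.tree_util.tree_leaves(tree):
            np.save(f, np.asarray(leaf), allow_pickle=False)


def load_with_fingerprint(path, skeleton):
    leaves, treedef = jax.tree_util.tree_flatten(skeleton)
    with open(path, "rb") as f:
        # Verify the structure before reading any array data.
        if f.read(32) != structural_fingerprint(skeleton):
            raise ValueError("structural fingerprint mismatch: skeleton differs from saved model")
        loaded = [jnp.asarray(np.load(f, allow_pickle=False)) for _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, loaded)


# Usage: round-trip a small pytree, then try a skeleton with the wrong shape.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "params.bin"
    params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
    save_with_fingerprint(path, params)
    restored = load_with_fingerprint(path, {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)})
    try:
        load_with_fingerprint(path, {"w": jnp.zeros((3, 3)), "b": jnp.zeros(3)})
        mismatch_detected = False
    except ValueError:
        mismatch_detected = True
```

Checking the header before any array data is read is what turns a wrong skeleton into an explicit `ValueError` instead of silently mis-assigned weights.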