Utils
lerax.utils.filter_cond
filter_cond[**ParamType, RetType](
pred: Bool[ArrayLike, ""],
true_fun: Callable[ParamType, RetType],
false_fun: Callable[ParamType, RetType],
*args: ParamType.args,
**kwargs: ParamType.kwargs,
) -> RetType
Like `lax.cond`, but handles non-array leaves (e.g. activation functions
inside Equinox modules).
Note
The non-array leaves of the outputs of true_fun and false_fun
must be identical.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred` | `Bool[ArrayLike, '']` | A boolean scalar determining which branch to select. | required |
| `true_fun` | `Callable[ParamType, RetType]` | A callable to be executed if `pred` is True. | required |
| `false_fun` | `Callable[ParamType, RetType]` | A callable to be executed if `pred` is False. | required |
| `args` | `ParamType.args` | Positional arguments to be passed to both `true_fun` and `false_fun`. | `()` |
| `kwargs` | `ParamType.kwargs` | Keyword arguments to be passed to both `true_fun` and `false_fun`. | `{}` |
Returns:
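| Type | Description |
|---|---|
| `RetType` | The output of `true_fun` if `pred` is True, otherwise the output of `false_fun`. Array leaves are selected according to `pred`; the non-array leaves, which must be identical in both outputs, are passed through unchanged. |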
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the non-array leaves of the outputs of `true_fun` and `false_fun` are not identical. |
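The behaviour described above can be illustrated with a minimal sketch. It is not lerax's implementation: `filter_cond_sketch` and `_is_array` are hypothetical names, and the sketch assumes a recent JAX in which tracers are instances of `jax.Array`. Array leaves of the inputs and outputs go through `lax.cond`; all other leaves are closed over (inputs) or recorded while each branch is traced (outputs) and then compared.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _is_array(x):
    # Treat JAX arrays (including tracers) and NumPy arrays/scalars as dynamic.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def filter_cond_sketch(pred, true_fun, false_fun, *args, **kwargs):
    """Hypothetical sketch: lax.cond over array leaves, other leaves kept static."""
    in_leaves, in_def = jax.tree_util.tree_flatten((args, kwargs))
    in_mask = [_is_array(x) for x in in_leaves]
    dyn_in = [x for x, m in zip(in_leaves, in_mask) if m]
    statics = {}  # non-array output parts, recorded while each branch is traced

    def make_branch(fun, name):
        def branch(dyn):
            # Rebuild the full inputs: traced arrays plus closed-over static leaves.
            it = iter(dyn)
            leaves = [next(it) if m else x for x, m in zip(in_leaves, in_mask)]
            a, kw = jax.tree_util.tree_unflatten(in_def, leaves)
            out_leaves, out_def = jax.tree_util.tree_flatten(fun(*a, **kw))
            out_mask = [_is_array(x) for x in out_leaves]
            static = [None if m else x for x, m in zip(out_leaves, out_mask)]
            statics[name] = (out_def, out_mask, static)
            # Only the array leaves are returned through lax.cond.
            return [x for x, m in zip(out_leaves, out_mask) if m]

        return branch

    dyn_out = lax.cond(
        pred, make_branch(true_fun, "true"), make_branch(false_fun, "false"), dyn_in
    )
    if statics["true"] != statics["false"]:
        raise ValueError("non-array leaves of the two branch outputs differ")
    out_def, out_mask, static = statics["true"]
    it = iter(dyn_out)
    leaves = [next(it) if m else s for m, s in zip(out_mask, static)]
    return jax.tree_util.tree_unflatten(out_def, leaves)


def relu(x):
    return jnp.maximum(x, 0.0)


def up(x):
    return {"y": x + 1.0, "act": relu}


def down(x):
    return {"y": x - 1.0, "act": relu}


out = filter_cond_sketch(jnp.array(True), up, down, jnp.array(2.0))
jitted = jax.jit(lambda p, x: filter_cond_sketch(p, up, down, x)["y"])
```

Because `lax.cond` traces both branches, the non-array leaves of both outputs are available for the equality check even when `pred` is a traced value.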
lerax.utils.filter_scan
filter_scan[Carry, X, Y](
f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X | None = None,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
) -> tuple[Carry, Y]
An easier-to-use version of `lax.scan`. All JAX and NumPy arrays are
traced, and only the non-array parts of the carry are static.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[Carry, X], tuple[Carry, Y]]` | A Python function to be scanned, of type `(Carry, X) -> (Carry, Y)`: it takes the current loop carry and a slice of `xs` along its leading axis, and returns the new loop carry and a slice of the output. | required |
| `init` | `Carry` | An initial loop carry value of type `Carry`: a scalar, array, or any pytree thereof. It must have the same structure as the first element of the pair returned by `f`. | required |
| `xs` | `X \| None` | The value of type `X` over which to scan along the leading axis: an array or any pytree of arrays with consistent leading-axis sizes. | `None` |
| `length` | `int \| None` | Optional integer specifying the number of loop iterations, which must agree with the sizes of the leading axes of the arrays in `xs`; it can also be used to run a scan that needs no `xs`. | `None` |
| `reverse` | `bool` | Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both `xs` and the stacked outputs. | `False` |
| `unroll` | `int \| bool` | Optional non-negative int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. An integer sets how many unrolled loop iterations run within a single rolled iteration of the loop; a boolean either unrolls the loop completely (`True`) or leaves it completely rolled (`False`). | `1` |
| `_split_transpose` | `bool` | Experimental. Whether to further split the transpose into a scan (computing activation gradients) and a map (computing gradients corresponding to the array arguments). Enabling this may increase memory requirements; the option may evolve or be rolled back. | `False` |
Returns:
| Type | Description |
|---|---|
| `tuple[Carry, Y]` | A pair whose first element is the final loop carry value and whose second element stacks the second output of `f` from every iteration along a new leading axis. |
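A minimal sketch of the same idea for `lax.scan`, assuming only JAX and NumPy; `filter_scan_sketch` is a hypothetical name, it is not lerax's implementation, and it omits `_split_transpose`. The carry is split into array leaves, which form the traced `lax.scan` carry, and non-array leaves, which are held fixed and must come back unchanged from `f`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _is_array(x):
    # JAX arrays (including tracers) and NumPy arrays/scalars are traced.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def filter_scan_sketch(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Hypothetical sketch: scan over the array leaves of the carry only."""
    init_leaves, init_def = jax.tree_util.tree_flatten(init)
    mask = [_is_array(x) for x in init_leaves]
    static = [None if m else x for x, m in zip(init_leaves, mask)]

    def split(carry):
        # Keep the array leaves; check that the static part is unchanged.
        leaves, treedef = jax.tree_util.tree_flatten(carry)
        carry_static = [None if _is_array(x) else x for x in leaves]
        if treedef != init_def or carry_static != static:
            raise ValueError("f must preserve the carry structure and non-array leaves")
        return [x for x, m in zip(leaves, mask) if m]

    def merge(dyn):
        # Reinsert the static leaves around the traced array leaves.
        it = iter(dyn)
        leaves = [next(it) if m else s for m, s in zip(mask, static)]
        return jax.tree_util.tree_unflatten(init_def, leaves)

    def step(dyn_carry, x):
        new_carry, y = f(merge(dyn_carry), x)
        return split(new_carry), y

    dyn_final, ys = lax.scan(
        step, split(init), xs, length=length, reverse=reverse, unroll=unroll
    )
    return merge(dyn_final), ys


def add(a, b):
    return a + b


def running_sum(carry, x):
    # carry["op"] is a plain Python function, i.e. a static leaf.
    total = carry["op"](carry["total"], x)
    return {"total": total, "op": carry["op"]}, total


final, ys = filter_scan_sketch(
    running_sum, {"total": jnp.array(0.0), "op": add}, jnp.arange(4.0)
)
```

Python scalars in the carry are not arrays, so this sketch treats them as static; running values such as counters should be wrapped in `jnp.asarray` so that they are traced.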
lerax.utils.callback_wrapper
callback_wrapper[**InType](
func: Callable[InType, Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[InType, None]
Return a JIT-safe version of `func`.
Wraps `func` in `jax.debug.callback` so that it can be used inside JIT-compiled
code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[InType, Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
| `partitioned` | `bool` | If True, `func` receives local shards only, which avoids an all-gather of the operands. If False, `func` receives the logical (global) operands, which requires an all-gather first. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[InType, None]` | A wrapped version of `func` that is JIT-safe. |
lerax.utils.callback_with_numpy_wrapper
callback_with_numpy_wrapper(
func: Callable[..., Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[..., None]
Like `callback_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument
to a plain `numpy.ndarray` before calling `func`.
Python's current type system cannot express this transformation, so parameter type information is lost.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
| `partitioned` | `bool` | If True, `func` receives local shards only, which avoids an all-gather of the operands. If False, `func` receives the logical (global) operands, which requires an all-gather first. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., None]` | A wrapped version of `func` that converts array arguments to NumPy arrays and is JIT-safe. |
lerax.utils.callback_with_list_wrapper
callback_with_list_wrapper(
func: Callable[..., Any],
ordered: bool = False,
partitioned: bool = False,
) -> Callable[..., None]
Like `callback_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument
to a plain list before calling `func`.
Python's current type system cannot express this transformation, so parameter type information is lost.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
| `partitioned` | `bool` | If True, `func` receives local shards only, which avoids an all-gather of the operands. If False, `func` receives the logical (global) operands, which requires an all-gather first. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., None]` | A wrapped version of `func` that converts array arguments to lists and is JIT-safe. |
lerax.utils.unstack_pytree
Split a stacked pytree along axis into a tuple of pytrees with the same
structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | A pytree with array leaves stacked along `axis`. | required |
| `axis` | `int` | The axis along which to unstack the arrays. | `0` |
Returns:
| Type | Description |
|---|---|
| `Sequence[T]` | A sequence of pytrees with the same structure, each corresponding to one slice along `axis`. |
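A minimal sketch of unstacking with `jax.tree_util.tree_map` and `jnp.take`. `unstack_pytree_sketch` is a hypothetical name, not lerax's implementation, and it assumes every leaf is an array with the same size along `axis`.

```python
import jax
import jax.numpy as jnp


def unstack_pytree_sketch(tree, axis=0):
    """Hypothetical sketch: return one pytree per index along `axis`."""
    sizes = {leaf.shape[axis] for leaf in jax.tree_util.tree_leaves(tree)}
    if len(sizes) != 1:
        raise ValueError("leaves must share a single size along `axis`")
    (n,) = sizes
    # jnp.take with an integer index removes `axis` from each leaf.
    return tuple(
        jax.tree_util.tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), tree)
        for i in range(n)
    )


stacked = {"w": jnp.arange(6).reshape(3, 2), "b": jnp.arange(3)}
parts = unstack_pytree_sketch(stacked)  # three dicts with the same keys
columns = unstack_pytree_sketch(jnp.arange(6).reshape(3, 2), axis=1)
```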
lerax.utils.Serializable
Bases: eqx.Module
serialize
deserialize (classmethod)
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path. Any additional arguments required by the class constructor must also be provided.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
|
Returns:
| Type | Description |
|---|---|
ClassType
|
The deserialized model. |