Utils
lerax.utils.filter_scan
filter_scan(
f,
init,
xs=None,
length=None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
)
An easier-to-use version of `lax.scan`. All JAX and NumPy arrays are traced, and only the non-array parts of the carry are treated as static.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `f` | | A Python function to be scanned. It takes two arguments, the current loop carry and a slice of `xs` along its leading axis, and returns a pair `(new_carry, output)`. | required |
| `init` | | The initial loop carry value. | required |
| `xs` | | The value to scan over along its leading axis. | `None` |
| `length` | | Optional integer specifying the number of loop iterations, which must agree with the sizes of the leading axes of the arrays in `xs`. | `None` |
| `reverse` | `bool` | Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both `xs` and the stacked outputs. | `False` |
| `unroll` | `int \| bool` | Optional non-negative int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. | `1` |
| `_split_transpose` | `bool` | Experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients) and a map (computing gradients corresponding to the array arguments). Enabling this may increase memory requirements, so it is an experimental feature that may evolve or even be rolled back. | `False` |
Returns:
| Type | Description |
|---|---|
| | A pair whose first element is the final loop carry value and whose second element holds the stacked outputs, i.e. the second output of `f` from every iteration, stacked along the leading axis. |
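The carry-splitting idea can be illustrated with a minimal sketch. This is not lerax's implementation: it assumes that "array" means a `jax.Array` or `numpy.ndarray` leaf, that `f` keeps the carry structure and returns every non-array leaf unchanged, and `sketch_filter_scan` is a hypothetical name. Array leaves are handed to `lax.scan` as the traced carry, while non-array leaves are closed over as static values.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def _is_array(leaf):
    # Treat JAX and NumPy arrays as dynamic (traced); everything else is static.
    return isinstance(leaf, (jax.Array, np.ndarray))


def sketch_filter_scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Hypothetical sketch: run lax.scan over the array part of the carry only."""
    leaves, treedef = jax.tree_util.tree_flatten(init)
    mask = [_is_array(leaf) for leaf in leaves]
    # None is an empty pytree node, so it works as a placeholder in each half.
    dynamic = [leaf if m else None for leaf, m in zip(leaves, mask)]
    static = [None if m else leaf for leaf, m in zip(leaves, mask)]

    def combine(dynamic_leaves):
        merged = [d if m else s for d, s, m in zip(dynamic_leaves, static, mask)]
        return jax.tree_util.tree_unflatten(treedef, merged)

    def body(dynamic_carry, x):
        new_carry, y = f(combine(dynamic_carry), x)
        new_leaves = jax.tree_util.tree_leaves(new_carry)
        # Assumption: f keeps the structure and leaves static parts unchanged.
        return [leaf if m else None for leaf, m in zip(new_leaves, mask)], y

    final_dynamic, ys = lax.scan(
        body, dynamic, xs, length=length, reverse=reverse, unroll=unroll
    )
    return combine(final_dynamic), ys


def step(carry, x):
    total, label = carry  # `label` is a str, so it stays static
    return (total + x, label), total * 2.0


(final_total, final_label), ys = sketch_filter_scan(
    step, (jnp.array(0.0), "running-sum"), jnp.arange(4.0)
)
```

Passing the same carry directly to `lax.scan` would raise a `TypeError`, because a `str` is not a valid JAX type; separating static leaves from traced ones is what makes such carries usable.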
lerax.utils.callback_wrapper
callback_wrapper[**InType](
func: Callable[InType, Any], ordered: bool = False
) -> Callable[InType, None]
Return a JIT-safe version of `func`.
Wraps `func` in a `jax.debug.callback` so that it can be used inside JIT-compiled code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[InType, Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[InType, None]` | A wrapped version of `func` that is JIT-safe. |
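A minimal sketch of this wrapping, assuming the wrapper simply forwards its arguments to `jax.debug.callback` and discards whatever `func` returns (consistent with the `Callable[InType, None]` return type). `sketch_callback_wrapper` and the demo names are hypothetical, not lerax's code.

```python
import functools

import jax
import jax.numpy as jnp


def sketch_callback_wrapper(func, ordered=False):
    """Hypothetical sketch: make `func` callable from inside jitted code."""

    def host_side(*args, **kwargs):
        # Runs on the host; any return value of func is discarded.
        func(*args, **kwargs)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        jax.debug.callback(host_side, *args, ordered=ordered, **kwargs)

    return wrapped


seen = []
log_value = sketch_callback_wrapper(lambda x: seen.append(float(x)))


@jax.jit
def double_and_log(x):
    log_value(x * 2)  # traced value is delivered to the host at run time
    return x + 1


result = double_and_log(jnp.array(3.0))
jax.effects_barrier()  # wait for pending host callbacks to finish
```

Host callbacks may run asynchronously, which is why the example calls `jax.effects_barrier()` before inspecting `seen`.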
lerax.utils.callback_with_numpy_wrapper
callback_with_numpy_wrapper(
func: Callable[..., Any], ordered: bool = False
) -> Callable[..., None]
Like `debug_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument to a plain `numpy.ndarray` before calling `func`.
Python's current type system cannot express this transformation, so the wrapper loses `func`'s parameter information and is typed as `Callable[..., None]`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., None]` | A wrapped version of `func` that converts array arguments to NumPy arrays and is JIT-safe. |
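A minimal sketch of the NumPy-converting variant, assuming only top-level positional and keyword arguments are converted and that the conversion happens on the host, inside the callback. `sketch_callback_with_numpy_wrapper` and `_to_numpy` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _to_numpy(value):
    # Convert JAX/NumPy arrays to plain numpy.ndarray; pass anything else through.
    if isinstance(value, (jax.Array, np.ndarray)):
        return np.asarray(value)
    return value


def sketch_callback_with_numpy_wrapper(func, ordered=False):
    """Hypothetical sketch: JIT-safe callback that receives NumPy arrays."""

    def host_side(*args, **kwargs):
        func(*(_to_numpy(a) for a in args),
             **{k: _to_numpy(v) for k, v in kwargs.items()})

    def wrapped(*args, **kwargs):
        jax.debug.callback(host_side, *args, ordered=ordered, **kwargs)

    return wrapped


received = []
record = sketch_callback_with_numpy_wrapper(lambda arr: received.append(arr))


@jax.jit
def square(x):
    record(x ** 2)
    return x ** 2


square(jnp.arange(3.0))
jax.effects_barrier()  # make sure the host callback has run
```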
lerax.utils.callback_with_list_wrapper
callback_with_list_wrapper(
func: Callable[..., Any], ordered: bool = False
) -> Callable[..., None]
Like `debug_wrapper`, but converts every `jax.Array`/`jnp.ndarray` argument to a plain list before calling `func`.
Python's current type system cannot express this transformation, so the wrapper loses `func`'s parameter information and is typed as `Callable[..., None]`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | The callback function to wrap. | required |
| `ordered` | `bool` | Whether to enforce ordered execution of callbacks. | `False` |
Returns:
| Type | Description |
|---|---|
| `Callable[..., None]` | A wrapped version of `func` that converts array arguments to lists and is JIT-safe. |
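The list-converting variant can be sketched the same way, assuming the conversion is `np.asarray(value).tolist()` applied to top-level arguments on the host. Note that `tolist()` on a 0-d array returns a Python scalar rather than a list. `sketch_callback_with_list_wrapper` and `_to_list` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _to_list(value):
    # Nested Python lists for arrays; a 0-d array becomes a plain scalar.
    if isinstance(value, (jax.Array, np.ndarray)):
        return np.asarray(value).tolist()
    return value


def sketch_callback_with_list_wrapper(func, ordered=False):
    """Hypothetical sketch: JIT-safe callback that receives Python lists."""

    def host_side(*args, **kwargs):
        func(*(_to_list(a) for a in args),
             **{k: _to_list(v) for k, v in kwargs.items()})

    def wrapped(*args, **kwargs):
        jax.debug.callback(host_side, *args, ordered=ordered, **kwargs)

    return wrapped


rows = []
record_rows = sketch_callback_with_list_wrapper(
    lambda matrix, step: rows.append((step, matrix))
)


@jax.jit
def log_identity(step):
    record_rows(jnp.eye(2), step=step)  # keyword arguments are converted too
    return step + 1


log_identity(jnp.array(7))
jax.effects_barrier()
```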
lerax.utils.unstack_pytree
Split a stacked pytree along `axis` into a tuple of pytrees with the same structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | | A pytree with array leaves stacked along `axis`. | required |
| `axis` | | The axis along which to unstack the arrays. | `0` |
Returns:
| Type | Description |
|---|---|
| | A sequence of pytrees with the same structure as `tree`, each corresponding to one slice along `axis`. |
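A minimal sketch of the unstacking behaviour, assuming every leaf has the same size along `axis`. `sketch_unstack_pytree` is a hypothetical name, and the batch layout in the usage example is illustrative only.

```python
import jax
import jax.numpy as jnp


def sketch_unstack_pytree(tree, axis=0):
    """Hypothetical sketch: split every leaf of `tree` along `axis`."""
    leaves = jax.tree_util.tree_leaves(tree)
    if not leaves:
        return ()
    # Assumption: all leaves share the same size along `axis`.
    n = leaves[0].shape[axis]
    # tree_map runs immediately for each i, so the lambda sees the current index.
    return tuple(
        jax.tree_util.tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), tree)
        for i in range(n)
    )


batch = {
    "obs": jnp.arange(6.0).reshape(3, 2),
    "reward": jnp.array([1.0, 0.5, -1.0]),
}
steps = sketch_unstack_pytree(batch)  # three dicts, one per time step
columns = sketch_unstack_pytree({"m": jnp.arange(6.0).reshape(2, 3)}, axis=1)
```

Each returned pytree drops the unstacked axis, so `steps[1]["obs"]` has shape `(2,)` and `steps[1]["reward"]` is a scalar.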
lerax.utils.Serializable
Bases:
serialize
Serialize the model to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | | The path to serialize to. | required |
| `no_suffix` | | If `True`, do not append the `.eqx` suffix. | `False` |
deserialize
classmethod
deserialize[**Params, ClassType](
path: str | Path,
*args: Params.args,
**kwargs: Params.kwargs,
) -> ClassType
Deserialize the model from the specified path. Any additional arguments required by the class constructor must be provided.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | The path to deserialize from. | required |
| `*args` | `Params.args` | Additional positional arguments to pass to the class constructor. | `()` |
| `**kwargs` | `Params.kwargs` | Additional keyword arguments to pass to the class constructor. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ClassType` | The deserialized model. |