Curriculum
lerax.curriculum.ScheduledCurriculum
Bases: AbstractStatelessCallback
Curriculum callback that modifies an environment field on a fixed schedule.
Each iteration, it uses `eqx.tree_at` to update a field on `state.env` based on the current iteration count.
Multiple `ScheduledCurriculum` instances can be composed via `CallbackList` to schedule several fields simultaneously.
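The per-iteration update can be sketched in plain Python. This is a minimal illustration, not lerax's implementation. `ToyEnv`, `scheduled_update`, `compose`, `mass_update` and `speed_update` are hypothetical names, and `dataclasses.replace` stands in for `eqx.tree_at`. `compose` only loosely mirrors running several callbacks through `CallbackList`. The sketch shows the key property: each update returns a new env and leaves the old one untouched.

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class ToyEnv:
    # Hypothetical environment with two tunable fields.
    m: float = 1.0
    max_speed: float = 4.0


def scheduled_update(
    env: ToyEnv, field: str, schedule_fn: Callable[[int], float], iteration: int
) -> ToyEnv:
    # Return a copy of env with one field set to its scheduled value.
    # dataclasses.replace stands in for eqx.tree_at in this sketch.
    return dataclasses.replace(env, **{field: schedule_fn(iteration)})


def compose(
    *updates: Callable[[ToyEnv, int], ToyEnv]
) -> Callable[[ToyEnv, int], ToyEnv]:
    # Apply several per-iteration updates in order, one per scheduled field.
    def run(env: ToyEnv, iteration: int) -> ToyEnv:
        for update in updates:
            env = update(env, iteration)
        return env

    return run


def mass_update(env: ToyEnv, iteration: int) -> ToyEnv:
    # Hypothetical linear ramp for env.m.
    return scheduled_update(env, "m", lambda i: 0.5 + i / 100, iteration)


def speed_update(env: ToyEnv, iteration: int) -> ToyEnv:
    # Hypothetical single step for env.max_speed at iteration 100.
    return scheduled_update(
        env, "max_speed", lambda i: 4.0 if i < 100 else 8.0, iteration
    )


both = compose(mass_update, speed_update)
env = ToyEnv()
print(both(env, 150))  # ToyEnv(m=2.0, max_speed=8.0)
print(env)             # ToyEnv(m=1.0, max_speed=4.0); the original is unchanged
```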
Attributes:
| Name | Type | Description |
|---|---|---|
| `where` | `Callable` | A function selecting the field to modify on the env, e.g. `lambda env: env.m`. |
| `schedule_fn` | `Callable` | A function mapping the iteration count to the scheduled parameter value. |
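Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `where` | `Callable` | Selector for the env field to schedule. | required |
| `schedule_fn` | `Callable` | Schedule function, e.g. one built by `linear_schedule`, `step_schedule`, or `cosine_schedule`. | required |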
Example::
from lerax.curriculum import ScheduledCurriculum, linear_schedule
curriculum = ScheduledCurriculum(
where=lambda env: env.m,
schedule_fn=linear_schedule(start=0.5, end=2.0, total=1000),
)
algo.learn(env, policy, total_timesteps=..., key=key, callback=curriculum)
lerax.curriculum.AbstractAdaptiveCurriculum
Bases: AbstractCallback[AdaptiveCurriculumState, AdaptiveCurriculumStepState]
Abstract base for adaptive curricula that track a performance metric.
Handles metric accumulation in `on_step` and EMA smoothing in `on_iteration`. Subclasses implement `apply_curriculum` to decide how the metric drives environment changes.
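One plausible reading of this bookkeeping is sketched below in plain Python. Several details are assumptions: the per-step metric is summed per episode, the running value starts at 0.0, the EMA takes the form `(1 - smoothing) * old + smoothing * new`, and the update is skipped when no episode finished during the iteration. These choices are consistent with the description of `smoothing` (higher values weight recent episodes more). `MetricTracker` is a hypothetical name.

```python
from dataclasses import dataclass, field


@dataclass
class MetricTracker:
    # Hypothetical stand-in for the adaptive curriculum's metric state.
    smoothing: float = 0.05
    running_metric: float = 0.0  # assumed initial value
    episode_total: float = 0.0  # metric summed over the current episode
    finished: list[float] = field(default_factory=list)  # totals of episodes ended this iteration

    def on_step(self, metric: float, done: bool) -> None:
        # Accumulate the per-step metric; record the total when the episode ends.
        self.episode_total += metric
        if done:
            self.finished.append(self.episode_total)
            self.episode_total = 0.0

    def on_iteration(self) -> float:
        # Blend this iteration's mean episode total into the running EMA.
        if self.finished:
            mean = sum(self.finished) / len(self.finished)
            self.running_metric = (
                (1.0 - self.smoothing) * self.running_metric + self.smoothing * mean
            )
            self.finished.clear()
        return self.running_metric


tracker = MetricTracker(smoothing=0.5)
tracker.on_step(1.0, done=False)
tracker.on_step(1.0, done=True)  # episode total 2.0
tracker.on_step(4.0, done=True)  # episode total 4.0
print(tracker.on_iteration())    # 1.5 = 0.5 * 0.0 + 0.5 * 3.0
```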
Attributes:
| Name | Type | Description |
|---|---|---|
| `metric_fn` | `Callable` | Function extracting a per-step metric, e.g. `lambda done, reward, locals: reward`. |
| `smoothing` | `float` | EMA smoothing factor for the running metric. Higher values give more weight to recent episodes. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `metric_fn` | `Callable` | Per-step metric extraction function. | required |
| `smoothing` | `float` | EMA smoothing factor (default 0.05). | required |
on_iteration
on_training_start
on_training_end
continue_training
apply_curriculum
abstractmethod
apply_curriculum[S: "AbstractAlgorithmState"](
state: S, callback_state: AdaptiveCurriculumState
) -> tuple[S, AdaptiveCurriculumState]
Modify the algorithm state based on the tracked metric.
Called after `on_iteration` at the end of each training iteration. The running metric and current level are available in `callback_state`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `S` | The current algorithm state. | required |
| `callback_state` | `AdaptiveCurriculumState` | This callback's own state, containing the running metric and current level. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[S, AdaptiveCurriculumState]` | A tuple of the (possibly modified) algorithm state and the (possibly modified) callback state. |
lerax.curriculum.LevelCurriculum
Bases: AbstractAdaptiveCurriculum
Adaptive curriculum with discrete parameter levels.
Advances to the next level when the running performance metric exceeds `threshold`. Each level maps to a specific value for an environment field, applied via `eqx.tree_at`.
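The advancement rule can be sketched as a pure function shaped like `apply_curriculum`: it takes the current state and callback state and returns an updated pair. This is an illustration under assumptions, not lerax's code. `Env`, `LevelState` and `advance_level` are hypothetical, and `NamedTuple._replace` stands in for `eqx.tree_at`. The sketch also assumes the level stops at the last entry of `levels` and that the running metric is not reset after advancing; the documentation does not say either way.

```python
from typing import NamedTuple, Sequence


class Env(NamedTuple):
    # Hypothetical environment with one curriculum-controlled field.
    max_speed: float


class LevelState(NamedTuple):
    # Hypothetical callback state: smoothed metric and current level index.
    running_metric: float
    level: int


def advance_level(
    env: Env, cb: LevelState, levels: Sequence[float], threshold: float
) -> tuple[Env, LevelState]:
    level = cb.level
    # Move up one level when the smoothed metric exceeds the threshold,
    # but never past the last level.
    if cb.running_metric > threshold:
        level = min(level + 1, len(levels) - 1)
    # Write the level's value into the env field; _replace returns a new tuple.
    return env._replace(max_speed=levels[level]), cb._replace(level=level)


levels = [4.0, 6.0, 8.0]
env, cb = advance_level(Env(4.0), LevelState(150.0, 0), levels, threshold=100.0)
print(env, cb)  # Env(max_speed=6.0) LevelState(running_metric=150.0, level=1)
```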
Attributes:
| Name | Type | Description |
|---|---|---|
| `where` | `Callable` | Selector for the env field to modify, e.g. `lambda env: env.max_speed`. |
| `levels` | `Float[Array, ' num_levels']` | Array of parameter values, one per level. |
| `metric_fn` | `Callable` | Per-step metric extraction function. |
| `threshold` | `float` | Advance to the next level when the running metric exceeds this value. |
| `smoothing` | `float` | EMA smoothing factor for the running metric. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `where` | `Callable` | Selector for the env field to modify. | required |
| `levels` | `Float[Array, ' num_levels']` | Parameter values per level. | required |
| `metric_fn` | `Callable` | Per-step metric extraction function. | required |
| `threshold` | `float` | Advancement threshold. | required |
| `smoothing` | `float` | EMA smoothing factor. | `0.05` |
Example::
import jax.numpy as jnp
from lerax.curriculum import LevelCurriculum
curriculum = LevelCurriculum(
where=lambda env: env.max_speed,
levels=jnp.array([4.0, 6.0, 8.0]),
metric_fn=lambda done, reward, locals: reward,
threshold=100.0,
)
algo.learn(env, policy, total_timesteps=..., key=key, callback=curriculum)
lerax.curriculum.linear_schedule
linear_schedule(
start: float, end: float, total: int
) -> Callable[[Int[Array, ""]], Float[Array, ""]]
Linear interpolation from `start` to `end` over `total` iterations.
Outside iterations 0 to `total`, the value is clamped to the interval between `start` and `end`.
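A plain-Python sketch of the documented behavior, useful for checking schedule values by hand. The library returns a function over JAX integer arrays. This hypothetical `linear_schedule_sketch` works on Python numbers instead, and implements the clamping by clipping the progress fraction to [0, 1].

```python
from typing import Callable


def linear_schedule_sketch(start: float, end: float, total: int) -> Callable[[int], float]:
    def schedule(iteration: int) -> float:
        # Fraction of the transition completed, clipped to [0, 1] so the value
        # stays at start before iteration 0 and at end after iteration total.
        frac = min(max(iteration / total, 0.0), 1.0)
        return start + frac * (end - start)

    return schedule


schedule = linear_schedule_sketch(start=0.5, end=2.0, total=1000)
print([schedule(i) for i in (-10, 0, 500, 1000, 5000)])  # [0.5, 0.5, 1.25, 2.0, 2.0]
```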
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start` | `float` | Value at iteration 0. | required |
| `end` | `float` | Value at iteration `total`. | required |
| `total` | `int` | Number of iterations for the full transition. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[[Int[Array, '']], Float[Array, '']]` | A function mapping iteration count to the scheduled value. |
lerax.curriculum.step_schedule
step_schedule(
values: list[float], boundaries: list[int]
) -> Callable[[Int[Array, ""]], Float[Array, ""]]
Step-wise schedule that jumps between discrete values at specified iteration boundaries.
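The stage lookup can be sketched in plain Python. The sketch assumes `boundaries` is sorted in ascending order and that the switch happens exactly at each boundary iteration; the documentation does not pin down whether a boundary belongs to the earlier or later stage. `step_schedule_sketch` is a hypothetical name. `bisect.bisect_right` counts how many boundaries have been reached, which is the index of the current stage.

```python
from bisect import bisect_right
from typing import Callable, Sequence


def step_schedule_sketch(
    values: Sequence[float], boundaries: Sequence[int]
) -> Callable[[int], float]:
    # n boundaries split training into n + 1 stages, one value per stage.
    if len(values) != len(boundaries) + 1:
        raise ValueError("len(values) must equal len(boundaries) + 1")

    def schedule(iteration: int) -> float:
        # Number of boundaries <= iteration; assumes boundaries are sorted.
        return values[bisect_right(boundaries, iteration)]

    return schedule


schedule = step_schedule_sketch(values=[1.0, 2.0, 3.0], boundaries=[10, 20])
print([schedule(i) for i in (0, 9, 10, 19, 20, 100)])  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```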
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `values` | `list[float]` | Parameter values for each stage. Length must be `len(boundaries) + 1`. | required |
| `boundaries` | `list[int]` | Iteration counts at which to transition to the next value. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[[Int[Array, '']], Float[Array, '']]` | A function mapping iteration count to the scheduled value. |
lerax.curriculum.cosine_schedule
cosine_schedule(
start: float, end: float, total: int
) -> Callable[[Int[Array, ""]], Float[Array, ""]]
Cosine annealing from start to end over total iterations.
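A plain-Python sketch of cosine annealing using the common half-cosine weighting. The exact formula lerax uses is not stated here, so treat this as an assumption. The weight `0.5 * (1 + cos(pi * t / total))` falls from 1 to 0, giving slow change near both ends and the fastest change at the midpoint. Clamping outside [0, `total`] is also assumed, mirroring `linear_schedule`. `cosine_schedule_sketch` is a hypothetical name.

```python
import math
from typing import Callable


def cosine_schedule_sketch(start: float, end: float, total: int) -> Callable[[int], float]:
    def schedule(iteration: int) -> float:
        # Progress through the transition, clipped to [0, 1] (assumed clamping).
        frac = min(max(iteration / total, 0.0), 1.0)
        # Half-cosine weight: 1 at frac = 0, 0.5 at the midpoint, 0 at frac = 1.
        weight = 0.5 * (1.0 + math.cos(math.pi * frac))
        return end + (start - end) * weight

    return schedule


schedule = cosine_schedule_sketch(start=1.0, end=0.0, total=100)
print([round(schedule(i), 4) for i in (0, 25, 50, 75, 100)])  # [1.0, 0.8536, 0.5, 0.1464, 0.0]
```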
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start` | `float` | Value at iteration 0. | required |
| `end` | `float` | Value at iteration `total`. | required |
| `total` | `int` | Number of iterations for the full transition. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[[Int[Array, '']], Float[Array, '']]` | A function mapping iteration count to the scheduled value. |