
Curriculum

lerax.curriculum.ScheduledCurriculum

Bases: AbstractStatelessCallback

Curriculum callback that modifies an environment field on a fixed schedule.

Uses eqx.tree_at to update a field on state.env each iteration based on the current iteration count.

Multiple ScheduledCurriculum instances can be composed via CallbackList to schedule multiple fields simultaneously.

Attributes:

where (Callable): A function selecting the field to modify on the env, e.g. lambda env: env.mass.

schedule_fn (Callable): A function mapping the iteration count to the scheduled parameter value.

Parameters:

where: Selector for the env field to schedule. Required.

schedule_fn: Schedule function (see linear_schedule, step_schedule, cosine_schedule). Required.

Example::

from lerax.curriculum import ScheduledCurriculum, linear_schedule

curriculum = ScheduledCurriculum(
    where=lambda env: env.m,
    # Ramp env.m linearly from 0.5 to 2.0 over the first 1000 iterations.
    schedule_fn=linear_schedule(start=0.5, end=2.0, total=1000),
)
algo.learn(env, policy, total_timesteps=..., key=key, callback=curriculum)

where class-attribute instance-attribute

where: Callable = eqx.field(static=True)

schedule_fn class-attribute instance-attribute

schedule_fn: Callable = eqx.field(static=True)

reset

reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> EmptyCallbackState

step_reset

step_reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> EmptyCallbackStepState

on_step

on_step(ctx: StepContext, *, key: Key[Array, ''])

on_iteration

on_iteration(ctx: IterationContext, *, key: Key[Array, ''])

on_training_start

on_training_start(ctx, *, key: Key[Array, ''])

on_training_end

on_training_end(ctx, *, key: Key[Array, ''])

continue_training

continue_training(
    ctx: IterationContext, *, key: Key[Array, ""]
)

apply_curriculum

apply_curriculum[S: "AbstractAlgorithmState"](
    state: S, callback_state: EmptyCallbackState
) -> tuple[S, EmptyCallbackState]

lerax.curriculum.AbstractAdaptiveCurriculum

Bases: AbstractCallback[AdaptiveCurriculumState, AdaptiveCurriculumStepState]

Abstract base for adaptive curricula that track a performance metric.

Handles metric accumulation in on_step and EMA smoothing in on_iteration. Subclasses implement apply_curriculum to decide how the metric drives environment changes.
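The accumulate-then-smooth pattern can be sketched in plain JAX. This is a minimal sketch under stated assumptions, not lerax's code. It assumes that the per-step values returned by `metric_fn` are summed into an episode total that resets when `done` is true. It also assumes that each finished episode updates the running metric with the common EMA rule `running = (1 - smoothing) * running + smoothing * episode_total`, which matches the description that higher `smoothing` gives recent episodes more weight. The helper names `accumulate` and `ema` are hypothetical.

```python
import jax.numpy as jnp


def accumulate(episode_total, done, step_value):
    # Add this step's metric contribution to the running episode total.
    total = episode_total + step_value
    # If the episode ended on this step, zero the accumulator for the next episode.
    # The second return value is the episode total so far (complete only when done).
    return jnp.where(done, 0.0, total), total


def ema(running, episode_total, smoothing):
    # Exponential moving average: larger `smoothing` gives the new episode more weight.
    return (1.0 - smoothing) * running + smoothing * episode_total


# A three-step episode with reward 1.0 per step, using the reward as the metric.
acc = jnp.array(0.0)
for done in (False, False, True):
    acc, episode_total = accumulate(acc, done, 1.0)

# The last step had done=True, so episode_total is the completed episode's total.
running = ema(jnp.array(0.0), episode_total, smoothing=0.05)
print(float(episode_total), float(running))
```

With `smoothing=0.05`, a single episode moves the running metric only 5% of the way toward the new episode total, which damps noise from individual episodes.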

Attributes:

metric_fn (Callable): Function (done, reward, locals_dict) -> scalar that extracts a per-step metric contribution. It is called every step, and the values are accumulated per episode.

smoothing (float): EMA smoothing factor for the running metric. Higher values give more weight to recent episodes.

Parameters:

metric_fn: Per-step metric extraction function. Required.

smoothing: EMA smoothing factor. Required on this base class; LevelCurriculum defaults it to 0.05.

metric_fn class-attribute instance-attribute

metric_fn: Callable = eqx.field(static=True)

smoothing instance-attribute

smoothing: float

reset

reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

step_reset

step_reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumStepState

on_step

on_step(
    ctx: StepContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumStepState

on_iteration

on_iteration(
    ctx: IterationContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

on_training_start

on_training_start(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

on_training_end

on_training_end(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

continue_training

continue_training(
    ctx: IterationContext, *, key: Key[Array, ""]
) -> Bool[Array, ""]

apply_curriculum abstractmethod

apply_curriculum[S: "AbstractAlgorithmState"](
    state: S, callback_state: AdaptiveCurriculumState
) -> tuple[S, AdaptiveCurriculumState]

Modify the algorithm state based on the tracked metric.

Called after on_iteration at the end of each training iteration. The running metric and current level are available in callback_state.

Parameters:

state (S): The current algorithm state. Required.

callback_state (AdaptiveCurriculumState): This callback's own state, containing running_metric and level. Required.

Returns:

tuple[S, AdaptiveCurriculumState]: A tuple of the (possibly modified) algorithm state and the (possibly modified) callback state.

lerax.curriculum.LevelCurriculum

Bases: AbstractAdaptiveCurriculum

Adaptive curriculum with discrete parameter levels.

Advances to the next level when the running performance metric exceeds threshold. Each level maps to a specific value for an environment field, applied via eqx.tree_at.

Attributes:

where (Callable): Selector for the env field to modify, e.g. lambda env: env.max_speed.

levels (Float[Array, ' num_levels']): Array of parameter values, one per level.

metric_fn (Callable): Per-step metric extraction function.

threshold (float): Advance to the next level when the running metric exceeds this value.

smoothing (float): EMA smoothing factor for the running metric.

Parameters:

where (Callable): Selector for the env field to modify. Required.

levels (Float[Array, ' num_levels']): Parameter values per level. Required.

metric_fn (Callable): Per-step metric extraction function. Required.

threshold (float): Advancement threshold. Required.

smoothing (float): EMA smoothing factor. Default: 0.05.

Example::

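import jax.numpy as jnp
from lerax.curriculum import LevelCurriculum

curriculum = LevelCurriculum(
    where=lambda env: env.max_speed,
    levels=jnp.array([4.0, 6.0, 8.0]),
    # Per-step reward, accumulated per episode, so the running metric tracks episode return.
    metric_fn=lambda done, reward, locals_dict: reward,
    threshold=100.0,
)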
algo.learn(env, policy, total_timesteps=..., key=key, callback=curriculum)

where class-attribute instance-attribute

where: Callable = where

levels instance-attribute

levels: Float[Array, ' num_levels'] = levels

threshold instance-attribute

threshold: float = threshold

metric_fn instance-attribute

metric_fn = metric_fn

smoothing instance-attribute

smoothing = smoothing

reset

reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

step_reset

step_reset(
    ctx: ResetContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumStepState

on_step

on_step(
    ctx: StepContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumStepState

on_iteration

on_iteration(
    ctx: IterationContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

on_training_start

on_training_start(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

on_training_end

on_training_end(
    ctx: TrainingContext, *, key: Key[Array, ""]
) -> AdaptiveCurriculumState

continue_training

continue_training(
    ctx: IterationContext, *, key: Key[Array, ""]
) -> Bool[Array, ""]

__init__

__init__(
    where: Callable,
    levels: Float[Array, " num_levels"],
    metric_fn: Callable,
    threshold: float,
    smoothing: float = 0.05,
)

apply_curriculum

apply_curriculum[S: "AbstractAlgorithmState"](
    state: S, callback_state: AdaptiveCurriculumState
) -> tuple[S, AdaptiveCurriculumState]

lerax.curriculum.linear_schedule

linear_schedule(
    start: float, end: float, total: int
) -> Callable[[Int[Array, ""]], Float[Array, ""]]

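Linear interpolation from start to end over total iterations.

Outside the range [0, total] the value is clamped, so it never leaves the interval between start and end and holds at end once the iteration count exceeds total.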

Parameters:

start (float): Value at iteration 0. Required.

end (float): Value at iteration total. Required.

total (int): Number of iterations for the full transition. Required.

Returns:

Callable[[Int[Array, '']], Float[Array, '']]: A function mapping the iteration count to the scheduled value.
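The clamped interpolation can be sketched in a few lines of JAX. This is an illustrative reimplementation that assumes the clamping is applied to the interpolation fraction. `linear_schedule_sketch` is a hypothetical name, not part of lerax.

```python
import jax.numpy as jnp


def linear_schedule_sketch(start, end, total):
    """Hypothetical reimplementation for illustration; use lerax.curriculum.linear_schedule in practice."""

    def schedule(iteration):
        # Fraction of the transition completed, clamped to [0, 1].
        frac = jnp.clip(iteration / total, 0.0, 1.0)
        return start + frac * (end - start)

    return schedule


schedule = linear_schedule_sketch(start=0.5, end=2.0, total=1000)
print([float(schedule(i)) for i in (0, 500, 1000, 5000)])  # [0.5, 1.25, 2.0, 2.0]
```

Clamping the fraction rather than the value also behaves correctly when start is greater than end, for example when annealing a parameter downward.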

lerax.curriculum.step_schedule

step_schedule(
    values: list[float], boundaries: list[int]
) -> Callable[[Int[Array, ""]], Float[Array, ""]]

Step-wise schedule that jumps between discrete values at specified iteration boundaries.

Parameters:

values (list[float]): Parameter values for each stage. Length must be len(boundaries) + 1. Required.

boundaries (list[int]): Iteration counts at which to transition to the next value. Required.

Returns:

Callable[[Int[Array, '']], Float[Array, '']]: A function mapping the iteration count to the scheduled value.
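One way to implement this lookup is with a sorted search over the boundaries. This is a hypothetical sketch, not lerax's code. It assumes that the switch happens exactly at the boundary iteration, meaning an iteration equal to a boundary already uses the next value. `step_schedule_sketch` is an illustrative name.

```python
import jax.numpy as jnp


def step_schedule_sketch(values, boundaries):
    """Hypothetical reimplementation for illustration; use lerax.curriculum.step_schedule in practice."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have exactly one more entry than boundaries")
    values = jnp.asarray(values)
    boundaries = jnp.asarray(boundaries)

    def schedule(iteration):
        # Count how many boundaries have been reached; that count indexes the stage.
        stage = jnp.searchsorted(boundaries, iteration, side="right")
        return values[stage]

    return schedule


schedule = step_schedule_sketch(values=[1.0, 0.5, 0.25], boundaries=[100, 200])
print([float(schedule(i)) for i in (0, 99, 100, 199, 200, 10_000)])
# [1.0, 1.0, 0.5, 0.5, 0.25, 0.25]
```

`jnp.searchsorted` works on traced integers, so the lookup needs no Python branching and can be jitted.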

lerax.curriculum.cosine_schedule

cosine_schedule(
    start: float, end: float, total: int
) -> Callable[[Int[Array, ""]], Float[Array, ""]]

Cosine annealing from start to end over total iterations.

Parameters:

start (float): Value at iteration 0. Required.

end (float): Value at iteration total. Required.

total (int): Number of iterations for the full transition. Required.

Returns:

Callable[[Int[Array, '']], Float[Array, '']]: A function mapping the iteration count to the scheduled value.
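Cosine annealing is usually defined as value(t) = end + (start - end) * (1 + cos(pi * t / total)) / 2. This formula gives start at t = 0, end at t = total, and their midpoint at t = total / 2. The sketch below assumes lerax follows this standard half-cosine form. The page does not say what happens after total iterations, so holding the value at end is an assumption here. `cosine_schedule_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def cosine_schedule_sketch(start, end, total):
    """Hypothetical half-cosine annealing sketch; use lerax.curriculum.cosine_schedule in practice."""

    def schedule(iteration):
        # Assumption: progress is clamped so the value holds at `end` after `total`.
        frac = jnp.clip(iteration / total, 0.0, 1.0)
        # cos goes from 1 to -1 as frac goes from 0 to 1, sweeping start -> end.
        return end + 0.5 * (start - end) * (1.0 + jnp.cos(jnp.pi * frac))

    return schedule


schedule = cosine_schedule_sketch(start=1.0, end=0.0, total=100)
print([round(float(schedule(i)), 4) for i in (0, 50, 100)])
```

Compared with the linear schedule, the cosine shape changes slowly near both ends and fastest in the middle of the transition.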