
Wrapper

lerax.wrapper.AbstractWrapper

Bases: AbstractEnvLike[WrapperStateType, WrapperActType, WrapperObsType, WrapperMaskType]

Base class for environment wrappers.

Attributes:

name (str): The name of the environment.

env (eqx.AbstractVar[AbstractEnvLike[StateType, ActType, ObsType, MaskType]]): The wrapped environment.

unwrapped (AbstractEnv): The environment without any wrappers.

action_space (eqx.AbstractVar[AbstractSpace[ActType, WrapperMaskType]]): The action space of the environment after wrapping.

observation_space (eqx.AbstractVar[AbstractSpace[ObsType, Any]]): The observation space of the environment after wrapping.

unwrapped property

unwrapped: AbstractEnv

Return the environment without any wrappers.
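To illustrate how `unwrapped` can reach the innermost environment through any number of wrapper layers, here is a minimal plain-Python sketch. `ToyEnv` and `ToyWrapper` are hypothetical stand-ins, not lerax classes, and forwarding `name` from the innermost environment is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class ToyEnv:
    """Hypothetical stand-in for an unwrapped environment (not lerax's AbstractEnv)."""

    name: str = "ToyEnv"


@dataclass(frozen=True)
class ToyWrapper:
    """Hypothetical stand-in for a wrapper that holds exactly one inner env."""

    env: Union["ToyWrapper", ToyEnv]

    @property
    def unwrapped(self) -> ToyEnv:
        # Ask the inner object for its unwrapped env; recursion stops at a bare env.
        inner = self.env
        return inner.unwrapped if isinstance(inner, ToyWrapper) else inner

    @property
    def name(self) -> str:
        # Assumption for illustration: the name is forwarded from the innermost env.
        return self.unwrapped.name


base = ToyEnv()
stacked = ToyWrapper(ToyWrapper(ToyWrapper(base)))
print(stacked.unwrapped is base)  # True
print(stacked.name)  # ToyEnv
```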

name property

name: str

Return the name of the environment.

initial abstractmethod

initial(*, key: Key) -> StateType

Generate the initial state of the environment.

Parameters:

key (Key, required): A JAX PRNG key for any stochasticity in the initial state.

Returns:

StateType: An initial environment state.
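A wrapper that carries its own bookkeeping can build the inner environment's initial state and then attach its own fields. The sketch below assumes a hypothetical step-counting wrapper; `CounterEnv`, `StepCountState`, and `StepCountWrapper` are illustrative names, and the key is simply passed through. A wrapper that needed randomness of its own would presumably split the key first (for example with `jax.random.split`).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CounterEnv:
    """Hypothetical toy env whose state is a plain integer."""

    start: int = 0

    def initial(self, *, key) -> int:
        # Deterministic toy: the key is accepted but unused.
        return self.start


@dataclass(frozen=True)
class StepCountState:
    """Hypothetical wrapper state: the inner env state plus a step counter."""

    env_state: int
    steps: int


@dataclass(frozen=True)
class StepCountWrapper:
    """Hypothetical wrapper that counts steps on top of its inner env."""

    env: CounterEnv

    def initial(self, *, key) -> StepCountState:
        # Build the inner initial state first, then attach the wrapper's own fields.
        return StepCountState(env_state=self.env.initial(key=key), steps=0)


state = StepCountWrapper(CounterEnv(start=5)).initial(key=0)
print(state)  # StepCountState(env_state=5, steps=0)
```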

action_mask abstractmethod

action_mask(
    state: StateType, *, key: Key
) -> MaskType | None

Generate an action mask from the environment state.

Parameters:

state (StateType, required): The current environment state.

key (Key, required): A JAX PRNG key for any stochasticity in the action mask.

Returns:

MaskType | None: A mask indicating valid and invalid actions for the environment state.
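One way a wrapper might refine the inner environment's mask is to combine it element-wise with its own restriction. This plain-Python sketch uses tuples of booleans in place of JAX arrays; all class names are hypothetical, and treating a `None` mask as "every action is valid" is an assumption, not documented behavior.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Mask = Tuple[bool, ...]


@dataclass(frozen=True)
class ThreeActionEnv:
    """Hypothetical toy env: action 2 is only valid in even-numbered states."""

    def action_mask(self, state: int, *, key) -> Optional[Mask]:
        return (True, True, state % 2 == 0)


@dataclass(frozen=True)
class UnmaskedEnv:
    """Hypothetical toy env that imposes no mask at all."""

    def action_mask(self, state: int, *, key) -> Optional[Mask]:
        return None


@dataclass(frozen=True)
class ForbidActionWrapper:
    """Hypothetical wrapper that additionally forbids one action index."""

    env: "ThreeActionEnv | UnmaskedEnv"
    forbidden: int
    n_actions: int = 3

    def action_mask(self, state: int, *, key) -> Optional[Mask]:
        inner = self.env.action_mask(state, key=key)
        if inner is None:
            # Assumption: None means "every action is valid".
            inner = (True,) * self.n_actions
        # An action stays valid only if the inner env allows it and it is not forbidden.
        return tuple(ok and i != self.forbidden for i, ok in enumerate(inner))


print(ForbidActionWrapper(ThreeActionEnv(), forbidden=0).action_mask(1, key=0))
# (False, True, False)
print(ForbidActionWrapper(UnmaskedEnv(), forbidden=1).action_mask(1, key=0))
# (True, False, True)
```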

transition abstractmethod

transition(
    state: StateType, action: ActType, *, key: Key
) -> StateType

Update the environment state given an action.

Parameters:

state (StateType, required): The current environment state.

action (ActType, required): The action to take.

key (Key, required): A JAX PRNG key for any stochasticity in the transition.

Returns:

StateType: The next environment state.
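A wrapper's `transition` can delegate the dynamics to the wrapped environment and then update its own state. In the hypothetical sketch below, frozen dataclasses stand in for immutable JAX pytrees, so the next state is built with `dataclasses.replace` rather than by mutation; the class names are illustrative.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class WalkEnv:
    """Hypothetical toy env: the state is a position on the integer line."""

    def transition(self, state: int, action: int, *, key) -> int:
        # Action 1 moves right, anything else moves left; the key is unused here.
        return state + 1 if action == 1 else state - 1


@dataclass(frozen=True)
class CountedState:
    """Hypothetical wrapper state: inner env state plus a step counter."""

    env_state: int
    steps: int


@dataclass(frozen=True)
class StepCountWrapper:
    """Hypothetical wrapper that counts how many transitions have occurred."""

    env: WalkEnv

    def transition(self, state: CountedState, action: int, *, key) -> CountedState:
        # Delegate the dynamics to the inner env, then update the wrapper's bookkeeping.
        inner_next = self.env.transition(state.env_state, action, key=key)
        return replace(state, env_state=inner_next, steps=state.steps + 1)


wrapper = StepCountWrapper(WalkEnv())
state = CountedState(env_state=0, steps=0)
for action in (1, 1, 0):
    state = wrapper.transition(state, action, key=0)
print(state)  # CountedState(env_state=1, steps=3)
```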

observation abstractmethod

observation(state: StateType, *, key: Key) -> ObsType

Generate an observation from the environment state.

Parameters:

state (StateType, required): The current environment state.

key (Key, required): A JAX PRNG key for any stochasticity in the observation.

Returns:

ObsType: An observation corresponding to the environment state.
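An observation wrapper transforms whatever the inner environment produces. This sketch scales each component by a constant; `PointEnv` and `ScaleObservation` are hypothetical, and a real wrapper of this kind would presumably also need to report a matching `observation_space`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class PointEnv:
    """Hypothetical toy env: the state is an (x, y) point observed directly."""

    def observation(self, state: Tuple[float, float], *, key) -> Tuple[float, float]:
        return state


@dataclass(frozen=True)
class ScaleObservation:
    """Hypothetical wrapper that multiplies every observation component by a constant."""

    env: PointEnv
    scale: float

    def observation(self, state: Tuple[float, float], *, key) -> Tuple[float, ...]:
        # Get the inner observation, then apply the wrapper's transformation.
        inner_obs = self.env.observation(state, key=key)
        return tuple(self.scale * v for v in inner_obs)


print(ScaleObservation(PointEnv(), scale=0.5).observation((2.0, -4.0), key=0))
# (1.0, -2.0)
```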

reward abstractmethod

reward(
    state: StateType,
    action: ActType,
    next_state: StateType,
    *,
    key: Key,
) -> Float[Array, ""]

Generate a reward from the environment state transition.

Parameters:

state (StateType, required): The current environment state.

action (ActType, required): The action taken.

next_state (StateType, required): The next environment state.

key (Key, required): A JAX PRNG key for any stochasticity in the reward.

Returns:

Float[Array, ""]: A reward corresponding to the environment state transition.
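Reward wrappers follow the same delegation pattern. The hypothetical sketch below clips the inner reward to a fixed range using plain floats; since the documented return type is a scalar JAX array, a JAX version would more likely use `jnp.clip`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DistanceEnv:
    """Hypothetical toy env: the reward is the change in position."""

    def reward(self, state: float, action: int, next_state: float, *, key) -> float:
        return next_state - state


@dataclass(frozen=True)
class ClipReward:
    """Hypothetical wrapper that clips the inner reward to [low, high]."""

    env: DistanceEnv
    low: float = -1.0
    high: float = 1.0

    def reward(self, state: float, action: int, next_state: float, *, key) -> float:
        r = self.env.reward(state, action, next_state, key=key)
        return min(max(r, self.low), self.high)


clip = ClipReward(DistanceEnv())
print(clip.reward(0.0, 1, 3.0, key=0))    # 1.0
print(clip.reward(0.0, 0, -0.25, key=0))  # -0.25
print(clip.reward(2.0, 0, -5.0, key=0))   # -1.0
```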

terminal abstractmethod

terminal(state: StateType, *, key: Key) -> Bool[Array, '']

Determine whether the environment state is terminal.

Parameters:

state (StateType, required): The current environment state.

key (Key, required): A JAX PRNG key for any stochasticity in the terminal condition.

Returns:

Bool[Array, '']: A boolean indicating whether the environment state is terminal.
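A wrapper can add its own termination condition on top of the inner one. The sketch below (hypothetical class names, plain Python booleans) ends the episode when the state leaves a bounded interval. Because the documented return type is a scalar `Bool` array, a jit-compiled version would need `jnp.logical_or` instead of Python's `or`, which cannot be applied to traced values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GoalEnv:
    """Hypothetical toy env that terminates when the state reaches the goal."""

    goal: int = 5

    def terminal(self, state: int, *, key) -> bool:
        return state == self.goal


@dataclass(frozen=True)
class BoundsTermination:
    """Hypothetical wrapper that also terminates when the state leaves [-limit, limit]."""

    env: GoalEnv
    limit: int

    def terminal(self, state: int, *, key) -> bool:
        # Terminal if either the inner condition or the wrapper's own condition holds.
        return self.env.terminal(state, key=key) or abs(state) > self.limit


w = BoundsTermination(GoalEnv(goal=5), limit=10)
print([w.terminal(s, key=0) for s in (3, 5, -11)])  # [False, True, True]
```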

truncate abstractmethod

truncate(state: StateType) -> Bool[Array, '']

Determine whether the environment state is truncated.

Parameters:

state (StateType, required): The current environment state.

Returns:

Bool[Array, '']: A boolean indicating whether the environment state is truncated.
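A common use of `truncate` is a time limit. Unlike `terminal`, the documented signature takes no key, so truncation is a deterministic function of the state. This sketch assumes a hypothetical wrapper state that stores a step counter next to the inner state; all class names are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EndlessEnv:
    """Hypothetical toy env that never truncates on its own."""

    def truncate(self, state: int) -> bool:
        return False


@dataclass(frozen=True)
class TimeLimitState:
    """Hypothetical wrapper state: inner env state plus steps taken so far."""

    env_state: int
    steps: int


@dataclass(frozen=True)
class TimeLimit:
    """Hypothetical wrapper that truncates episodes after max_steps steps."""

    env: EndlessEnv
    max_steps: int

    def truncate(self, state: TimeLimitState) -> bool:
        # No key argument: truncation here depends only on the state.
        return self.env.truncate(state.env_state) or state.steps >= self.max_steps


limit = TimeLimit(EndlessEnv(), max_steps=3)
print(limit.truncate(TimeLimitState(env_state=0, steps=2)))  # False
print(limit.truncate(TimeLimitState(env_state=0, steps=3)))  # True
```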

state_info abstractmethod

state_info(state: StateType) -> dict

Generate additional info from the environment state.

In many cases, this can simply return an empty dictionary.

Parameters:

state (StateType, required): The current environment state.

Returns:

dict: A dictionary of additional info from the environment state.
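When a wrapper has something to report, it can extend the inner environment's info rather than replace it. In this hypothetical sketch the inner environment returns the empty dictionary mentioned above, and the wrapper adds its step counter under an illustrative `"steps"` key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PlainEnv:
    """Hypothetical toy env with nothing extra to report."""

    def state_info(self, state: int) -> dict:
        return {}


@dataclass(frozen=True)
class StepsState:
    """Hypothetical wrapper state: inner env state plus a step counter."""

    env_state: int
    steps: int


@dataclass(frozen=True)
class StepInfoWrapper:
    """Hypothetical wrapper that adds its step counter to the info dict."""

    env: PlainEnv

    def state_info(self, state: StepsState) -> dict:
        # Keep whatever the inner env reports, then add the wrapper's own entry.
        return {**self.env.state_info(state.env_state), "steps": state.steps}


print(StepInfoWrapper(PlainEnv()).state_info(StepsState(env_state=7, steps=4)))
# {'steps': 4}
```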

transition_info abstractmethod

transition_info(
    state: StateType, action: ActType, next_state: StateType
) -> dict

Generate additional info from the environment state transition.

In many cases, this can simply return an empty dictionary.

Parameters:

state (StateType, required): The current environment state.

action (ActType, required): The action taken.

next_state (StateType, required): The next environment state.

Returns:

dict: A dictionary of additional info from the environment state transition.
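`transition_info` can report details that depend on the whole step rather than on a single state. The hypothetical wrapper below clips actions to [-1, 1] and flags when clipping occurred; the clipping behavior and the `action_clipped` key are invented for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LineEnv:
    """Hypothetical toy env with no extra transition info."""

    def transition_info(self, state: int, action: int, next_state: int) -> dict:
        return {}


@dataclass(frozen=True)
class ClipActionInfo:
    """Hypothetical wrapper that clips actions to [-1, 1] and reports when it did so."""

    env: LineEnv

    def transition_info(self, state: int, action: int, next_state: int) -> dict:
        clipped = max(-1, min(1, action))
        # Pass the clipped action to the inner env, as this wrapper would in every method.
        inner = self.env.transition_info(state, clipped, next_state)
        return {**inner, "action_clipped": clipped != action}


info_wrapper = ClipActionInfo(LineEnv())
print(info_wrapper.transition_info(0, 5, 1))  # {'action_clipped': True}
print(info_wrapper.transition_info(0, 1, 1))  # {'action_clipped': False}
```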

render_states

render_states(
    states: Sequence[StateType],
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render a sequence of frames from multiple states.

Parameters:

states (Sequence[StateType], required): A sequence of environment states to render.

renderer (AbstractRenderer | Literal["auto"], default "auto"): The renderer to use for rendering. If "auto", the default renderer is used.

dt (float, default 0.0): The time delay between rendering each frame, in seconds.
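A sketch of how a loop like `render_states` might resolve `"auto"` and apply the per-frame delay, under stated assumptions: `RecordingRenderer` and its `draw` method are hypothetical (they are not the `AbstractRenderer` interface), and the toy method returns the renderer only so the recorded frames can be inspected.

```python
import time
from dataclasses import dataclass, field
from typing import List, Literal, Sequence, Union


@dataclass
class RecordingRenderer:
    """Hypothetical renderer that records text frames instead of drawing them."""

    frames: List[str] = field(default_factory=list)

    def draw(self, frame: str) -> None:
        self.frames.append(frame)


@dataclass(frozen=True)
class TextEnv:
    """Hypothetical toy env whose integer state renders as a text frame."""

    def default_renderer(self) -> RecordingRenderer:
        return RecordingRenderer()

    def render(self, state: int, renderer: RecordingRenderer) -> None:
        renderer.draw(f"pos={state}")

    def render_states(
        self,
        states: Sequence[int],
        renderer: Union[RecordingRenderer, Literal["auto"]] = "auto",
        dt: float = 0.0,
    ) -> RecordingRenderer:
        # Resolve the "auto" sentinel to the environment's default renderer.
        if renderer == "auto":
            renderer = self.default_renderer()
        for state in states:
            self.render(state, renderer)
            time.sleep(dt)  # pause between frames; 0.0 means no delay
        # Returned only so the example can inspect the recorded frames.
        return renderer


frames = TextEnv().render_states([0, 1, 2]).frames
print(frames)  # ['pos=0', 'pos=1', 'pos=2']
```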

render_stacked

render_stacked(
    states: StateType,
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render multiple frames from stacked states.

Stacked states are typically batched states stored in a pytree structure.

Parameters:

states (StateType, required): A pytree of stacked environment states to render.

renderer (AbstractRenderer | Literal["auto"], default "auto"): The renderer to use for rendering. If "auto", the default renderer is used.

dt (float, default 0.0): The time delay between rendering each frame, in seconds.
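Rendering stacked states requires slicing one state out of the pytree per frame. This plain-Python sketch treats a flat dictionary of equal-length lists as the "pytree"; the `unstack` helper is hypothetical. With JAX, the same slice would typically be taken with `jax.tree.map(lambda leaf: leaf[i], states)`. Whether `render_stacked` is implemented this way is an assumption.

```python
from typing import Dict, List, Sequence


def unstack(stacked: Dict[str, Sequence]) -> List[Dict[str, object]]:
    """Split a flat dict of equal-length sequences into a list of per-step dicts."""
    lengths = {len(leaf) for leaf in stacked.values()}
    if len(lengths) != 1:
        # Every leaf must share the same leading (stacking) dimension.
        raise ValueError("all leaves must share the same leading length")
    (n,) = lengths
    # The i-th state takes the i-th entry of every leaf.
    return [{name: leaf[i] for name, leaf in stacked.items()} for i in range(n)]


stacked_states = {"x": [0, 1, 2], "v": [1.0, 1.0, 0.5]}
print(unstack(stacked_states))
# [{'x': 0, 'v': 1.0}, {'x': 1, 'v': 1.0}, {'x': 2, 'v': 0.5}]
```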

reset

reset(*, key: Key) -> tuple[StateType, ObsType, dict]

Wrap the functional logic into a Gym API reset method.

Parameters:

key (Key, required): A JAX PRNG key for any stochasticity in the reset.

Returns:

tuple[StateType, ObsType, dict]: A tuple of the initial state, initial observation, and additional info.
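`reset` can be read as a composition of `initial`, `observation`, and `state_info`. The sketch below shows one plausible composition, not lerax's implementation; the key-splitting scheme is an assumption, and `split` is a hypothetical integer-seed stand-in for `jax.random.split`.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


def split(key: int, num: int = 2) -> List[int]:
    """Hypothetical stand-in for jax.random.split, using integer seeds."""
    rng = random.Random(key)
    return [rng.getrandbits(32) for _ in range(num)]


@dataclass(frozen=True)
class NoisyStartEnv:
    """Hypothetical toy env whose initial state is a random integer in [0, 9]."""

    def initial(self, *, key: int) -> int:
        return random.Random(key).randint(0, 9)

    def observation(self, state: int, *, key: int) -> Tuple[float]:
        return (float(state),)

    def state_info(self, state: int) -> dict:
        return {}

    def reset(self, *, key: int) -> Tuple[int, Tuple[float], dict]:
        # Assumed key handling: one subkey each for the initial state and observation.
        init_key, obs_key = split(key)
        state = self.initial(key=init_key)
        return state, self.observation(state, key=obs_key), self.state_info(state)


state, obs, info = NoisyStartEnv().reset(key=42)
print(0 <= state <= 9, obs == (float(state),), info)  # True True {}
```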

step

step(
    state: StateType, action: ActType, *, key: Key
) -> tuple[
    StateType,
    ObsType,
    Float[Array, ""],
    Bool[Array, ""],
    Bool[Array, ""],
    dict,
]

Wrap the functional logic into a Gym API step method.

Parameters:

state (StateType, required): The current environment state.

action (ActType, required): The action to take.

key (Key, required): A JAX PRNG key for any stochasticity in the step.

Returns:

tuple[StateType, ObsType, Float[Array, ''], Bool[Array, ''], Bool[Array, ''], dict]: A tuple of the next state, observation, reward, terminal flag, truncate flag, and additional info.
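`step` can be read as a composition of the abstract methods documented above. The sketch below is one plausible composition, not lerax's implementation: the per-hook key split, computing the observation and flags from the next state, and merging `state_info` with `transition_info` are all assumptions, and `split` is a hypothetical integer-seed stand-in for `jax.random.split`.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


def split(key: int, num: int) -> List[int]:
    """Hypothetical stand-in for jax.random.split, using integer seeds."""
    rng = random.Random(key)
    return [rng.getrandbits(32) for _ in range(num)]


@dataclass(frozen=True)
class LineEnv:
    """Hypothetical deterministic toy env on the integer line."""

    goal: int = 3

    def transition(self, state: int, action: int, *, key: int) -> int:
        return state + action

    def observation(self, state: int, *, key: int) -> Tuple[float]:
        return (float(state),)

    def reward(self, state: int, action: int, next_state: int, *, key: int) -> float:
        return float(next_state - state)

    def terminal(self, state: int, *, key: int) -> bool:
        return state >= self.goal

    def truncate(self, state: int) -> bool:
        return False

    def state_info(self, state: int) -> dict:
        return {}

    def transition_info(self, state: int, action: int, next_state: int) -> dict:
        return {"action": action}

    def step(self, state: int, action: int, *, key: int):
        # Assumed key handling: one subkey per stochastic hook.
        k_trans, k_obs, k_rew, k_term = split(key, 4)
        next_state = self.transition(state, action, key=k_trans)
        obs = self.observation(next_state, key=k_obs)
        reward = self.reward(state, action, next_state, key=k_rew)
        terminal = self.terminal(next_state, key=k_term)
        truncate = self.truncate(next_state)
        # Assumed: both info dicts are merged into one.
        info = {**self.state_info(next_state), **self.transition_info(state, action, next_state)}
        return next_state, obs, reward, terminal, truncate, info


print(LineEnv().step(2, 1, key=0))
# (3, (3.0,), 1.0, True, False, {'action': 1})
```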

default_renderer

default_renderer() -> AbstractRenderer

Return the default renderer for the wrapped environment.

render

render(state: WrapperStateType, renderer: AbstractRenderer)

Render a frame from a state.


lerax.wrapper.AbstractWrapperState

Bases: AbstractEnvLikeState

unwrapped property

unwrapped: AbstractEnvState

The state of the underlying (unwrapped) environment.
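The state-side `unwrapped` property mirrors the environment-side one: each wrapper state holds an inner state, and peeling off the layers yields the underlying environment state. In this hypothetical sketch, the field name `env_state` and the extra wrapper fields (`steps`, `total_reward`) are assumptions chosen for illustration.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class PositionState:
    """Hypothetical unwrapped environment state."""

    position: int


@dataclass(frozen=True)
class ToyWrapperState:
    """Hypothetical stand-in for a wrapper state holding one inner state."""

    env_state: Union["ToyWrapperState", PositionState]

    @property
    def unwrapped(self) -> PositionState:
        # Peel off wrapper layers until a plain environment state remains.
        inner = self.env_state
        return inner.unwrapped if isinstance(inner, ToyWrapperState) else inner


@dataclass(frozen=True)
class StepsState(ToyWrapperState):
    steps: int = 0


@dataclass(frozen=True)
class ReturnState(ToyWrapperState):
    total_reward: float = 0.0


nested = ReturnState(
    env_state=StepsState(env_state=PositionState(position=4), steps=2),
    total_reward=1.5,
)
print(nested.unwrapped)  # PositionState(position=4)
```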
