Gymnasium

lerax.compatibility.gym.GymToLeraxEnv

Bases: AbstractEnv[GymEnvState, Array, Array, None]

Wrapper of a Gymnasium environment to make it compatible with Lerax.

Note

Uses JAX's io_callback to wrap the env's reset and step functions. This is generally slower than a native JAX environment and prevents vmapped rollouts. The info dict returned by Gymnasium envs is dropped because its shape cannot be known ahead of time. Because the state objects do not contain all the necessary information, it is especially important to call methods in order.
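The io_callback mechanism mentioned in the note can be sketched as follows. This is a minimal illustration, not Lerax's implementation: ToyHostEnv, jitted_step, and result_spec are hypothetical names, and the toy env stands in for a Gymnasium environment.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


class ToyHostEnv:
    """Stand-in for a stateful Python environment (hypothetical, not Gymnasium)."""

    def __init__(self):
        self.position = 0.0

    def step(self, action):
        # Mutates Python-side state; JAX cannot trace through this.
        self.position += float(action)
        return np.asarray(self.position, dtype=np.float32)


host_env = ToyHostEnv()

# The shape and dtype of the callback result must be declared up front,
# which is why variable-shaped data such as an info dict is hard to return.
result_spec = jax.ShapeDtypeStruct((), jnp.float32)


@jax.jit
def jitted_step(action):
    # ordered=True asks JAX to preserve the relative order of these side effects.
    return io_callback(host_env.step, result_spec, action, ordered=True)


first = jitted_step(jnp.float32(1.0))
second = jitted_step(jnp.float32(2.0))
```

Each call leaves the compiled program to run Python on the host, which is where the overhead relative to a native JAX environment comes from.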

Parameters:

Name Type Description Default
env gym.Env

Gymnasium environment to wrap.

required

Attributes:

Name Type Description
name str

Name of the environment.

action_space AbstractSpace

Action space of the environment.

observation_space AbstractSpace

Observation space of the environment.

env gym.Env

The original Gymnasium environment.

name class-attribute

name: str = 'GymnasiumEnv'

action_space instance-attribute

action_space: AbstractSpace = gym_space_to_lerax_space(
    env.action_space
)

observation_space instance-attribute

observation_space: AbstractSpace = gym_space_to_lerax_space(
    env.observation_space
)

__init__

__init__(env: gym.Env)

initial

initial(*args, key: Key, **kwargs) -> GymEnvState

Forwards to the Gymnasium reset.

Note

If no seed is provided, one is generated from key so that resets are reproducible.

Parameters:

Name Type Description Default
*args

Positional arguments to pass to env.reset.

()
key Key

JAX PRNG key, used to generate a seed if none is provided.

required
**kwargs

Keyword arguments to pass to env.reset. If "seed" is provided here, it overrides the key-generated seed.

{}

Returns:

Type Description
GymEnvState

The initial environment state.
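One way the seed handling described above could work is sketched below. The helper resolve_seed is hypothetical and the exact way Lerax derives an integer seed from the key is an assumption. The sketch only shows the precedence rule: an explicit "seed" keyword wins, otherwise a seed is drawn from the key.

```python
import jax.random as jr


def resolve_seed(key, **reset_kwargs):
    """Return reset kwargs with a seed, preferring a caller-supplied one."""
    if "seed" in reset_kwargs:
        # An explicit seed overrides anything derived from the key.
        return reset_kwargs
    # Assumption: draw a non-negative integer that fits in a signed 32-bit int.
    seed = int(jr.randint(key, (), 0, 2**31 - 1))
    return {**reset_kwargs, "seed": seed}


seeded = resolve_seed(jr.key(0))
seeded_again = resolve_seed(jr.key(0))
explicit = resolve_seed(jr.key(0), seed=7)
```

Because the derived seed is a deterministic function of the key, passing the same key twice yields the same reset kwargs.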

transition

transition(
    state: GymEnvState, action: Array, *, key: Key
) -> GymEnvState

Forwards to the Gymnasium step.

In practice, this just calls the env's step function via io_callback, so the state argument is ignored and the order of operations matters.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required
action Array

Action to take.

required
key Key

Unused.

required

Returns:

Type Description
GymEnvState

Next environment state.
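The consequence of ignoring the state argument can be illustrated with a toy example. Everything here (ToyCounterEnv, Snapshot, the free function transition) is hypothetical and written in plain Python. It only mimics the pattern of a functional-looking call backed by a stateful host environment.

```python
from dataclasses import dataclass


class ToyCounterEnv:
    """Stand-in for a stateful host environment (hypothetical, not Gymnasium)."""

    def __init__(self):
        self.count = 0

    def reset(self):
        self.count = 0
        return self.count

    def step(self, action):
        self.count += action
        return self.count


@dataclass(frozen=True)
class Snapshot:
    """Immutable record of what the host env last reported."""

    obs: int


def transition(env, state, action):
    # `state` is intentionally unused: the host env owns the real state.
    return Snapshot(obs=env.step(action))


env = ToyCounterEnv()
s0 = Snapshot(obs=env.reset())
s1 = transition(env, s0, 1)
# Stepping from s0 a second time does not rewind the host env.
s1_again = transition(env, s0, 1)
```

The two calls receive identical arguments but return different snapshots, so states from such a wrapper cannot be replayed or branched the way states of a pure JAX environment can.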

observation

observation(state: GymEnvState, *, key: Key) -> Array

Forwards to the Gymnasium observation.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required

Returns:

Type Description
Array

Observation corresponding to the environment state.

reward

reward(
    state: GymEnvState,
    action: Array,
    next_state: GymEnvState,
    *,
    key: Key,
) -> Float[Array, ""]

Forwards to the Gymnasium reward.

In practice, this just reads the reward from the next_state.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required
action Array

Action taken.

required
next_state GymEnvState

Next environment state.

required

Returns:

Type Description
Float[Array, '']

Reward obtained from the transition.

terminal

terminal(
    state: GymEnvState, *, key: Key
) -> Bool[Array, ""]

Forwards to the Gymnasium terminated flag.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required

Returns:

Type Description
Bool[Array, '']

Boolean indicating whether the state is terminal.

truncate

truncate(state: GymEnvState) -> Bool[Array, '']

Forwards to the Gymnasium truncated flag.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required

Returns:

Type Description
Bool[Array, '']

Boolean indicating whether the state is truncated.

state_info

state_info(state: GymEnvState) -> dict

Empty info dict to ensure stable shapes for JIT compilation.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required

Returns:

Type Description
dict

Empty info dict.
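To see why an empty dict keeps shapes stable, note that it is a pytree with no leaves, so its structure is identical on every step. The sketch below uses jax.lax.scan with a hypothetical step function. It is not taken from Lerax.

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    new_carry = carry + 1.0
    # An empty dict has no leaves, so the per-step output structure never changes.
    info = {}
    return new_carry, (new_carry, info)


# scan requires every iteration to return outputs with the same structure.
final, (trace, infos) = jax.lax.scan(step, jnp.float32(0.0), None, length=3)
```

A Gymnasium info dict whose keys vary between steps would break this requirement, which is consistent with the wrapper dropping it.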

transition_info

transition_info(
    state: GymEnvState,
    action: Array,
    next_state: GymEnvState,
) -> dict

Empty info dict to ensure stable shapes for JIT compilation.

Parameters:

Name Type Description Default
state GymEnvState

Current environment state.

required
action Array

Action taken.

required
next_state GymEnvState

Next environment state.

required

Returns:

Type Description
dict

Empty info dict.

default_renderer

default_renderer() -> AbstractRenderer

Not supported for Gymnasium environments.

Raises:

Type Description
NotImplementedError

Always.

render

render(state: GymEnvState, renderer: AbstractRenderer)

Not supported for Gymnasium environments.

Raises:

Type Description
NotImplementedError

Always.

render_stacked

render_stacked(
    states: StateType,
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render multiple frames from stacked states.

Stacked states are typically batched states stored in a pytree structure.

Parameters:

Name Type Description Default
states StateType

A pytree of stacked environment states to render.

required
renderer AbstractRenderer | Literal['auto']

The renderer to use for rendering. If "auto", uses the default renderer.

'auto'
dt float

The time delay between rendering each frame, in seconds.

0.0

reset

reset(*, key: Key) -> tuple[StateType, ObsType, dict]

Wrap the functional logic into a Gym API reset method.

Parameters:

Name Type Description Default
key Key

A JAX PRNG key for any stochasticity in the reset.

required

Returns:

Type Description
tuple[StateType, ObsType, dict]

A tuple of the initial state, initial observation, and additional info.

step

step(
    state: StateType, action: ActType, *, key: Key
) -> tuple[
    StateType,
    ObsType,
    Float[Array, ""],
    Bool[Array, ""],
    Bool[Array, ""],
    dict,
]

Wrap the functional logic into a Gym API step method.

Parameters:

Name Type Description Default
state StateType

The current environment state.

required
action ActType

The action to take.

required
key Key

A JAX PRNG key for any stochasticity in the step.

required

Returns:

Type Description
tuple[StateType, ObsType, Float[Array, ''], Bool[Array, ''], Bool[Array, ''], dict]

A tuple of the next state, observation, reward, terminal flag, truncate flag, and additional info.
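How reset and step might compose the functional methods can be sketched with a toy class. ToyFunctionalEnv is hypothetical and reuses only the method names documented above. Two points are assumptions, not confirmed Lerax behaviour: the terminal and truncate flags are evaluated on the next state, and the key is passed through unchanged (a real implementation would likely split it per call).

```python
class ToyFunctionalEnv:
    """Hypothetical env with the method names documented above; state is an int."""

    def initial(self, *, key):
        return 0

    def transition(self, state, action, *, key):
        return state + action

    def observation(self, state, *, key):
        return float(state)

    def reward(self, state, action, next_state, *, key):
        return 1.0 if next_state > state else 0.0

    def terminal(self, state, *, key):
        return state >= 3

    def truncate(self, state):
        return False

    def state_info(self, state):
        return {}

    def transition_info(self, state, action, next_state):
        return {}

    def reset(self, *, key):
        state = self.initial(key=key)
        return state, self.observation(state, key=key), self.state_info(state)

    def step(self, state, action, *, key):
        next_state = self.transition(state, action, key=key)
        return (
            next_state,
            self.observation(next_state, key=key),
            self.reward(state, action, next_state, key=key),
            self.terminal(next_state, key=key),  # assumption: checked on next_state
            self.truncate(next_state),
            self.transition_info(state, action, next_state),
        )


env = ToyFunctionalEnv()
key = None  # placeholder: nothing in this toy is random
state, obs, info = env.reset(key=key)
total_reward, done = 0.0, False
while not done:
    state, obs, reward, terminated, truncated, info = env.step(state, 1, key=key)
    total_reward += reward
    done = terminated or truncated
```

The caller threads the state through explicitly, which is the main difference from the stateful Gymnasium API.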

close

close()

lerax.compatibility.gym.LeraxToGymEnv

Bases: gym.Env

Wrapper of a Lerax environment to make it compatible with Gymnasium.

Executes the Lerax env directly on the Python side and keeps an internal eqx state and PRNG key.

Attributes:

Name Type Description
metadata dict

Metadata for the Gym environment.

action_space gym.Space

Action space of the environment.

observation_space gym.Space

Observation space of the environment.

render_mode str | None

Render mode for the environment.

env AbstractEnv[StateType, Array, Array, Any]

The Lerax environment to wrap.

state StateType

Current state of the Lerax environment.

key Key

PRNG key for the environment.

Parameters:

Name Type Description Default
env AbstractEnv[StateType, Array, Array, Any]

Lerax environment to wrap.

required
render_mode Literal['human'] | None

Render mode for the environment.

None

metadata class-attribute instance-attribute

metadata: dict = {'render_modes': ['human']}

state instance-attribute

state: StateType

key instance-attribute

key: Key = jr.key(0)

env instance-attribute

env: AbstractEnv[StateType, Array, Array, Any] = env

action_space instance-attribute

action_space: gym.Space = lerax_to_gym_space(
    env.action_space
)

observation_space instance-attribute

observation_space: gym.Space = lerax_to_gym_space(
    env.observation_space
)

render_mode class-attribute instance-attribute

render_mode: str | None = render_mode

__init__

__init__(
    env: AbstractEnv[StateType, Array, Array, Any],
    render_mode: Literal["human"] | None = None,
)

reset

reset(
    *, seed: int | None = None, options: dict | None = None
)

step

step(action)

render

render()

Not supported yet.

Raises:

Type Description
NotImplementedError

Always.

close

close()

Placeholder close method.

Does nothing but completes the Gymnasium Env interface.
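The pattern of holding state and a PRNG key inside a Gym-style object can be sketched as follows. Both classes are hypothetical stand-ins, not LeraxToGymEnv itself. In particular, reseeding with jr.key(seed) on reset is an assumption about how a seed could be handled. The return shapes follow the Gymnasium convention: reset returns (obs, info) and step returns (obs, reward, terminated, truncated, info).

```python
import jax.numpy as jnp
import jax.random as jr


class ToyRandomWalk:
    """Hypothetical functional env: the state is a scalar position."""

    def reset(self, *, key):
        state = jnp.float32(0.0)
        return state, state, {}

    def step(self, state, action, *, key):
        next_state = state + action + 0.1 * jr.normal(key, ())
        reward = -jnp.abs(next_state)
        terminated = jnp.abs(next_state) > 5.0
        truncated = jnp.asarray(False)
        return next_state, next_state, reward, terminated, truncated, {}


class ToyStatefulWrapper:
    """Keeps the functional env's state and a PRNG key between calls."""

    def __init__(self, env):
        self.env = env
        self.key = jr.key(0)
        self.state = None

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            self.key = jr.key(seed)  # assumption: the seed replaces the key
        # Split so the stored key is never reused for sampling.
        self.key, reset_key = jr.split(self.key)
        self.state, obs, info = self.env.reset(key=reset_key)
        return obs, info

    def step(self, action):
        self.key, step_key = jr.split(self.key)
        self.state, obs, reward, terminated, truncated, info = self.env.step(
            self.state, action, key=step_key
        )
        return obs, reward, terminated, truncated, info


wrapper_a = ToyStatefulWrapper(ToyRandomWalk())
wrapper_b = ToyStatefulWrapper(ToyRandomWalk())
wrapper_a.reset(seed=42)
wrapper_b.reset(seed=42)
obs_a, *_ = wrapper_a.step(1.0)
obs_b, *_ = wrapper_b.step(1.0)
```

Two wrappers reset with the same seed and given the same actions produce the same observations, since all randomness flows from the stored key.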

lerax.compatibility.gym.gym_space_to_lerax_space

gym_space_to_lerax_space(
    space: gymnasium.Space,
) -> AbstractSpace

Returns a Lerax space corresponding to the given Gymnasium space.

Parameters:

Name Type Description Default
space gymnasium.Space

Gymnasium space to convert.

required

Returns:

Type Description
lerax.space.AbstractSpace

The corresponding Lerax space.

lerax.compatibility.gym.lerax_to_gym_space

lerax_to_gym_space(
    space: lerax.space.AbstractSpace,
) -> gym.Space

Returns a Gymnasium space corresponding to the given Lerax space.

Parameters:

Name Type Description Default
space lerax.space.AbstractSpace

Lerax space to convert.

required

Returns:

Type Description
gymnasium.Space

The corresponding Gymnasium space.