Clip Reward

lerax.wrapper.ClipReward

Bases: AbstractPureTransformRewardWrapper[StateType, ActType, ObsType, MaskType]

Clip the rewards emitted by the wrapped environment to a specified range.

Attributes:

    env (AbstractEnvLike[StateType, ActType, ObsType, MaskType]):
        The environment to wrap.
    min (Float[Array, ""]):
        The minimum reward value.
    max (Float[Array, ""]):
        The maximum reward value.

Parameters:

    env (AbstractEnvLike[StateType, ActType, ObsType, MaskType]):
        The environment to wrap. Required.
    min (Float[ArrayLike, ""]):
        The minimum reward value. Default: jnp.asarray(-1.0).
    max (Float[ArrayLike, ""]):
        The maximum reward value. Default: jnp.asarray(1.0).
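For example, constructing the wrapper might look like the sketch below; MyEnv is a hypothetical AbstractEnvLike implementation standing in for whatever environment you wrap:

import jax.numpy as jnp

from lerax.wrapper import ClipReward

env = MyEnv()  # hypothetical AbstractEnvLike implementation
clipped = ClipReward(env, min=jnp.asarray(-0.5), max=jnp.asarray(0.5))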

name property

name: str

Return the name of the environment.

action_space instance-attribute

action_space: eqx.AbstractVar[
    AbstractSpace[ActType, WrapperMaskType]
]

observation_space instance-attribute

observation_space: eqx.AbstractVar[
    AbstractSpace[ObsType, Any]
]

unwrapped property

unwrapped: AbstractEnv

Return the wrapped environment.

env instance-attribute

env: AbstractEnvLike[
    StateType, ActType, ObsType, MaskType
] = env

func instance-attribute

func: Callable[[Float[Array, ""]], Float[Array, ""]] = (
    partial(jnp.clip, min=self.min, max=self.max)
)
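The stored func is simply a partial application of jnp.clip over the configured bounds; the clipping behavior itself can be illustrated directly:

from functools import partial

import jax.numpy as jnp

clip_fn = partial(jnp.clip, min=jnp.asarray(-1.0), max=jnp.asarray(1.0))

clip_fn(jnp.asarray(2.5))   # -> 1.0  (above max, clipped down)
clip_fn(jnp.asarray(-3.0))  # -> -1.0 (below min, clipped up)
clip_fn(jnp.asarray(0.25))  # -> 0.25 (in range, unchanged)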

min instance-attribute

min: Float[Array, ""] = jnp.asarray(min)

max instance-attribute

max: Float[Array, ""] = jnp.asarray(max)

initial

initial(*, key: Key) -> PureTransformRewardState[StateType]

action_mask

action_mask(
    state: PureTransformRewardState[StateType], *, key: Key
) -> MaskType | None

transition

transition(
    state: PureTransformRewardState[StateType],
    action: ActType,
    *,
    key: Key,
) -> PureTransformRewardState[StateType]

observation

observation(
    state: PureTransformRewardState[StateType], *, key: Key
) -> ObsType

reward

reward(
    state: PureTransformRewardState[StateType],
    action: ActType,
    next_state: PureTransformRewardState[StateType],
    *,
    key: Key,
) -> Float[Array, ""]

terminal

terminal(
    state: PureTransformRewardState[StateType], *, key: Key
) -> Bool[Array, ""]

truncate

truncate(
    state: PureTransformRewardState[StateType],
) -> Bool[Array, ""]

state_info

state_info(
    state: PureTransformRewardState[StateType],
) -> dict

transition_info

transition_info(
    state: PureTransformRewardState[StateType],
    action: ActType,
    next_state: PureTransformRewardState[StateType],
) -> dict
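These methods make up the functional (stateless) environment interface that the wrapper forwards to the underlying env, clipping only the reward. A minimal sketch of one transition, assuming clipped is the ClipReward instance from the earlier example and that the action space exposes a sample method (an assumption about the AbstractSpace API):

import jax

k_init, k_act, k_trans, k_rew, k_term = jax.random.split(jax.random.key(0), 5)

state = clipped.initial(key=k_init)
action = clipped.action_space.sample(key=k_act)  # assumed sampling method
next_state = clipped.transition(state, action, key=k_trans)
reward = clipped.reward(state, action, next_state, key=k_rew)
# reward has already been passed through jnp.clip(min=..., max=...)
done = clipped.terminal(next_state, key=k_term)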

default_renderer

default_renderer() -> AbstractRenderer

Return the default renderer for the wrapped environment.

render

render(state: WrapperStateType, renderer: AbstractRenderer)

Render a frame from a state.

render_states

render_states(
    states: Sequence[StateType],
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render a sequence of frames from multiple states.

Parameters:

    states (Sequence[StateType]):
        A sequence of environment states to render. Required.
    renderer (AbstractRenderer | Literal["auto"]):
        The renderer to use for rendering. If "auto", uses the default
        renderer. Default: "auto".
    dt (float):
        The time delay between rendering each frame, in seconds. Default: 0.0.
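As a sketch, replaying a rollout collected as a Python list of states (the collection of states itself is omitted here):

# states: Sequence[StateType] gathered during a rollout
clipped.render_states(states, renderer="auto", dt=0.05)  # ~20 frames per second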

render_stacked

render_stacked(
    states: StateType,
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render multiple frames from stacked states.

Stacked states are typically batched states stored in a pytree structure.

Parameters:

    states (StateType):
        A pytree of stacked environment states to render. Required.
    renderer (AbstractRenderer | Literal["auto"]):
        The renderer to use for rendering. If "auto", uses the default
        renderer. Default: "auto".
    dt (float):
        The time delay between rendering each frame, in seconds. Default: 0.0.
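Stacked states typically come out of jax.lax.scan, which stacks each pytree leaf of the state along a leading axis. A hedged sketch (the sample call is an assumed space API; reset and step are documented below):

import jax

def scan_step(state, key):
    k_act, k_step = jax.random.split(key)
    action = clipped.action_space.sample(key=k_act)  # assumed sampling method
    next_state, *_ = clipped.step(state, action, key=k_step)
    return next_state, next_state

state0, _, _ = clipped.reset(key=jax.random.key(0))
keys = jax.random.split(jax.random.key(1), 100)
_, stacked = jax.lax.scan(scan_step, state0, keys)

clipped.render_stacked(stacked, dt=0.05)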

reset

reset(*, key: Key) -> tuple[StateType, ObsType, dict]

Wrap the functional logic into a Gym API reset method.

Parameters:

    key (Key):
        A JAX PRNG key for any stochasticity in the reset. Required.

Returns:

    tuple[StateType, ObsType, dict]:
        A tuple of the initial state, initial observation, and additional info.
step

step(
    state: StateType, action: ActType, *, key: Key
) -> tuple[
    StateType,
    ObsType,
    Float[Array, ""],
    Bool[Array, ""],
    Bool[Array, ""],
    dict,
]

Wrap the functional logic into a Gym API step method.

Parameters:

    state (StateType):
        The current environment state. Required.
    action (ActType):
        The action to take. Required.
    key (Key):
        A JAX PRNG key for any stochasticity in the step. Required.

Returns:

    tuple[StateType, ObsType, Float[Array, ""], Bool[Array, ""], Bool[Array, ""], dict]:
        A tuple of the next state, observation, reward, terminal flag,
        truncate flag, and additional info.
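Continuing the reset sketch, one environment step; the action-sampling call again assumes the space provides a sample method:

import jax

k_act, k_step = jax.random.split(jax.random.key(1))

action = clipped.action_space.sample(key=k_act)  # assumed sampling method
state, obs, reward, terminal, truncate, info = clipped.step(
    state, action, key=k_step
)
# reward lies in [clipped.min, clipped.max] by construction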

__init__

__init__(
    env: AbstractEnvLike[
        StateType, ActType, ObsType, MaskType
    ],
    min: Float[ArrayLike, ""] = jnp.asarray(-1.0),
    max: Float[ArrayLike, ""] = jnp.asarray(1.0),
)