Transform Reward
lerax.wrapper.TransformReward
Bases:
Apply an arbitrary function to the rewards emitted by the wrapped environment.
Attributes:
| Name | Type | Description |
|---|---|---|
| | | The environment to wrap. |
| | | The function to apply to the rewards. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `env` | | The environment to wrap. | required |
| `func` | | The function to apply to the rewards. | required |
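The idea behind the wrapper can be sketched in plain JAX without relying on any lerax internals. In the minimal example below, `toy_reward` and `transform_reward` are hypothetical names introduced only for illustration; this is not the library's implementation, only the pattern of passing a scalar reward through a user-supplied `func`.

```python
import jax
import jax.numpy as jnp


def toy_reward(state, action, next_state):
    """Hypothetical base reward: negative squared distance to the origin."""
    return -jnp.sum(next_state**2)


def transform_reward(reward_fn, func):
    """Return a reward function whose scalar output is passed through `func`."""

    def wrapped(state, action, next_state):
        return func(reward_fn(state, action, next_state))

    return wrapped


# Example transform: clip rewards to the range [-1, 1].
clipped_reward = transform_reward(toy_reward, lambda r: jnp.clip(r, -1.0, 1.0))

state = jnp.array([0.0, 0.0])
next_state = jnp.array([3.0, 4.0])

raw = toy_reward(state, 0, next_state)  # -(3^2 + 4^2) = -25.0
clipped = jax.jit(clipped_reward)(state, 0, next_state)  # clipped to -1.0
```

Because `func` is applied inside a function that may be traced by `jax.jit`, it should itself be a pure, JAX-traceable function that returns a scalar.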
action_space
instance-attribute
observation_space
instance-attribute
action_mask
transition
transition(
state: PureTransformRewardState[StateType],
action: ActType,
*,
key: Key,
) -> PureTransformRewardState[StateType]
reward
reward(
state: PureTransformRewardState[StateType],
action: ActType,
next_state: PureTransformRewardState[StateType],
*,
key: Key,
) -> Float[Array, ""]
transition_info
transition_info(
state: PureTransformRewardState[StateType],
action: ActType,
next_state: PureTransformRewardState[StateType],
) -> dict
default_renderer
Return the default renderer for the wrapped environment.
render_states
render_states(
states: Sequence[StateType],
renderer: AbstractRenderer | Literal["auto"] = "auto",
dt: float = 0.0,
)
Render a sequence of frames from multiple states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | | A sequence of environment states to render. | required |
| `renderer` | | The renderer to use for rendering. If "auto", uses the default renderer. | `'auto'` |
| `dt` | | The time delay between rendering each frame, in seconds. | `0.0` |
render_stacked
render_stacked(
states: StateType,
renderer: AbstractRenderer | Literal["auto"] = "auto",
dt: float = 0.0,
)
Render multiple frames from stacked states.
Stacked states are typically batched states stored in a pytree structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | | A pytree of stacked environment states to render. | required |
| `renderer` | | The renderer to use for rendering. If "auto", uses the default renderer. | `'auto'` |
| `dt` | | The time delay between rendering each frame, in seconds. | `0.0` |
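To make the notion of stacked states concrete, the sketch below builds one in plain JAX. The dictionary state used here is a hypothetical stand-in, not a lerax state type. Each leaf of the stacked pytree gains a leading axis that indexes the frame, which is also the layout produced by `jax.lax.scan` or `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-step states: any pytree with array leaves would work.
states = [{"pos": jnp.array([float(t), 0.0]), "t": jnp.array(t)} for t in range(3)]

# Stack: combine matching leaves along a new leading (frame) axis.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *states)

# Unstack: index the leading axis of every leaf to recover one state per frame.
num_frames = jax.tree_util.tree_leaves(stacked)[0].shape[0]
unstacked = [
    jax.tree_util.tree_map(lambda leaf: leaf[i], stacked) for i in range(num_frames)
]
```

Here `stacked` is the kind of input described for `render_stacked`, while `unstacked` is a plain sequence of states of the kind described for `render_states`.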
reset
Wrap the functional logic into a Gym API reset method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | | A JAX PRNG key for any stochasticity in the reset. | required |
Returns:
| Type | Description |
|---|---|
| | A tuple of the initial state, initial observation, and additional info. |
step
step(
state: StateType, action: ActType, *, key: Key
) -> tuple[
StateType,
ObsType,
Float[Array, ""],
Bool[Array, ""],
Bool[Array, ""],
dict,
]
Wrap the functional logic into a Gym API step method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | | The current environment state. | required |
| `action` | | The action to take. | required |
| `key` | | A JAX PRNG key for any stochasticity in the step. | required |
Returns:
| Type | Description |
|---|---|
| | A tuple of the next state, observation, reward, termination flag, truncation flag, and additional info. |
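The calling pattern implied by the `reset` and `step` descriptions can be sketched with a toy functional environment. `toy_reset` and `toy_step` below are hypothetical stand-ins that only mirror the documented return tuples; they are not lerax functions, and passing `key` as a keyword-only argument to `toy_reset` is an assumption modelled on the `step` signature. The point of the example is the explicit state threading and the splitting of a fresh PRNG key for every call.

```python
import jax
import jax.numpy as jnp


def toy_reset(*, key):
    """Hypothetical reset: returns (state, observation, info)."""
    state = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return state, state, {}


def toy_step(state, action, *, key):
    """Hypothetical step: returns the six-element tuple described above."""
    noise = 0.01 * jax.random.normal(key, ())
    next_state = state + action + noise
    reward = -jnp.abs(next_state)
    terminated = jnp.abs(next_state) > 10.0
    truncated = jnp.array(False)
    return next_state, next_state, reward, terminated, truncated, {}


key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
state, obs, info = toy_reset(key=reset_key)

total_reward = jnp.array(0.0)
for _ in range(5):
    # Never reuse a key: split off a new one for each stochastic call.
    key, step_key = jax.random.split(key)
    action = -0.5 * obs  # simple proportional policy, for illustration only
    state, obs, reward, terminated, truncated, info = toy_step(
        state, action, key=step_key
    )
    total_reward = total_reward + reward
    if bool(terminated) or bool(truncated):
        break
```

Since the state is returned rather than stored on the environment, the loop body is a pure function of `(state, key)` and could be moved into `jax.lax.scan` once the Python-level `break` is replaced by masking.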