
G1Locomotion

lerax.env.unitree.g1.G1Locomotion

Bases: AbstractG1Env

Unitree G1 locomotion: track velocity commands with a natural gait.

The agent controls 29 joint position targets (legs, waist, arms) and must track randomly sampled linear and angular velocity commands while maintaining a stable bipedal gait. Domain randomization is applied per-episode for sim-to-real transfer.

Observation (103 dims):
- Local linear velocity (3)
- Gyroscope angular velocity (3)
- Gravity vector in body frame (3)
- Velocity command [vx, vy, yaw_rate] (3)
- Joint angles offset from default (29)
- Joint velocities (29)
- Last action (29)
- Gait phase [cos_l, sin_l, cos_r, sin_r] (4)
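The component sizes above sum to 103. The following is a minimal sketch that assumes the components are concatenated in the listed order (the list gives sizes, not an explicit concatenation order). It computes where each component sits in the flat observation vector. OBS_LAYOUT and obs_slices are hypothetical names for illustration, not part of lerax.

```python
# Hypothetical layout table: component name -> size, in the order listed above.
# Assumption: the observation is a flat concatenation in exactly this order.
OBS_LAYOUT = [
    ("local_linvel", 3),
    ("gyro", 3),
    ("gravity", 3),
    ("command", 3),
    ("joint_angles_offset", 29),
    ("joint_velocities", 29),
    ("last_action", 29),
    ("gait_phase", 4),
]


def obs_slices(layout):
    """Return {name: slice} giving each component's position in the flat vector."""
    slices = {}
    start = 0
    for name, size in layout:
        slices[name] = slice(start, start + size)
        start += size
    return slices


SLICES = obs_slices(OBS_LAYOUT)
OBS_DIM = sum(size for _, size in OBS_LAYOUT)

print(OBS_DIM)               # 103
print(SLICES["command"])     # slice(9, 12, None)
print(SLICES["gait_phase"])  # slice(99, 103, None)
```

With these slices, a component can be read back from an observation array, e.g. obs[SLICES["command"]] for [vx, vy, yaw_rate].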

Action: a vector in [-1, 1]^29, scaled to joint position targets around the default pose.
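Below is a minimal sketch of the absolute (non-relative) action mapping. It assumes targets are computed as default_pose + action_scale * action, with the action clipped to [-1, 1] first. action_scale = 0.5 matches the default in _init_common. The clipping step and the helper name action_to_targets are assumptions. The relative_actions flag in _init_common selects a different mode whose exact behavior is not documented on this page, so it is not shown.

```python
import numpy as np

NUM_JOINTS = 29


def action_to_targets(action, default_pose, action_scale=0.5):
    """Map a policy action in [-1, 1]^29 to joint position targets.

    Assumption: targets = default_pose + action_scale * clip(action, -1, 1).
    """
    action = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
    return np.asarray(default_pose, dtype=float) + action_scale * action


default_pose = np.zeros(NUM_JOINTS)  # placeholder; the real default comes from the keyframe
action = np.full(NUM_JOINTS, 2.0)    # out-of-range input, clipped to 1.0
targets = action_to_targets(action, default_pose)
print(targets[:3])  # [0.5 0.5 0.5]
```

With action_scale = 0.5, a full-scale action moves each target at most 0.5 rad (assuming revolute joints in radians) from the default pose.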

Reward: Weighted sum of velocity tracking, energy penalties, gait phase tracking, and pose regularization, multiplied by dt.
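The reward structure can be sketched as follows. The sketch assumes that dt is the control period (1 / control_frequency_hz = 0.02 s at the default 50 Hz). It also uses an exponential kernel for velocity tracking, a common choice in legged-locomotion rewards that the source does not confirm. The term names, weights, and sigma are hypothetical placeholders, not lerax's actual reward configuration.

```python
import math

CONTROL_FREQUENCY_HZ = 50.0
DT = 1.0 / CONTROL_FREQUENCY_HZ  # 0.02 s; assumes dt is the control period


def tracking_reward(command, actual, sigma=0.25):
    """exp(-||command - actual||^2 / sigma): 1.0 at perfect tracking, decaying with error."""
    err_sq = sum((c - a) ** 2 for c, a in zip(command, actual))
    return math.exp(-err_sq / sigma)


def total_reward(terms, weights, dt=DT):
    """Weighted sum of reward terms, multiplied by dt.

    Penalties are expressed as negative weights on non-negative terms.
    """
    return dt * sum(weights[name] * value for name, value in terms.items())


# Hypothetical terms and weights, for illustration only.
terms = {
    "tracking_lin_vel": tracking_reward([1.0, 0.0], [1.0, 0.0]),  # perfect tracking -> 1.0
    "energy": 2.0,  # e.g. a non-negative power-like penalty term
}
weights = {"tracking_lin_vel": 1.0, "energy": -0.1}
print(total_reward(terms, weights))  # approximately 0.02 * (1.0 - 0.2) = 0.016
```

Multiplying by dt keeps the per-episode return roughly comparable if the control frequency changes, because the sum then approximates a time integral.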

unwrapped property

unwrapped: Self

Return the unwrapped environment.

action_mask

action_mask(
    state: G1EnvState, *, key: Key[Array, ""]
) -> None

transition

transition(
    state: G1EnvState,
    action: Float[Array, "29"],
    *,
    key: Key[Array, ""],
) -> G1EnvState

Step physics with domain-randomized model.

Computes motor targets from the action, optionally applies push perturbations or pulling forces, steps physics via lax.scan using the per-episode randomized model on the state, then updates gait phase and tracking.
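The "step physics via lax.scan" pattern repeats one physics step n_substeps times per control step, threading the physics state through as the scan carry. The sketch below uses a pure-Python stand-in with the same (carry, x) -> (carry, y) semantics as jax.lax.scan. A toy one-joint PD-controlled integrator replaces mjx.step. The physics timestep of 0.004 s and the PD gains are assumed values, not taken from the source. Push perturbations, pulling forces, and gait-phase updates are omitted.

```python
def scan(f, init, xs=None, length=None):
    """Pure-Python reference for jax.lax.scan: f(carry, x) -> (carry, y)."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


CTRL_DT = 1.0 / 50.0                  # control period at the default 50 Hz
SIM_DT = 0.004                        # assumed physics timestep (hypothetical)
N_SUBSTEPS = round(CTRL_DT / SIM_DT)  # physics substeps per control step (5)

KP, KD = 100.0, 2.0  # hypothetical PD gains


def physics_step(carry, _):
    """One toy substep: PD torque toward the target, semi-implicit Euler update."""
    q, qd, target = carry
    qdd = KP * (target - q) - KD * qd
    qd = qd + SIM_DT * qdd
    q = q + SIM_DT * qd
    return (q, qd, target), q


def transition(q, qd, action, default=0.0, action_scale=0.5):
    """Compute the motor target from the action, then run N_SUBSTEPS physics steps."""
    target = default + action_scale * action
    (q, qd, _), trajectory = scan(physics_step, (q, qd, target), length=N_SUBSTEPS)
    return q, qd, trajectory


q, qd, traj = transition(0.0, 0.0, action=1.0)
print(N_SUBSTEPS, len(traj))  # 5 5
```

In JAX, the Python loop becomes jax.lax.scan(physics_step, init, None, length=N_SUBSTEPS), which compiles to a single loop under jit.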

truncate

truncate(state: G1EnvState) -> Bool[Array, '']

state_info

state_info(state: G1EnvState) -> dict

transition_info

transition_info(
    state: G1EnvState,
    action: Float[Array, "29"],
    next_state: G1EnvState,
) -> dict

default_renderer

default_renderer() -> MujocoRenderer

render

render(state: G1EnvState, renderer: AbstractRenderer)

render_states

render_states(
    states: Sequence[StateType],
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render a sequence of frames from multiple states.

Parameters:

- states (Sequence[StateType], required): A sequence of environment states to render.
- renderer (AbstractRenderer | Literal['auto'], default 'auto'): The renderer to use for rendering. If "auto", uses the default renderer.
- dt (float, default 0.0): The time delay between rendering each frame, in seconds.

render_stacked

render_stacked(
    states: StateType,
    renderer: AbstractRenderer | Literal["auto"] = "auto",
    dt: float = 0.0,
)

Render multiple frames from stacked states.

Stacked states are typically batched states stored in a pytree structure.

Parameters:

- states (StateType, required): A pytree of stacked environment states to render.
- renderer (AbstractRenderer | Literal['auto'], default 'auto'): The renderer to use for rendering. If "auto", uses the default renderer.
- dt (float, default 0.0): The time delay between rendering each frame, in seconds.
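In a stacked state, each leaf carries an extra leading (time or batch) axis. One way to turn such a pytree into a list of per-frame states is to index every leaf along that axis. That list is the form render_states takes. The sketch below assumes states are nested dicts of NumPy arrays and uses a tiny tree_map stand-in. Real code would typically use jax.tree_util.tree_map over an arbitrary pytree. unstack is a hypothetical helper, not a lerax function, and the qpos size of 36 is illustrative.

```python
import numpy as np


def tree_map(fn, tree):
    """Apply fn to every array leaf of a nested dict (a minimal pytree stand-in)."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


def first_leaf(tree):
    """Return the first array leaf, used to read the stacked length."""
    while isinstance(tree, dict):
        tree = next(iter(tree.values()))
    return tree


def unstack(stacked):
    """Split a stacked pytree into a list of per-frame pytrees along axis 0."""
    n = first_leaf(stacked).shape[0]
    return [tree_map(lambda leaf, i=i: leaf[i], stacked) for i in range(n)]


stacked = {
    "qpos": np.zeros((4, 36)),       # 4 frames; hypothetical qpos size
    "info": {"step": np.arange(4)},
}
frames = unstack(stacked)
print(len(frames), frames[2]["info"]["step"])  # 4 2
```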

reset

reset(
    *, key: Key[Array, ""]
) -> tuple[StateType, ObsType, dict]

Wrap the functional logic in a Gym-style reset method.

Parameters:

- key (Key[Array, ''], required): A JAX PRNG key for any stochasticity in the reset.

Returns:

- tuple[StateType, ObsType, dict]: A tuple of the initial state, initial observation, and additional info.

step

step(
    state: StateType,
    action: ActType,
    *,
    key: Key[Array, ""],
) -> tuple[
    StateType,
    ObsType,
    Float[Array, ""],
    Bool[Array, ""],
    Bool[Array, ""],
    dict,
]

Wrap the functional logic in a Gym-style step method.

Parameters:

- state (StateType, required): The current environment state.
- action (ActType, required): The action to take.
- key (Key[Array, ''], required): A JAX PRNG key for any stochasticity in the step.

Returns:

- tuple[StateType, ObsType, Float[Array, ''], Bool[Array, ''], Bool[Array, ''], dict]: A tuple of the next state, observation, reward, terminal flag, truncate flag, and additional info.
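A rollout over this reset/step interface threads the state through each call, splits the key before every stochastic call, and stops when either the terminal or the truncate flag is set. The sketch below is self-contained. ToyEnv is a hypothetical stand-in that mirrors only the call signatures documented above. An integer seed stands in for a JAX PRNG key. With JAX, you would call jax.random.split on a real key instead of the split_key helper used here.

```python
import random


def split_key(key):
    """Stand-in for jax.random.split: derive two new integer seeds from one."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


class ToyEnv:
    """Hypothetical 1-D env with the same reset/step call signatures as above."""

    max_steps = 10

    def reset(self, *, key):
        state = {"t": 0, "x": random.Random(key).uniform(-0.1, 0.1)}
        return state, state["x"], {}

    def step(self, state, action, *, key):
        t = state["t"] + 1
        x = state["x"] + action
        next_state = {"t": t, "x": x}
        reward = -abs(x)
        terminal = abs(x) > 1.0
        truncate = t >= self.max_steps
        return next_state, x, reward, terminal, truncate, {}


def rollout(env, policy, key):
    """Run one episode; return the final state and the undiscounted return."""
    key, reset_key = split_key(key)
    state, obs, _ = env.reset(key=reset_key)
    total = 0.0
    while True:
        key, step_key = split_key(key)
        state, obs, reward, terminal, truncate, _ = env.step(state, policy(obs), key=step_key)
        total += reward
        if terminal or truncate:
            return state, total


final_state, ret = rollout(ToyEnv(), policy=lambda obs: -0.5 * obs, key=0)
print(final_state["t"])  # 10
```

Under jit, a Python while loop like this is usually replaced by jax.lax.scan over a fixed number of steps, with automatic resets handled by a wrapper.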

_local_linvel

_local_linvel(data: mjx.Data) -> Float[Array, '3']

Local linear velocity from pelvis IMU velocimeter.

_gyro

_gyro(data: mjx.Data) -> Float[Array, '3']

Angular velocity from pelvis gyroscope.

_gravity_vector

_gravity_vector(data: mjx.Data) -> Float[Array, '3']

Gravity direction in pelvis frame (upvector sensor).
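Expressing gravity in the pelvis frame amounts to applying the transpose of the body-to-world rotation to the world-frame gravity direction [0, 0, -1]. The sketch below builds that rotation from a MuJoCo-convention (w, x, y, z) unit quaternion. It illustrates the math only. It is not how lerax reads the upvector sensor, and a sensor-based value may differ from this sketch by a sign convention.

```python
import numpy as np


def quat_to_mat(q):
    """Rotation matrix (body -> world) from a unit quaternion (w, x, y, z)."""
    w, x, y, z = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def gravity_in_body(quat):
    """World gravity direction [0, 0, -1] expressed in the body frame: R^T g."""
    R = quat_to_mat(quat)
    return R.T @ np.array([0.0, 0.0, -1.0])


upright = np.array([1.0, 0.0, 0.0, 0.0])
print(gravity_in_body(upright))  # gravity along body -z when upright

# Pitched 90 degrees about the body y-axis: gravity now lies along body +x.
half = np.pi / 4
pitched = np.array([np.cos(half), 0.0, np.sin(half), 0.0])
print(gravity_in_body(pitched))
```

Policies use this vector instead of raw orientation because it captures tilt (the part relevant to balance) while remaining invariant to heading.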

_torso_gravity_vector

_torso_gravity_vector(data: mjx.Data) -> Float[Array, '3']

Gravity direction in torso frame.

_global_linvel

_global_linvel(
    data: mjx.Data, site: str = "pelvis"
) -> Float[Array, "3"]

Global linear velocity from frame sensor.

_global_angvel

_global_angvel(
    data: mjx.Data, site: str = "pelvis"
) -> Float[Array, "3"]

Global angular velocity from frame sensor.

_joint_angles_offset

_joint_angles_offset(data: mjx.Data) -> Float[Array, '29']

Joint angles relative to default pose.

_joint_velocities

_joint_velocities(data: mjx.Data) -> Float[Array, '29']

Actuated joint velocities.

_foot_contact

_foot_contact(data: mjx.Data) -> Bool[Array, '2']

Binary foot contact flags from contact sensors.
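Binary contact flags are typically obtained by thresholding a contact sensor reading. The sketch below is minimal and assumes each foot sensor returns a non-negative scalar, such as a normal-force magnitude. It treats any reading above a threshold as contact. The sensor semantics and the default threshold of 0.0 are assumptions.

```python
import numpy as np


def foot_contact(sensor_values, threshold=0.0):
    """Return a [left, right] boolean array: True where the reading exceeds threshold."""
    return np.asarray(sensor_values, dtype=float) > threshold


print(foot_contact([12.3, 0.0]))  # [ True False]
```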

_self_contact_termination

_self_contact_termination(
    data: mjx.Data,
) -> Bool[Array, ""]

Check for dangerous self-contacts (foot-foot, foot-shin).

_hand_collision

_hand_collision(data: mjx.Data) -> Bool[Array, '']

Check for hand-thigh collisions.

_foot_positions

_foot_positions(data: mjx.Data) -> Float[Array, '2 3']

World-frame positions of foot sites.

_foot_velocities

_foot_velocities(data: mjx.Data) -> Float[Array, '2 3']

Global linear velocities of foot sites.

_snap_to_ground

_snap_to_ground(
    model: mjx.Model, data: mjx.Data
) -> mjx.Data

Shift the robot vertically so its lowest point just touches the ground.

Checks both body positions and foot site positions to find the true lowest point, then adjusts qpos[2] so that point sits at a small clearance above z = 0.
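Because the floating base's height is qpos[2], snapping reduces to one subtraction. The sketch below assumes the world-frame z coordinates of body origins and foot sites are already available. In MJX they would typically come from data.xpos and data.site_xpos after forward kinematics. The clearance of 0.001 m is a hypothetical value that the source does not specify, and all sample heights are illustrative.

```python
import numpy as np


def snap_to_ground(qpos, body_z, site_z, clearance=0.001):
    """Shift qpos[2] so the lowest body or site point sits `clearance` above z = 0.

    qpos[2] is the free-joint root height; changing it translates every body rigidly.
    """
    lowest = min(np.min(body_z), np.min(site_z))
    qpos = np.array(qpos, dtype=float)  # copy; leave the input untouched
    qpos[2] += clearance - lowest
    return qpos


qpos = np.zeros(36)
qpos[2] = 0.80                         # root height before snapping
body_z = np.array([0.80, 0.45, 0.06])  # hypothetical body-origin heights
site_z = np.array([0.02, 0.03])        # hypothetical foot-site heights (lowest point)
snapped = snap_to_ground(qpos, body_z, site_z)
print(round(snapped[2], 6))  # 0.781
```

Checking foot sites as well as body origins matters because the foot body's origin is usually at the ankle, above the sole.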

_init_common

_init_common(
    xml_file: str | Path = "scene_mjx.xml",
    control_frequency_hz: float = 50.0,
    action_scale: float = 0.5,
    keyframe_name: str = "knees_bent",
    soft_joint_pos_limit_factor: float = 0.95,
    push_enable: bool = True,
    push_interval_range: tuple[float, float] = (5.0, 10.0),
    push_magnitude_range: tuple[float, float] = (0.1, 2.0),
    noise_level: float = 1.0,
    noise_scales: dict[str, float] | None = None,
    friction_range: tuple[float, float] = (0.4, 1.0),
    friction_loss_scale_range: tuple[float, float] = (
        0.5,
        2.0,
    ),
    armature_scale_range: tuple[float, float] = (1.0, 1.05),
    mass_scale_range: tuple[float, float] = (0.9, 1.1),
    torso_offset_range: tuple[float, float] = (-1.0, 1.0),
    relative_actions: bool = False,
    pulling_force_magnitude: float = 0.0,
    unactuated_steps: int = 0,
)

Shared initialization for all G1 environments.
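The ranges above point to a per-episode domain-randomization step along the following lines. The sketch draws one scalar uniformly from each range and converts the push interval from seconds to control steps at control_frequency_hz. Several details are assumptions, not documented lerax behavior: uniform sampling, a single scalar per range (rather than per geom or per body), and the helper name sample_randomization.

```python
import random

# Default ranges copied from the _init_common signature above.
DEFAULTS = {
    "friction_range": (0.4, 1.0),
    "friction_loss_scale_range": (0.5, 2.0),
    "armature_scale_range": (1.0, 1.05),
    "mass_scale_range": (0.9, 1.1),
    "torso_offset_range": (-1.0, 1.0),
    "push_interval_range": (5.0, 10.0),
    "push_magnitude_range": (0.1, 2.0),
}
CONTROL_FREQUENCY_HZ = 50.0


def sample_randomization(rng, ranges=DEFAULTS, control_hz=CONTROL_FREQUENCY_HZ):
    """Draw one value per range uniformly; convert the push interval to control steps."""
    params = {
        name.removesuffix("_range"): rng.uniform(lo, hi)
        for name, (lo, hi) in ranges.items()
    }
    # 5-10 s at 50 Hz -> 250-500 control steps between pushes.
    params["push_interval_steps"] = round(params.pop("push_interval") * control_hz)
    return params


params = sample_randomization(random.Random(0))
print(sorted(params))
```

Sampling from a seeded generator once per episode keeps each episode's physics fixed while varying it across episodes. This matches the class description's "per-episode" randomization.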