Skip to content

Commit

Permalink
Replace hardcoded observation size with jax.eval_shape.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 720599097
Change-Id: I71b7914417b389155c716b08c08cf5781adfdfe5
  • Loading branch information
kevinzakka authored and copybara-github committed Jan 28, 2025
1 parent 5305951 commit 4424fa1
Show file tree
Hide file tree
Showing 37 changed files with 3 additions and 161 deletions.
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 6

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/ball_in_cup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 8

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 5

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 17

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
- self.assertEqual(state.obs.shape[0], env.observation_size)
+ self.assertEqual(state.obs.shape, env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())


Expand Down
8 changes: 0 additions & 8 deletions mujoco_playground/_src/dm_control_suite/finger.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,6 @@ def sim_dt(self) -> float:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 9

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
Expand Up @@ -344,10 +340,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 12

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 24

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 15

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 67

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 3

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/point_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 4

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/reacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 6

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/swimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 25

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/dm_control_suite/walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 94

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/locomotion/barkour/joystick.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,6 @@ def xml_path(self) -> str:
def action_size(self) -> int:
return self.mjx_model.nu

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 465

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:
jp.zeros(3),
jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]),
)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {
"state": (52,),
"privileged_state": (114,),
}
7 changes: 0 additions & 7 deletions mujoco_playground/_src/locomotion/g1/joystick.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,10 +790,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:
jp.zeros(3),
jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]),
)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {
"state": (103,),
"privileged_state": (216,),
}
7 changes: 0 additions & 7 deletions mujoco_playground/_src/locomotion/go1/getup.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,3 @@ def _cost_dof_vel(self, qvel: jax.Array) -> jax.Array:

def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array:
return jp.sum(jp.square(qacc))

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {
"state": (42,),
"privileged_state": (91,),
}
7 changes: 0 additions & 7 deletions mujoco_playground/_src/locomotion/go1/handstand.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,6 @@ def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:
def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array:
return jp.sum(jp.square(qacc))

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {
"state": (45,),
"privileged_state": (94,),
}


class Footstand(Handstand):
"""Footstand task for Go1."""
Expand Down
7 changes: 0 additions & 7 deletions mujoco_playground/_src/locomotion/go1/joystick.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,3 @@ def sample_command(self, rng: jax.Array, x_k: jax.Array) -> jax.Array:
w_k = jax.random.bernoulli(w_rng, 0.5, shape=(3,))
x_kp1 = x_k - w_k * (x_k - y_k * z_k)
return x_kp1

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {
"state": (48,),
"privileged_state": (123,),
}
4 changes: 0 additions & 4 deletions mujoco_playground/_src/locomotion/h1/inplace_gait_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,3 @@ def _cost_ang_vel(self, global_angvel: jax.Array) -> jax.Array:
def _cost_lin_vel(self, global_linvel: jax.Array) -> jax.Array:
# Penalize xy linear velocity.
return jp.sum(jp.square(global_linvel[:2]))

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 186
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:
)
cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])
return cmd

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 113
1 change: 0 additions & 1 deletion mujoco_playground/_src/locomotion/locomotion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs)
obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape
self.assertEqual(obs_shape, env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())

Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/locomotion/op3/joystick.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:

cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])
return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 147
4 changes: 0 additions & 4 deletions mujoco_playground/_src/locomotion/spot/getup.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,3 @@ def _cost_action_rate(
c1 = jp.sum(jp.square(act - info["last_act"]))
c2 = jp.sum(jp.square(act - 2 * info["last_act"] + info["last_last_act"]))
return c1 + c2

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 30
4 changes: 0 additions & 4 deletions mujoco_playground/_src/locomotion/spot/joystick.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:
)
cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])
return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {"privileged_state": (167,), "state": (81,)}
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array:
)
cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])
return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 69
4 changes: 0 additions & 4 deletions mujoco_playground/_src/manipulation/aloha/handover.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,3 @@ def _get_obs(self, data: mjx.Data, info: Dict[str, Any]) -> jax.Array:
])

return obs

@property
def observation_size(self) -> int:
return 83
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,3 @@ def _get_reward(
"peg_z_up": peg_orientation * peg_lift,
"peg_insertion_reward": peg_insertion_reward,
}

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 82
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,3 @@ def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
])

return obs

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 55
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:

return obs

@property
def observation_size(self) -> mjx_env.ObservationSize:
return 66


class PandaPickCubeOrientation(PandaPickCube):
"""Bring a box to a target and orientation."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,6 @@ def _move_tip(

return new_ctrl, new_tip_pos, no_soln

@property
def observation_size(self) -> mjx_env.ObservationSize:
ret = {'pixels/view_0': (64, 64, 3)} if self._vision else 70
return ret

@property
def action_size(self) -> int:
return 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,3 @@ def _get_single_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
@property
def action_size(self):
return 7

@property
def observation_size(self):
return 48
4 changes: 0 additions & 4 deletions mujoco_playground/_src/manipulation/leap_hand/reorient.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,6 @@ def get_xfrc(
data = state.data.replace(xfrc_applied=xfrc)
return state.replace(data=data)

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {"privileged_state": (128,), "state": (57,)}


def domain_randomize(model: mjx.Model, rng: jax.Array):
mj_model = CubeReorient().mj_model
Expand Down
4 changes: 0 additions & 4 deletions mujoco_playground/_src/manipulation/leap_hand/rotate_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,6 @@ def _cost_action_rate(
def _cost_pose(self, joint_angles: jax.Array) -> jax.Array:
return jp.sum(jp.square(joint_angles - self._default_pose))

@property
def observation_size(self) -> mjx_env.ObservationSize:
return {"privileged_state": (105,), "state": (32,)}


def domain_randomize(model: mjx.Model, rng: jax.Array):
mj_model = CubeRotateZAxis().mj_model
Expand Down
1 change: 0 additions & 1 deletion mujoco_playground/_src/manipulation/manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs)
obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape
self.assertEqual(obs_shape, env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())

Expand Down
8 changes: 2 additions & 6 deletions mujoco_playground/_src/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,8 @@ def n_substeps(self) -> int:

@property
def observation_size(self) -> ObservationSize:
- rng = jax.random.PRNGKey(0)
- reset_state = self.unwrapped.reset(rng)
- obs = reset_state.obs
- if isinstance(obs, jax.Array):
- return obs.shape[-1]
- return jax.tree_util.tree_map(lambda x: x.shape, obs)
+ out = jax.eval_shape(self.reset, jax.random.PRNGKey(0))
+ return jax.tree_util.tree_map(lambda x: x.shape, out.obs)

def render(
self,
Expand Down

0 comments on commit 4424fa1

Please sign in to comment.