diff --git a/mujoco_playground/_src/dm_control_suite/acrobot.py b/mujoco_playground/_src/dm_control_suite/acrobot.py index 22f3399..951178f 100644 --- a/mujoco_playground/_src/dm_control_suite/acrobot.py +++ b/mujoco_playground/_src/dm_control_suite/acrobot.py @@ -142,10 +142,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 6 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/ball_in_cup.py b/mujoco_playground/_src/dm_control_suite/ball_in_cup.py index 830ac71..1564b52 100644 --- a/mujoco_playground/_src/dm_control_suite/ball_in_cup.py +++ b/mujoco_playground/_src/dm_control_suite/ball_in_cup.py @@ -116,10 +116,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 8 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/cartpole.py b/mujoco_playground/_src/dm_control_suite/cartpole.py index 3b01c29..2df6e63 100644 --- a/mujoco_playground/_src/dm_control_suite/cartpole.py +++ b/mujoco_playground/_src/dm_control_suite/cartpole.py @@ -278,10 +278,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 5 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/cheetah.py b/mujoco_playground/_src/dm_control_suite/cheetah.py index 7160335..74d253d 100644 --- a/mujoco_playground/_src/dm_control_suite/cheetah.py +++ b/mujoco_playground/_src/dm_control_suite/cheetah.py @@ -135,10 +135,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 17 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py index 124838e..f8ec3ab 100644 --- a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py +++ b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py @@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.reset)(jax.random.PRNGKey(42)) state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) - self.assertEqual(state.obs.shape[0], env.observation_size) + self.assertEqual(state.obs.shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/dm_control_suite/finger.py b/mujoco_playground/_src/dm_control_suite/finger.py index b0f8d1a..26d21ed 100644 --- a/mujoco_playground/_src/dm_control_suite/finger.py +++ b/mujoco_playground/_src/dm_control_suite/finger.py @@ -188,10 +188,6 @@ def sim_dt(self) -> float: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 9 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model @@ -344,10 +340,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 12 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/fish.py b/mujoco_playground/_src/dm_control_suite/fish.py index bee0098..e270af7 100644 --- a/mujoco_playground/_src/dm_control_suite/fish.py +++ b/mujoco_playground/_src/dm_control_suite/fish.py @@ -187,10 +187,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 24 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/hopper.py b/mujoco_playground/_src/dm_control_suite/hopper.py index dbcc480..a11db19 100644 --- a/mujoco_playground/_src/dm_control_suite/hopper.py +++ b/mujoco_playground/_src/dm_control_suite/hopper.py @@ -194,10 +194,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 15 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/humanoid.py b/mujoco_playground/_src/dm_control_suite/humanoid.py index 6f3f584..33c9f27 100644 --- a/mujoco_playground/_src/dm_control_suite/humanoid.py +++ b/mujoco_playground/_src/dm_control_suite/humanoid.py @@ -215,10 +215,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 67 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/pendulum.py b/mujoco_playground/_src/dm_control_suite/pendulum.py index 6e3544e..5d0d32c 100644 --- a/mujoco_playground/_src/dm_control_suite/pendulum.py +++ b/mujoco_playground/_src/dm_control_suite/pendulum.py @@ -132,10 +132,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 3 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/point_mass.py b/mujoco_playground/_src/dm_control_suite/point_mass.py index 9a4b291..760cdfb 100644 --- a/mujoco_playground/_src/dm_control_suite/point_mass.py +++ b/mujoco_playground/_src/dm_control_suite/point_mass.py @@ -138,10 +138,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 4 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/reacher.py b/mujoco_playground/_src/dm_control_suite/reacher.py index ea65865..5c933e7 100644 --- a/mujoco_playground/_src/dm_control_suite/reacher.py +++ b/mujoco_playground/_src/dm_control_suite/reacher.py @@ -158,10 +158,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 6 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/swimmer.py b/mujoco_playground/_src/dm_control_suite/swimmer.py index 850b9c7..402c776 100644 --- a/mujoco_playground/_src/dm_control_suite/swimmer.py +++ b/mujoco_playground/_src/dm_control_suite/swimmer.py @@ -233,10 +233,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 25 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/dm_control_suite/walker.py b/mujoco_playground/_src/dm_control_suite/walker.py index 0952a90..fd79375 100644 --- a/mujoco_playground/_src/dm_control_suite/walker.py +++ b/mujoco_playground/_src/dm_control_suite/walker.py @@ -185,10 +185,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 94 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/locomotion/barkour/joystick.py b/mujoco_playground/_src/locomotion/barkour/joystick.py index c14b7a9..72a085e 100644 --- a/mujoco_playground/_src/locomotion/barkour/joystick.py +++ b/mujoco_playground/_src/locomotion/barkour/joystick.py @@ -457,10 +457,6 @@ def xml_path(self) -> str: def action_size(self) -> int: return self.mjx_model.nu - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 465 - @property def mj_model(self) -> mujoco.MjModel: return self._mj_model diff --git a/mujoco_playground/_src/locomotion/berkeley_humanoid/joystick.py b/mujoco_playground/_src/locomotion/berkeley_humanoid/joystick.py index 650d470b..b43c00d 100644 --- a/mujoco_playground/_src/locomotion/berkeley_humanoid/joystick.py +++ b/mujoco_playground/_src/locomotion/berkeley_humanoid/joystick.py @@ -667,10 +667,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: jp.zeros(3), jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]), ) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return { - "state": (52,), - "privileged_state": (114,), - } diff --git a/mujoco_playground/_src/locomotion/g1/joystick.py b/mujoco_playground/_src/locomotion/g1/joystick.py index b5fce64..561a189 100644 --- a/mujoco_playground/_src/locomotion/g1/joystick.py +++ b/mujoco_playground/_src/locomotion/g1/joystick.py @@ -790,10 +790,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: jp.zeros(3), jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]), ) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return { - "state": (103,), - "privileged_state": (216,), - } diff --git a/mujoco_playground/_src/locomotion/go1/getup.py b/mujoco_playground/_src/locomotion/go1/getup.py index a26fc9c..fe3246d 100644 --- a/mujoco_playground/_src/locomotion/go1/getup.py +++ b/mujoco_playground/_src/locomotion/go1/getup.py @@ -367,10 +367,3 @@ def _cost_dof_vel(self, qvel: jax.Array) -> jax.Array: def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array: return jp.sum(jp.square(qacc)) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return { - "state": (42,), - "privileged_state": (91,), - } diff --git a/mujoco_playground/_src/locomotion/go1/handstand.py b/mujoco_playground/_src/locomotion/go1/handstand.py index 65c72f0..8597256 100644 --- a/mujoco_playground/_src/locomotion/go1/handstand.py +++ b/mujoco_playground/_src/locomotion/go1/handstand.py @@ -374,13 +374,6 @@ def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array: def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array: return jp.sum(jp.square(qacc)) - @property - def observation_size(self) -> mjx_env.ObservationSize: - return { - "state": (45,), - "privileged_state": (94,), - } - class Footstand(Handstand): """Footstand task for Go1.""" diff --git a/mujoco_playground/_src/locomotion/go1/joystick.py b/mujoco_playground/_src/locomotion/go1/joystick.py index 36f9b0e..aec444a 100644 --- a/mujoco_playground/_src/locomotion/go1/joystick.py +++ b/mujoco_playground/_src/locomotion/go1/joystick.py @@ -599,10 +599,3 @@ def sample_command(self, rng: jax.Array, x_k: jax.Array) -> jax.Array: w_k = jax.random.bernoulli(w_rng, 0.5, shape=(3,)) x_kp1 = x_k - w_k * (x_k - y_k * z_k) return x_kp1 - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return { - "state": (48,), - "privileged_state": (123,), - } diff --git a/mujoco_playground/_src/locomotion/h1/inplace_gait_tracking.py b/mujoco_playground/_src/locomotion/h1/inplace_gait_tracking.py index 6d8f9c5..946b0ae 100644 --- a/mujoco_playground/_src/locomotion/h1/inplace_gait_tracking.py +++ b/mujoco_playground/_src/locomotion/h1/inplace_gait_tracking.py @@ -375,7 +375,3 @@ def _cost_ang_vel(self, global_angvel: jax.Array) -> jax.Array: def _cost_lin_vel(self, global_linvel: jax.Array) -> jax.Array: # Penalize xy linear velocity. return jp.sum(jp.square(global_linvel[:2])) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 186 diff --git a/mujoco_playground/_src/locomotion/h1/joystick_gait_tracking.py b/mujoco_playground/_src/locomotion/h1/joystick_gait_tracking.py index 90e1ff1..d2b3c14 100644 --- a/mujoco_playground/_src/locomotion/h1/joystick_gait_tracking.py +++ b/mujoco_playground/_src/locomotion/h1/joystick_gait_tracking.py @@ -477,7 +477,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: ) cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]) return cmd - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 113 diff --git a/mujoco_playground/_src/locomotion/locomotion_test.py b/mujoco_playground/_src/locomotion/locomotion_test.py index 6afd8f8..971a849 100644 --- a/mujoco_playground/_src/locomotion/locomotion_test.py +++ b/mujoco_playground/_src/locomotion/locomotion_test.py @@ -36,7 +36,6 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) - obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/locomotion/op3/joystick.py b/mujoco_playground/_src/locomotion/op3/joystick.py index 9c333e3..43942f3 100644 --- a/mujoco_playground/_src/locomotion/op3/joystick.py +++ b/mujoco_playground/_src/locomotion/op3/joystick.py @@ -466,7 +466,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]) return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 147 diff --git a/mujoco_playground/_src/locomotion/spot/getup.py b/mujoco_playground/_src/locomotion/spot/getup.py index 3eaa9c9..e0453ae 100644 --- a/mujoco_playground/_src/locomotion/spot/getup.py +++ b/mujoco_playground/_src/locomotion/spot/getup.py @@ -277,7 +277,3 @@ def _cost_action_rate( c1 = jp.sum(jp.square(act - info["last_act"])) c2 = jp.sum(jp.square(act - 2 * info["last_act"] + info["last_last_act"])) return c1 + c2 - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 30 diff --git a/mujoco_playground/_src/locomotion/spot/joystick.py b/mujoco_playground/_src/locomotion/spot/joystick.py index f0f8857..ecedb90 100644 --- a/mujoco_playground/_src/locomotion/spot/joystick.py +++ b/mujoco_playground/_src/locomotion/spot/joystick.py @@ -612,7 +612,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: ) cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]) return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return {"privileged_state": (167,), "state": (81,)} diff --git a/mujoco_playground/_src/locomotion/spot/joystick_gait_tracking.py b/mujoco_playground/_src/locomotion/spot/joystick_gait_tracking.py index 1faf5cf..9fb9e58 100644 --- a/mujoco_playground/_src/locomotion/spot/joystick_gait_tracking.py +++ b/mujoco_playground/_src/locomotion/spot/joystick_gait_tracking.py @@ -386,7 +386,3 @@ def sample_command(self, rng: jax.Array) -> jax.Array: ) cmd = jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw]) return jp.where(jax.random.bernoulli(rng4, 0.1), jp.zeros(3), cmd) - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 69 diff --git a/mujoco_playground/_src/manipulation/aloha/handover.py b/mujoco_playground/_src/manipulation/aloha/handover.py index 537b04c..c54d658 100644 --- a/mujoco_playground/_src/manipulation/aloha/handover.py +++ b/mujoco_playground/_src/manipulation/aloha/handover.py @@ -283,7 +283,3 @@ def _get_obs(self, data: mjx.Data, info: Dict[str, Any]) -> jax.Array: ]) return obs - - @property - def observation_size(self) -> int: - return 83 diff --git a/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py b/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py index e8fcf59..5230d7c 100644 --- a/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py +++ b/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py @@ -252,7 +252,3 @@ def _get_reward( "peg_z_up": peg_orientation * peg_lift, "peg_insertion_reward": peg_insertion_reward, } - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 82 diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py b/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py index cf23d7e..6a98955 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py @@ -213,7 +213,3 @@ def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: ]) return obs - - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 55 diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/pick.py b/mujoco_playground/_src/manipulation/franka_emika_panda/pick.py index 2a2d73f..bd3f903 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/pick.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/pick.py @@ -230,10 +230,6 @@ def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: return obs - @property - def observation_size(self) -> mjx_env.ObservationSize: - return 66 - class PandaPickCubeOrientation(PandaPickCube): """Bring a box to a target and orientation.""" diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py b/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py index 4415ebb..dcbd277 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py @@ -433,11 +433,6 @@ def _move_tip( return new_ctrl, new_tip_pos, no_soln - @property - def observation_size(self) -> mjx_env.ObservationSize: - ret = {'pixels/view_0': (64, 64, 3)} if self._vision else 70 - return ret - @property def action_size(self) -> int: return 3 diff --git a/mujoco_playground/_src/manipulation/franka_emika_panda_robotiq/push_cube.py b/mujoco_playground/_src/manipulation/franka_emika_panda_robotiq/push_cube.py index e578d5f..649bad2 100644 --- a/mujoco_playground/_src/manipulation/franka_emika_panda_robotiq/push_cube.py +++ b/mujoco_playground/_src/manipulation/franka_emika_panda_robotiq/push_cube.py @@ -567,7 +567,3 @@ def _get_single_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: @property def action_size(self): return 7 - - @property - def observation_size(self): - return 48 diff --git a/mujoco_playground/_src/manipulation/leap_hand/reorient.py b/mujoco_playground/_src/manipulation/leap_hand/reorient.py index 3ae7f03..838aca9 100644 --- a/mujoco_playground/_src/manipulation/leap_hand/reorient.py +++ b/mujoco_playground/_src/manipulation/leap_hand/reorient.py @@ -474,10 +474,6 @@ def get_xfrc( data = state.data.replace(xfrc_applied=xfrc) return state.replace(data=data) - @property - def observation_size(self) -> mjx_env.ObservationSize: - return {"privileged_state": (128,), "state": (57,)} - def domain_randomize(model: mjx.Model, rng: jax.Array): mj_model = CubeReorient().mj_model diff --git a/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py b/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py index ac883e7..186ffb4 100644 --- a/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py +++ b/mujoco_playground/_src/manipulation/leap_hand/rotate_z.py @@ -260,10 +260,6 @@ def _cost_action_rate( def _cost_pose(self, joint_angles: jax.Array) -> jax.Array: return jp.sum(jp.square(joint_angles - self._default_pose)) - @property - def observation_size(self) -> mjx_env.ObservationSize: - return {"privileged_state": (105,), "state": (32,)} - def domain_randomize(model: mjx.Model, rng: jax.Array): mj_model = CubeRotateZAxis().mj_model diff --git a/mujoco_playground/_src/manipulation/manipulation_test.py b/mujoco_playground/_src/manipulation/manipulation_test.py index 0b52d2e..971aa19 100644 --- a/mujoco_playground/_src/manipulation/manipulation_test.py +++ b/mujoco_playground/_src/manipulation/manipulation_test.py @@ -36,7 +36,6 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) - obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index 1a4de43..4871554 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -270,12 +270,8 @@ def n_substeps(self) -> int: @property def observation_size(self) -> ObservationSize: - rng = jax.random.PRNGKey(0) - reset_state = self.unwrapped.reset(rng) - obs = reset_state.obs - if isinstance(obs, jax.Array): - return obs.shape[-1] - return jax.tree_util.tree_map(lambda x: x.shape, obs) + out = jax.eval_shape(self.reset, jax.random.PRNGKey(0)) + return jax.tree_util.tree_map(lambda x: x.shape, out.obs) def render( self,