diff --git a/README.md b/README.md index 04655fb..74eadd9 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,11 @@ If `obs_type` is set to `state`, the observation space is a 5-dimensional vector environment: [agent_x, agent_y, block_x, block_y, block_angle]. The values are in the range [0, 512] for the agent and block positions and [0, 2*pi] for the block angle. +If `obs_type` is set to `environment_state_agent_pos`, the observation space is a dictionary with: + - `environment_state`: 16-dimensional vector representing the keypoint locations of the T (in [x0, y0, x1, y1, ...] + format). The values are in the range [0, 512]. + - `agent_pos`: A 2-dimensional vector representing the position of the robot end-effector. + If `obs_type` is set to `pixels`, the observation space is a 96x96 RGB image of the environment. ### Rewards @@ -84,7 +89,7 @@ The episode terminates when the block is at least 95% in the goal zone. >>>> ``` -* `obs_type`: (str) The observation type. Can be either `state`, `pixels` or `pixels_agent_pos`. Default is `state`. +* `obs_type`: (str) The observation type. Can be either `state`, `environment_state_agent_pos`, `pixels` or `pixels_agent_pos`. Default is `state`. * `block_cog`: (tuple) The center of gravity of the block if different from the center of mass. Default is `None`. @@ -105,7 +110,7 @@ The episode terminates when the block is at least 95% in the goal zone. Passing the option `options["reset_to_state"]` will reset the environment to a specific state. > [!WARNING] -> For legacy compatibility, the inner fonctionning has been preserved, and the state set is not the same as the +> For legacy compatibility, the inner functioning has been preserved, and the state set is not the same as > the one passed in the argument. 
```python diff --git a/gym_pusht/envs/pusht.py b/gym_pusht/envs/pusht.py index 6c3b820..73c9a35 100644 --- a/gym_pusht/envs/pusht.py +++ b/gym_pusht/envs/pusht.py @@ -56,6 +56,12 @@ class PushTEnv(gym.Env): environment: [agent_x, agent_y, block_x, block_y, block_angle]. The values are in the range [0, 512] for the agent and block positions and [0, 2*pi] for the block angle. + If `obs_type` is set to `environment_state_agent_pos`, the observation space is a dictionary with: + - `environment_state`: 16-dimensional vector representing the keypoint locations of the T (in [x0, y0, x1, y1, ...] + format). The values are in the range [0, 512]. See `get_keypoints` for a diagram showing the location of the + keypoint indices. + - `agent_pos`: A 2-dimensional vector representing the position of the robot end-effector. + If `obs_type` is set to `pixels`, the observation space is a 96x96 RGB image of the environment. ## Rewards @@ -84,7 +90,8 @@ class PushTEnv(gym.Env): >>>> ``` - * `obs_type`: (str) The observation type. Can be either `state`, `pixels` or `pixels_agent_pos`. Default is `state`. + * `obs_type`: (str) The observation type. Can be either `state`, `environment_state_agent_pos`, + `pixels` or `pixels_agent_pos`. Default is `state`. * `block_cog`: (tuple) The center of gravity of the block if different from the center of mass. Default is `None`. 
@@ -181,6 +188,21 @@ def _initialize_observation_space(self): high=np.array([512, 512, 512, 512, 2 * np.pi]), dtype=np.float64, ) + elif self.obs_type == "environment_state_agent_pos": + self.observation_space = spaces.Dict( + { + "environment_state": spaces.Box( + low=np.zeros(16), + high=np.full((16,), 512), + dtype=np.float64, + ), + "agent_pos": spaces.Box( + low=np.array([0, 0]), + high=np.array([512, 512]), + dtype=np.float64, + ), + }, + ) elif self.obs_type == "pixels": self.observation_space = spaces.Box( low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8 @@ -203,7 +225,8 @@ def _initialize_observation_space(self): ) else: raise ValueError( - f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, pixels_agent_pos]" + f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, environment_state_agent_pos, " + "pixels_agent_pos]" ) def _get_coverage(self): @@ -364,6 +387,12 @@ def get_obs(self): block_angle = self.block.angle % (2 * np.pi) return np.concatenate([agent_position, block_position, [block_angle]], dtype=np.float64) + if self.obs_type == "environment_state_agent_pos": + return { + "environment_state": self.get_keypoints(self._block_shapes).flatten(), + "agent_pos": np.array(self.agent.position), + } + pixels = self._render() if self.obs_type == "pixels": return pixels @@ -416,7 +445,7 @@ def _setup(self): # Add agent, block, and goal zone self.agent = self.add_circle(self.space, (256, 400), 15) - self.block = self.add_tee(self.space, (256, 300), 0) + self.block, self._block_shapes = self.add_tee(self.space, (256, 300), 0) self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) if self.block_cog is not None: self.block.center_of_gravity = self.block_cog @@ -429,7 +458,7 @@ def _setup(self): def _set_state(self, state): self.agent.position = list(state[:2]) # Setting angle rotates with respect to center of mass, therefore will modify the geometric position if not - # the 
same as CoM. Therefore should theoritically set the angle first. But for compatibility with legacy data, + # the same as CoM. Therefore should theoretically set the angle first. But for compatibility with legacy data, # we do the opposite. self.block.position = list(state[2:4]) self.block.angle = state[4] @@ -482,8 +511,31 @@ def add_tee(space, position, angle, scale=30, color="LightSlateGray", mask=None) shape1.filter = pymunk.ShapeFilter(mask=mask) shape2.filter = pymunk.ShapeFilter(mask=mask) body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 - body.position = position body.angle = angle + body.position = position body.friction = 1 space.add(body, shape1, shape2) - return body + return body, [shape1, shape2] + + @staticmethod + def get_keypoints(block_shapes): + """Get an (8, 2) numpy array with the T keypoints, rotated and translated by the block's body pose. + + The T is composed of two rectangles each with 4 keypoints. + + 0───────────1 + │ │ + 3───4───5───2 + │ │ + │ │ + │ │ + │ │ + 7───6 + """ + keypoints = [] + for shape in block_shapes: + for v in shape.get_vertices(): + v = v.rotated(shape.body.angle) + v = v + shape.body.position + keypoints.append(np.array(v)) + return np.vstack(keypoints) diff --git a/tests/test_env.py b/tests/test_env.py index f747e2a..91562ba 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -11,6 +11,7 @@ ("PushT-v0", "state"), ("PushT-v0", "pixels"), ("PushT-v0", "pixels_agent_pos"), + ("PushT-v0", "environment_state_agent_pos"), ], ) def test_env(env_task, obs_type):