Skip to content

Commit

Permalink
Merge pull request #13 from alexander-soare/add_keypoints_mode
Browse files Browse the repository at this point in the history
Add keypoints mode, and fix order of operations in `add_tee`
  • Loading branch information
alexander-soare authored Jul 4, 2024
2 parents 2061b8f + c6ee94c commit ff76b74
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ If `obs_type` is set to `state`, the observation space is a 5-dimensional vector
environment: [agent_x, agent_y, block_x, block_y, block_angle]. The values are in the range [0, 512] for the agent
and block positions and [0, 2*pi] for the block angle.

If `obs_type` is set to `environment_state_agent_pos` the observation space is a dictionary with:
- `environment_state`: 16-dimensional vector representing the keypoint locations of the T (in [x0, y0, x1, y1, ...]
format). The values are in the range [0, 512].
- `agent_pos`: A 2-dimensional vector representing the position of the robot end-effector.

If `obs_type` is set to `pixels`, the observation space is a 96x96 RGB image of the environment.

### Rewards
Expand Down Expand Up @@ -84,7 +89,7 @@ The episode terminates when the block is at least 95% in the goal zone.
<TimeLimit<OrderEnforcing<PassiveEnvChecker<PushTEnv<gym_pusht/PushT-v0>>>>>
```

* `obs_type`: (str) The observation type. Can be either `state`, `pixels` or `pixels_agent_pos`. Default is `state`.
* `obs_type`: (str) The observation type. Can be either `state`, `environment_state_agent_pos`, `pixels` or `pixels_agent_pos`. Default is `state`.

* `block_cog`: (tuple) The center of gravity of the block if different from the center of mass. Default is `None`.

Expand All @@ -105,7 +110,7 @@ The episode terminates when the block is at least 95% in the goal zone.
Passing the option `options["reset_to_state"]` will reset the environment to a specific state.

> [!WARNING]
> For legacy compatibility, the inner fonctionning has been preserved, and the state set is not the same as the
> For legacy compatibility, the inner functioning has been preserved, and the state set is not the same as the
> the one passed in the argument.
```python
Expand Down
64 changes: 58 additions & 6 deletions gym_pusht/envs/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class PushTEnv(gym.Env):
environment: [agent_x, agent_y, block_x, block_y, block_angle]. The values are in the range [0, 512] for the agent
and block positions and [0, 2*pi] for the block angle.
If `obs_type` is set to `environment_state_agent_pos` the observation space is a dictionary with:
- `environment_state`: 16-dimensional vector representing the keypoint locations of the T (in [x0, y0, x1, y1, ...]
format). The values are in the range [0, 512]. See `get_keypoints` for a diagram showing the location of the
keypoint indices.
- `agent_pos`: A 2-dimensional vector representing the position of the robot end-effector.
If `obs_type` is set to `pixels`, the observation space is a 96x96 RGB image of the environment.
## Rewards
Expand Down Expand Up @@ -84,7 +90,8 @@ class PushTEnv(gym.Env):
<TimeLimit<OrderEnforcing<PassiveEnvChecker<PushTEnv<gym_pusht/PushT-v0>>>>>
```
* `obs_type`: (str) The observation type. Can be either `state`, `pixels` or `pixels_agent_pos`. Default is `state`.
* `obs_type`: (str) The observation type. Can be either `state`, `keypoints`, `pixels` or `pixels_agent_pos`.
Default is `state`.
* `block_cog`: (tuple) The center of gravity of the block if different from the center of mass. Default is `None`.
Expand Down Expand Up @@ -181,6 +188,21 @@ def _initialize_observation_space(self):
high=np.array([512, 512, 512, 512, 2 * np.pi]),
dtype=np.float64,
)
elif self.obs_type == "environment_state_agent_pos":
self.observation_space = spaces.Dict(
{
"environment_state": spaces.Box(
low=np.zeros(16),
high=np.full((16,), 512),
dtype=np.float64,
),
"agent_pos": spaces.Box(
low=np.array([0, 0]),
high=np.array([512, 512]),
dtype=np.float64,
),
},
)
elif self.obs_type == "pixels":
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.observation_height, self.observation_width, 3), dtype=np.uint8
Expand All @@ -203,7 +225,8 @@ def _initialize_observation_space(self):
)
else:
raise ValueError(
f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, pixels_agent_pos]"
f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, environment_state_agent_pos, "
"pixels_agent_pos]"
)

def _get_coverage(self):
Expand Down Expand Up @@ -364,6 +387,12 @@ def get_obs(self):
block_angle = self.block.angle % (2 * np.pi)
return np.concatenate([agent_position, block_position, [block_angle]], dtype=np.float64)

if self.obs_type == "environment_state_agent_pos":
return {
"environment_state": self.get_keypoints(self._block_shapes).flatten(),
"agent_pos": np.array(self.agent.position),
}

pixels = self._render()
if self.obs_type == "pixels":
return pixels
Expand Down Expand Up @@ -416,7 +445,7 @@ def _setup(self):

# Add agent, block, and goal zone
self.agent = self.add_circle(self.space, (256, 400), 15)
self.block = self.add_tee(self.space, (256, 300), 0)
self.block, self._block_shapes = self.add_tee(self.space, (256, 300), 0)
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
if self.block_cog is not None:
self.block.center_of_gravity = self.block_cog
Expand All @@ -429,7 +458,7 @@ def _setup(self):
def _set_state(self, state):
self.agent.position = list(state[:2])
# Setting angle rotates with respect to center of mass, therefore will modify the geometric position if not
# the same as CoM. Therefore should theoritically set the angle first. But for compatibility with legacy data,
# the same as CoM. Therefore should theoretically set the angle first. But for compatibility with legacy data,
# we do the opposite.
self.block.position = list(state[2:4])
self.block.angle = state[4]
Expand Down Expand Up @@ -482,8 +511,31 @@ def add_tee(space, position, angle, scale=30, color="LightSlateGray", mask=None)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.position = position
body.friction = 1
space.add(body, shape1, shape2)
return body
return body, [shape1, shape2]

@staticmethod
def get_keypoints(block_shapes):
"""Get a (8, 2) numpy array with the T keypoints.
The T is composed of two rectangles each with 4 keypoints.
0───────────1
│ │
3───4───5───2
│ │
│ │
│ │
│ │
7───6
"""
keypoints = []
for shape in block_shapes:
for v in shape.get_vertices():
v = v.rotated(shape.body.angle)
v = v + shape.body.position
keypoints.append(np.array(v))
return np.row_stack(keypoints)
1 change: 1 addition & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
("PushT-v0", "state"),
("PushT-v0", "pixels"),
("PushT-v0", "pixels_agent_pos"),
("PushT-v0", "environment_state_agent_pos"),
],
)
def test_env(env_task, obs_type):
Expand Down

0 comments on commit ff76b74

Please sign in to comment.