Skip to content

Commit

Permalink
Merge pull request #2 from huggingface/user/rcadene/2024_04_03_compat…
Browse files Browse the repository at this point in the history
…iblity

Make it compatible with LeRobot
  • Loading branch information
aliberts authored Apr 4, 2024
2 parents ec00cd6 + e9c5712 commit 0fe4449
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ pip install gym-pusht
import gymnasium as gym
import gym_pusht

env = gym.make("gym_pusht/PushT-v0", render_mode="human")
env = gym.make("gym_pusht/PushT-v0", render_mode="human", render_size=(84,84), visualization_size=(680,680))
observation, info = env.reset()

for _ in range(1000):
action = env.action_space.sample()
observation, reward, terminated, truncated, info = env.step(action)
image = env.render()

if terminated or truncated:
observation, info = env.reset()
Expand Down
33 changes: 29 additions & 4 deletions gym_pusht/envs/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ def __init__(
)
elif self.obs_type == "pixels":
self.observation_space = spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict({
"pixels": spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8),
"agent_pos": spaces.Box(
low=np.array([0, 0], dtype=np.float32),
high=np.array([512, 512], dtype=np.float32),
),
})

self.action_space = spaces.Box(low=0, high=512, shape=(2,), dtype=np.float32)

Expand Down Expand Up @@ -186,14 +194,14 @@ def step(self, action):
coverage = intersection_area / goal_area
reward = np.clip(coverage / self.success_threshold, 0.0, 1.0)
is_success = coverage > self.success_threshold
terminated = is_success

observation = self._get_obs()
info = self._get_info()

info["is_success"] = is_success
terminated = is_success

return observation, reward, terminated, False, info
truncated = False
return observation, reward, terminated, truncated, info

def reset(self, seed=None, options=None):
super().reset(seed=seed)
Expand All @@ -202,11 +210,22 @@ def reset(self, seed=None, options=None):
if options is not None and options.get("reset_to_state") is not None:
state = np.array(options.get("reset_to_state"))
else:
state = self.np_random.uniform(low=[50, 50, 100, 100, -np.pi], high=[450, 450, 400, 400, np.pi])
#state = self.np_random.uniform(low=[50, 50, 100, 100, -np.pi], high=[450, 450, 400, 400, np.pi])
rs = np.random.RandomState(seed=seed)
state = np.array(
[
rs.randint(50, 450),
rs.randint(50, 450),
rs.randint(100, 400),
rs.randint(100, 400),
rs.randn() * 2 * np.pi - np.pi,
]
)
self._set_state(state)

observation = self._get_obs()
info = self._get_info()
info["is_success"] = False

if self.render_mode == "human":
self.render()
Expand Down Expand Up @@ -296,6 +315,12 @@ def _get_obs(self):
elif self.obs_type == "pixels":
screen = self._draw()
obs = self._get_img(screen)
elif self.obs_type == "pixels_agent_pos":
screen = self._draw()
obs = {
"pixels": self._get_img(screen),
"agent_pos": np.array(self.agent.position),
}
return obs

def _get_goal_pose_body(self, pose):
Expand Down

0 comments on commit 0fe4449

Please sign in to comment.