From e9c5712692040d476d25e31339bf05c4e7bde069 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 3 Apr 2024 18:10:11 +0000 Subject: [PATCH] WIP --- README.md | 3 ++- gym_pusht/envs/pusht.py | 33 +++++++++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8a02240..4e497db 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,13 @@ pip install gym-pusht import gymnasium as gym import gym_pusht -env = gym.make("gym_pusht/PushT-v0", render_mode="human") +env = gym.make("gym_pusht/PushT-v0", render_mode="human", render_size=(84,84), visualization_size=(680,680)) observation, info = env.reset() for _ in range(1000): action = env.action_space.sample() observation, reward, terminated, truncated, info = env.step(action) + image = env.render() if terminated or truncated: observation, info = env.reset() diff --git a/gym_pusht/envs/pusht.py b/gym_pusht/envs/pusht.py index 3a960bb..8241e7c 100644 --- a/gym_pusht/envs/pusht.py +++ b/gym_pusht/envs/pusht.py @@ -133,6 +133,14 @@ def __init__( ) elif self.obs_type == "pixels": self.observation_space = spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict({ + "pixels": spaces.Box(low=0, high=255, shape=(render_size, render_size, 3), dtype=np.uint8), + "agent_pos": spaces.Box( + low=np.array([0, 0], dtype=np.float32), + high=np.array([512, 512], dtype=np.float32), + ), + }) self.action_space = spaces.Box(low=0, high=512, shape=(2,), dtype=np.float32) @@ -186,14 +194,14 @@ def step(self, action): coverage = intersection_area / goal_area reward = np.clip(coverage / self.success_threshold, 0.0, 1.0) is_success = coverage > self.success_threshold + terminated = is_success observation = self._get_obs() info = self._get_info() - info["is_success"] = is_success - terminated = is_success - return observation, reward, terminated, False, info + truncated = False + return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): super().reset(seed=seed) @@ -202,11 +210,22 @@ def reset(self, seed=None, options=None): if options is not None and options.get("reset_to_state") is not None: state = np.array(options.get("reset_to_state")) else: - state = self.np_random.uniform(low=[50, 50, 100, 100, -np.pi], high=[450, 450, 400, 400, np.pi]) + #state = self.np_random.uniform(low=[50, 50, 100, 100, -np.pi], high=[450, 450, 400, 400, np.pi]) + rs = np.random.RandomState(seed=seed) + state = np.array( + [ + rs.randint(50, 450), + rs.randint(50, 450), + rs.randint(100, 400), + rs.randint(100, 400), + rs.randn() * 2 * np.pi - np.pi, + ] + ) self._set_state(state) observation = self._get_obs() info = self._get_info() + info["is_success"] = False if self.render_mode == "human": self.render() @@ -296,6 +315,12 @@ def _get_obs(self): elif self.obs_type == "pixels": screen = self._draw() obs = self._get_img(screen) + elif self.obs_type == "pixels_agent_pos": + screen = self._draw() + obs = { + "pixels": self._get_img(screen), + "agent_pos": np.array(self.agent.position), + } return obs def _get_goal_pose_body(self, pose):