From 79e18a320076e003e59c811267e5c0994097c5d9 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 6 Apr 2024 11:07:51 +0200 Subject: [PATCH] Cleanup --- gym_xarm/tasks/__init__.py | 186 +++++++---------------------------- gym_xarm/tasks/base.py | 91 ++++++----------- gym_xarm/tasks/lift.py | 44 ++------- gym_xarm/tasks/peg_in_box.py | 2 + gym_xarm/tasks/push.py | 2 + gym_xarm/tasks/reach.py | 2 + tests/test_env.py | 15 +-- 7 files changed, 79 insertions(+), 263 deletions(-) diff --git a/gym_xarm/tasks/__init__.py b/gym_xarm/tasks/__init__.py index a20e29d..3fabb3e 100644 --- a/gym_xarm/tasks/__init__.py +++ b/gym_xarm/tasks/__init__.py @@ -1,44 +1,42 @@ -from collections import OrderedDict, deque - -import gymnasium as gym -import numpy as np -from gymnasium.wrappers import TimeLimit +from collections import OrderedDict from gym_xarm.tasks.base import Base as Base from gym_xarm.tasks.lift import Lift -from gym_xarm.tasks.peg_in_box import PegInBox -from gym_xarm.tasks.push import Push -from gym_xarm.tasks.reach import Reach + +# from gym_xarm.tasks.peg_in_box import PegInBox +# from gym_xarm.tasks.push import Push +# from gym_xarm.tasks.reach import Reach + TASKS = OrderedDict( ( - ( - "reach", - { - "env": Reach, - "action_space": "xyz", - "episode_length": 50, - "description": "Reach a target location with the end effector", - }, - ), - ( - "push", - { - "env": Push, - "action_space": "xyz", - "episode_length": 50, - "description": "Push a cube to a target location", - }, - ), - ( - "peg_in_box", - { - "env": PegInBox, - "action_space": "xyz", - "episode_length": 50, - "description": "Insert a peg into a box", - }, - ), + # ( + # "reach", + # { + # "env": Reach, + # "action_space": "xyz", + # "episode_length": 50, + # "description": "Reach a target location with the end effector", + # }, + # ), + # ( + # "push", + # { + # "env": Push, + # "action_space": "xyz", + # "episode_length": 50, + # "description": "Push a cube to a target location", + # }, + # ), + # ( + # "peg_in_box", + # { + # "env": PegInBox, + # "action_space": "xyz", + # "episode_length": 50, + # "description": "Insert a peg into a box", + # }, + # ), ( "lift", { @@ -50,121 +48,3 @@ ), ) ) - - -class SimXarmWrapper(gym.Wrapper): - """ - DEPRECATED: Use gym.make() - - A wrapper for the SimXarm environments. This wrapper is used to - convert the action and observation spaces to the correct format. - """ - - def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False): - super().__init__(env) - self._env = env - self.obs_mode = obs_mode - self.image_size = image_size - self.action_repeat = action_repeat - self.frame_stack = frame_stack - self._frames = deque([], maxlen=frame_stack) - self.channel_last = channel_last - self._max_episode_steps = task["episode_length"] // action_repeat - - image_shape = ( - (image_size, image_size, 3 * frame_stack) - if channel_last - else (3 * frame_stack, image_size, image_size) - ) - if obs_mode == "state": - self.observation_space = env.observation_space["observation"] - elif obs_mode == "rgb": - self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8) - elif obs_mode == "all": - self.observation_space = gym.spaces.Dict( - state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32), - rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8), - ) - else: - raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]") - self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),)) - self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32) - if "w" not in task["action_space"]: - self.action_padding[-1] = 1.0 - - def _render_obs(self): - obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) - if not self.channel_last: - obs = obs.transpose(2, 0, 1) - return obs.copy() - - def _update_frames(self, reset=False): - pixels = self._render_obs() - self._frames.append(pixels) - if reset: - for _ in range(1, self.frame_stack): - self._frames.append(pixels) - assert len(self._frames) == self.frame_stack - - def transform_obs(self, obs, reset=False): - if self.obs_mode == "state": - return obs["observation"] - elif self.obs_mode == "rgb": - self._update_frames(reset=reset) - rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - return rgb_obs - elif self.obs_mode == "all": - self._update_frames(reset=reset) - rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state))) - else: - raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]") - - def reset(self): - return self.transform_obs(self._env.reset(), reset=True) - - def step(self, action): - action = np.concatenate([action, self.action_padding]) - reward = 0.0 - for _ in range(self.action_repeat): - obs, r, done, info = self._env.step(action) - reward += r - return self.transform_obs(obs), reward, done, info - - def render(self, mode="rgb_array", width=384, height=384, **kwargs): - return self._env.render(mode, width=width, height=height) - - @property - def state(self): - return self._env.robot_state - - -def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0): - """ - DEPRECATED: Use gym.make() - - Create a new environment. - Args: - task (str): The task to create an environment for. Must be one of: - - 'reach' - - 'push' - - 'peg-in-box' - - 'lift' - obs_mode (str): The observation mode to use. Must be one of: - - 'state': Only state observations - - 'rgb': RGB images - - 'all': RGB images and state observations - image_size (int): The size of the image observations - action_repeat (int): The number of times to repeat the action - seed (int): The random seed to use - Returns: - gym.Env: The environment - """ - if task not in TASKS: - raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}") - env = TASKS[task]["env"]() - env = TimeLimit(env, TASKS[task]["episode_length"]) - env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last) - env.seed(seed) - - return env diff --git a/gym_xarm/tasks/base.py b/gym_xarm/tasks/base.py index 394e63d..5ae90ae 100644 --- a/gym_xarm/tasks/base.py +++ b/gym_xarm/tasks/base.py @@ -1,5 +1,4 @@ import os -from collections import deque import gymnasium as gym import mujoco @@ -36,10 +35,9 @@ def __init__( observation_height=84, visualization_width=680, visualization_height=680, - frame_stack=1, channel_last=True, ): - # Env setup + # Coordinates if gripper_rotation is None: gripper_rotation = [0, 1, 0, 0] self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32) @@ -56,8 +54,6 @@ def __init__( self.observation_height = observation_height self.visualization_width = visualization_width self.visualization_height = visualization_height - self.frame_stack = frame_stack - self._frames = deque([], maxlen=frame_stack) # Assets self.xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml") @@ -79,15 +75,6 @@ def __init__( if "w" not in self.metadata["action_space"]: self.action_padding[-1] = 1.0 - # super().__init__( - # xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml"), - # n_substeps=20, - # n_actions=4, - # initial_qpos={}, - # width=image_size, - # height=image_size, - # ) - def _initialize_simulation(self): """Initialize MuJoCo simulation data structures mjModel and mjData.""" self.model = self._mujoco.MjModel.from_xml_path(self.xml_path) @@ -116,9 +103,9 @@ def _env_setup(self, initial_qpos): def _initialize_observation_space(self): image_shape = ( - (self.observation_height, self.observation_width, 3 * self.frame_stack) + (self.observation_height, self.observation_width, 3) if self.channel_last - else (3 * self.frame_stack, self.observation_height, self.observation_width) + else (3, self.observation_height, self.observation_width) ) obs = self.get_obs() if self.obs_type == "state": @@ -145,7 +132,7 @@ def _initialize_renderer(self, type: str): if type == "observation": model = self.model elif type == "visualization": - # HACK: MujoCo doesn't allow for custom size rendering on-the-fly, so we + # HACK: gymnasium doesn't allow for custom size rendering on-the-fly, so we # initialize another renderer with appropriate size for visualization purposes # see https://gymnasium.farama.org/content/migration-guide/#environment-render from copy import deepcopy @@ -165,16 +152,35 @@ def dt(self): @property def eef(self): - return self._utils.get_site_xpos(self.model, self.data, "grasp") + return self._utils.get_site_xpos(self.model, self.data, "grasp") - self.center_of_table @property - def obj(self): - return self._utils.get_site_xpos(self.model, self.data, "object_site") + def eef_velp(self): + return self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt + + @property + def gripper_angle(self): + return self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") @property def robot_state(self): - gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") - return np.concatenate([self.eef, gripper_angle]) + return np.concatenate([self.eef - self.center_of_table, self.gripper_angle]) + + @property + def obj(self): + return self._utils.get_site_xpos(self.model, self.data, "object_site") - self.center_of_table + + @property + def obj_rot(self): + return self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:] + + @property + def obj_velp(self): + return self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt + + @property + def obj_velr(self): + return self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt def is_success(self): """Indicates whether or not the achieved goal successfully achieved the desired goal.""" @@ -237,14 +243,6 @@ def _reset_sim(self): mujoco.mj_forward(self.model, self.data) return True - # def reset(self, seed=None, options=None): - # super().reset(seed=seed, options=options) - # self._reset_sim() - # observation = self._get_obs() - # observation = self._transform_obs(observation) - # info = {} - # return observation, info - def step(self, action): assert action.shape == (4,) assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid" @@ -252,12 +250,11 @@ def step(self, action): self._mujoco.mj_step(self.model, self.data, nstep=2) self._step_callback() observation = self.get_obs() - # observation = self._transform_obs(observation) reward = self.get_reward() terminated = is_success = self.is_success() + truncated = False info = {"is_success": is_success} - truncated = False return observation, reward, terminated, truncated, info def _step_callback(self): @@ -307,40 +304,12 @@ def _set_gripper(self, gripper_pos, gripper_rotation): self.data.qpos[10] = 0.0 self.data.qpos[12] = 0.0 - # def _transform_obs(self, obs, reset=False): - # if self.obs_type == "state": - # return obs["observation"] - # elif self.obs_type == "rgb": - # self._update_frames(reset=reset) - # rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - # return rgb_obs - # elif self.obs_type == "all": - # self._update_frames(reset=reset) - # rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - # return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state))) - # else: - # raise ValueError(f"Unknown obs_type {self.obs_type}. Must be one of [rgb, all, state]") - - # def _update_frames(self, reset=False): - # pixels = self._render_obs() - # self._frames.append(pixels) - # if reset: - # for _ in range(1, self.frame_stack): - # self._frames.append(pixels) - # assert len(self._frames) == self.frame_stack - - # def _render_obs(self): - # obs = self.render(mode="rgb_array") - # if not self.channel_last: - # obs = obs.transpose(2, 0, 1) - # return obs.copy() - def render(self, mode="rgb_array"): self._render_callback() if mode == "visualize": return self.visualization_renderer.render("rgb_array", camera_name="camera0") - render = self.observation_renderer.render("rgb_array", camera_name="camera0") + render = self.observation_renderer.render(mode, camera_name="camera0") if self.channel_last: return render.copy() else: diff --git a/gym_xarm/tasks/lift.py b/gym_xarm/tasks/lift.py index 06897cb..12cb564 100644 --- a/gym_xarm/tasks/lift.py +++ b/gym_xarm/tasks/lift.py @@ -19,26 +19,6 @@ def __init__(self, **kwargs): def z_target(self): return self._init_z + self._z_threshold - @property - def eef_velp(self): - return self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt - - @property - def obj_rot(self): - return self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:] - - @property - def obj_velp(self): - return self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt - - @property - def obj_velr(self): - return self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt - - @property - def gripper_angle(self): - return self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") - def is_success(self): return self.obj[2] >= self.z_target @@ -76,29 +56,19 @@ def get_obs(self): elif self.obs_type == "pixels_agent_pos": return { "pixels": pixels, - "agent_pos": self._get_obs(agent_only=True), + "agent_pos": self.robot_state, } else: raise ValueError( f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, pixels_agent_pos]" ) - def _get_obs(self, agent_only=False): - eef = self.eef - self.center_of_table - if agent_only: - return np.concatenate( - [ - eef, - self.gripper_angle, - ] - ) - - obj = self.obj - self.center_of_table + def _get_obs(self): return np.concatenate( [ - eef, + self.eef, self.eef_velp, - obj, + self.obj, self.obj_rot, self.obj_velp, self.obj_velr, @@ -106,10 +76,10 @@ def _get_obs(self, agent_only=False): np.array( [ np.linalg.norm(self.eef - self.obj), - np.linalg.norm(eef[:-1] - obj[:-1]), + np.linalg.norm(self.eef[:-1] - self.obj[:-1]), self.z_target, - self.z_target - obj[-1], - self.z_target - eef[-1], + self.z_target - self.obj[-1], + self.z_target - self.eef[-1], ] ), self.gripper_angle, diff --git a/gym_xarm/tasks/peg_in_box.py b/gym_xarm/tasks/peg_in_box.py index 4f21c48..1db5e42 100644 --- a/gym_xarm/tasks/peg_in_box.py +++ b/gym_xarm/tasks/peg_in_box.py @@ -4,6 +4,8 @@ class PegInBox(Base): + """DEPRECATED: use only Lift for now""" + def __init__(self): super().__init__("peg_in_box") diff --git a/gym_xarm/tasks/push.py b/gym_xarm/tasks/push.py index 8b145a6..daa6b60 100644 --- a/gym_xarm/tasks/push.py +++ b/gym_xarm/tasks/push.py @@ -4,6 +4,8 @@ class Push(Base): + """DEPRECATED: use only Lift for now""" + def __init__(self): super().__init__("push") diff --git a/gym_xarm/tasks/reach.py b/gym_xarm/tasks/reach.py index 94688e7..ee1a533 100644 --- a/gym_xarm/tasks/reach.py +++ b/gym_xarm/tasks/reach.py @@ -4,6 +4,8 @@ class Reach(Base): + """DEPRECATED: use only Lift for now""" + def __init__(self): super().__init__("reach") diff --git a/tests/test_env.py b/tests/test_env.py index 9d2e8ba..f08f951 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -2,29 +2,20 @@ import gymnasium as gym from gymnasium.utils.env_checker import check_env +import gym_xarm # noqa: F401 + @pytest.mark.parametrize( "env_task, obs_type", [ ("XarmLift-v0", "state"), ("XarmLift-v0", "pixels"), ("XarmLift-v0", "pixels_agent_pos"), - # TODO(aliberts): Add other tasks - # ("reach", False, False), - # ("reach", True, False), - # ("push", False, False), - # ("push", True, False), - # ("peg_in_box", False, False), - # ("peg_in_box", True, False), ], ) def test_env(env_task, obs_type): - import gym_xarm # noqa: F401 env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type) - check_env(env.unwrapped, skip_render_check=True) - # env.reset() - # env.render() + check_env(env.unwrapped) if __name__ == "__main__": test_env("XarmLift-v0", "pixels_agent_pos") - # test_env("XarmLift-v0", "state") \ No newline at end of file