Skip to content

Commit

Permalink
Cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
aliberts committed Apr 6, 2024
1 parent 2524966 commit 79e18a3
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 263 deletions.
186 changes: 33 additions & 153 deletions gym_xarm/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,42 @@
from collections import OrderedDict, deque

import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TimeLimit
from collections import OrderedDict

from gym_xarm.tasks.base import Base as Base
from gym_xarm.tasks.lift import Lift
from gym_xarm.tasks.peg_in_box import PegInBox
from gym_xarm.tasks.push import Push
from gym_xarm.tasks.reach import Reach

# from gym_xarm.tasks.peg_in_box import PegInBox
# from gym_xarm.tasks.push import Push
# from gym_xarm.tasks.reach import Reach


TASKS = OrderedDict(
(
(
"reach",
{
"env": Reach,
"action_space": "xyz",
"episode_length": 50,
"description": "Reach a target location with the end effector",
},
),
(
"push",
{
"env": Push,
"action_space": "xyz",
"episode_length": 50,
"description": "Push a cube to a target location",
},
),
(
"peg_in_box",
{
"env": PegInBox,
"action_space": "xyz",
"episode_length": 50,
"description": "Insert a peg into a box",
},
),
# (
# "reach",
# {
# "env": Reach,
# "action_space": "xyz",
# "episode_length": 50,
# "description": "Reach a target location with the end effector",
# },
# ),
# (
# "push",
# {
# "env": Push,
# "action_space": "xyz",
# "episode_length": 50,
# "description": "Push a cube to a target location",
# },
# ),
# (
# "peg_in_box",
# {
# "env": PegInBox,
# "action_space": "xyz",
# "episode_length": 50,
# "description": "Insert a peg into a box",
# },
# ),
(
"lift",
{
Expand All @@ -50,121 +48,3 @@
),
)
)


class SimXarmWrapper(gym.Wrapper):
    """
    DEPRECATED: Use gym.make()

    Legacy adapter around a SimXarm environment. It converts the action and
    observation spaces to the format selected by ``obs_mode`` and exposes the
    legacy calling convention: ``reset()`` returns only the observation and
    ``step()`` returns ``(obs, reward, done, info)``.

    Args:
        env: The environment to wrap.
        task (dict): Task entry providing ``"episode_length"`` and
            ``"action_space"`` (a string with one character per action dimension).
        obs_mode (str): One of ``"state"``, ``"rgb"`` or ``"all"``.
        image_size (int): Side length of the square rendered observations.
        action_repeat (int): Number of env steps performed per ``step()`` call.
        frame_stack (int): Number of rendered frames concatenated per observation.
        channel_last (bool): If True images are (H, W, C), otherwise (C, H, W).

    Raises:
        ValueError: If ``obs_mode`` is not one of the supported modes.
    """

    def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
        super().__init__(env)
        self._env = env
        self.obs_mode = obs_mode
        self.image_size = image_size
        self.action_repeat = action_repeat
        self.frame_stack = frame_stack
        # Bounded buffer: appending beyond frame_stack drops the oldest frame.
        self._frames = deque([], maxlen=frame_stack)
        self.channel_last = channel_last
        self._max_episode_steps = task["episode_length"] // action_repeat

        image_shape = (
            (image_size, image_size, 3 * frame_stack)
            if channel_last
            else (3 * frame_stack, image_size, image_size)
        )
        if obs_mode == "state":
            self.observation_space = env.observation_space["observation"]
        elif obs_mode == "rgb":
            self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
        elif obs_mode == "all":
            self.observation_space = gym.spaces.Dict(
                state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
                rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
            )
        else:
            raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
        # The wrapped env is fed 4-dimensional actions; dimensions missing from
        # the task's action space are filled in from this padding vector.
        self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
        if "w" not in task["action_space"]:
            # No "w" dimension in the task: pad the last component with 1.0.
            self.action_padding[-1] = 1.0

    def _render_obs(self):
        """Render one frame at ``image_size`` and return a copy in the configured layout."""
        obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
        if not self.channel_last:
            obs = obs.transpose(2, 0, 1)
        return obs.copy()

    def _update_frames(self, reset=False):
        """Push a freshly rendered frame; on reset, fill the whole buffer with it."""
        pixels = self._render_obs()
        self._frames.append(pixels)
        if reset:
            for _ in range(1, self.frame_stack):
                self._frames.append(pixels)
        assert len(self._frames) == self.frame_stack

    def _stacked_frames(self, reset=False):
        """Refresh the frame buffer and return its frames concatenated along the channel axis."""
        self._update_frames(reset=reset)
        return np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)

    def transform_obs(self, obs, reset=False):
        """
        Convert a raw observation of the wrapped env to the ``obs_mode`` format.

        Args:
            obs: Raw observation; indexed with ``"observation"`` in "state" mode.
            reset (bool): True when called from ``reset()`` so the frame buffer is refilled.

        Returns:
            The state vector ("state"), the stacked frames ("rgb"), or an
            OrderedDict with keys "rgb" and "state" ("all").

        Raises:
            ValueError: If ``self.obs_mode`` is not a supported mode.
        """
        if self.obs_mode == "state":
            return obs["observation"]
        if self.obs_mode == "rgb":
            return self._stacked_frames(reset=reset)
        if self.obs_mode == "all":
            rgb_obs = self._stacked_frames(reset=reset)
            # Go through the explicit `state` property rather than relying on
            # gym.Wrapper forwarding unknown attributes to the wrapped env.
            return OrderedDict((("rgb", rgb_obs), ("state", self.state)))
        raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")

    def reset(self):
        """Reset the wrapped env and return only the transformed observation."""
        result = self._env.reset()
        # gymnasium envs return (obs, info); legacy envs return the bare observation.
        obs = result[0] if isinstance(result, tuple) else result
        return self.transform_obs(obs, reset=True)

    def step(self, action):
        """
        Apply ``action`` up to ``action_repeat`` times and sum the rewards.

        Repetition stops as soon as the wrapped env reports the end of the
        episode, so a finished env is never stepped again.

        Returns:
            tuple: ``(obs, reward, done, info)`` from the last executed step,
            with ``reward`` accumulated over all executed steps.
        """
        action = np.concatenate([action, self.action_padding])
        reward = 0.0
        for _ in range(self.action_repeat):
            result = self._env.step(action)
            if len(result) == 5:
                # gymnasium API: (obs, reward, terminated, truncated, info)
                obs, r, terminated, truncated, info = result
                done = bool(terminated or truncated)
            else:
                # legacy gym API: (obs, reward, done, info)
                obs, r, done, info = result
            reward += r
            if done:
                break
        return self.transform_obs(obs), reward, done, info

    def render(self, mode="rgb_array", width=384, height=384, **kwargs):
        """Render through the wrapped env; extra keyword arguments are ignored."""
        # NOTE(review): this forwards mode/width/height to the wrapped env's
        # render(); confirm the wrapped env accepts these arguments.
        return self._env.render(mode, width=width, height=height)

    @property
    def state(self):
        """The ``robot_state`` attribute of the wrapped env."""
        return self._env.robot_state


def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
    """
    DEPRECATED: Use gym.make()
    Create a new environment.
    Args:
        task (str): The task to create an environment for. Must be a key of TASKS;
            the task names appearing in this module are:
            - 'reach'
            - 'push'
            - 'peg_in_box'
            - 'lift'
        obs_mode (str): The observation mode to use. Must be one of:
            - 'state': Only state observations
            - 'rgb': RGB images
            - 'all': RGB images and state observations
        image_size (int): The size of the image observations
        action_repeat (int): The number of times to repeat the action
        frame_stack (int): Forwarded to SimXarmWrapper as its frame_stack argument
        channel_last (bool): Forwarded to SimXarmWrapper as its channel_last argument
        seed (int): The random seed to use (passed to env.seed)
    Returns:
        gym.Env: The environment
    Raises:
        ValueError: If task is not a key of TASKS
    """
    if task not in TASKS:
        raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
    # Instantiate the task's env class with its default arguments.
    env = TASKS[task]["env"]()
    # Cap the episode at the task's configured length.
    env = TimeLimit(env, TASKS[task]["episode_length"])
    env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
    # NOTE(review): gymnasium removed Env.seed() in favour of reset(seed=...);
    # confirm the wrapped env still provides a seed() method.
    env.seed(seed)

    return env
91 changes: 30 additions & 61 deletions gym_xarm/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections import deque

import gymnasium as gym
import mujoco
Expand Down Expand Up @@ -36,10 +35,9 @@ def __init__(
observation_height=84,
visualization_width=680,
visualization_height=680,
frame_stack=1,
channel_last=True,
):
# Env setup
# Coordinates
if gripper_rotation is None:
gripper_rotation = [0, 1, 0, 0]
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
Expand All @@ -56,8 +54,6 @@ def __init__(
self.observation_height = observation_height
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.frame_stack = frame_stack
self._frames = deque([], maxlen=frame_stack)

# Assets
self.xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml")
Expand All @@ -79,15 +75,6 @@ def __init__(
if "w" not in self.metadata["action_space"]:
self.action_padding[-1] = 1.0

# super().__init__(
# xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml"),
# n_substeps=20,
# n_actions=4,
# initial_qpos={},
# width=image_size,
# height=image_size,
# )

def _initialize_simulation(self):
"""Initialize MuJoCo simulation data structures mjModel and mjData."""
self.model = self._mujoco.MjModel.from_xml_path(self.xml_path)
Expand Down Expand Up @@ -116,9 +103,9 @@ def _env_setup(self, initial_qpos):

def _initialize_observation_space(self):
image_shape = (
(self.observation_height, self.observation_width, 3 * self.frame_stack)
(self.observation_height, self.observation_width, 3)
if self.channel_last
else (3 * self.frame_stack, self.observation_height, self.observation_width)
else (3, self.observation_height, self.observation_width)
)
obs = self.get_obs()
if self.obs_type == "state":
Expand All @@ -145,7 +132,7 @@ def _initialize_renderer(self, type: str):
if type == "observation":
model = self.model
elif type == "visualization":
# HACK: MujoCo doesn't allow for custom size rendering on-the-fly, so we
# HACK: gymnasium doesn't allow for custom size rendering on-the-fly, so we
# initialize another renderer with appropriate size for visualization purposes
# see https://gymnasium.farama.org/content/migration-guide/#environment-render
from copy import deepcopy
Expand All @@ -165,16 +152,35 @@ def dt(self):

@property
def eef(self):
return self._utils.get_site_xpos(self.model, self.data, "grasp")
return self._utils.get_site_xpos(self.model, self.data, "grasp") - self.center_of_table

@property
def obj(self):
return self._utils.get_site_xpos(self.model, self.data, "object_site")
def eef_velp(self):
return self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt

@property
def gripper_angle(self):
return self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")

@property
def robot_state(self):
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
return np.concatenate([self.eef, gripper_angle])
return np.concatenate([self.eef - self.center_of_table, self.gripper_angle])

@property
def obj(self):
return self._utils.get_site_xpos(self.model, self.data, "object_site") - self.center_of_table

@property
def obj_rot(self):
return self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]

@property
def obj_velp(self):
return self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt

@property
def obj_velr(self):
return self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt

def is_success(self):
"""Indicates whether or not the achieved goal successfully achieved the desired goal."""
Expand Down Expand Up @@ -237,27 +243,18 @@ def _reset_sim(self):
mujoco.mj_forward(self.model, self.data)
return True

# def reset(self, seed=None, options=None):
# super().reset(seed=seed, options=options)
# self._reset_sim()
# observation = self._get_obs()
# observation = self._transform_obs(observation)
# info = {}
# return observation, info

def step(self, action):
assert action.shape == (4,)
assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
self._apply_action(action)
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback()
observation = self.get_obs()
# observation = self._transform_obs(observation)
reward = self.get_reward()
terminated = is_success = self.is_success()
truncated = False
info = {"is_success": is_success}

truncated = False
return observation, reward, terminated, truncated, info

def _step_callback(self):
Expand Down Expand Up @@ -307,40 +304,12 @@ def _set_gripper(self, gripper_pos, gripper_rotation):
self.data.qpos[10] = 0.0
self.data.qpos[12] = 0.0

# def _transform_obs(self, obs, reset=False):
# if self.obs_type == "state":
# return obs["observation"]
# elif self.obs_type == "rgb":
# self._update_frames(reset=reset)
# rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
# return rgb_obs
# elif self.obs_type == "all":
# self._update_frames(reset=reset)
# rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
# return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
# else:
# raise ValueError(f"Unknown obs_type {self.obs_type}. Must be one of [rgb, all, state]")

# def _update_frames(self, reset=False):
# pixels = self._render_obs()
# self._frames.append(pixels)
# if reset:
# for _ in range(1, self.frame_stack):
# self._frames.append(pixels)
# assert len(self._frames) == self.frame_stack

# def _render_obs(self):
# obs = self.render(mode="rgb_array")
# if not self.channel_last:
# obs = obs.transpose(2, 0, 1)
# return obs.copy()

def render(self, mode="rgb_array"):
self._render_callback()
if mode == "visualize":
return self.visualization_renderer.render("rgb_array", camera_name="camera0")

render = self.observation_renderer.render("rgb_array", camera_name="camera0")
render = self.observation_renderer.render(mode, camera_name="camera0")
if self.channel_last:
return render.copy()
else:
Expand Down
Loading

0 comments on commit 79e18a3

Please sign in to comment.