Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Apr 5, 2024
1 parent 0024fd7 commit 186095d
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 128 deletions.
217 changes: 111 additions & 106 deletions gym_xarm/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from collections import OrderedDict, deque
from collections import deque

import gymnasium as gym
import mujoco
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
visualization_height=680,
render_mode=None,
frame_stack=1,
channel_last=False,
channel_last=True,
):
# Env setup
if gripper_rotation is None:
Expand Down Expand Up @@ -69,6 +69,7 @@ def __init__(
if not os.path.exists(self.xml_path):
raise OSError(f"File {self.xml_path} does not exist")

# Initialize sim, spaces & renderers
self._initialize_simulation()
self.observation_space = self._initialize_observation_space()
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(self.metadata["action_space"]),))
Expand All @@ -81,6 +82,9 @@ def __init__(
if "w" not in self.metadata["action_space"]:
self.action_padding[-1] = 1.0

self.observation_renderer = self._initialize_renderer(type="observation")
self.visualization_renderer = self._initialize_renderer(type="visualization")

# super().__init__(
# xml_path = os.path.join(os.path.dirname(__file__), "assets", f"{task}.xml"),
# n_substeps=20,
Expand All @@ -90,12 +94,9 @@ def __init__(
# height=image_size,
# )

self.observation_renderer = self._initialize_renderer(type="observation")
self.visualization_renderer = self._initialize_renderer(type="visualization")

def _initialize_simulation(self):
"""Initialize MuJoCo simulation data structures mjModel and mjData."""
self.model = self._mujoco.MjModel.from_xml_path(self.fullpath)
self.model = self._mujoco.MjModel.from_xml_path(self.xml_path)
self.data = self._mujoco.MjData(self.model)
self._model_names = self._utils.MujocoModelNames(self.model)

Expand All @@ -121,9 +122,9 @@ def _env_setup(self, initial_qpos):

def _initialize_observation_space(self):
image_shape = (
(self.image_size, self.image_size, 3 * self.frame_stack)
(self.observation_width, self.observation_height, 3 * self.frame_stack)
if self.channel_last
else (3 * self.frame_stack, self.image_size, self.image_size)
else (3 * self.frame_stack, self.observation_width, self.observation_height)
)
if self.obs_type == "state":
obs = self._get_obs()
Expand Down Expand Up @@ -159,22 +160,6 @@ def _initialize_renderer(self, type: str):

return MujocoRenderer(model, self.data)

def _reset_sim(self):
"""Resets a simulation and indicates whether or not it was successful.
If a reset was unsuccessful (e.g. if a randomized state caused an error in the
simulation), this method should indicate such a failure by returning False.
In such a case, this method will be called again to attempt a the reset again.
"""
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
if self.model.na != 0:
self.data.act[:] = None

mujoco.mj_forward(self.model, self.data)
return True

@property
def dt(self):
"""Return the timestep of each Gymanisum step."""
Expand Down Expand Up @@ -205,60 +190,12 @@ def _sample_goal(self):
raise NotImplementedError()

def get_obs(self):
return self._get_obs()

def _step_callback(self):
self._mujoco.mj_forward(self.model, self.data)

def _limit_gripper(self, gripper_pos, pos_ctrl):
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
pos_ctrl[0] = min(pos_ctrl[0], 0)
if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
pos_ctrl[0] = max(pos_ctrl[0], 0)
if gripper_pos[1] > self.center_of_table[1] + 0.3:
pos_ctrl[1] = min(pos_ctrl[1], 0)
if gripper_pos[1] < self.center_of_table[1] - 0.3:
pos_ctrl[1] = max(pos_ctrl[1], 0)
if gripper_pos[2] > self.max_z:
pos_ctrl[2] = min(pos_ctrl[2], 0)
if gripper_pos[2] < self.min_z:
pos_ctrl[2] = max(pos_ctrl[2], 0)
return pos_ctrl

def _apply_action(self, action):
assert action.shape == (4,)
action = action.copy()
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl = self._limit_gripper(
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
) * (1 / self.n_substeps)
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
mocap.apply_action(
self.model,
self._model_names,
self.data,
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
)

self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
self._sample_goal()
self._mujoco.mj_step(self.model, self.data, nstep=10)
return True

def _set_gripper(self, gripper_pos, gripper_rotation):
self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
self.data.qpos[10] = 0.0
self.data.qpos[12] = 0.0
raise NotImplementedError()

def reset(
self,
*,
seed: int | None = None,
options: dict | None = None,
):
"""Reset MuJoCo simulation to initial state.
Expand Down Expand Up @@ -287,6 +224,22 @@ def reset(
info = {}
return observation, info

def _reset_sim(self):
"""Resets a simulation and indicates whether or not it was successful.
If a reset was unsuccessful (e.g. if a randomized state caused an error in the
simulation), this method should indicate such a failure by returning False.
In such a case, this method will be called again to attempt a the reset again.
"""
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
if self.model.na != 0:
self.data.act[:] = None

mujoco.mj_forward(self.model, self.data)
return True

# def reset(self, seed=None, options=None):
# super().reset(seed=seed, options=options)
# self._reset_sim()
Expand All @@ -297,56 +250,106 @@ def reset(

def step(self, action):
assert action.shape == (4,)
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
self._apply_action(action)
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback()
observation = self._get_obs()
observation = self._transform_obs(observation)
observation = self.get_obs()
# observation = self.get_obs()
# observation = self._transform_obs(observation)
reward = self.get_reward()
done = False
info = {"is_success": self.is_success(), "success": self.is_success()}
return observation, reward, done, info

def _transform_obs(self, obs, reset=False):
if self.obs_type == "state":
return obs["observation"]
elif self.obs_type == "rgb":
self._update_frames(reset=reset)
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
return rgb_obs
elif self.obs_type == "all":
self._update_frames(reset=reset)
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
else:
raise ValueError(f"Unknown obs_type {self.obs_type}. Must be one of [rgb, all, state]")

def _render_callback(self):
def _step_callback(self):
self._mujoco.mj_forward(self.model, self.data)

def _update_frames(self, reset=False):
pixels = self._render_obs()
self._frames.append(pixels)
if reset:
for _ in range(1, self.frame_stack):
self._frames.append(pixels)
assert len(self._frames) == self.frame_stack
def _limit_gripper(self, gripper_pos, pos_ctrl):
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
pos_ctrl[0] = min(pos_ctrl[0], 0)
if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
pos_ctrl[0] = max(pos_ctrl[0], 0)
if gripper_pos[1] > self.center_of_table[1] + 0.3:
pos_ctrl[1] = min(pos_ctrl[1], 0)
if gripper_pos[1] < self.center_of_table[1] - 0.3:
pos_ctrl[1] = max(pos_ctrl[1], 0)
if gripper_pos[2] > self.max_z:
pos_ctrl[2] = min(pos_ctrl[2], 0)
if gripper_pos[2] < self.min_z:
pos_ctrl[2] = max(pos_ctrl[2], 0)
return pos_ctrl

def _apply_action(self, action):
assert action.shape == (4,)
action = action.copy()
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl = self._limit_gripper(
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
) * (1 / self.n_substeps)
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
mocap.apply_action(
self.model,
self._model_names,
self.data,
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
)

def _render_obs(self):
obs = self.render(mode="rgb_array")
if not self.channel_last:
obs = obs.transpose(2, 0, 1)
return obs.copy()
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
self._sample_goal()
self._mujoco.mj_step(self.model, self.data, nstep=10)
return True

def _set_gripper(self, gripper_pos, gripper_rotation):
self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
self.data.qpos[10] = 0.0
self.data.qpos[12] = 0.0

# def _transform_obs(self, obs, reset=False):
# if self.obs_type == "state":
# return obs["observation"]
# elif self.obs_type == "rgb":
# self._update_frames(reset=reset)
# rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
# return rgb_obs
# elif self.obs_type == "all":
# self._update_frames(reset=reset)
# rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
# return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
# else:
# raise ValueError(f"Unknown obs_type {self.obs_type}. Must be one of [rgb, all, state]")

# def _update_frames(self, reset=False):
# pixels = self._render_obs()
# self._frames.append(pixels)
# if reset:
# for _ in range(1, self.frame_stack):
# self._frames.append(pixels)
# assert len(self._frames) == self.frame_stack

# def _render_obs(self):
# obs = self.render(mode="rgb_array")
# if not self.channel_last:
# obs = obs.transpose(2, 0, 1)
# return obs.copy()

def render(self, mode="rgb_array"):
self._render_callback()
# return self._mujoco.physics.render(height=84, width=84, camera_name="camera0")

if mode == "visualize":
return self.visualization_renderer.render("rgb_array", camera_name="camera0")

return self.observation_renderer.render(mode, camera_name="camera0")
render = self.observation_renderer.render("rgb_array", camera_name="camera0")
if self.channel_last:
return render
else:
return render.transpose(2, 0, 1)

def _render_callback(self):
self._mujoco.mj_forward(self.model, self.data)

def close(self):
"""Close contains the code necessary to "clean up" the environment.
Expand All @@ -355,3 +358,5 @@ def close(self):
"""
if self.observation_renderer is not None:
self.observation_renderer.close()
if self.visualization_renderer is not None:
self.visualization_renderer.close()
Loading

0 comments on commit 186095d

Please sign in to comment.