Skip to content

Commit

Permalink
Fix, remove eef_velop, add copy(), same API as pusht
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Apr 5, 2024
1 parent fc0cc41 commit 2524966
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 323 deletions.
306 changes: 0 additions & 306 deletions gym_xarm/robot_env.py

This file was deleted.

26 changes: 10 additions & 16 deletions gym_xarm/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ class Base(gym.Env):
"""

metadata = {
"render_modes": [
"human",
"rgb_array",
],
"render_modes": [],
"render_fps": 25,
}
n_substeps = 20
Expand All @@ -39,7 +36,6 @@ def __init__(
observation_height=84,
visualization_width=680,
visualization_height=680,
render_mode=None,
frame_stack=1,
channel_last=True,
):
Expand All @@ -60,7 +56,6 @@ def __init__(
self.observation_height = observation_height
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.render_mode = render_mode
self.frame_stack = frame_stack
self._frames = deque([], maxlen=frame_stack)

Expand Down Expand Up @@ -121,9 +116,9 @@ def _env_setup(self, initial_qpos):

def _initialize_observation_space(self):
image_shape = (
(self.observation_width, self.observation_height, 3 * self.frame_stack)
(self.observation_height, self.observation_width, 3 * self.frame_stack)
if self.channel_last
else (3 * self.frame_stack, self.observation_width, self.observation_height)
else (3 * self.frame_stack, self.observation_height, self.observation_width)
)
obs = self.get_obs()
if self.obs_type == "state":
Expand Down Expand Up @@ -223,8 +218,6 @@ def reset(
while not did_reset_sim:
did_reset_sim = self._reset_sim()
observation = self.get_obs()
if self.render_mode == "human":
self.render()
info = {}
return observation, info

Expand Down Expand Up @@ -259,12 +252,13 @@ def step(self, action):
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback()
observation = self.get_obs()
# observation = self.get_obs()
# observation = self._transform_obs(observation)
reward = self.get_reward()
done = False
info = {"is_success": self.is_success(), "success": self.is_success()}
return observation, reward, done, info
terminated = is_success = self.is_success()
info = {"is_success": is_success}

truncated = False
return observation, reward, terminated, truncated, info

def _step_callback(self):
self._mujoco.mj_forward(self.model, self.data)
Expand Down Expand Up @@ -348,9 +342,9 @@ def render(self, mode="rgb_array"):

render = self.observation_renderer.render("rgb_array", camera_name="camera0")
if self.channel_last:
return render
return render.copy()
else:
return render.transpose(2, 0, 1)
return render.transpose(2, 0, 1).copy()

def _render_callback(self):
self._mujoco.mj_forward(self.model, self.data)
Expand Down
1 change: 0 additions & 1 deletion gym_xarm/tasks/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _get_obs(self, agent_only=False):
return np.concatenate(
[
eef,
self.eef_velp,
self.gripper_angle,
]
)
Expand Down

0 comments on commit 2524966

Please sign in to comment.