# wrapper.py (forked from rail-berkeley/hil-serl)

from collections import OrderedDict
import copy
import time

import gymnasium as gym
import numpy as np
import requests

from franka_env.camera.rs_capture import RSCapture
from franka_env.camera.video_capture import VideoCapture
from franka_env.envs.franka_env import FrankaEnv
from franka_env.utils.rotations import euler_2_quat


class USBEnv(FrankaEnv):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_cameras(self, name_serial_dict=None):
        """Initialize the cameras; the side classifier view reuses the side policy capture."""
        if self.cap is not None:  # close cameras if they are already open
            self.close_cameras()

        self.cap = OrderedDict()
        for cam_name, kwargs in name_serial_dict.items():
            if cam_name == "side_classifier":
                # Reuse the already-opened policy camera stream instead of
                # opening the same device a second time.
                self.cap["side_classifier"] = self.cap["side_policy"]
            else:
                cap = VideoCapture(RSCapture(name=cam_name, **kwargs))
                self.cap[cam_name] = cap

    def reset(self, **kwargs):
        self._recover()
        self._update_currpos()
        self._send_pos_command(self.currpos)
        time.sleep(0.1)
        requests.post(self.url + "update_param", json=self.config.PRECISION_PARAM)
        self._send_gripper_command(1.0)

        # Move above the target pose, then descend onto it
        target = copy.deepcopy(self.currpos)
        target[2] = self.config.TARGET_POSE[2] + 0.05
        self.interpolate_move(target, timeout=0.5)
        time.sleep(0.5)
        self.interpolate_move(self.config.TARGET_POSE, timeout=0.5)
        time.sleep(0.5)

        # Close the gripper at the target pose, then shift slightly along y
        # to the reset pose
        self._send_gripper_command(-1.0)
        self._update_currpos()
        reset_pose = copy.deepcopy(self.config.TARGET_POSE)
        reset_pose[1] += 0.04
        self.interpolate_move(reset_pose, timeout=0.5)

        obs, info = super().reset(**kwargs)
        self._send_gripper_command(1.0)
        time.sleep(1)
        self.success = False
        self._update_currpos()
        obs = self._get_obs()
        return obs, info

    def interpolate_move(self, goal: np.ndarray, timeout: float):
        """Move the robot to the goal position with linear interpolation."""
        if goal.shape == (6,):
            goal = np.concatenate([goal[:3], euler_2_quat(goal[3:])])
        self._send_pos_command(goal)
        time.sleep(timeout)
        self._update_currpos()

    def go_to_reset(self, joint_reset=False):
        """
        The concrete reset steps should be implemented in each subclass for
        the specific task. Override this method if a custom reset procedure
        is needed.
        """
        # Perform joint reset if needed
        if joint_reset:
            print("JOINT RESET")
            requests.post(self.url + "jointreset")
            time.sleep(0.5)

        # Perform Cartesian reset
        if self.randomreset:  # randomize reset position in xy plane
            reset_pose = self.resetpos.copy()
            reset_pose[:2] += np.random.uniform(
                -self.random_xy_range, self.random_xy_range, (2,)
            )
            euler_random = self._RESET_POSE[3:].copy()
            euler_random[-1] += np.random.uniform(
                -self.random_rz_range, self.random_rz_range
            )
            reset_pose[3:] = euler_2_quat(euler_random)
            self.interpolate_move(reset_pose, timeout=1)
        else:
            reset_pose = self.resetpos.copy()
            self.interpolate_move(reset_pose, timeout=1)

        # Change to compliance mode
        requests.post(self.url + "update_param", json=self.config.COMPLIANCE_PARAM)


class GripperPenaltyWrapper(gym.Wrapper):
    def __init__(self, env, penalty=-0.05):
        super().__init__(env)
        assert env.action_space.shape == (7,)
        self.penalty = penalty
        self.last_gripper_pos = None

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.last_gripper_pos = obs["state"][0, 0]
        return obs, info

    def step(self, action):
        """Step the wrapped env and report a gripper-toggle penalty in ``info["grasp_penalty"]``."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        if "intervene_action" in info:
            action = info["intervene_action"]

        # Penalize actions that toggle the gripper state (the last action
        # element is the gripper command), discouraging needless open/close
        # switching.
        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
            action[-1] > 0.5 and self.last_gripper_pos < 0.9
        ):
            info["grasp_penalty"] = self.penalty
        else:
            info["grasp_penalty"] = 0.0

        self.last_gripper_pos = observation["state"][0, 0]
        return observation, reward, terminated, truncated, info
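

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming an env config object (the hypothetical name
# ``ExampleConfig`` below) that provides TARGET_POSE, PRECISION_PARAM,
# COMPLIANCE_PARAM, etc., and that FrankaEnv accepts it via a ``config``
# keyword, as the methods above suggest. Kept commented out because it needs
# real hardware and a concrete config class to run.
#
#     env = USBEnv(config=ExampleConfig())             # hypothetical config class
#     env = GripperPenaltyWrapper(env, penalty=-0.05)
#     obs, info = env.reset()
#     action = np.zeros(7)                             # last element is the gripper command
#     obs, reward, terminated, truncated, info = env.step(action)
#     print("grasp_penalty:", info["grasp_penalty"])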