Skip to content

Commit

Permalink
Revised reward, terminated, truncated in RL aviaries
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Nov 19, 2023
1 parent 85a726d commit 1b14a1a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions gym_pybullet_drones/envs/HoverAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self,
The type of action space (1 or 3D; RPMS, thrust and torques, or waypoint with PID control)
"""
self.target_pos = np.array([0,0,1])
self.TARGET_POS = np.array([0,0,1])
super().__init__(drone_model=drone_model,
num_drones=1,
initial_xyzs=initial_xyzs,
Expand All @@ -74,7 +74,7 @@ def _computeReward(self):
"""
state = self._getDroneStateVector(0)
ret = max(0, 500 - np.linalg.norm(self.target_pos-state[0:3])**2)
ret = max(0, 500 - np.linalg.norm(self.TARGET_POS-state[0:3])**2)
return ret

################################################################################
Expand All @@ -89,7 +89,7 @@ def _computeTerminated(self):
"""
state = self._getDroneStateVector(0)
if np.linalg.norm(self.target_pos-state[0:3]) < .001:
if np.linalg.norm(self.TARGET_POS-state[0:3]) < .001:
return True
else:
return False
Expand Down
6 changes: 3 additions & 3 deletions gym_pybullet_drones/envs/LeaderFollowerAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
The type of action space (1 or 3D; RPMS, thrust and torques, or waypoint with PID control)
"""
self.target_pos = np.array([0,0,1])
self.TARGET_POS = np.array([0,0,1])
super().__init__(drone_model=drone_model,
num_drones=num_drones,
neighbourhood_radius=neighbourhood_radius,
Expand Down Expand Up @@ -82,7 +82,7 @@ def _computeReward(self):
"""
rewards = np.zeros(self.NUM_DRONES)
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
ret = max(0, 500 - np.linalg.norm(self.target_pos-states[0, 0:3])**2)
ret = max(0, 500 - np.linalg.norm(self.TARGET_POS-states[0, 0:3])**2)
for i in range(1, self.NUM_DRONES):
ret += max(0, 100 - np.linalg.norm(states[i-1, 3]-states[i, 3])**2)
return ret
Expand All @@ -99,7 +99,7 @@ def _computeTerminated(self):
"""
state = self._getDroneStateVector(0)
if np.linalg.norm(self.target_pos-state[0:3]) < .001:
if np.linalg.norm(self.TARGET_POS-state[0:3]) < .001:
return True
else:
return False
Expand Down

0 comments on commit 1b14a1a

Please sign in to comment.