Skip to content

Commit

Permalink
control frequency based action buffer in rl state
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Dec 11, 2023
1 parent 1a82134 commit 2c2470c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions gym_pybullet_drones/envs/BaseRLAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(self,
The type of action space (1 or 3D; RPMS, thrust and torques, waypoint or velocity with PID control; etc.)
"""
#### Create a buffer for the last 10 actions ###############
self.ACTION_BUFFER_SIZE = 10
#### Create a buffer for the last .5 sec of actions ########
self.ACTION_BUFFER_SIZE = int(ctrl_freq//2)
self.action_buffer = deque(maxlen=self.ACTION_BUFFER_SIZE)
####
vision_attributes = True if obs == ObservationType.RGB else False
Expand Down
4 changes: 2 additions & 2 deletions gym_pybullet_drones/envs/HoverAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def _computeTruncated(self):
"""
state = self._getDroneStateVector(0)
if (abs(state[0]) > 2.0 or abs(state[1]) > 2.0 or state[2] > 2.0 # Truncate when the drone is too far away
or abs(state[7]) > .5 or abs(state[8]) > .5 # Truncate when the drone is too tilted
if (abs(state[0]) > 1.5 or abs(state[1]) > 1.5 or state[2] > 2.0 # Truncate when the drone is too far away
or abs(state[7]) > .4 or abs(state[8]) > .4 # Truncate when the drone is too tilted
):
return True
if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
Expand Down
2 changes: 1 addition & 1 deletion gym_pybullet_drones/envs/MultiHoverAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _computeTruncated(self):
states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)])
for i in range(self.NUM_DRONES):
if (abs(states[i][0]) > 2.0 or abs(states[i][1]) > 2.0 or states[i][2] > 2.0 # Truncate when a drones is too far away
or abs(states[i][7]) > .5 or abs(states[i][8]) > .5 # Truncate when a drone is too tilted
or abs(states[i][7]) > .4 or abs(states[i][8]) > .4 # Truncate when a drone is too tilted
):
return True
if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
Expand Down
10 changes: 5 additions & 5 deletions gym_pybullet_drones/examples/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_

#### Target cumulative rewards (problem-dependent) ##########
if DEFAULT_ACT == ActionType.ONE_D_RPM:
target_reward = 474.1 if not multiagent else 950.
target_reward = 474.15 if not multiagent else 949.5
else:
target_reward = 465. if not multiagent else 920.
target_reward = 467. if not multiagent else 920.
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=target_reward,
verbose=1)
eval_callback = EvalCallback(eval_env,
Expand Down Expand Up @@ -111,9 +111,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_
if local:
input("Press Enter to continue...")

if os.path.isfile(filename+'/final_model.zip'):
path = filename+'/final_model.zip'
elif os.path.isfile(filename+'/best_model.zip'):
# if os.path.isfile(filename+'/final_model.zip'):
# path = filename+'/final_model.zip'
if os.path.isfile(filename+'/best_model.zip'):
path = filename+'/best_model.zip'
else:
print("[ERROR]: no model under the specified path", filename)
Expand Down

0 comments on commit 2c2470c

Please sign in to comment.