From 2c2470c2043eb4cf474d9af2681c9fa10f2d0e2f Mon Sep 17 00:00:00 2001 From: Jacopo Panerati Date: Mon, 11 Dec 2023 12:14:56 +0800 Subject: [PATCH] control frequency based action buffer in rl state --- gym_pybullet_drones/envs/BaseRLAviary.py | 4 ++-- gym_pybullet_drones/envs/HoverAviary.py | 4 ++-- gym_pybullet_drones/envs/MultiHoverAviary.py | 2 +- gym_pybullet_drones/examples/learn.py | 10 +++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gym_pybullet_drones/envs/BaseRLAviary.py b/gym_pybullet_drones/envs/BaseRLAviary.py index e6988e5df..0c2549822 100644 --- a/gym_pybullet_drones/envs/BaseRLAviary.py +++ b/gym_pybullet_drones/envs/BaseRLAviary.py @@ -62,8 +62,8 @@ def __init__(self, The type of action space (1 or 3D; RPMS, thurst and torques, waypoint or velocity with PID control; etc.) """ - #### Create a buffer for the last 10 actions ############### - self.ACTION_BUFFER_SIZE = 10 + #### Create a buffer for the last .5 sec of actions ######## + self.ACTION_BUFFER_SIZE = int(ctrl_freq//2) self.action_buffer = deque(maxlen=self.ACTION_BUFFER_SIZE) #### vision_attributes = True if obs == ObservationType.RGB else False diff --git a/gym_pybullet_drones/envs/HoverAviary.py b/gym_pybullet_drones/envs/HoverAviary.py index 449a37a57..927504036 100644 --- a/gym_pybullet_drones/envs/HoverAviary.py +++ b/gym_pybullet_drones/envs/HoverAviary.py @@ -107,8 +107,8 @@ def _computeTruncated(self): """ state = self._getDroneStateVector(0) - if (abs(state[0]) > 2.0 or abs(state[1]) > 2.0 or state[2] > 2.0 # Truncate when the drone is too far away - or abs(state[7]) > .5 or abs(state[8]) > .5 # Truncate when the drone is too tilted + if (abs(state[0]) > 1.5 or abs(state[1]) > 1.5 or state[2] > 2.0 # Truncate when the drone is too far away + or abs(state[7]) > .4 or abs(state[8]) > .4 # Truncate when the drone is too tilted ): return True if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC: diff --git a/gym_pybullet_drones/envs/MultiHoverAviary.py b/gym_pybullet_drones/envs/MultiHoverAviary.py index 6a8db25d5..a124668c3 100644 --- a/gym_pybullet_drones/envs/MultiHoverAviary.py +++ b/gym_pybullet_drones/envs/MultiHoverAviary.py @@ -121,7 +121,7 @@ def _computeTruncated(self): states = np.array([self._getDroneStateVector(i) for i in range(self.NUM_DRONES)]) for i in range(self.NUM_DRONES): if (abs(states[i][0]) > 2.0 or abs(states[i][1]) > 2.0 or states[i][2] > 2.0 # Truncate when a drones is too far away - or abs(states[i][7]) > .5 or abs(states[i][8]) > .5 # Truncate when a drone is too tilted + or abs(states[i][7]) > .4 or abs(states[i][8]) > .4 # Truncate when a drone is too tilted ): return True if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC: diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py index 6010f9f6e..98f373753 100644 --- a/gym_pybullet_drones/examples/learn.py +++ b/gym_pybullet_drones/examples/learn.py @@ -76,9 +76,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ #### Target cumulative rewards (problem-dependent) ########## if DEFAULT_ACT == ActionType.ONE_D_RPM: - target_reward = 474.1 if not multiagent else 950. + target_reward = 474.15 if not multiagent else 949.5 else: - target_reward = 465. if not multiagent else 920. + target_reward = 467. if not multiagent else 920. callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=target_reward, verbose=1) eval_callback = EvalCallback(eval_env, @@ -111,9 +111,9 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ if local: input("Press Enter to continue...") - if os.path.isfile(filename+'/final_model.zip'): - path = filename+'/final_model.zip' - elif os.path.isfile(filename+'/best_model.zip'): + # if os.path.isfile(filename+'/final_model.zip'): + # path = filename+'/final_model.zip' + if os.path.isfile(filename+'/best_model.zip'): path = filename+'/best_model.zip' else: print("[ERROR]: no model under the specified path", filename)