diff --git a/gym_pybullet_drones/envs/BaseRLAviary.py b/gym_pybullet_drones/envs/BaseRLAviary.py index b154cf478..e6988e5df 100644 --- a/gym_pybullet_drones/envs/BaseRLAviary.py +++ b/gym_pybullet_drones/envs/BaseRLAviary.py @@ -69,7 +69,6 @@ def __init__(self, vision_attributes = True if obs == ObservationType.RGB else False self.OBS_TYPE = obs self.ACT_TYPE = act - self.EPISODE_LEN_SEC = 5 #### Create integrated controllers ######################### if act in [ActionType.PID, ActionType.VEL, ActionType.ONE_D_PID]: os.environ['KMP_DUPLICATE_LIB_OK']='True' diff --git a/gym_pybullet_drones/envs/HoverAviary.py b/gym_pybullet_drones/envs/HoverAviary.py index 7957fe74c..165b35c32 100644 --- a/gym_pybullet_drones/envs/HoverAviary.py +++ b/gym_pybullet_drones/envs/HoverAviary.py @@ -49,6 +49,7 @@ def __init__(self, """ self.TARGET_POS = np.array([0,0,1]) + self.EPISODE_LEN_SEC = 8 super().__init__(drone_model=drone_model, num_drones=1, initial_xyzs=initial_xyzs, @@ -74,7 +75,7 @@ def _computeReward(self): """ state = self._getDroneStateVector(0) - ret = max(0, 10 - np.linalg.norm(self.TARGET_POS-state[0:3])**4) + ret = max(0, 2 - np.linalg.norm(self.TARGET_POS-state[0:3])**4) return ret ################################################################################ diff --git a/gym_pybullet_drones/envs/LeaderFollowerAviary.py b/gym_pybullet_drones/envs/LeaderFollowerAviary.py index 6beca9fa7..171b641ba 100644 --- a/gym_pybullet_drones/envs/LeaderFollowerAviary.py +++ b/gym_pybullet_drones/envs/LeaderFollowerAviary.py @@ -55,6 +55,7 @@ def __init__(self, """ self.TARGET_POS = np.array([0,0,1]) + self.EPISODE_LEN_SEC = 5 super().__init__(drone_model=drone_model, num_drones=num_drones, neighbourhood_radius=neighbourhood_radius, diff --git a/gym_pybullet_drones/examples/learn.py b/gym_pybullet_drones/examples/learn.py index 4b229363c..397d5fbf6 100644 --- a/gym_pybullet_drones/examples/learn.py +++ b/gym_pybullet_drones/examples/learn.py @@ -85,7 +85,11 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ eval_freq=int(2000), deterministic=True, render=False) - model.learn(total_timesteps=int(1e6), + if not multiagent: + steps = 2 * int(1e5) + else: + steps = int(1e6) + model.learn(total_timesteps=steps, callback=eval_callback, log_interval=100) @@ -140,7 +144,7 @@ def run(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_ obs, info = test_env.reset(seed=42, options={}) start = time.time() - for i in range(3*test_env.CTRL_FREQ): + for i in range((test_env.EPISODE_LEN_SEC+2)*test_env.CTRL_FREQ): action, _states = model.predict(obs, deterministic=True )