diff --git a/gym_pybullet_drones/envs/BaseAviary.py b/gym_pybullet_drones/envs/BaseAviary.py index 592cc5292..ce578ea1d 100755 --- a/gym_pybullet_drones/envs/BaseAviary.py +++ b/gym_pybullet_drones/envs/BaseAviary.py @@ -1092,3 +1092,49 @@ def _computeInfo(self): """ raise NotImplementedError + + ################################################################################ + + def _calculateNextStep(self, current_position, destination, step_size=1): + """ + Calculates intermediate waypoint + towards drone's destination + from drone's current position + + Enables drones to reach distant waypoints without + losing control/crashing, and hover on arrival at destintion + + Parameters + ---------- + current_position : ndarray + drone's current position from state vector + destination : ndarray + drone's target position + step_size: int + distance next waypoint is from current position, default 1 + + Returns + ---------- + next_pos: int + intermediate waypoint for drone + + """ + direction = ( + destination - current_position + ) # Calculate the direction vector + distance = np.linalg.norm( + direction + ) # Calculate the distance to the destination + + if distance <= step_size: + # If the remaining distance is less than or equal to the step size, + # return the destination + return destination + + normalized_direction = ( + direction / distance + ) # Normalize the direction vector + next_step = ( + current_position + normalized_direction * step_size + ) # Calculate the next step + return next_step \ No newline at end of file diff --git a/gym_pybullet_drones/envs/multi_agent_rl/BaseMultiagentAviary.py b/gym_pybullet_drones/envs/multi_agent_rl/BaseMultiagentAviary.py index 3910fe030..40415449a 100644 --- a/gym_pybullet_drones/envs/multi_agent_rl/BaseMultiagentAviary.py +++ b/gym_pybullet_drones/envs/multi_agent_rl/BaseMultiagentAviary.py @@ -22,7 +22,7 @@ def __init__(self, pyb_freq: int = 240, ctrl_freq: int = 240, gui=False, - record=False, + record=False, obs: ObservationType=ObservationType.KIN, act: ActionType=ActionType.RPM ): @@ -182,16 +182,21 @@ def _preprocessAction(self, rpm = np.zeros((self.NUM_DRONES,4)) for k in range(action.shape[0]): target = action[k, :] - if self.ACT_TYPE == ActionType.RPM: + if self.ACT_TYPE == ActionType.RPM: rpm[k,:] = np.array(self.HOVER_RPM * (1+0.05*target)) elif self.ACT_TYPE == ActionType.PID: state = self._getDroneStateVector(k) + next_pos = self._calculateNextStep( + current_position=state[0:3], + destination=target, + step_size=1, + ) rpm_k, _, _ = self.ctrl[k].computeControl(control_timestep=self.CTRL_TIMESTEP, cur_pos=state[0:3], cur_quat=state[3:7], cur_vel=state[10:13], cur_ang_vel=state[13:16], - target_pos=state[0:3]+0.1*target + target_pos=next_pos ) rpm[k,:] = rpm_k elif self.ACT_TYPE == ActionType.VEL: diff --git a/gym_pybullet_drones/envs/single_agent_rl/BaseSingleAgentAviary.py b/gym_pybullet_drones/envs/single_agent_rl/BaseSingleAgentAviary.py index fa54a5b7c..5b4745cb3 100644 --- a/gym_pybullet_drones/envs/single_agent_rl/BaseSingleAgentAviary.py +++ b/gym_pybullet_drones/envs/single_agent_rl/BaseSingleAgentAviary.py @@ -72,7 +72,7 @@ def __init__(self, num_drones=1, initial_xyzs=initial_xyzs, initial_rpys=initial_rpys, - physics=physics, + physics=physics, pyb_freq=pyb_freq, ctrl_freq=ctrl_freq, gui=gui, @@ -170,14 +170,19 @@ def _preprocessAction(self, """ if self.ACT_TYPE == ActionType.RPM: return np.array(self.HOVER_RPM * (1+0.05*action)) - elif self.ACT_TYPE == ActionType.PID: + elif self.ACT_TYPE == ActionType.PID: state = self._getDroneStateVector(0) + next_pos = self._calculateNextStep( + current_position=state[0:3], + destination=action, + step_size=1, + ) rpm, _, _ = self.ctrl.computeControl(control_timestep=self.CTRL_TIMESTEP, cur_pos=state[0:3], cur_quat=state[3:7], cur_vel=state[10:13], cur_ang_vel=state[13:16], - target_pos=state[0:3]+0.1*action + target_pos=next_pos ) return rpm elif self.ACT_TYPE == ActionType.VEL: diff --git a/pyproject.toml b/pyproject.toml index a0a391df2..b98aa08fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,5 +24,5 @@ stable-baselines3 = "^2.0.0" [tool.poetry.dev-dependencies] [build-system] -requires = ["poetry-core @ git+https://github.com/python-poetry/poetry-core.git@main"] +requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"