From ae9460e6e3194f5854e0ef77f0f36d3e6f97f3fc Mon Sep 17 00:00:00 2001
From: Federico-PizarroBejarano
Date: Fri, 10 Jan 2025 15:45:00 -0500
Subject: [PATCH] Simple code for cartpole swingup with pure RL (PPO)

---
 .gitignore                                    |  2 +
 .../cartpole/cartpole_stab.yaml               | 46 ++++++-----------
 .../cartpole/ppo_cartpole.yaml                |  2 +-
 examples/rl/rl_experiment.py                  |  6 +--
 examples/rl/rl_experiment.sh                  | 20 ++------
 examples/rl/train_rl_model.sh                 | 50 ++-----------------
 safe_control_gym/controllers/ppo/ppo.py       |  2 +
 safe_control_gym/envs/gym_control/cartpole.py |  2 +-
 8 files changed, 30 insertions(+), 100 deletions(-)

diff --git a/.gitignore b/.gitignore
index a6f4b5c8d..c8043a63f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ examples/mpsc/unsafe_rl_temp_data/
 #
 examples/pid/*data/
 #
+examples/rl/test_model/
+#
 results/
 z_docstring.py
 TODOs.md
diff --git a/examples/rl/config_overrides/cartpole/cartpole_stab.yaml b/examples/rl/config_overrides/cartpole/cartpole_stab.yaml
index 1f3a0e430..ec5a6eb2c 100644
--- a/examples/rl/config_overrides/cartpole/cartpole_stab.yaml
+++ b/examples/rl/config_overrides/cartpole/cartpole_stab.yaml
@@ -8,34 +8,34 @@ task_config:
 
   # state initialization
   init_state:
-    init_x: 0.1
-    init_x_dot: -1.5
-    init_theta: -0.155
-    init_theta_dot: 0.75
+    init_x: 0.0
+    init_x_dot: 0.0
+    init_theta: 3.14
+    init_theta_dot: 0
   randomized_init: True
   randomized_inertial_prop: False
 
   init_state_randomization_info:
     init_x:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_x_dot:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_theta:
       distrib: 'uniform'
-      low: -0.16
-      high: 0.16
+      low: 3.0
+      high: 3.3
     init_theta_dot:
       distrib: 'uniform'
-      low: -1
-      high: 1
+      low: 0
+      high: 0
 
   task: stabilization
   task_info:
-    stabilization_goal: [0.7, 0]
+    stabilization_goal: [0.0, 0]
     stabilization_goal_tolerance: 0.0
 
   inertial_prop:
@@ -48,25 +48,9 @@ task_config:
   obs_goal_horizon: 0
 
   # RL Reward
-  rew_state_weight: [1, 1, 1, 1]
-  rew_act_weight: 0.1
+  rew_state_weight: [0.1, 0.1, 1, 0.1]
+  rew_act_weight: 0.001
   rew_exponential: True
 
-  # constraints
-  constraints:
-  - constraint_form: default_constraint
-    constrained_variable: state
-    upper_bounds:
-    - 2
-    - 2
-    - 0.16
-    - 1
-    lower_bounds:
-    - -2
-    - -2
-    - -0.16
-    - -1
-  - constraint_form: default_constraint
-    constrained_variable: input
   done_on_out_of_bound: True
   done_on_violation: False
diff --git a/examples/rl/config_overrides/cartpole/ppo_cartpole.yaml b/examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
index 0dc45648b..af2dbb326 100644
--- a/examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
+++ b/examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
@@ -25,7 +25,7 @@ algo_config:
   max_grad_norm: 0.5
 
   # runner args
-  max_env_steps: 300000
+  max_env_steps: 1000000
   num_workers: 1
   rollout_batch_size: 4
   rollout_steps: 150
diff --git a/examples/rl/rl_experiment.py b/examples/rl/rl_experiment.py
index d4427adf7..b9f776da8 100644
--- a/examples/rl/rl_experiment.py
+++ b/examples/rl/rl_experiment.py
@@ -12,7 +12,7 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
+def run(gui=True, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     '''Main function to run RL experiments.
 
     Args:
@@ -32,7 +32,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     fac = ConfigFactory()
     config = fac.merge()
 
-    task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
+    # task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
     if config.task == Environment.QUADROTOR:
         system = f'quadrotor_{str(config.task_config.quad_type)}D'
     else:
@@ -50,7 +50,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
                 output_dir=curr_path + '/temp')
 
     # Load state_dict from trained.
-    ctrl.load(f'{curr_path}/models/{config.algo}/{config.algo}_model_{system}_{task}.pt')
+    ctrl.load(f'{curr_path}/test_model/model_best.pt')
 
     # Remove temporary files and directories
     shutil.rmtree(f'{curr_path}/temp', ignore_errors=True)
diff --git a/examples/rl/rl_experiment.sh b/examples/rl/rl_experiment.sh
index 203bc14e4..cc6a08c05 100755
--- a/examples/rl/rl_experiment.sh
+++ b/examples/rl/rl_experiment.sh
@@ -1,25 +1,11 @@
 #!/bin/bash
 
-# SYS='cartpole'
-# SYS='quadrotor_2D'
-SYS='quadrotor_3D'
-
-# TASK='stab'
-TASK='track'
-
+SYS='cartpole'
+TASK='stab'
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
 
-# RL Experiment
 python3 ./rl_experiment.py \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --algo ${ALGO} \
     --overrides \
         ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
diff --git a/examples/rl/train_rl_model.sh b/examples/rl/train_rl_model.sh
index d9647def7..70c8322fb 100755
--- a/examples/rl/train_rl_model.sh
+++ b/examples/rl/train_rl_model.sh
@@ -1,61 +1,17 @@
 #!/bin/bash
 
 SYS='cartpole'
-# SYS='quadrotor_2D'
-# SYS='quadrotor_3D'
-
 TASK='stab'
-# TASK='track'
-
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
-
-if [ "$ALGO" == 'safe_explorer_ppo' ]; then
-    # Pretrain the unsafe controller/agent.
-    python3 ../../safe_control_gym/experiments/train_rl_controller.py \
-        --algo ${ALGO} \
-        --task ${SYS_NAME} \
-        --overrides \
-            ./config_overrides/${SYS}/${ALGO}_${SYS}_pretrain.yaml \
-            ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-        --output_dir ./unsafe_rl_temp_data/ \
-        --seed 2 \
-        --kv_overrides \
-            task_config.init_state=None
-    # Move the newly trained unsafe model.
-    mv ./unsafe_rl_temp_data/model_latest.pt ./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-    # Removed the temporary data used to train the new unsafe model.
-    rm -r -f ./unsafe_rl_temp_data/
-fi
 
-# Train the unsafe controller/agent.
 python3 ../../safe_control_gym/experiments/train_rl_controller.py \
     --algo ${ALGO} \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --overrides \
         ./config_overrides/${SYS}/${ALGO}_${SYS}.yaml \
         ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-    --output_dir ./unsafe_rl_temp_data/ \
+    --output_dir ./test_model/ \
     --seed 2 \
     --kv_overrides \
         task_config.init_state=None \
-        task_config.randomized_init=True \
-        algo_config.pretrained=./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-# Move the newly trained unsafe model.
-mv ./unsafe_rl_temp_data/model_best.pt ./models/${ALGO}/${ALGO}_model_${SYS}_${TASK}.pt
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
+        task_config.randomized_init=True
diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py
index b444b5e6d..b46d46e27 100644
--- a/safe_control_gym/controllers/ppo/ppo.py
+++ b/safe_control_gym/controllers/ppo/ppo.py
@@ -205,6 +205,8 @@ def select_action(self, obs, info=None):
         obs = torch.FloatTensor(obs).to(self.device)
         action = self.agent.ac.act(obs)
 
+        # Enable this sleep to slow the rollout down for easier visualization; there is probably a cleaner way to do this.
+        # time.sleep(1/20.0)
         return action
 
     def run(self,
diff --git a/safe_control_gym/envs/gym_control/cartpole.py b/safe_control_gym/envs/gym_control/cartpole.py
index 7ed6ea442..eb785e144 100644
--- a/safe_control_gym/envs/gym_control/cartpole.py
+++ b/safe_control_gym/envs/gym_control/cartpole.py
@@ -443,7 +443,7 @@ def _set_observation_space(self):
         # NOTE: different value in PyBullet gym (0.4) and OpenAI gym (2.4).
         self.x_threshold = 2.4
         self.x_dot_threshold = 20
-        self.theta_threshold_radians = 90 * math.pi / 180  # Angle at which to fail the episode.
+        self.theta_threshold_radians = 2 * math.pi  # Angle at which to fail the episode.
         self.theta_dot_threshold = 20
         # Limit set to 2x: i.e. a failing observation is still within bounds.
         obs_bound = np.array([self.x_threshold * 2,
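
Usage sketch for the modified scripts: assuming a working safe-control-gym install, that both scripts are run from examples/rl/, and that the trainer saves a model_best.pt checkpoint into the chosen --output_dir, the swing-up policy can be trained and then replayed as follows:

    cd examples/rl
    ./train_rl_model.sh    # trains PPO on the cartpole swing-up task; checkpoints go to ./test_model/
    ./rl_experiment.sh     # loads ./test_model/model_best.pt and runs an episode with the GUI enabled (gui=True)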