Commit

Simple code for cartpole swingup with pure RL (PPO)
Federico-PizarroBejarano committed Jan 10, 2025
1 parent 5f20bac commit ae9460e
Showing 8 changed files with 30 additions and 100 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,6 +6,8 @@ examples/mpsc/unsafe_rl_temp_data/
 #
 examples/pid/*data/
 #
+examples/rl/test_model/
+#
 results/
 z_docstring.py
 TODOs.md
46 changes: 15 additions & 31 deletions examples/rl/config_overrides/cartpole/cartpole_stab.yaml
@@ -8,34 +8,34 @@ task_config:
 
   # state initialization
   init_state:
-    init_x: 0.1
-    init_x_dot: -1.5
-    init_theta: -0.155
-    init_theta_dot: 0.75
+    init_x: 0.0
+    init_x_dot: 0.0
+    init_theta: 3.14
+    init_theta_dot: 0
   randomized_init: True
   randomized_inertial_prop: False
 
   init_state_randomization_info:
     init_x:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_x_dot:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.25
+      high: 0.25
     init_theta:
       distrib: 'uniform'
-      low: -0.16
-      high: 0.16
+      low: 3.0
+      high: 3.3
     init_theta_dot:
       distrib: 'uniform'
-      low: -1
-      high: 1
+      low: 0
+      high: 0
 
   task: stabilization
   task_info:
-    stabilization_goal: [0.7, 0]
+    stabilization_goal: [0.0, 0]
     stabilization_goal_tolerance: 0.0
 
   inertial_prop:
@@ -48,25 +48,9 @@ task_config:
   obs_goal_horizon: 0
 
   # RL Reward
-  rew_state_weight: [1, 1, 1, 1]
-  rew_act_weight: 0.1
+  rew_state_weight: [0.1, 0.1, 1, 0.1]
+  rew_act_weight: 0.001
   rew_exponential: True
 
-  # constraints
-  constraints:
-  - constraint_form: default_constraint
-    constrained_variable: state
-    upper_bounds:
-    - 2
-    - 2
-    - 0.16
-    - 1
-    lower_bounds:
-    - -2
-    - -2
-    - -0.16
-    - -1
-  - constraint_form: default_constraint
-    constrained_variable: input
   done_on_out_of_bound: True
   done_on_violation: False
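
For orientation, here is a minimal sketch of how these reward weights typically enter the RL reward in safe-control-gym (assuming the usual quadratic state/action cost wrapped in an exponential when rew_exponential is True); the variable names and the handling of the angle are illustrative assumptions, not a copy of the environment code:

    import numpy as np

    # Illustrative reward sketch for the swingup task (assumed form, not the env source).
    rew_state_weight = np.array([0.1, 0.1, 1.0, 0.1])   # weights on [x, x_dot, theta, theta_dot]
    rew_act_weight = np.array([0.001])                   # weight on the single cart force input
    goal = np.array([0.0, 0.0, 0.0, 0.0])                # cart at the origin, pole upright

    def rl_reward(state, action, rew_exponential=True):
        cost = np.sum(rew_state_weight * (np.asarray(state) - goal) ** 2)
        cost += np.sum(rew_act_weight * np.asarray(action) ** 2)
        return float(np.exp(-cost)) if rew_exponential else float(-cost)

    print(rl_reward([0.0, 0.0, np.pi, 0.0], [0.0]))   # pole hanging down: ~exp(-pi^2), about 5e-5
    print(rl_reward([0.0, 0.0, 0.0, 0.0], [0.0]))     # at the goal: 1.0

With the exponential form the reward stays in (0, 1] even when the pole starts a full half-turn from the goal, and the small weights on x, x_dot, theta_dot and the action let the angle term dominate during the swing.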
2 changes: 1 addition & 1 deletion examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
@@ -25,7 +25,7 @@ algo_config:
   max_grad_norm: 0.5
 
   # runner args
-  max_env_steps: 300000
+  max_env_steps: 1000000
   num_workers: 1
   rollout_batch_size: 4
   rollout_steps: 150
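
Rough training-budget arithmetic for the new max_env_steps, assuming environment steps are counted across the rollout batch (an assumption about the step accounting, not a statement of the trainer's internals):

    rollout_batch_size = 4                                      # from ppo_cartpole.yaml
    rollout_steps = 150
    steps_per_iteration = rollout_batch_size * rollout_steps    # 600 env steps per PPO iteration
    print(300_000 // steps_per_iteration)     # 500 iterations under the old budget
    print(1_000_000 // steps_per_iteration)   # 1666 iterations for the harder swingup task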
6 changes: 3 additions & 3 deletions examples/rl/rl_experiment.py
@@ -12,7 +12,7 @@
 from safe_control_gym.utils.registration import make
 
 
-def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
+def run(gui=True, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     '''Main function to run RL experiments.
 
     Args:
@@ -32,7 +32,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
     fac = ConfigFactory()
     config = fac.merge()
 
-    task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
+    # task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
     if config.task == Environment.QUADROTOR:
         system = f'quadrotor_{str(config.task_config.quad_type)}D'
     else:
@@ -50,7 +50,7 @@ def run(gui=False, plot=True, n_episodes=1, n_steps=None, curr_path='.'):
                 output_dir=curr_path + '/temp')
 
     # Load state_dict from trained.
-    ctrl.load(f'{curr_path}/models/{config.algo}/{config.algo}_model_{system}_{task}.pt')
+    ctrl.load(f'{curr_path}/test_model/model_best.pt')
 
     # Remove temporary files and directories
     shutil.rmtree(f'{curr_path}/temp', ignore_errors=True)
20 changes: 3 additions & 17 deletions examples/rl/rl_experiment.sh
@@ -1,25 +1,11 @@
 #!/bin/bash
 
-# SYS='cartpole'
-# SYS='quadrotor_2D'
-SYS='quadrotor_3D'
-
-# TASK='stab'
-TASK='track'
-
+SYS='cartpole'
+TASK='stab'
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
-
 # RL Experiment
 python3 ./rl_experiment.py \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --algo ${ALGO} \
     --overrides \
         ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
50 changes: 3 additions & 47 deletions examples/rl/train_rl_model.sh
@@ -1,61 +1,17 @@
 #!/bin/bash
 
 SYS='cartpole'
-# SYS='quadrotor_2D'
-# SYS='quadrotor_3D'
-
 TASK='stab'
-# TASK='track'
-
 ALGO='ppo'
-# ALGO='sac'
-# ALGO='safe_explorer_ppo'
-
-if [ "$SYS" == 'cartpole' ]; then
-    SYS_NAME=$SYS
-else
-    SYS_NAME='quadrotor'
-fi
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
-
-if [ "$ALGO" == 'safe_explorer_ppo' ]; then
-    # Pretrain the unsafe controller/agent.
-    python3 ../../safe_control_gym/experiments/train_rl_controller.py \
-        --algo ${ALGO} \
-        --task ${SYS_NAME} \
-        --overrides \
-            ./config_overrides/${SYS}/${ALGO}_${SYS}_pretrain.yaml \
-            ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-        --output_dir ./unsafe_rl_temp_data/ \
-        --seed 2 \
-        --kv_overrides \
-            task_config.init_state=None
-
-    # Move the newly trained unsafe model.
-    mv ./unsafe_rl_temp_data/model_latest.pt ./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-    # Removed the temporary data used to train the new unsafe model.
-    rm -r -f ./unsafe_rl_temp_data/
-fi
 
 # Train the unsafe controller/agent.
 python3 ../../safe_control_gym/experiments/train_rl_controller.py \
     --algo ${ALGO} \
-    --task ${SYS_NAME} \
+    --task ${SYS} \
     --overrides \
         ./config_overrides/${SYS}/${ALGO}_${SYS}.yaml \
         ./config_overrides/${SYS}/${SYS}_${TASK}.yaml \
-    --output_dir ./unsafe_rl_temp_data/ \
+    --output_dir ./test_model/ \
    --seed 2 \
     --kv_overrides \
-        task_config.init_state=None \
-        task_config.randomized_init=True \
-        algo_config.pretrained=./models/${ALGO}/${ALGO}_pretrain_${SYS}_${TASK}.pt
-
-# Move the newly trained unsafe model.
-mv ./unsafe_rl_temp_data/model_best.pt ./models/${ALGO}/${ALGO}_model_${SYS}_${TASK}.pt
-
-# Removed the temporary data used to train the new unsafe model.
-rm -r -f ./unsafe_rl_temp_data/
+        task_config.randomized_init=True
2 changes: 2 additions & 0 deletions safe_control_gym/controllers/ppo/ppo.py
@@ -205,6 +205,8 @@ def select_action(self, obs, info=None):
         obs = torch.FloatTensor(obs).to(self.device)
         action = self.agent.ac.act(obs)
 
+        # Turn this on when you want to visualize it nicely. There is probably a better way tho
+        # time.sleep(1/20.0)
         return action
 
     def run(self,
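
As a sketch of the "better way" the new comment alludes to, the sleep could be gated behind an optional attribute so it never slows training; render_delay is a hypothetical attribute here, not part of the PPO controller's API:

    import time

    # Hypothetical alternative: only throttle stepping when a delay has been configured,
    # e.g. self.render_delay = 1/20.0 when visualizing a rollout in the GUI.
    render_delay = getattr(self, 'render_delay', None)
    if render_delay:
        time.sleep(render_delay)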
2 changes: 1 addition & 1 deletion safe_control_gym/envs/gym_control/cartpole.py
@@ -443,7 +443,7 @@ def _set_observation_space(self):
         # NOTE: different value in PyBullet gym (0.4) and OpenAI gym (2.4).
         self.x_threshold = 2.4
         self.x_dot_threshold = 20
-        self.theta_threshold_radians = 90 * math.pi / 180  # Angle at which to fail the episode.
+        self.theta_threshold_radians = 2 * math.pi  # Angle at which to fail the episode.
         self.theta_dot_threshold = 20
         # Limit set to 2x: i.e. a failing observation is still within bounds.
         obs_bound = np.array([self.x_threshold * 2,
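
The wider angle limit matters because the swingup episode now starts with the pole hanging down; a quick check of the numbers (illustration only, not the environment's termination code):

    import math

    theta_reset = 3.14                   # swingup initial pole angle from cartpole_stab.yaml
    old_limit = 90 * math.pi / 180       # ~1.571 rad: the reset state would already count as a failure
    new_limit = 2 * math.pi              # ~6.283 rad: a full swing stays within the episode bounds
    print(abs(theta_reset) > old_limit)  # True
    print(abs(theta_reset) > new_limit)  # False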
