Fixing HPO Linting Pipeline Failure (#152)
* Fixing linting issues; tests remain to be fixed

* Test out running the HPO tests on GitHub

* Fixing MySQL issue

* Trying again

* Trying again

* Removing everything related to fixing tests

* Skipping tests for now
Federico-PizarroBejarano authored May 14, 2024
1 parent f90ac22 commit f6b850e
Showing 25 changed files with 87 additions and 113 deletions.
.pre-commit-config.yaml: 2 changes (1 addition & 1 deletion)
@@ -9,7 +9,7 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-rev: v4.5.0
+rev: v4.6.0
hooks:
- id: check-ast
- id: check-yaml
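The only change here is the hook revision bump. A minimal sketch of reproducing it locally, assuming pre-commit is installed and the commands run from the repository root (autoupdate rewrites the rev: fields; run --all-files lints the whole tree):

import subprocess

# Bump every hook's rev: to the latest tagged release, as this commit does.
subprocess.run(['pre-commit', 'autoupdate'], check=True)
# Re-run all hooks over the whole tree to confirm the pipeline passes.
subprocess.run(['pre-commit', 'run', '--all-files'], check=True)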
examples/cbf/config_overrides/ppo_config.yaml: 2 changes (1 addition & 1 deletion)
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 64
-activation: "relu"
+activation: relu
norm_obs: False
norm_reward: False
clip_obs: 10.0
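The quote removal on activation (repeated across the config files below) is purely cosmetic: YAML parses relu and "relu" to the same string, so training behavior is unchanged. A quick check with PyYAML:

import yaml

# Both spellings load as the identical Python string.
assert yaml.safe_load('activation: "relu"') == yaml.safe_load('activation: relu')
print(yaml.safe_load('activation: relu'))  # {'activation': 'relu'}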
examples/cbf/config_overrides/sac_config.yaml: 2 changes (1 addition & 1 deletion)
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 256
-activation: "relu"
+activation: relu
use_entropy_tuning: False

# optim args
@@ -1,5 +1,4 @@
hpo_config:
-
hpo: True # do hyperparameter optimization
load_if_exists: True # this should be set to True if hpo is run in parallel
use_database: False # this is set to True if MySQL is used
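For context, load_if_exists and use_database match the semantics of Optuna's study API, which the hard-coded optuna MySQL user elsewhere in this commit suggests is the backend. A hedged sketch, not the repository's actual HPO code; the study name and storage URL are hypothetical:

import optuna

study = optuna.create_study(
    study_name='mystudy',  # hypothetical --tag value
    # MySQL storage is only needed when use_database is True; omit the
    # storage argument to keep the study in memory instead.
    storage='mysql+pymysql://optuna@localhost/mystudy',
    load_if_exists=True,   # parallel workers attach to the shared study
    direction='minimize',
)
study.optimize(lambda trial: trial.suggest_float('x', -1, 1) ** 2, n_trials=5)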
examples/hpo/hpo_experiment.py: 26 changes (8 additions & 18 deletions)
@@ -1,30 +1,22 @@
-"""Template hyperparameter optimization/hyperparameter evaluation script.
-"""
+'''Template hyperparameter optimization/hyperparameter evaluation script.'''
import os
from functools import partial

import yaml

import matplotlib.pyplot as plt
import numpy as np

from safe_control_gym.envs.benchmark_env import Environment, Task

-from safe_control_gym.hyperparameters.hpo import HPO
from safe_control_gym.experiments.base_experiment import BaseExperiment
+from safe_control_gym.hyperparameters.hpo import HPO
from safe_control_gym.utils.configuration import ConfigFactory
from safe_control_gym.utils.registration import make
from safe_control_gym.utils.utils import set_device_from_config, set_dir_from_config, set_seed_from_config


def hpo(config):
-"""Hyperparameter optimization.
+'''Hyperparameter optimization.
Usage:
* to start HPO, use with `--func hpo`.
-"""
+'''

# Experiment setup.
if config.hpo_config.hpo:
@@ -48,12 +40,11 @@ def hpo(config):


def train(config):
-"""Training for a given set of hyperparameters.
+'''Training for a given set of hyperparameters.
Usage:
* to start training, use with `--func train`.
-"""
+'''
# Override algo_config with given yaml file
if config.opt_hps == '':
# if no opt_hps file is given
@@ -94,15 +85,14 @@ def train(config):
experiment.launch_training()
results, metrics = experiment.run_evaluation(n_episodes=1, n_steps=None, done_on_max_steps=True)
control_agent.close()

return eval_env.X_GOAL, results, metrics


MAIN_FUNCS = {'hpo': hpo, 'train': train}


if __name__ == '__main__':
-
# Make config.
fac = ConfigFactory()
fac.add_argument('--func', type=str, default='train', help='main function to run.')
@@ -115,5 +105,5 @@ def train(config):
# Execute.
func = MAIN_FUNCS.get(config.func, None)
if func is None:
-raise Exception('Main function {} not supported.'.format(config.func))
+raise Exception(f'Main function {config.func} not supported.')
func(config)
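Per the docstrings above, the script is driven by --func: `python examples/hpo/hpo_experiment.py --func hpo` starts optimization and `--func train` trains a single configuration. A standalone sketch of the same dispatch pattern, with plain argparse standing in for ConfigFactory and stub function bodies:

import argparse


def hpo(config):
    print('would run hyperparameter optimization with', config)


def train(config):
    print('would train one hyperparameter configuration with', config)


# Map --func values to callables and fail loudly on unknown names.
MAIN_FUNCS = {'hpo': hpo, 'train': train}

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--func', type=str, default='train', help='main function to run.')
    config = parser.parse_args()
    func = MAIN_FUNCS.get(config.func, None)
    if func is None:
        raise Exception(f'Main function {config.func} not supported.')
    func(config)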
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 64
-activation: "relu"
+activation: relu
norm_obs: False
norm_reward: False
clip_obs: 10.0
@@ -1,5 +1,4 @@
hpo_config:
-
hpo: True # do hyperparameter optimization
load_if_exists: True # this should be set to True if hpo is run in parallel
use_database: False # this is set to True if MySQL is used
@@ -21,7 +20,7 @@ hpo_config:
hps_config:
# model args
hidden_dim: 64
-activation: "relu"
+activation: relu

# loss args
gamma: 0.99
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 256
-activation: "relu"
+activation: relu
norm_obs: False
norm_reward: False
clip_obs: 10.0
@@ -1,5 +1,4 @@
hpo_config:
-
hpo: True # do hyperparameter optimization
load_if_exists: True # this should be set to True if hpo is run in parallel
use_database: False # this is set to True if MySQL is used
@@ -21,7 +20,7 @@ hpo_config:
hps_config:
# model args
hidden_dim: 256
-activation: "relu"
+activation: relu

# loss args
gamma: 0.99
examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml: 2 changes (1 addition & 1 deletion)
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 64
-activation: "relu"
+activation: relu
norm_obs: False
norm_reward: False
clip_obs: 10.0
examples/mpsc/config_overrides/cartpole/sac_cartpole.yaml: 2 changes (1 addition & 1 deletion)
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 64
-activation: "relu"
+activation: relu
use_entropy_tuning: False

# optim args
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 256
-activation: "relu"
+activation: relu

# loss args
use_gae: True
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 256
-activation: "relu"
+activation: relu
use_entropy_tuning: False

# optim args
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 128
-activation: "relu"
+activation: relu

# loss args
use_gae: True
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 128
-activation: "relu"
+activation: relu
use_entropy_tuning: False

# optim args
@@ -2,7 +2,7 @@ algo: ppo
algo_config:
# model args
hidden_dim: 128
-activation: "relu"
+activation: relu

# loss args
use_gae: True
@@ -2,7 +2,7 @@ algo: sac
algo_config:
# model args
hidden_dim: 128
-activation: "relu"
+activation: relu
use_entropy_tuning: False

# optim args
safe_control_gym/controllers/mpc/gp_mpc.py: 2 changes (1 addition & 1 deletion)
@@ -30,10 +30,10 @@
from sklearn.model_selection import train_test_split
from skopt.sampler import Lhs

+from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.controllers.mpc.gp_utils import (GaussianProcessCollection, ZeroMeanIndependentGPModel,
covMatern52ard, covSEard, kmeans_centriods)
from safe_control_gym.controllers.mpc.linear_mpc import MPC, LinearMPC
-from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.envs.benchmark_env import Task


safe_control_gym/controllers/mpc/gp_utils.py: 10 changes (5 additions & 5 deletions)
@@ -221,11 +221,11 @@ def __init__(self, model_type,
self.parallel = parallel
if parallel:
self.gps = BatchGPModel(model_type,
-                                    likelihood,
-                                    input_mask=input_mask,
-                                    target_mask=target_mask,
-                                    normalize=normalize,
-                                    kernel=kernel)
+                               likelihood,
+                               input_mask=input_mask,
+                               target_mask=target_mask,
+                               normalize=normalize,
+                               kernel=kernel)
else:
for _ in range(target_dim):
self.gp_list.append(GaussianProcess(model_type,
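The gp_utils.py change is whitespace only: the keyword arguments of the BatchGPModel call are re-aligned under the opening parenthesis, the continuation style pycodestyle checks (e.g. rule E128). An illustrative sketch with a stand-in function; only the call layout matters:

# Stand-in for the real BatchGPModel; names are taken from the diff above.
def batch_gp_model(model_type, likelihood, input_mask=None, target_mask=None,
                   normalize=False, kernel=None):
    return (model_type, likelihood, input_mask, target_mask, normalize, kernel)


gps = batch_gp_model('gp',
                     likelihood=None,      # continuation lines start directly
                     input_mask=None,      # under the first argument above
                     target_mask=None,
                     normalize=True,
                     kernel='matern')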
safe_control_gym/hyperparameters/database.py: 34 changes (14 additions & 20 deletions)
@@ -1,18 +1,14 @@
-"""
-This script already assumes that mysql server is up and hard coded user 'optuna' without password was added.
-"""
+'''This script already assumes that mysql server is up and hard
+coded user 'optuna' without password was added.
+'''

import mysql.connector

from safe_control_gym.utils.configuration import ConfigFactory


def create(config):
-"""
-This function is used to create database named after --Tag.
-"""
+'''This function is used to create database named after --Tag.'''

db = mysql.connector.connect(
host='localhost',
@@ -21,19 +17,17 @@ def create(config):

mycursor = db.cursor()

-mycursor.execute('CREATE DATABASE IF NOT EXISTS {}'.format(config.tag))
+mycursor.execute(f'CREATE DATABASE IF NOT EXISTS {config.tag}')


def drop(config):
-"""
-This function is used to drop database named after --Tag.
-Be sure to backup before dropping.
-* Backup: mysqldump --no-tablespaces -u optuna DATABASE_NAME > DATABASE_NAME.sql
-* Restore:
-    1. mysql -u optuna -e "create database DATABASE_NAME".
+'''This function is used to drop database named after --Tag.
+Be sure to backup before dropping.
+* Backup: mysqldump --no-tablespaces -u optuna DATABASE_NAME > DATABASE_NAME.sql
+* Restore:
+    1. mysql -u optuna -e 'create database DATABASE_NAME'.
     2. mysql -u optuna DATABASE_NAME < DATABASE_NAME.sql
-"""
+'''

db = mysql.connector.connect(
host='localhost',
@@ -42,18 +36,18 @@ def drop(config):

mycursor = db.cursor()

-mycursor.execute('drop database if exists {}'.format(config.tag))
+mycursor.execute(f'drop database if exists {config.tag}')


MAIN_FUNCS = {'drop': drop, 'create': create}

-if __name__ == '__main__':
-
+if __name__ == '__main__':
fac = ConfigFactory()
fac.add_argument('--func', type=str, default='create', help='main function to run.')
config = fac.merge()

func = MAIN_FUNCS.get(config.func, None)
if func is None:
-raise Exception('Main function {} not supported.'.format(config.func))
+raise Exception(f'Main function {config.func} not supported.')
func(config)
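A hedged companion check for the script above, assuming the passwordless optuna user from the docstring and a database created with a hypothetical tag mystudy:

import mysql.connector

# Connect the same way database.py does and list the existing schemas.
db = mysql.connector.connect(host='localhost', user='optuna')
cursor = db.cursor()
cursor.execute('SHOW DATABASES')
names = [name for (name,) in cursor.fetchall()]
print('mystudy exists:', 'mystudy' in names)
db.close()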