Skip to content

Commit

Permalink
Merge branch 'improved_linting' into minor_updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Oct 12, 2023
2 parents 5f078bf + c5f27c8 commit 560b1a3
Show file tree
Hide file tree
Showing 30 changed files with 80 additions and 118 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-ast
- id: check-yaml
Expand All @@ -30,6 +30,7 @@ repos:
hooks:
- id: isort
name: isort
args: ['--line-length=110']

- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v2.0.4
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ license = "MIT"
[tool.poetry.dependencies]
python = "^3.10"
matplotlib = "^3.5.1"
Pillow = "^9.0.0"
munch = "^2.5.0"
PyYAML = "^6.0"
imageio = "^2.14.1"
dict-deep = "^4.1.2"
scikit-optimize = "^0.9.0"
scikit-learn = "^1.3.0"
gymnasium = "^0.28"
torch = "^1.10.2"
gpytorch = "^1.6.0"
Expand All @@ -31,5 +31,5 @@ pre-commit = "^3.3.2"
[tool.poetry.dev-dependencies]

[build-system]
requires = ["poetry-core @ git+https://github.com/python-poetry/poetry-core.git@master"]
requires = ["poetry-core @ git+https://github.com/python-poetry/poetry-core.git@main"]
build-backend = "poetry.core.masonry.api"
17 changes: 7 additions & 10 deletions safe_control_gym/controllers/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
import torch

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.ddpg.ddpg_utils import (
DDPGAgent, DDPGBuffer, make_action_noise_process)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.controllers.ddpg.ddpg_utils import DDPGAgent, DDPGBuffer, make_action_noise_process
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import (
_flatten_obs, _unflatten_obs)
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import _flatten_obs, _unflatten_obs
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class DDPG(BaseController):
Expand Down
4 changes: 2 additions & 2 deletions safe_control_gym/controllers/lqr/ilqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from termcolor import colored

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.lqr.lqr_utils import (
compute_lqr_gain, discretize_linear_system, get_cost_weight_matrix)
from safe_control_gym.controllers.lqr.lqr_utils import (compute_lqr_gain, discretize_linear_system,
get_cost_weight_matrix)
from safe_control_gym.envs.benchmark_env import Task


Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/controllers/lqr/lqr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
'''Linear Quadratic Regulator (LQR).'''

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.lqr.lqr_utils import (compute_lqr_gain,
get_cost_weight_matrix)
from safe_control_gym.controllers.lqr.lqr_utils import compute_lqr_gain, get_cost_weight_matrix
from safe_control_gym.envs.benchmark_env import Task


Expand Down
7 changes: 3 additions & 4 deletions safe_control_gym/controllers/mpc/gp_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
from skopt.sampler import Lhs

from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.controllers.mpc.gp_utils import (
GaussianProcessCollection, ZeroMeanIndependentGPModel, covSEard,
kmeans_centriods)
from safe_control_gym.controllers.mpc.gp_utils import (GaussianProcessCollection, ZeroMeanIndependentGPModel,
covSEard, kmeans_centriods)
from safe_control_gym.controllers.mpc.linear_mpc import MPC, LinearMPC
from safe_control_gym.envs.benchmark_env import Task

Expand Down Expand Up @@ -714,7 +713,7 @@ def learn(self,
self.train_iterations + validation_iterations,
random_state=self.seed)
input_samples = np.array(input_samples) # not being used currently
seeds = self.env.np_random.randint(0, 99999, size=self.train_iterations + validation_iterations)
seeds = self.env.np_random.integers(0, 99999, size=self.train_iterations + validation_iterations)
for i in range(self.train_iterations + validation_iterations):
# For random initial state training.
# init_state = init_state_samples[i,:]
Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/controllers/mpc/linear_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.controllers.mpc.mpc import MPC
from safe_control_gym.controllers.mpc.mpc_utils import \
compute_discrete_lqr_gain_from_cont_linear_system
from safe_control_gym.controllers.mpc.mpc_utils import compute_discrete_lqr_gain_from_cont_linear_system
from safe_control_gym.envs.benchmark_env import Task


Expand Down
9 changes: 4 additions & 5 deletions safe_control_gym/controllers/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import numpy as np

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.mpc.mpc_utils import (
compute_discrete_lqr_gain_from_cont_linear_system, compute_state_rmse,
get_cost_weight_matrix, reset_constraints, rk_discrete)
from safe_control_gym.controllers.mpc.mpc_utils import (compute_discrete_lqr_gain_from_cont_linear_system,
compute_state_rmse, get_cost_weight_matrix,
reset_constraints, rk_discrete)
from safe_control_gym.envs.benchmark_env import Task
from safe_control_gym.envs.constraints import (GENERAL_CONSTRAINTS,
create_constraint_list)
from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS, create_constraint_list


class MPC(BaseController):
Expand Down
14 changes: 6 additions & 8 deletions safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
import torch

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.ppo.ppo_utils import (
PPOAgent, PPOBuffer, compute_returns_and_advantages)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.controllers.ppo.ppo_utils import PPOAgent, PPOBuffer, compute_returns_and_advantages
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class PPO(BaseController):
Expand Down
14 changes: 6 additions & 8 deletions safe_control_gym/controllers/rarl/rap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
import torch

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.ppo.ppo_utils import (
PPOAgent, PPOBuffer, compute_returns_and_advantages)
from safe_control_gym.controllers.ppo.ppo_utils import PPOAgent, PPOBuffer, compute_returns_and_advantages
from safe_control_gym.controllers.rarl.rarl_utils import split_obs_by_adversary
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class RAP(BaseController):
Expand Down
14 changes: 6 additions & 8 deletions safe_control_gym/controllers/rarl/rarl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
import torch

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.ppo.ppo_utils import (
PPOAgent, PPOBuffer, compute_returns_and_advantages)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.controllers.ppo.ppo_utils import PPOAgent, PPOBuffer, compute_returns_and_advantages
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class RARL(BaseController):
Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/controllers/rarl/rarl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import numpy as np

from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import (
_flatten_obs, _unflatten_obs)
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import _flatten_obs, _unflatten_obs


def split_obs_by_adversary(obs, indices_splits):
Expand Down
14 changes: 6 additions & 8 deletions safe_control_gym/controllers/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.sac.sac_utils import SACAgent, SACBuffer
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import (
_flatten_obs, _unflatten_obs)
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import _flatten_obs, _unflatten_obs
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class SAC(BaseController):
Expand Down
20 changes: 8 additions & 12 deletions safe_control_gym/controllers/safe_explorer/safe_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@
import torch

from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.ppo.ppo_utils import \
compute_returns_and_advantages
from safe_control_gym.controllers.safe_explorer.safe_explorer_utils import (
ConstraintBuffer, SafetyLayer)
from safe_control_gym.controllers.safe_explorer.safe_ppo_utils import (
SafePPOAgent, SafePPOBuffer)
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (
RecordEpisodeStatistics, VecRecordEpisodeStatistics)
from safe_control_gym.controllers.ppo.ppo_utils import compute_returns_and_advantages
from safe_control_gym.controllers.safe_explorer.safe_explorer_utils import ConstraintBuffer, SafetyLayer
from safe_control_gym.controllers.safe_explorer.safe_ppo_utils import SafePPOAgent, SafePPOBuffer
from safe_control_gym.envs.env_wrappers.record_episode_statistics import (RecordEpisodeStatistics,
VecRecordEpisodeStatistics)
from safe_control_gym.envs.env_wrappers.vectorized_env import make_vec_envs
from safe_control_gym.math_and_models.normalization import (
BaseNormalizer, MeanStdNormalizer, RewardStdNormalizer)
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
from safe_control_gym.utils.utils import (get_random_state, is_wrapped,
set_random_state)
from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state


class SafeExplorerPPO(BaseController):
Expand Down
4 changes: 2 additions & 2 deletions safe_control_gym/envs/disturbances.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def reset(self,
env
):
if self.step_offset is None:
self.current_step_offset = self.np_random.randint(self.max_step)
self.current_step_offset = self.np_random.integers(self.max_step)
else:
self.current_step_offset = self.step_offset
self.current_peak_step = int(self.current_step_offset + self.duration / 2)
Expand Down Expand Up @@ -146,7 +146,7 @@ def reset(self,
env
):
if self.step_offset is None:
self.current_step_offset = self.np_random.randint(self.max_step)
self.current_step_offset = self.np_random.integers(self.max_step)
else:
self.current_step_offset = self.step_offset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import gymnasium as gym
import numpy as np

from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env import \
VecEnvWrapper
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env import VecEnvWrapper


class RecordEpisodeStatistics(gym.Wrapper):
Expand Down
6 changes: 2 additions & 4 deletions safe_control_gym/envs/env_wrappers/vectorized_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import numpy as np
import torch

from safe_control_gym.envs.env_wrappers.vectorized_env.dummy_vec_env import \
DummyVecEnv
from safe_control_gym.envs.env_wrappers.vectorized_env.subproc_vec_env import \
SubprocVecEnv
from safe_control_gym.envs.env_wrappers.vectorized_env.dummy_vec_env import DummyVecEnv
from safe_control_gym.envs.env_wrappers.vectorized_env.subproc_vec_env import SubprocVecEnv


def make_env_fn(env_func,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import numpy as np

from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env import VecEnv
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import \
_flatten_obs
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import _flatten_obs
from safe_control_gym.utils.utils import get_random_state, set_random_state


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import numpy as np

from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env import VecEnv
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import (
CloudpickleWrapper, _flatten_list, _flatten_obs, clear_mpi_env_vars)
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import (CloudpickleWrapper,
_flatten_list, _flatten_obs,
clear_mpi_env_vars)
from safe_control_gym.utils.utils import get_random_state, set_random_state


Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/envs/env_wrappers/vectorized_env/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from abc import ABC, abstractmethod

from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import \
tile_images
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env_utils import tile_images


class VecEnv(ABC):
Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/envs/gym_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from gymnasium import spaces

from safe_control_gym.envs.benchmark_env import BenchmarkEnv, Cost, Task
from safe_control_gym.envs.constraints import (GENERAL_CONSTRAINTS,
SymmetricStateConstraint)
from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS, SymmetricStateConstraint
from safe_control_gym.math_and_models.normalization import normalize_angle
from safe_control_gym.math_and_models.symbolic_systems import SymbolicModel

Expand Down
6 changes: 2 additions & 4 deletions safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
from safe_control_gym.envs.benchmark_env import Cost, Task
from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS
from safe_control_gym.envs.gym_pybullet_drones.base_aviary import BaseAviary
from safe_control_gym.envs.gym_pybullet_drones.quadrotor_utils import (
QuadType, cmd2pwm, pwm2rpm)
from safe_control_gym.envs.gym_pybullet_drones.quadrotor_utils import QuadType, cmd2pwm, pwm2rpm
from safe_control_gym.math_and_models.symbolic_systems import SymbolicModel
from safe_control_gym.math_and_models.transformations import (
csRotXYZ, transform_trajectory)
from safe_control_gym.math_and_models.transformations import csRotXYZ, transform_trajectory


class Quadrotor(BaseAviary):
Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/experiments/base_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from munch import munchify
from termcolor import colored

from safe_control_gym.math_and_models.metrics.performance_metrics import \
compute_cvar
from safe_control_gym.math_and_models.metrics.performance_metrics import compute_cvar
from safe_control_gym.utils.utils import is_wrapped


Expand Down
4 changes: 1 addition & 3 deletions safe_control_gym/safety_filters/cbf/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import numpy as np

from safe_control_gym.safety_filters.base_safety_filter import BaseSafetyFilter
from safe_control_gym.safety_filters.cbf.cbf_utils import (cartesian_product,
cbf_cartpole,
linear_function)
from safe_control_gym.safety_filters.cbf.cbf_utils import cartesian_product, cbf_cartpole, linear_function


class CBF(BaseSafetyFilter):
Expand Down
9 changes: 3 additions & 6 deletions safe_control_gym/safety_filters/mpsc/linear_mpsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.controllers.mpc.mpc_utils import rk_discrete
from safe_control_gym.envs.benchmark_env import Environment, Task
from safe_control_gym.envs.constraints import (ConstrainedVariableType,
LinearConstraint,
QuadraticContstraint)
from safe_control_gym.envs.constraints import ConstrainedVariableType, LinearConstraint, QuadraticContstraint
from safe_control_gym.safety_filters.mpsc.mpsc import MPSC
from safe_control_gym.safety_filters.mpsc.mpsc_utils import (
Cost_Function, compute_RPI_set, ellipse_bounding_box,
pontryagin_difference_AABB)
from safe_control_gym.safety_filters.mpsc.mpsc_utils import (Cost_Function, compute_RPI_set,
ellipse_bounding_box, pontryagin_difference_AABB)


class LINEAR_MPSC(MPSC):
Expand Down
Loading

0 comments on commit 560b1a3

Please sign in to comment.