Skip to content

Commit

Permalink
Clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Dec 31, 2018
1 parent 7fbe91a commit 339a02b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
5 changes: 2 additions & 3 deletions rl_baselines/rl_algorithm/sac.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pickle
import os
import pickle

import numpy as np
from stable_baselines import SAC
from stable_baselines.sac.policies import MlpPolicy, CnnPolicy
from stable_baselines.common.vec_env import VecNormalize, DummyVecEnv
from stable_baselines.sac.policies import MlpPolicy, CnnPolicy

from environments.utils import makeEnv
from rl_baselines.base_classes import StableBaselinesRLObject
Expand Down
9 changes: 6 additions & 3 deletions rl_baselines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def latestPath(path):
:param path: path to the log folder (defined in srl_model.yaml) (str)
:return: path to latest learned model in the same dataset folder (str)
"""
return max([path + "/" + d for d in os.listdir(path) if not d.startswith('baselines') and os.path.isdir(path + "/" + d)],
key=os.path.getmtime) + '/srl_model.pth'
return max(
[path + "/" + d for d in os.listdir(path) if not d.startswith('baselines') and os.path.isdir(path + "/" + d)],
key=os.path.getmtime) + '/srl_model.pth'


def configureEnvAndLogFolder(args, env_kwargs, all_models):
"""
Expand Down Expand Up @@ -199,7 +201,8 @@ def main():
parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
help='Set the location of the SRL model path configuration.')
parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
parser.add_argument('--min-episodes-save', type=int, default=100, help="Min number of episodes before saving best model")
parser.add_argument('--min-episodes-save', type=int, default=100,
help="Min number of episodes before saving best model")
parser.add_argument('--latest', action='store_true', default=False,
help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')

Expand Down
1 change: 0 additions & 1 deletion rl_baselines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def get_original_obs(self):
"""
return self.venv.get_original_obs()


def saveRunningAverage(self, path):
"""
Hack to use VecNormalize
Expand Down

0 comments on commit 339a02b

Please sign in to comment.