-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmanager.py
87 lines (77 loc) · 3.56 KB
/
manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
from logging import getLogger
from src.logger import setup_logger
from src.config import Config, ControllerType
# Module logger, named after this module; the code visible in this file does
# not emit anything through it.
logger = getLogger(__name__)

# Sub-commands accepted by the positional ``cmd`` argument of create_parser().
CMD_LIST = ["train", "evaluate"]
def create_parser():
    """Build the command-line parser for training or evaluating a controller.

    The positional ``cmd`` must be one of ``CMD_LIST``. ``--controller`` must
    name a member of ``ControllerType``. Every other option has a default, and
    ``start()`` copies it onto the corresponding ``Config`` field.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("cmd", help="what to do", choices=CMD_LIST)
    # start() indexes ControllerType[args.controller] unconditionally, so an
    # omitted --controller used to crash with ``KeyError: None``. Requiring it
    # makes argparse print a proper usage error instead.
    parser.add_argument("--controller", help="choose an algorithm (controller)",
                        choices=list(ControllerType.__members__),
                        required=True)

    # Rendering / output switches (boolean flags, default False).
    parser.add_argument(
        "--render", help="set to render the env when evaluate", action="store_true")
    parser.add_argument(
        "--save_replay", help="set to save replay", action="store_true")
    parser.add_argument(
        "--save_plot", help="set to save Q-value plot when evaluate", action="store_true")
    parser.add_argument(
        "--show_plot", help="set to show Q-value plot when evaluate", action="store_true")

    # Training-loop settings (copied onto config.trainer by start()).
    parser.add_argument(
        "--num_episodes", help="set to run how many episodes", default=10000, type=int)
    parser.add_argument(
        "--batch_size", help="set the batch size", default=50, type=int)
    parser.add_argument(
        "--eva_interval", help="set how many episodes evaluate once", default=500, type=int)
    parser.add_argument("--evaluate_episodes",
                        help="set evaluate how many episodes", default=100, type=int)
    parser.add_argument("--lr", help="set learning rate",
                        default=0.0001, type=float)

    # Controller hyper-parameters (mostly copied onto config.controller).
    parser.add_argument(
        "--epsilon", help="set epsilon when use epsilon-greedy", default=0.5, type=float)
    parser.add_argument(
        "--gamma", help="set reward decay rate", default=0.9, type=float)
    parser.add_argument(
        "--lam", help="set lambda if use sarsa(lambda) algorithm", default=0.5, type=float)
    parser.add_argument(
        "--forward", help="set to use forward-view sarsa(lambda)", action="store_true")
    parser.add_argument(
        "--rawpixels", help="set to use raw pixels as input (only valid to PPO)", action="store_true")
    parser.add_argument(
        "--max_workers", help="set max workers to train", default=8, type=int)
    parser.add_argument(
        "--t_max", help="set simulate how many timesteps until update param", default=5, type=int)
    return parser
def start():
    """Entry point: parse the command line, build a Config, and run main().

    The ``--controller`` choice determines both the ``Config`` that is built
    and the runner that is imported. ``src.main_a3c`` is used for
    ``ControllerType.A3C`` and ``src.main`` for every other controller.
    ``cmd`` sets ``config.train`` / ``config.evaluate``. The remaining options
    are copied onto the top-level, ``trainer`` and ``controller`` fields of
    the config, which is then passed to the selected ``main``.
    """
    cli_args = create_parser().parse_args()
    config = Config(ControllerType[cli_args.controller])

    # The runner is imported inside the branch, so only the selected module
    # is loaded.
    if config.controller.controller_type == ControllerType.A3C:
        from src.main_a3c import main
    else:
        from src.main import main

    # cmd is restricted to CMD_LIST by argparse; anything but 'train' means
    # evaluation.
    is_training = cli_args.cmd == 'train'
    config.train = is_training
    config.evaluate = not is_training

    print("\n===============================================================")
    trainer, controller = config.trainer, config.controller
    config.render = cli_args.render
    config.save_replay = cli_args.save_replay
    config.show_plot = cli_args.show_plot
    config.save_plot = cli_args.save_plot
    trainer.num_episodes = cli_args.num_episodes
    trainer.batch_size = cli_args.batch_size
    trainer.evaluate_interval = cli_args.eva_interval
    trainer.lr = cli_args.lr
    trainer.evaluate_episodes = cli_args.evaluate_episodes
    controller.epsilon = cli_args.epsilon
    controller.gamma = cli_args.gamma
    controller.lambda_ = cli_args.lam
    controller.forward = cli_args.forward
    controller.raw_pixels = cli_args.rawpixels
    controller.max_workers = cli_args.max_workers
    trainer.t_max = cli_args.t_max
    print("===============================================================\n")

    main(config)