-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
69 lines (58 loc) · 2.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os, shutil
import sys
import config

# Load the shared configuration BEFORE the heavyweight imports below:
# the presets/agents referenced later all read from this module-level cfg.
try:
    path = os.path.dirname(__file__)
    config.load(os.path.join(path, "config.json"))
    cfg = config.get_cfg()
except Exception as e:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt still propagate
    # Include the actual error instead of silently discarding it.
    print(f"cannot load config.json: {e}")
    sys.exit(1)  # non-zero exit code signals failure to callers/scripts

import torch
import reward
import multiprocessing
from multiprocessing import Pool
import argparse
import json
from DQN.DQN_agent import DQN_Agent
from Actor_Critic.AC_agent import AC_Agent
def unpack(args):
    """Pool.map adapter: expand one training-settings dict into run_training kwargs."""
    run_training(**args)
def run_training(id, algo, episodes, reward_func, reward_settings, targets, reg_inits, root_dir):
    """Run a single training job: build the preset agent, train it, save the best model.

    Logs are written to <root_dir>/<id>; that directory must not already exist
    (os.makedirs raises if it does, preventing accidental overwrites).
    """
    # Per-job log directory, keyed by the job id.
    log_dir = os.path.join(root_dir, str(id))
    os.makedirs(log_dir)

    # Look up the algorithm preset in the global config, then instantiate the
    # matching agent class (e.g. DQN_Agent / AC_Agent) by its name.
    preset = getattr(cfg.presets, algo)
    agent_cls = globals()[preset.agent]
    agent = agent_cls(**preset.parameters.todict(), verbose=True, log_dir=log_dir)

    # Resolve the reward function from the reward module by name and train.
    reward_fn = getattr(reward, reward_func)
    agent.train(reward_fn, reward_settings, episodes, targets, reg_inits)
    agent.save("best", best=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run a multi-training series')
    parser.add_argument('name', help="name of the experiment")
    parser.add_argument('-f', '--force', help="force overwriting of the log folder", action='store_true')
    parser.add_argument('-t', '--training', help="specify custom training file", default="training.json")
    parser.add_argument('-j', '--jobs', type=int, help="number of concurrent jobs", default=multiprocessing.cpu_count())
    args = parser.parse_args()

    # The training file is a JSON list of dicts; each dict holds the
    # keyword arguments for one run_training call.
    with open(args.training) as f:
        trainings = json.load(f)

    # One folder per experiment under Experiments/<name>; refuse to clobber
    # an existing experiment unless --force was given.
    os.makedirs("Experiments", exist_ok=True)
    root_dir = os.path.join("Experiments", args.name)
    if args.force:
        shutil.rmtree(root_dir, ignore_errors=True)
    try:
        os.makedirs(root_dir)
    except FileExistsError:
        # Previously this surfaced as a raw traceback; fail with a clear hint.
        exit(f"experiment '{args.name}' already exists - rerun with --force to overwrite")

    # Keep a copy of the training descriptions alongside the results.
    with open(os.path.join(root_dir, "Training_descriptions.json"), 'w') as f:
        json.dump(trainings, f, indent=4)
        # json.dump(cfg.todict(), f, indent=4) doesn't work

    # Tell every job where to write its logs.
    for training in trainings:
        training['root_dir'] = root_dir

    if args.jobs > 1:
        # 'spawn' starts each worker in a fresh interpreter — required on
        # Windows and generally safer with torch/CUDA than 'fork'.
        multiprocessing.set_start_method('spawn')
        with Pool(processes=args.jobs) as pool:
            try:
                pool.map(unpack, trainings)
            except KeyboardInterrupt:
                # Stop workers promptly on Ctrl-C instead of hanging.
                pool.terminate()
                pool.join()
    else:
        # Single-job mode: run each training sequentially in-process
        # (easier to debug than going through the pool).
        for training in trainings:
            unpack(training)