-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_tune_job.py
97 lines (84 loc) · 2.68 KB
/
run_tune_job.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from os import stat, walk
import numpy as np
import random
from itertools import permutations
from azureml.core import Run
import time
import logging
from pyrsistent import m
from typing import Dict, Tuple
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.dqn as dqn
import ray.tune as tune
from ray.rllib import train
from ray.tune.registry import register_env
from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray_on_aml.core import Ray_On_AML
from contosocabs_env import ContosoCabs_v0
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
# Module-level logger at INFO level, with its own StreamHandler
# (the default stream is stderr).
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
class Callbacks(DefaultCallbacks):
    """RLlib callbacks that report per-episode metrics to the current Azure ML run."""

    def on_episode_end(self,
                       *,
                       worker: RolloutWorker,
                       base_env: BaseEnv,
                       policies: Dict[str, Policy],
                       episode: Episode,
                       env_index: int,
                       **kwargs):
        """Log reward and length metrics for an episode that has just ended.

        Metrics are written with ``Run.log`` to the run returned by
        ``Run.get_context()``:

        * ``episode_reward_mean`` - mean of the values in
          ``episode.agent_rewards`` (presumably one cumulative reward per
          agent, so for a single-agent env this is the episode's total
          reward - confirm).
        * ``episode_reward_max`` - the largest of those values.
        * ``episode_length`` - ``episode.length``.

        The two reward metrics are skipped when ``episode.agent_rewards`` is
        empty (a mean/max of nothing is undefined). Nothing is logged when
        ``episode`` is None.

        Args:
            worker: Rollout worker that ran the episode (unused).
            base_env: Environment wrapper (unused).
            policies: Policy map (unused).
            episode: The finished episode whose metrics are reported.
            env_index: Index of the sub-environment, used only in the printout.
            **kwargs: Extra arguments forwarded by RLlib (unused).
        """
        if episode is None:
            # Nothing to report; bail out before dereferencing the episode.
            return

        run = Run.get_context()
        print("episode {} (env-idx={}) ended.".format(episode.episode_id, env_index))
        print(f'result: {episode}')

        rewards = list(episode.agent_rewards.values())
        if rewards:
            run.log(
                name='episode_reward_mean',
                value=sum(rewards) / len(rewards))
            run.log(
                name='episode_reward_max',
                value=max(rewards))
        run.log(
            name='episode_length',
            value=episode.length)
def merge_dict(config, args):
    """Override values in ``config`` with those from ``args``, for existing keys only.

    Keys that appear only in ``args`` are ignored, so ``config`` acts as the
    whitelist of accepted settings. ``config`` is modified in place and also
    returned for convenience.

    Args:
        config: Base dictionary whose existing keys may be overridden.
        args: Mapping of override values. ``None`` or empty means "no overrides".

    Returns:
        The same ``config`` object, updated.
    """
    if not args:
        return config
    # Collect the overrides first, so the dict is not mutated while it is being iterated.
    config.update({key: args[key] for key in config if key in args})
    return config
def initiate_train():
    """Build an RLlib trainer config from the command line and launch a Tune run.

    Arguments come from RLlib's own ``train`` parser. ``--run`` selects the
    algorithm, ``--config`` supplies config overrides, and ``--stop`` is passed
    unchanged to ``tune.run`` as the stopping criteria. Results are written
    under ``./logs``.
    """
    args = train.create_parser().parse_args()
    user_config = args.config or {}

    if args.run == "APEX":
        # Copy, so the library's module-level APEX_DEFAULT_CONFIG is not
        # mutated by the assignments and merge below.
        config = APEX_DEFAULT_CONFIG.copy()
    else:
        config = {}

    # Project-specific settings; the user's --config may still replace them.
    config["env"] = ContosoCabs_v0
    config["log_level"] = "INFO"
    config["callbacks"] = Callbacks

    if args.run == "APEX":
        # APEX has a full default config: accept overrides only for keys it defines.
        config = merge_dict(config, user_config)
    else:
        # There is no default config to filter against here. Filtering by
        # existing keys would silently drop every override except
        # env/log_level/callbacks, so apply the user's config verbatim and
        # let RLlib validate the keys.
        config.update(user_config)

    print(f'config: {config}')
    tune.run(
        run_or_experiment=args.run,
        config=config,
        stop=args.stop,
        local_dir='./logs')
def main():
    """Start Ray through ray-on-aml and run training, on the head node only.

    ``getRay()`` returns a truthy handle on the head node. On worker nodes
    the result is falsy, and this function only logs and returns.
    """
    cluster = Ray_On_AML()
    # On a GPU cluster, use cluster.getRay(gpu_support=True) instead.
    head_ray = cluster.getRay()

    if not head_ray:
        logger.info("in worker node")
        return

    logger.info("head node detected")
    # Fixed pause before querying resources; presumably gives worker nodes
    # time to join the cluster - confirm.
    time.sleep(15)
    print(head_ray.cluster_resources())
    initiate_train()


if __name__ == "__main__":
    main()