mpe_lstm_main.py
import supersuit
from pettingzoo.mpe import simple_tag_v2
import gym
import random
import numpy as np
from time import time_ns
from tensorboardX import SummaryWriter
from agent.ppo_agent import PPOAgent
from agent.ppo_lstm_agent import PPOLSTMAgent


def get_action():
    # Random discrete action in [0, 5) used for the untrained good (prey) agents.
    return np.random.randint(0, 5)

env_config = {
    "num_good": 2,
    "num_adversaries": 3,
    "num_obstacles": 2,
    "max_cycles": 100,
    "continuous_actions": False,
}


if __name__ == "__main__":
    env = simple_tag_v2.env(num_good=env_config["num_good"],
                            num_adversaries=env_config["num_adversaries"],
                            num_obstacles=env_config["num_obstacles"],
                            max_cycles=env_config["max_cycles"],
                            continuous_actions=env_config["continuous_actions"])
    sum_of_agents = env_config["num_good"] + env_config["num_adversaries"]
    # simple_tag observations: own velocity/position (4) + obstacle relative positions
    # + other-agent relative positions + velocities of the visible good agents.
    adversary_observation = 4 + (env_config["num_obstacles"] * 2) + (env_config["num_good"] + env_config["num_adversaries"] - 1) * 2 + env_config["num_good"] * 2
    good_observation = 4 + (env_config["num_obstacles"] * 2) + (env_config["num_good"] + env_config["num_adversaries"] - 1) * 2 + (env_config["num_good"] - 1) * 2
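    # Illustrative check, plugging the config above into the formulas
    # (num_good=2, num_adversaries=3, num_obstacles=2):
    #   adversary obs = 4 + 4 + 8 + 4 = 20
    #   good obs      = 4 + 4 + 8 + 2 = 18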
    # Adversaries (predators) are trained with a recurrent PPO policy; the good
    # agents (prey) act randomly via get_action(), so good_agent is unused below.
    adversary_agent = PPOLSTMAgent(adversary_observation, 5)
    good_agent = PPOAgent(good_observation, 5)
    summary_writer = SummaryWriter('logs/mpe_lstm_main_' + str(time_ns()))

    for i_eps in range(10000):
        env.reset()
        prev_state = [np.zeros(adversary_observation) for _ in range(env_config["num_adversaries"])]
        step_cnt = 0
        sum_reward = 0
        while step_cnt < env_config["max_cycles"] * sum_of_agents:
            agent_idx = step_cnt % sum_of_agents
            next_state, reward, done, info = env.last()
            action = None  # a done agent must be stepped with None in the AEC API
            if not done:
                # print('step cnt : {}'.format(step_cnt))
                # print('next_state : {}, reward : {}, done : {}, info : {}'.format(next_state, reward, done, info))
                if agent_idx < env_config["num_adversaries"]:
                    # Adversary turn: act with the LSTM policy and store the transition.
                    action, action_prob = adversary_agent.get_action(next_state)
                    adversary_agent.save_xp((prev_state[agent_idx], next_state, action, action_prob[action].item(), reward, done))
                    prev_state[agent_idx] = next_state
                    sum_reward += reward
                else:
                    # Good-agent turn: take a random action.
                    action = get_action()
            env.step(action)
            # env.render()
            step_cnt += 1
        adversary_agent.train()
        # print('{} eps total reward : {}'.format(i_eps, sum_reward))
        summary_writer.add_scalar('Episode reward', sum_reward, i_eps)