train_APT.py
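"""Train an APT_PPO agent on the MiniGrid MultiToy environment.

Registers a training and a test variant of the environment, builds a
vectorized set of wrapped environments, and runs policy training followed
by evaluation and a final model save.
"""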
import gym
import os

from mini_behavior.register import register
from algorithms.APT_PPO import APT_PPO
from env_wrapper import CustomObservationWrapper

# Task and environment configuration
TASK = 'MultiToy'
ROOM_SIZE = 8
MAX_STEPS = 1000

# Training hyperparameters
TOTAL_TIMESTEPS = int(1e6)
DENSE_REWARD = False
POLICY_TYPE = 'CnnPolicy'
NUM_ENVS = 8
NUM_STEPS = 125
SAVE_FREQUENCY = 100
TEST_STEPS = 500

# Environment IDs and constructor kwargs (v0 = training env, v1 = test env)
env_name = f"MiniGrid-{TASK}-{ROOM_SIZE}x{ROOM_SIZE}-N2-v0"
env_kwargs = {"room_size": ROOM_SIZE, "max_steps": MAX_STEPS}
test_env_name = f"MiniGrid-{TASK}-{ROOM_SIZE}x{ROOM_SIZE}-N2-v1"
test_env_kwargs = {"room_size": ROOM_SIZE, "max_steps": MAX_STEPS, "test_env": True}
def make_env(env_id, seed, idx, env_kwargs):
    """Return a thunk that builds a single seeded, wrapped environment."""
    def thunk():
        env = gym.make(env_id, **env_kwargs)
        env = CustomObservationWrapper(env)
        env.seed(seed + idx)
        return env
    return thunk


def init_env(num_envs: int, seed):
    """Create a synchronous vectorized set of training environments."""
    return gym.vector.SyncVectorEnv(
        [make_env(env_name, seed, i, env_kwargs) for i in range(num_envs)]
    )
if __name__ == "__main__":
    # Register the training and test variants of the task
    register(
        id=env_name,
        entry_point=f'mini_behavior.envs:{TASK}Env',
        kwargs=env_kwargs,
    )
    register(
        id=test_env_name,
        entry_point=f'mini_behavior.envs:{TASK}Env',
        kwargs=test_env_kwargs,
    )

    env = init_env(NUM_ENVS, seed=1)

    save_dir = f"models/APT_PPO_{TASK}_Run5"
    # Create the save directory up front so checkpoints can be written during training
    os.makedirs(save_dir, exist_ok=True)

    print('begin training')

    # Policy training
    model = APT_PPO(
        env,
        env_id=env_name,
        save_dir=save_dir,
        test_env_id=test_env_name,
        test_env_kwargs=test_env_kwargs,
        num_envs=NUM_ENVS,
        total_timesteps=TOTAL_TIMESTEPS,
        num_steps=NUM_STEPS,
        save_freq=SAVE_FREQUENCY,
        test_steps=TEST_STEPS,
    )

    print("\n=== Observation Space ===")
    print(f"Shape: {env.observation_space.shape}")
    print(f"Type: {env.observation_space.dtype}")

    model.train()
    model.evaluate_training()
    model.save(f"{save_dir}/{env_name}", env_kwargs=env_kwargs)