-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
38 lines (30 loc) · 1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import sys
from agent import Agent
from stable_baselines3 import DQN
from stable_baselines3.common.logger import configure
from snek.environment import Snek
# Headless mode: when the script is launched with ``--novid``, point SDL at its
# "dummy" video driver so no window is opened. The variable is set at import
# time, before ``main()`` constructs the environment.
if "--novid" in sys.argv:
    import os

    os.environ.update({"SDL_VIDEODRIVER": "dummy"})
def main(total_timesteps=5_000_000, eval_episodes=None, render=True):
    """Train a DQN agent on the Snek environment, save it, then watch it play.

    Training progress is logged to ``./results`` (stdout, CSV, JSON and text
    log formats), and the trained model is saved under the name ``dqn_snek``.
    Afterwards the model is run in the same environment. After every finished
    episode, the running average of ``env.player.len`` at episode end is
    printed.

    Args:
        total_timesteps: Number of environment steps passed to
            ``model.learn``.
        eval_episodes: Number of evaluation episodes to play after training.
            ``None`` (the default) keeps playing until interrupted with
            Ctrl+C, as the original script did.
        render: Whether to call ``env.render()`` on every evaluation step.
            This only affects the evaluation phase; training never renders
            here.

    Returns:
        The average ``env.player.len`` over the completed evaluation
        episodes, or ``None`` if no episode finished.
    """
    env = Snek()

    new_logger = configure('./results', ["stdout", "csv", "json", "log"])
    model = DQN("MlpPolicy", env, verbose=1)
    model.set_logger(new_logger)
    model.learn(total_timesteps=total_timesteps, log_interval=100)
    model.save("dqn_snek")

    total_len = 0
    num_episodes = 0
    # NOTE(review): this uses the classic gym API (reset() -> obs,
    # step() -> 4-tuple). stable-baselines3 >= 2.0 expects the gymnasium API
    # (reset() -> (obs, info), step() -> 5-tuple) -- confirm which one Snek
    # implements and which SB3 version is pinned.
    obs = env.reset()
    try:
        while eval_episodes is None or num_episodes < eval_episodes:
            action, _ = model.predict(obs, deterministic=False)
            obs, _, done, _ = env.step(action)
            if render:
                # Rendering slows down evaluation only; set render=False
                # (or run with --novid) to play back faster.
                env.render()
            if done:
                total_len += env.player.len
                num_episodes += 1
                obs = env.reset()
                print('Average len: {:.1f}'.format(total_len / num_episodes))
    except KeyboardInterrupt:
        # With eval_episodes=None the loop never ends on its own; let Ctrl+C
        # stop it cleanly instead of dumping a traceback.
        pass

    return total_len / num_episodes if num_episodes else None
# Script entry point: train and evaluate only when run directly, not on import.
if __name__ == "__main__":
    main()