-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
78 lines (58 loc) · 2.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import time
import argparse
import datetime
from .config import ConfigLoader
from .gamelauncher import GameLauncher
from .trainers import FullTrainer, AttackTrainer
from .mainbot import MainBot
from .loggers import logger
import numpy as np
from sc2 import Result
import tensorflow as tf
import keras.backend.tensorflow_backend as backend
def get_session(gpu_fraction=0.85):
    """Create a TensorFlow 1.x session whose GPU memory use is capped.

    Args:
        gpu_fraction: value passed as ``per_process_gpu_memory_fraction``
            in the session's ``GPUOptions`` (default 0.85).

    Returns:
        A ``tf.Session`` built with that GPU configuration.
    """
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction),
    )
    return tf.Session(config=session_config)


# Install the capped session as the Keras TensorFlow backend's session.
# This runs at import time, before any model is built.
backend.set_session(get_session())
# Load command-line arguments.
parser = argparse.ArgumentParser(
    prog='main.py',
    description='Yuri, the StarCraft II bot'
)
# --type selects both the yuri.json section and the run mode below. Only
# 'game' and 'train' are handled. Omitting it used to crash later with an
# AttributeError on the missing config section, so reject bad values up front.
parser.add_argument(
    '--type',
    required=True,
    choices=('game', 'train'),
    help="run mode: 'game' plays one match, 'train' trains a model",
)
cmd_args = parser.parse_args()
game_type = cmd_args.type
# Load the configuration file and select the section named by --type.
# NOTE(review): 'yuri.json' is passed as a bare name; whether ConfigLoader
# resolves it against the CWD or this package's directory is not visible
# here — confirm.
cl = ConfigLoader('yuri.json')
configs = cl.get_json()
config_json = configs.get(game_type)
if config_json is None:
    # Fail with a clear message instead of an AttributeError on None.get below.
    raise SystemExit(f"yuri.json has no '{game_type}' section")
model_type = config_json.get('model_type')
model = config_json.get('model', None)
if game_type == 'game':
    # Play one game against the built-in AI, keep winning games as training
    # data, and append the outcome to the configured log file.
    realtime = config_json.get('realtime', False)
    difficulty = config_json.get('difficulty', 'easy')
    race = config_json.get('race', 'zerg')
    map_name = config_json.get('map')
    logfile_name = config_json.get('logfn')
    # The model is optional (the launcher receives `model is not None` as its
    # use-model flag), so only build a path when one is configured.
    # os.path.join(dirname, None) would raise TypeError.
    # NOTE(review): assumes GameLauncher ignores model_path when the flag is
    # False — confirm.
    model_path = (
        os.path.join(os.path.dirname(__file__), model) if model is not None else None
    )
    game_launcher = GameLauncher(
        MainBot, model is not None, model_path, map_name=map_name, realtime=realtime
    )
    # Passed to start_game, which presumably fills it in place with per-step
    # samples — verify against GameLauncher.
    train_data_tensor = list()
    game_result = game_launcher.start_game(difficulty, race, train_data_tensor)
    logger.debug(f'Game Result: {game_result}')
    if game_result == Result.Victory:
        # Only victories are saved as training data. Create the directory so
        # the first save does not fail with FileNotFoundError.
        train_dir = os.path.join(os.path.dirname(__file__), f'{model_type}_local_train')
        os.makedirs(train_dir, exist_ok=True)
        np.save(
            os.path.join(train_dir, f'{str(int(time.time()))}.npy'),
            np.array(train_data_tensor)
        )
    if logfile_name is None:
        # Without 'logfn' there is nowhere to append the result; don't crash
        # after a full game has been played.
        logger.warning("no 'logfn' configured; game result not written to a log file")
    else:
        log_path = os.path.join(os.path.dirname(__file__), logfile_name)
        with open(log_path, 'a', encoding='utf-8') as f:
            prefix = 'Model: ' if game_launcher.use_model else 'Random: '
            f.write(f'{datetime.datetime.now()}:{prefix} Against {difficulty} {race} {game_result}\n')
elif game_type == 'train':
    # Pick the trainer implementation matching the configured model type.
    if model_type == 'attack':
        trainer = AttackTrainer(config_json)
    elif model_type == 'full':
        trainer = FullTrainer(config_json)
    else:
        # Previously fell through and crashed with NameError on `trainer`.
        raise ValueError(
            f"unsupported model_type {model_type!r} for training; expected 'attack' or 'full'"
        )
    trainer.prepare_model(model)
    trainer.train()