-
Notifications
You must be signed in to change notification settings - Fork 0
/
dqn_train.py
109 lines (93 loc) · 4.14 KB
/
dqn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import sys
import json
import random
import warnings
import logging
import argparse
import torch
import numpy as np
from datetime import datetime
from dqn import DQN, DQNAgent, DQNTrainer
from environment.graph import KnowledgeGraph
from environment.chatenv import StoryBotRetellEnv
# Timestamp of this run; it is embedded in the log file name below and in the
# default checkpoint names of the command-line options.
NOW_STR = datetime.now().strftime('%Y%m%d_%H%M%S')

# Per-run log file.
LOG_FILE = f'log/dqn_{NOW_STR}.log'
# logging.basicConfig opens the log file right away and raises
# FileNotFoundError when its directory is missing (e.g. on a fresh checkout),
# so make sure the directory exists first.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(filename=LOG_FILE,
                    level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Mirror every record to stdout in addition to the log file.
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
# Command-line options. Every option has a default, so the script also runs
# with no arguments; the default checkpoint names embed the run timestamp.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    (("-e", "--epochs"), dict(help="number of epochs", type=int, default=10)),
    (("-b", "--batch_size"), dict(help="batch size", type=int, default=1)),
    (("-l", "--learning_rate"), dict(help="learning rate", type=float, default=0.001)),
    (("-s", "--seed"), dict(help="random seed", type=int, default=42)),
    (("-c", "--cuda"), dict(help="cuda device", type=int, default=0, choices=[0, 1, 2])),
    (("-u", "--summary"), dict(help="story summary file", type=str,
                               default="data/summary/summary_train.json")),
    (("-k", "--kg"), dict(help="story knowledge graph folder", type=str,
                          default="data/kg/train_coref")),
    (("-m", "--name1"), dict(help="save_model_1 name", type=str,
                             default=f'model/dqn1_{NOW_STR}.pth')),
    (("-n", "--name2"), dict(help="save_model_2 name", type=str,
                             default=f'model/dqn2_{NOW_STR}.pth')),
)
# Register the options in the order listed above.
for _flags, _settings in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_settings)
args = parser.parse_args()
# Bind the parsed options to module-level constants. The explicit casts repeat
# the `type=` converters already declared on the parser.
EPOCHS, BATCH_SIZE, SEED = int(args.epochs), int(args.batch_size), int(args.seed)
LEARNING_RATE = float(args.learning_rate)
CUDA, SUMMARY, KG = args.cuda, args.summary, args.kg
SAVE_MODEL_NAME_1, SAVE_MODEL_NAME_2 = args.name1, args.name2

# Log the effective configuration, one "<label>: <value>" record per setting.
for label, value in (
    ('epochs', EPOCHS),
    ('batch_size', BATCH_SIZE),
    ('learning_rate', LEARNING_RATE),
    ('seed', SEED),
    ('cuda', CUDA),
    ('summary', SUMMARY),
    ('kg', KG),
    ('save model_1 name', SAVE_MODEL_NAME_1),
    ('save model_2 name', SAVE_MODEL_NAME_2),
):
    logging.info(f'{label}: {value}')
# Seed Python's, NumPy's and PyTorch's random number generators (including the
# CUDA ones) with the same value.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(SEED)
# cuDNN settings used for reproducible runs (see the PyTorch reproducibility notes).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU selected on the command line when CUDA is available; otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.cuda}')
else:
    device = torch.device('cpu')
# Suppress all Python warnings from this point on.
warnings.filterwarnings('ignore')

# Load the story summaries into a fresh dict. The ** unpacking requires the
# JSON root to be an object (mapping); anything else raises TypeError.
with open(SUMMARY, encoding='utf8') as summary_file:
    story_summary_dataset = {**json.load(summary_file)}
# Two separate environment instances built from the same dataset, checkpoints,
# embedding model and device.
_env_settings = dict(
    dialogue_evalution_model_ckpt='environment/dialogue_evalution/model/dialogue_evalution_model_best.pt',
    kg2text_model_ckpt='environment/kg2text/model/kg2text_model.pt',
    embedding_model_name='sentence-transformers/all-MiniLM-L6-v2',
    device=device,
)
storybot_env_1 = StoryBotRetellEnv(story_summary_dataset, **_env_settings)
storybot_env_2 = StoryBotRetellEnv(story_summary_dataset, **_env_settings)
# One DQN agent per environment, both using the same learning rate.
agent_1, agent_2 = [
    DQNAgent(env=env, lr=LEARNING_RATE)
    for env in (storybot_env_1, storybot_env_2)
]

# The trainer receives both environment/agent pairs plus the data locations
# and the training schedule.
trainer = DQNTrainer(
    env1=storybot_env_1,
    env2=storybot_env_2,
    agent1=agent_1,
    agent2=agent_2,
    story_summary_path=SUMMARY,
    story_kg_path=KG,
    epoch=EPOCHS,
    batch_size=BATCH_SIZE,
)
# Checkpoint names are set as attributes after construction.
# NOTE(review): the default names point into model/, which this script never
# creates — confirm the trainer creates it, or saving may fail.
trainer.agent1_model_name = SAVE_MODEL_NAME_1
trainer.agent2_model_name = SAVE_MODEL_NAME_2

# Run the training, bracketed by start/finish log records.
logging.info('Start training')
trainer.train()
logging.info('Finish training')