# dialog_gen_qk.py
import sys
sys.path.append("/home/wyshi/simulator")
from simulator.loose_user import LooseUser
# from simulator.user import Goal
from simulator.user import User
from simulator.system import System
from simulator.loose_system import LooseSystem
from sequicity_user.seq_user import Seq_User
from sequicity_user.seq_user_act import Seq_User_Act
from simulator.env_for_evaluation import Enviroment
import simulator.dialog_config as dialog_config
import numpy as np
from simulator.agent.core import SystemAct
from config import Config
from evaluation.config import Config as evaluation_config
from tqdm import tqdm
# Instantiate the evaluation settings and the main configuration object.
eval_config = evaluation_config()
config = Config()

# Choose the user simulator from the evaluation flags.  `user` starts as None
# so that an unsupported flag combination is reported explicitly below rather
# than surfacing later as a NameError when the environment is constructed.
user = None
if eval_config.rule_policy:
    # rule_policy set: LooseUser is used; nlg_sample mirrors the NLG flag.
    if eval_config.nlg_template:
        user = LooseUser(nlg_sample=False)
    elif eval_config.nlg_sample:
        user = LooseUser(nlg_sample=True)
    elif eval_config.nlg_generation:
        # This branch was an empty `pass` before, which left `user` unbound.
        raise NotImplementedError(
            "No user simulator is wired up for rule_policy together with "
            "nlg_generation."
        )
else:
    # rule_policy not set: the sequicity-based user classes are used.
    if eval_config.nlg_template:
        user = Seq_User_Act(nlg_sample=False)
    elif eval_config.nlg_sample:
        user = Seq_User_Act(nlg_sample=True)
    elif eval_config.nlg_generation:
        user = Seq_User()

if user is None:
    # None of nlg_template / nlg_sample / nlg_generation was enabled.
    raise ValueError(
        "Evaluation config selects no user simulator "
        "(rule_policy={!r}, nlg_template={!r}, nlg_sample={!r}, "
        "nlg_generation={!r}).".format(
            eval_config.rule_policy,
            eval_config.nlg_template,
            eval_config.nlg_sample,
            eval_config.nlg_generation,
        )
    )

system = System(config=config)  # sequicity system
env = Enviroment(user=user, system=system, verbose=True, config=config)
sys_act = None
status = []  # one user.dialog_status entry is appended per finished dialog

# Mode passed to every env.reset / env.step call.  The original trailing
# comment listed the alternatives RANDOM_ACT, INTERACTIVE and RL_TRAINING.
MODE = dialog_config.RL_WARM_START

# Run 100 dialogs, printing per-turn and per-dialog information.
for _ in tqdm(range(100)):
    print("-"*20)

    # Start a new dialog and record the user act exposed by the environment.
    usr_act_seq = []
    next_state = env.reset(mode=MODE)
    usr_act_seq.append(env.last_usr_act_true)

    sys_act = None  # initial sys act
    total_rewards = 0

    # Step the environment until it reports the dialog as done.
    done = False
    while not done:
        provided_sys_act = None  # no system act is supplied from this script
        next_state, reward, done = env.step(
            provided_sys_act=provided_sys_act,
            mode=MODE,
        )
        print("env.last_usr_act_true", env.last_usr_act_true)
        usr_act_seq.append(env.last_usr_act_true)
        total_rewards += reward
        print("user turn status: ", env.user.dialog_status)

    # Dialog finished: store the final status and print the summary.
    status.append(user.dialog_status)
    print('dialog_status: {}'.format(env.success))
    print('reward: {}'.format(total_rewards))
    print("-"*20)
    print("\n\n\n")