# prepare_data_ectbps_para_mask.py
from utils import *
# Paraphrasing
# Doc - All lines that cover target summary sentences.
# Summ - Corresponding factually grounded sentences.
# Numbers in both documents and summaries are masked by placeholders.
def _replace_known_numbers(text, placeholders, decimals_only=False):
    """Substitute number tokens found in *text* with their placeholders.

    Tokens are located with ``pattern7`` and replaced longest-first, so a
    token that is a substring of a longer token (``5`` inside ``15`` inside
    ``115``) can never clobber part of the longer one. Tokens with no entry
    in *placeholders* are left untouched.

    Args:
        text: Line to mask.
        placeholders: Mapping from number token to placeholder string.
        decimals_only: When true, only tokens containing ``'.'`` are replaced.

    Returns:
        The line with every known token replaced.
    """
    # dict.fromkeys de-duplicates while keeping first-occurrence order; the
    # stable sort then keeps that order among tokens of equal length.
    tokens = dict.fromkeys(re.findall(pattern7, text))
    for token in sorted(tokens, key=len, reverse=True):
        if decimals_only and '.' not in token:
            continue
        if token in placeholders:
            text = text.replace(token, placeholders[token])
    return text


def _mask_line(text, placeholders):
    """Mask decimals first, then rescan and mask every remaining known token."""
    text = _replace_known_numbers(text, placeholders, decimals_only=True)
    return _replace_known_numbers(text, placeholders)


def getMaskedLines(d_lines, s_lines):
    """Mask the numbers of aligned document/summary lines with placeholders.

    For every ``(document line, summary line)`` pair, both lines are passed
    through ``getPPText``. Each distinct ``pattern7`` match in the document
    line is assigned a placeholder ``num-<num2words(k)>``, where ``k`` counts
    distinct matches in order of first appearance, starting at 1 for every
    pair. The same mapping is then applied to the summary line; summary
    numbers that do not occur in the document line are left as they are.

    ``getPPText``, ``pattern7``, ``num2words`` and ``re`` are not defined in
    this file; they are expected to be provided by ``from utils import *``.

    Args:
        d_lines: Document lines.
        s_lines: Summary lines, aligned index-by-index with ``d_lines``.
            Pairing uses ``zip``, so extra lines on either side are dropped.

    Returns:
        A tuple ``(masked_document_lines, masked_summary_lines)`` of two
        equally long lists.
    """
    dlines_masked, slines_masked = [], []
    for dline, sline in zip(d_lines, s_lines):
        dline = getPPText(dline)

        # Placeholder table is rebuilt per pair, numbered by first appearance.
        placeholders = {}
        for token in re.findall(pattern7, dline):
            if token not in placeholders:
                placeholders[token] = f'num-{num2words(len(placeholders) + 1)}'

        dlines_masked.append(_mask_line(dline, placeholders))
        slines_masked.append(_mask_line(getPPText(sline), placeholders))
    return dlines_masked, slines_masked
def prepare_data(dataPath, out_path):
    """Write number-masked copies of every ``.txt`` pair under *dataPath*.

    Reads ``<dataPath>/source/<name>.txt`` and ``<dataPath>/target/<name>.txt``
    for each ``.txt`` file listed in the source directory, masks both with
    ``getMaskedLines`` and writes the results to ``<out_path>/source/`` and
    ``<out_path>/target/`` under the same file name. Output directories are
    created when missing. Lines are joined with ``'\n'`` and no trailing
    newline is written.

    Args:
        dataPath: Directory containing the ``source`` and ``target`` folders.
        out_path: Directory that receives the masked ``source`` and ``target``
            folders.

    Raises:
        AssertionError: If a source file and its target file have different
            numbers of lines.
    """
    source_path = f'{dataPath}/source/'
    target_path = f'{dataPath}/target/'
    out_source_path = f'{out_path}/source/'
    out_target_path = f'{out_path}/target/'
    os.makedirs(out_source_path, exist_ok=True)
    os.makedirs(out_target_path, exist_ok=True)

    for file in os.listdir(source_path):
        if not file.endswith('.txt'):
            continue
        print(file)

        # Context managers close every handle, including on error paths.
        with open(f'{source_path}{file}', 'r') as f_ect_in:
            doc_lines = [line.strip() for line in f_ect_in]
        with open(f'{target_path}{file}', 'r') as f_summ_in:
            summ_lines = [line.strip() for line in f_summ_in]
        assert len(doc_lines) == len(summ_lines), (
            f'{file}: {len(doc_lines)} source lines vs '
            f'{len(summ_lines)} target lines'
        )

        d_lines, s_lines = getMaskedLines(doc_lines, summ_lines)
        assert len(d_lines) == len(s_lines)

        with open(f'{out_source_path}{file}', 'w') as doc_out:
            doc_out.write('\n'.join(d_lines))
        with open(f'{out_target_path}{file}', 'w') as summ_out:
            summ_out.write('\n'.join(s_lines))
# Script driver: build the masked copy of each paraphrase split. Both paths
# are relative, so they resolve against the current working directory.
for split in ('train', 'val'):
    print(f'\n\n Preparing {split} data..\n')
    para_dir = f'codes/ECT-BPS/ectbps_para/data/para/{split}'
    para_mask_dir = f'codes/ECT-BPS/ectbps_para/data/para_mask/{split}'
    prepare_data(para_dir, para_mask_dir)