from sentence_graph import SentenceGraph
from utils import *
from sklearn.cluster import SpectralClustering
import numpy as np
import torch
import nltk.data


class SummPip():
    def __init__(
        self,
        nb_clusters: int = 14,
        nb_words: int = 5,
        ita: float = 0.98,
        seed: int = 88,
        w2v_file: str = "word_vec/multi_news/news_w2v.txt",
        lm_path: str = "gpt2/mutli_news",
        use_lm: bool = False
    ):
        """
        This is the SummPip class.
        :param nb_clusters: determines the number of sentences in the output summary
        :param nb_words: controls the length of each sentence in the output summary
        :param ita: threshold for deciding whether two sentences are similar by vector similarity
        :param seed: the random state used to reproduce summarization
        :param w2v_file: file storing the w2v matrix
        :param lm_path: path to the language model
        :param use_lm: whether to use the language model
        """
        self.nb_clusters = nb_clusters
        self.nb_words = nb_words
        self.ita = ita
        self.seed = seed
        self.use_lm = use_lm
        if not self.use_lm:
            self.w2v = self._get_w2v_embeddings(w2v_file)
            self.lm_tokenizer = ""
            self.lm_model = ""
        else:
            from transformers import GPT2Tokenizer, GPT2Model
            self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_path)
            self.lm_model = GPT2Model.from_pretrained(lm_path,
                                                      output_hidden_states=True,
                                                      output_attentions=False)
            self.w2v = ""
        self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
        # set seeds for reproducibility
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

    def _get_w2v_embeddings(self, w2v_file):
        """
        Load the w2v word embedding matrix.
        :return: dict mapping each word to its embedding vector
        """
        word_embeddings = {}
        with open(w2v_file, encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                word_embeddings[word] = coefs
        return word_embeddings

    def construct_sentence_graph(self, sentences_list):
        """
        Construct a sentence graph.
        :return: adjacency matrix
        """
        graph = SentenceGraph(sentences_list, self.w2v, self.use_lm, self.lm_model, self.lm_tokenizer, self.ita)
        X = graph.build_sentence_graph()
        return X

    def cluster_graph(self, X, sentences_list):
        """
        Perform graph clustering.
        :return: a dictionary mapping each cluster ID to its list of sentences
        """
        # spectral clustering needs at least nb_clusters sentences;
        # shorter inputs are handled separately in summarize()
        clustering = SpectralClustering(n_clusters=self.nb_clusters, random_state=self.seed).fit(X)
        clusterIDs = clustering.labels_
        num_clusters = max(clusterIDs) + 1
        cluster_dict = {cluster_id: [] for cluster_id in range(num_clusters)}
        # group sentences by cluster ID
        for i, clusterID in enumerate(clusterIDs):
            cluster_dict[clusterID].append(sentences_list[i])
        return cluster_dict

    def convert_sents_to_tagged_sents(self, sent_list):
        """POS-tag each sentence in a cluster; fall back to a single "." for an empty cluster."""
        tagged_list = []
        if len(sent_list) > 0:
            for s in sent_list:
                s = s.replace("/", "")
                temp_tagged = tag_pos(s)
                tagged_list.append(temp_tagged)
        else:
            tagged_list.append(tag_pos("."))
        return tagged_list

    def get_compressed_sen(self, sentences):
        """Compress a cluster of tagged sentences into a single sentence via a word graph."""
        compresser = takahe.word_graph(sentence_list=sentences, nb_words=self.nb_words, lang='en', punct_tag=".")
        candidates = compresser.get_compression(3)
        reranker = takahe.keyphrase_reranker(sentences, candidates, lang='en')
        reranked_candidates = reranker.rerank_nbest_compressions()
        if len(reranked_candidates) > 0:
            score, path = reranked_candidates[0]
            result = ' '.join([u[0] for u in path])
        else:
            result = ' '
        return result

    def compress_cluster(self, cluster_dict):
        """
        Perform cluster compression.
        :return: a string of concatenated compressed sentences, one per cluster
        """
        summary = []
        for k, v in cluster_dict.items():
            tagged_sens = self.convert_sents_to_tagged_sents(v)
            compressed_sent = self.get_compressed_sen(tagged_sens)
            summary.append(compressed_sent)
        return " ".join(summary)

    def split_sentences(self, docs):
        """Split each raw document string into a list of sentences, dropping the separator tag."""
        tag = "story_separator_special_tag"
        src_list = []
        for doc in docs:
            doc = doc.replace(tag, "")
            sent_list = self.sent_detector.tokenize(doc.strip())
            src_list.append(sent_list)
        return src_list

    def summarize(self, src_list):
        """
        Construct a graph, run graph clustering, compress each cluster, then concatenate sentences.
        :param src_list: a list of examples, each represented as a list of sentences
        :return: a list of summaries
        """
        # TODO: split sentences
        summary_list = []
        # iterate over all examples
        for idx, sentences_list in enumerate(src_list):
            num_sents = len(sentences_list)
            # short inputs cannot be split into nb_clusters clusters, so return them as-is
            if num_sents <= self.nb_clusters:
                summary_list.append(" ".join(sentences_list))
                print("continue----")
                continue
            X = self.construct_sentence_graph(sentences_list)
            cluster_dict = self.cluster_graph(X, sentences_list)
            summary = self.compress_cluster(cluster_dict)
            summary_list.append(summary)
        return summary_list
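

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the default w2v file exists at word_vec/multi_news/news_w2v.txt
# and that `docs` holds one raw multi-document string per example, with the
# source articles joined by "story_separator_special_tag". The example text
# and parameter values below are placeholders; real inputs should contain
# more sentences than nb_clusters for clustering to kick in.
if __name__ == "__main__":
    docs = [
        "First source article ... story_separator_special_tag Second source article ..."
    ]
    summ = SummPip(nb_clusters=14, nb_words=5)
    # split each raw string into sentences, then summarize
    src_list = summ.split_sentences(docs)
    summaries = summ.summarize(src_list)
    for summary in summaries:
        print(summary)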