-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_text.py
27 lines (24 loc) · 982 Bytes
/
generate_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import re
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Load the tokenizer and the seq2seq model used by generate() below.
# The tokenizer comes from the "t5-base" checkpoint; the weights come from
# "trained_model" (presumably a local directory holding fine-tuned weights
# derived from t5-base — TODO confirm the tokenizer/vocab match).
tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path="t5-base",
)
model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="trained_model",
    return_dict=True,
)
# ------------------------------------------------------------------ #
# generate text from correctly preprocessed keyword input          #
# ------------------------------------------------------------------ #
def generate(text):
    """Generate text from period-separated keyword groups.

    ``text`` is split on "." and each non-blank segment is processed in turn:
    it is wrapped in the ``"WebNLG:{} </s>"`` prompt, encoded with the
    module-level ``tokenizer``, passed to ``model.generate``, and decoded.
    The decoded pieces are concatenated in order.

    Args:
        text: keyword input; "." separates groups that are generated
            independently of each other.

    Returns:
        The concatenated generated text with every ``<pad>`` and ``</s>``
        marker removed. Returns "" when ``text`` has no non-blank segment.
    """
    # Switch to inference mode once instead of once per segment.
    model.eval()
    pieces = []
    for segment in text.split("."):
        # A trailing "." (or an empty/whitespace-only input) produces blank
        # segments; prompting the model with nothing would append unrelated
        # generated text to the result.
        if not segment.strip():
            continue
        # Prompt format kept byte-for-byte: it presumably matches the format
        # used when fine-tuning the model — TODO confirm before changing.
        input_ids = tokenizer.encode("WebNLG:{} </s>".format(segment),
                                     return_tensors="pt")
        outputs = model.generate(input_ids)
        pieces.append(tokenizer.decode(outputs[0]))
    # decode() leaves special-token markers in the string; strip them.
    return re.sub(r'<pad>|</s>', "", "".join(pieces))