-
Notifications
You must be signed in to change notification settings - Fork 4
/
app.py
94 lines (72 loc) · 2.44 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Needs code structuring
Date - 08/14/2020
"""
import torch
import logging
import sys
from flask import Flask, render_template, request
from utils.dataloader2 import Dataloader
from models.models import LSTMTagger
from config.config import Configuration
app = Flask(__name__)
def get_logger():
    """Return the app's named logger, attaching the stdout handler only once.

    Bug fix: the original called ``logging.getLogger().addHandler(...)``
    unconditionally, and because ``get_config()`` invokes this per request,
    every request stacked another handler on the root logger, duplicating
    each log line. The handler is now attached only if absent.

    Returns:
        logging.Logger: the shared logger named "logger", level DEBUG.
    """
    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)
    # basicConfig is a no-op after the first call, so it is safe to repeat.
    logging.basicConfig(format="%(message)s", level=logging.INFO)
    root = logging.getLogger()
    # Only attach the stdout handler if no handler already targets stdout.
    if not any(getattr(h, "stream", None) is sys.stdout for h in root.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(
            "%(levelname)s:%(message)s"
        ))
        root.addHandler(handler)
    return logger
def get_config():
    """Build the inference-time Configuration for the web app.

    Loads the base settings from the INI file, then overrides the
    fields needed for CPU-only single-sentence inference.

    Returns:
        Configuration: fully populated config object.
    """
    logger = get_logger()
    config = Configuration(config_file="./config/config.ini", logger=logger)
    # Locations of the trained artifacts.
    config.model_file = './saved_models/lstm_1.pth'
    config.vocab_file = './vocab/vocab.pkl'
    config.label_file = './vocab/labels.pkl'
    # Force CPU, quiet, inference-only mode (no eval pass, no POS features).
    config.device = 'cpu'
    config.verbose = False
    config.eval = False
    config.use_pos = False
    config.infer = True
    return config
def pred_to_tag(dataloader, predictions):
    """Map predicted label indices back to their tag strings.

    Args:
        dataloader: object exposing ``label_field.vocab.itos`` (index-to-string).
        predictions: iterable of integer label indices.

    Returns:
        list[str]: tag strings, joined then re-split on whitespace
        (drops any empty-string labels, as the original did).
    """
    itos = dataloader.label_field.vocab.itos
    joined = ' '.join(itos[idx] for idx in predictions)
    return joined.split()
def infer(config, dataloader, model):
    """Run the tagger over the pre-tokenized sentence in ``config.txt``.

    Args:
        config: holds ``txt`` (list of tokens) and ``device``.
        dataloader: exposes ``txt_field.vocab.stoi`` for numericalization.
        model: callable taking (token batch, second positional input).

    Returns:
        list[str]: predicted tag strings, one per token.
    """
    tokens = config.txt
    # Numericalize through the training vocabulary.
    indices = [dataloader.txt_field.vocab.stoi[tok] for tok in tokens]
    # Shape to a batch of one sentence on the target device.
    batch = torch.LongTensor(indices).to(config.device).unsqueeze(0)
    # The model's second positional input is passed as None here
    # (presumably the optional POS features; use_pos is False at inference).
    scores = model(batch, None)
    best = torch.max(scores, 1)[1]
    predicted_indices = best.cpu().data.numpy().tolist()
    return pred_to_tag(dataloader, predicted_indices)
# Inference section
def inference(config):
    """Assemble vocab, model and weights, then tag ``config.txt``.

    Returns:
        list[str]: predicted tags for the tokens in ``config.txt``.
    """
    loader = Dataloader(config, '1')
    tagger = LSTMTagger(config, loader).to(config.device)
    # map_location lets a checkpoint trained on GPU load onto config.device.
    checkpoint = torch.load(config.model_file, map_location=config.device)
    tagger.load_state_dict(checkpoint['state_dict'])
    return infer(config, loader, tagger)
@app.route('/')
def hello():
    """Serve the landing page with the input form."""
    page = 'index.html'
    return render_template(page)
@app.route('/post', methods=['GET', 'POST'])
def post():
    """Handle the tag-inference form.

    GET renders the bare form; POST whitespace-tokenizes the submitted
    text, runs the tagger, and renders token/tag pairs.

    Bug fix: the original read ``request.form['input']`` and ran
    inference BEFORE checking the method, so a plain GET failed with a
    400 (missing form key) and POSTs did wasted work on the GET branch.
    The method check now comes first.
    """
    if request.method == "GET":
        return render_template('index.html')
    errors = []
    config = get_config()
    text = request.form['input']
    config.txt = text.split()
    res = inference(config)
    # Pair each input token with its predicted tag for the template.
    results = zip(config.txt, res)
    return render_template('index.html', errors=errors, results=results)
if __name__ == "__main__":
app.run()