forked from google-research/bert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi-server.py
52 lines (43 loc) · 1.58 KB
/
api-server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from fastapi import FastAPI
from pydantic import BaseModel
import json
import os
import requests
import tokenization
# FastAPI application instance; served by an ASGI server (e.g. uvicorn).
app = FastAPI()
class TextContainer(BaseModel):
    """Request body schema for POST /predict: the raw text to classify."""
    text: str
# Fallback vocabulary location used when VOCAB_FILE is unset (or empty).
_DEFAULT_VOCAB = "./weights_base/uncased_L-12_H-768_A-12/vocab.txt"

# `or` (rather than a getenv default) deliberately treats an empty env var
# the same as an unset one.
VOCAB_FILE_PATH = os.getenv('VOCAB_FILE') or _DEFAULT_VOCAB
# Host and port of the TensorFlow Serving REST endpoint.
SERVE_API = os.getenv('SERVE_API_HOST') or "localhost"
SERVE_API_PORT = os.getenv('SERVE_API_PORT') or "8501"
@app.post('/predict')
def predict(body: TextContainer):
    """Tokenize the request text, query the TF Serving BERT model, and return its predictions.

    Args:
        body: request payload containing the raw ``text`` to classify.

    Returns:
        dict with a single ``predictions`` key holding the model output.

    Raises:
        requests.HTTPError: if the serving backend returns an error status.
    """
    serve_endpoint = f"http://{SERVE_API}:{SERVE_API_PORT}/v1/models/bert:predict"
    # BUG FIX: original sent "application-json", which is not a valid MIME type.
    headers = {"content-type": "application/json"}

    # NOTE(review): building a FullTokenizer per request re-reads the vocab
    # file on every call; consider hoisting to module level if latency matters.
    tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE_PATH, do_lower_case=True)

    max_seq_length = 128
    # Reserve two slots for [CLS] and [SEP]; truncate so the padded sequence
    # never exceeds the model's fixed input length (the original could emit
    # sequences longer than 128 for long inputs).
    token_a = tokenizer.tokenize(body.text)[:max_seq_length - 2]

    tokens = ["[CLS]"] + token_a + ["[SEP]"]
    # Single-sentence input: every position belongs to segment 0.
    segment_ids = [0] * len(tokens)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Attention mask: 1 for real tokens, 0 for padding.
    input_mask = [1] * len(input_ids)

    # Zero-pad all three arrays up to the fixed sequence length.
    pad_len = max_seq_length - len(input_ids)
    input_ids += [0] * pad_len
    input_mask += [0] * pad_len
    segment_ids += [0] * pad_len

    instances = [{
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        "label_ids": 0,
    }]
    data = json.dumps({"signature_name": "serving_default", "instances": instances})
    response = requests.post(serve_endpoint, data=data, headers=headers)
    # Surface serving-side failures explicitly instead of an opaque KeyError.
    response.raise_for_status()
    prediction = response.json()['predictions']
    return {
        'predictions': prediction
    }