-
Notifications
You must be signed in to change notification settings - Fork 421
/
Copy pathweb_demo.py
129 lines (103 loc) · 5.67 KB
/
web_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
import gradio as gr
from PIL import Image
import os
import json
from model import is_chinese, get_infer_setting, generate_input, chat
import torch
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Generate an answer to ``input_text`` about ``image`` with the global model.

    Relies on the module-level ``model`` and ``tokenizer`` globals, which
    must be set before this function is called.

    Args:
        input_text: The user's text prompt.
        image: The image handed to ``generate_input`` (flagged there as not
            encoded via ``image_is_encoded=False``).
        history: Previous conversation turns, forwarded to ``generate_input``
            and ``chat``. ``None`` (the default) means a fresh empty list.
        request_data: Optional mapping of generation-parameter overrides that
            is merged over the built-in defaults. ``None`` means no overrides.
        is_zh: When true, ``chat`` is called with ``english=False``.

    Returns:
        The answer produced by ``chat``.
    """
    # Create fresh containers per call: a mutable default would be shared
    # between calls and could leak conversation state across requests.
    history = [] if history is None else history
    request_data = {} if request_data is None else request_data

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2,
    }
    # Caller-supplied values win over the defaults above.
    input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image = input_data['input_image']
    gen_kwargs = input_data['gen_kwargs']

    # Inference only: gradients are not needed.
    with torch.no_grad():
        # Only max_length, top_p, top_k and temperature are read from
        # gen_kwargs here; the updated history returned by chat is discarded.
        answer, history, _ = chat(
            None, model, tokenizer, input_text,
            history=history, image=input_image,
            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )
    return answer
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Handle one chat turn and return the new textbox value and chat history.

    Args:
        input_text: The prompt entered by the user.
        temperature: Sampling temperature forwarded to generation.
        top_p: Top-p value forwarded to generation.
        image_prompt: Value passed to ``Image.open`` (a file path), or ``None``
            when no image has been provided.
        result_previous: Existing chat history; each element is indexed as
            ``(user_text, bot_text)``.

    Returns:
        A ``(textbox_value, history)`` tuple. The textbox value is the
        original ``input_text`` when the image is missing (so the user can
        retry without retyping) and ``""`` in every other case. ``history``
        is the filtered previous history plus one new ``(input_text, reply)``
        turn, where the reply is either the model answer or an error message.
    """
    # Rebuild the history as plain tuples, dropping turns where either side
    # is an empty string (e.g. the greeting placeholder).
    result_text = [
        (ele[0], ele[1])
        for ele in result_previous
        if ele[0] != "" and ele[1] != ""
    ]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
        return input_text, result_text
    elif input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    # Open the image as a context manager so the underlying file handle is
    # released once generation finishes, whether it succeeds or fails.
    with Image.open(image_prompt) as image:
        try:
            # Pass a copy so generation cannot mutate the history we return.
            answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
        except Exception as e:
            # Any generation failure is reported to the user as a timeout.
            print(f"error: {e}")
            if is_zh:
                result_text.append((input_text, '超时!请稍等几分钟再重试。'))
            else:
                result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
            return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
# Markdown heading with a link to the VisualGLM-6B repository; rendered by
# main() via gr.Markdown.
DESCRIPTION = '''# <a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a>'''
# English usage hints (connection errors behind a proxy, slow uploads of large
# images); rendered by main() via gr.Markdown.
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
# Chinese-language version of the same two hints. Not referenced by the code
# visible in this file.
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'
# Attribution note pointing to the upstream repository; rendered by main()
# via gr.Markdown.
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'
def clear_fn(value):
    """Return the reset state for the textbox, chat history and image.

    ``value`` is accepted to satisfy the callback signature and is ignored.

    Returns:
        A 3-tuple: an empty string, a new history list holding only the
        greeting turn, and ``None``.
    """
    cleared_text = ""
    greeting_history = [("", "Hi, What do you want to know about this image?")]
    cleared_image = None
    return cleared_text, greeting_history, cleared_image
def clear_fn2(value):
    """Return a new chat history containing only the greeting turn.

    ``value`` is accepted to satisfy the callback signature and is ignored.
    """
    greeting_turn = ("", "Hi, What do you want to know about this image?")
    return [greeting_turn]
def main(args):
    """Load the model, build the Gradio UI and launch the demo.

    Args:
        args: Parsed command-line options; ``args.quant`` is passed to
            ``get_infer_setting`` and ``args.share`` to ``demo.launch``.
    """
    # NOTE(review): the block nesting below was reconstructed from a copy of
    # the file whose indentation had been stripped -- confirm the layout
    # (which components sit in which Row/Column/Group) against upstream.
    gr.close_all()
    # model/tokenizer are module globals because generate_text_with_image
    # reads them directly instead of receiving them as arguments.
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # Left column: prompt input, buttons, image upload, sampling sliders.
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    # type="filepath": request_model passes this value to Image.open.
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            # Right column: the conversation view, seeded with a greeting turn.
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)
        # Both the Generate button and pressing ENTER in the textbox run
        # request_model with the same inputs and outputs.
        run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        # clear_fn / clear_fn2 ignore their argument; clear_button is passed
        # only to fill the inputs slot.
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or removing an image resets the conversation to the greeting.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

        print(gr.__version__)

    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
if __name__ == "__main__":
    import argparse

    # Script entry point: parse the command line, then start the demo.
    cli = argparse.ArgumentParser()
    # --quant accepts only 8 or 4; when omitted, args.quant is None.
    cli.add_argument("--quant", type=int, choices=[8, 4], default=None)
    # --share is a boolean flag; main() forwards it as args.share.
    cli.add_argument("--share", action="store_true")
    args = cli.parse_args()
    main(args)