diff --git a/web_demo_streamlit-2_5.py b/web_demo_streamlit-2_5.py
index bf100ce..5182a39 100644
--- a/web_demo_streamlit-2_5.py
+++ b/web_demo_streamlit-2_5.py
@@ -4,11 +4,11 @@
 from transformers import AutoModel, AutoTokenizer
 
 # Model path
-model_path = "openbmb/MiniCPM-Llama3-V-2_5"
+model_path = "./MiniCPM-Llama3-V-2_5-int4"
 
 # User and assistant names
-U_NAME = "User"
-A_NAME = "Assistant"
+U_NAME = "YOU"
+A_NAME = "AI"
 
 # Set page configuration
 st.set_page_config(
@@ -22,7 +22,7 @@
 @st.cache_resource
 def load_model_and_tokenizer():
     print(f"load_model_and_tokenizer from {model_path}")
-    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device="cuda")
+    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer
 
@@ -89,20 +89,37 @@ def load_model_and_tokenizer():
     model = st.session_state.model
     tokenizer = st.session_state.tokenizer
 
-    with st.chat_message(A_NAME, avatar="assistant"):
-        # If the previous message contains an image, pass the image to the model
-        if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
-            uploaded_image = st.session_state.chat_history[-2]["image"]
-            imagefile = Image.open(uploaded_image).convert('RGB')
+    if selected_mode == "Image":
+        with st.chat_message(A_NAME, avatar="assistant"):
+            # Default to no image so the chat call below never sees an unbound name
+            # when Image mode is selected but the previous turn carried no upload.
+            imagefile = None
+            # If the previous message contains an image, pass the image to the model
+            if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
+                uploaded_image = st.session_state.chat_history[-2]["image"]
+                imagefile = Image.open(uploaded_image).convert('RGB')
 
-        msgs = [{"role": "user", "content": user_text}]
-        res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
-                         sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
-                         temperature=temperature, stream=True)
+            msgs = [{"role": "user", "content": user_text}]
+            res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
+                             sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
+                             temperature=temperature, stream=True)
 
-        # Collect the generated_text str
-        generated_text = st.write_stream(res)
+            # Collect the generated_text str
+            generated_text = st.write_stream(res)
 
-        st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})
+            st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})
 
-    st.divider()
+        st.divider()
+    else:
+        with st.chat_message(A_NAME, avatar="assistant"):
+            msgs = [{"role": "user", "content": user_text}]
+            res = model.chat(image=None, msgs=msgs, context=None, tokenizer=tokenizer,
+                             sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
+                             temperature=temperature, stream=True)
+
+            # Collect the generated_text str
+            generated_text = st.write_stream(res)
+
+            st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})
+
+        st.divider()