From 7c199bf05fb7a6c0e3ab66988f66c48fb2d06b0c Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Fri, 15 Nov 2024 12:50:55 +0100
Subject: [PATCH 1/8] top_p/check_hybrid

---
 app/src/modules/metadata.py  | 14 +++++++-------
 app/src/modules/st_common.py |  8 +++-----
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/app/src/modules/metadata.py b/app/src/modules/metadata.py
index 4abfd12..1132121 100644
--- a/app/src/modules/metadata.py
+++ b/app/src/modules/metadata.py
@@ -95,7 +95,7 @@ def ll_models():
             "openai_compat": False,
             "context_length": 127072,
             "temperature": [0.3, 0.3, 0.0, 2.0],
-            "top_p": [0.75, 0.75, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [100, 100, 1, 4096],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -108,7 +108,7 @@
             "openai_compat": True,
             "context_length": 4191,
             "temperature": [1.0, 1.0, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 4096],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -121,7 +121,7 @@
             "openai_compat": True,
             "context_length": 127072,
             "temperature": [1.0, 1.0, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 4096],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -134,7 +134,7 @@
             "openai_compat": True,
             "context_length": 127072,
             "temperature": [1.0, 1.0, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 8191],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -147,7 +147,7 @@
             "openai_compat": True,
             "context_length": 127072,
             "temperature": [1.0, 1.0, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 4095],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -160,7 +160,7 @@
             "openai_compat": False,
             "context_length": 127072,
             "temperature": [0.2, 0.2, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 28000],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
@@ -173,7 +173,7 @@
             "openai_compat": False,
             "context_length": 127072,
             "temperature": [0.2, 0.2, 0.0, 2.0],
-            "top_p": [0.9, 0.9, 0.0, 1.0],
+            "top_p": [1.0, 1.0, 0.0, 1.0],
             "max_tokens": [256, 256, 1, 28000],
             "frequency_penalty": [0.0, 0.0, -1.0, 1.0],
             "presence_penalty": [0.0, 0.0, -2.0, 2.0],
diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py
index fbc3fc8..c3540bc 100644
--- a/app/src/modules/st_common.py
+++ b/app/src/modules/st_common.py
@@ -436,11 +436,9 @@ def create_zip(state_dict_filt, provider):
 
 # Check if the conf is full ollama or openai, currently supported for springai export
 def check_hybrid_conf(session_state_json):
-    embedding_models = meta.embedding_models()
-    chat_models = meta.ll_models()
-
-    embModel = embedding_models.get(session_state_json["rag_params"].get("model"))
-    chatModel = chat_models.get(session_state_json["ll_model"])
+    chatModel = state.ll_model_config.get(session_state_json["ll_model"])
+    embModel = state.embed_model_config.get(state.rag_params["model"])
+
     logger.info("Model: %s",session_state_json["ll_model"])
     logger.info("Embedding Model embModel: %s",embModel)
     logger.info("Chat Model: %s",chatModel)
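A note on the four-element lists edited above: index 0 appears to be the active value (utilities.py later reads `lm_params["temperature"][0]`), with the remaining entries presumably the default, minimum, and maximum. A minimal sketch of how such an entry could drive a Streamlit control; `param_slider` is a hypothetical helper, not code from this repository:

```python
import streamlit as st

# Hypothetical helper, assuming each parameter list in metadata.py is laid
# out as [current, default, min, max]; index 0 matches what utilities.py
# reads as the active value.
def param_slider(label: str, param: list) -> float:
    current, default, low, high = param
    return st.slider(label, min_value=low, max_value=high, value=current)

# After this patch, top_p defaults to 1.0 across all models:
top_p = param_slider("Top P", [1.0, 1.0, 0.0, 1.0])
```

Under that reading, changing `top_p` to `[1.0, 1.0, 0.0, 1.0]` updates both the active value and the default to 1.0 while keeping the 0.0 to 1.0 range.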
From 00d6bdb8b768b3c330586c60991a3ddc6508fc68 Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Mon, 18 Nov 2024 14:56:33 +0100
Subject: [PATCH 2/8] fix embed model check failing at startup

---
 .gitignore                   | 4 ++++
 app/src/modules/st_common.py | 7 ++++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 6fbdeca..adc6f37 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,7 @@ spring_ai/create_user.sql
 spring_ai/drop.sql
 start.sh
 spring_ai/env.sh
+temp/rag_agent.ipynb
+temp/tools.ipynb
+temp/tools.py
+temp/json-dual.sql
diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py
index c3540bc..b668043 100644
--- a/app/src/modules/st_common.py
+++ b/app/src/modules/st_common.py
@@ -437,8 +437,13 @@ def create_zip(state_dict_filt, provider):
 def check_hybrid_conf(session_state_json):
     chatModel = state.ll_model_config.get(session_state_json["ll_model"])
-    embModel = state.embed_model_config.get(state.rag_params["model"])
+    if "rag_params" in state and "model" in state.rag_params:
+        embModel = state.embed_model_config.get(state.rag_params["model"])
+    else:
+        # Handle the case where rag_params or "model" key does not exist
+        embModel = None  # or some default value
+
     logger.info("Model: %s",session_state_json["ll_model"])
     logger.info("Embedding Model embModel: %s",embModel)
     logger.info("Chat Model: %s",chatModel)
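The guard added in PATCH 2/8 is the usual defensive pattern for nested Streamlit session state, where keys exist only after the relevant page has initialized them. A minimal standalone sketch; `safe_rag_model` is an invented name, and `state` stands in for the session-state object:

```python
def safe_rag_model(state, embed_model_config):
    # At startup rag_params may not exist yet, or may lack a "model" key,
    # so dereferencing state.rag_params["model"] unconditionally raises.
    if "rag_params" in state and "model" in state.rag_params:
        return embed_model_config.get(state.rag_params["model"])
    return None  # no RAG model configured yet
```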
From abab3c1eef122ffae5d9d4085f5fd0deb135981d Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Tue, 19 Nov 2024 11:40:20 +0100
Subject: [PATCH 3/8] ollama fix

---
 README.md                         | 6 ++++++
 app/src/content/test_framework.py | 6 ++----
 app/src/modules/utilities.py      | 5 +++--
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 05bd1b5..cadea55 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,12 @@ To run the application on bare-metal; download the [source](https://github.com/o
    pip3 install -r app/requirements.txt
    ```
 
+1. Restart your shell and activate the virtual environment again:
+
+   ```bash
+   source .venv/bin/activate
+   ```
+
 1. Start Streamlit:
 
    ```bash
diff --git a/app/src/content/test_framework.py b/app/src/content/test_framework.py
index c6550cb..a5f2504 100644
--- a/app/src/content/test_framework.py
+++ b/app/src/content/test_framework.py
@@ -188,10 +188,6 @@ def main():
     st_common.set_default_state("test_set", None)
     st_common.set_default_state("temp_dir", None)
 
-    # TO CHECK
-    # file_path = "temp_file.csv"
-    # if os.path.exists(file_path):
-    #     os.remove(file_path)
     if state.toggle_generate_test:
         st.header("Q&A Test Dataset Generation")
 
@@ -256,6 +252,8 @@
            # Generate Q&A
            qa_file = os.path.join(state["temp_dir"], f"{file_name}_{str(qa_count)}_test_set.json")
            state.qa_file = qa_file
+           logger.info("calling generate with client: ")
+           logger.info(llm_client)
            state.test_set = utilities.generate_qa(qa_file, kb, qa_count, client=llm_client)
            placeholder.empty()
            st.success("Q&A Generation Succeeded.", icon="✅")
diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py
index e005e16..6c7ee95 100644
--- a/app/src/modules/utilities.py
+++ b/app/src/modules/utilities.py
@@ -82,7 +82,8 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     lm_params = ll_models_config[model]
 
     logger.info(
-        "Configuring LLM - URL: %s; Temp - %s; Max Tokens - %s",
+        "Configuring LLM - Model: %s; URL: %s; Temp - %s; Max Tokens - %s",
+        model,
         lm_params["url"],
         lm_params["temperature"][0],
         lm_params["max_tokens"][0],
@@ -115,7 +116,7 @@
     elif llm_api == "ChatPerplexity":
         client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
     elif llm_api == "ChatOllama":
-        client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params)
+        client = ChatOllama(model=model,base_url=lm_params["url"], model_kwargs=common_params)
     ## End - Add Additional Model Authentication Here
 
     api_accessible, err_msg = is_url_accessible(llm_url)
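On the ChatOllama change in PATCH 3/8: without an explicit `model=`, the client falls back to the library's default model and ignores the entry selected in the app. A standalone sketch mirroring the fixed call in utilities.py; the model name and URL below are placeholders:

```python
from langchain_community.chat_models import ChatOllama

# Before the fix: no model passed, so Ollama served its default model.
# broken = ChatOllama(base_url="http://localhost:11434", model_kwargs={...})

# After the fix: the selected model name is forwarded explicitly.
client = ChatOllama(
    model="llama3.1",                   # placeholder model name
    base_url="http://localhost:11434",  # placeholder Ollama URL
    model_kwargs={"temperature": 0.2, "max_tokens": 256},
)
```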
From f909d4f54f46e65306881a4e68163b5cc534ae43 Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Tue, 19 Nov 2024 15:03:38 +0100
Subject: [PATCH 4/8] report_utils: fix text_area heights for Streamlit 1.40.1 (the minimum height is now 68)

---
 app/src/modules/report_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py
index c6a3d4f..645c75d 100644
--- a/app/src/modules/report_utils.py
+++ b/app/src/modules/report_utils.py
@@ -129,12 +129,12 @@ def record_update():
     # Text input for the question, storing the user's input in the session state
     state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]))
     state.hide_input = st.checkbox("Hide", value=state.hide_input)
-    state.question_input = st.text_area("question", height=1, value=state.question_input)
-    state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input)
+    state.question_input = st.text_area("question", height=100, value=state.question_input)
+    state.reference_answer_input = st.text_area("Reference answer", height=100, value=state.reference_answer_input)
     state.reference_context_input = st.text_area(
-        "Reference context", height=10, value=state.reference_context_input, disabled=True
+        "Reference context", height=100, value=state.reference_context_input, disabled=True
     )
-    state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True)
+    state.metadata_input = st.text_area("Metadata", height=100, value=state.metadata_input, disabled=True)
 
     if save_clicked:
         st.success("Q&A saved successfully!")

From 7e796f7414cf78849fce3ebc9be3e4e3f497ec2c Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Tue, 19 Nov 2024 16:08:19 +0100
Subject: [PATCH 5/8] fix widget

---
 widget/app.js | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/widget/app.js b/widget/app.js
index 2ce6e8e..f90ff13 100644
--- a/widget/app.js
+++ b/widget/app.js
@@ -63,7 +63,8 @@ document.addEventListener('DOMContentLoaded', () => {
         });
 
         const data = await response.json();
-        return data.choices[0].message.content;
+        return data.answer;
+        //return data.choices[0].message.content;
     } catch (error) {
         console.error('Error fetching API:', error);
         return "Sorry, I couldn't connect to the server.";

From 9f0c3be85a135162354b39727cf6f30f5c809a26 Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Tue, 19 Nov 2024 18:14:09 +0100
Subject: [PATCH 6/8] backport giskard

The new 2.15.5 release fails in the Llama3.1 agent test.
---
 app/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/requirements.txt b/app/requirements.txt
index 28ff963..c2e7bcb 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -14,7 +14,7 @@ bokeh==3.6.1
 
 evaluate==0.4.3
 faiss-cpu==1.9.0
-giskard==2.15.5
+giskard==2.15.2
 IPython==8.29.0
 langchain-cohere==0.3.1
 langchain-community==0.3.7

From 745f494187fc43cb06618903b772e738477c7a5b Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Tue, 19 Nov 2024 22:14:30 +0100
Subject: [PATCH 7/8] fix report

Fix an issue when changing the generated dataset.
---
 app/src/modules/report_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py
index 645c75d..c5ae724 100644
--- a/app/src/modules/report_utils.py
+++ b/app/src/modules/report_utils.py
@@ -120,7 +120,7 @@ def record_update():
         "question": state.question_input,
         "reference_answer": state.reference_answer_input,
         "reference_context": state.reference_context_input,
-        "conversation_history": "",
+        "conversation_history": [],
     }
     for key, value in new_data.items():
         state.df.at[index, key] = value
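On PATCH 7/8: record_update() writes one Q&A record per row, and `conversation_history` now round-trips as a list rather than an empty string, which appears to be what broke switching between generated datasets. A sketch of the record shape; field values are illustrative:

```python
import pandas as pd

# Illustrative Q&A record, shaped like what record_update() writes back.
# conversation_history is a list (it was "" before this patch), matching
# the list-typed field in the generated test sets.
new_data = {
    "question": "What does check_hybrid_conf verify?",
    "reference_answer": "Whether the configuration is fully Ollama or OpenAI.",
    "reference_context": "(chunk text from the knowledge base)",
    "conversation_history": [],
}
df = pd.DataFrame([new_data])
print(df.loc[0, "conversation_history"])  # -> []
```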
From 74b47ddcaad2536dcc395e4fa07ca6ab27553f2e Mon Sep 17 00:00:00 2001
From: corradodebari
Date: Wed, 20 Nov 2024 17:09:11 +0100
Subject: [PATCH 8/8] OpenAI API compatibility and Giskard upgrade

Update the API server to accept requests and return responses with an
OpenAI-style signature. Giskard has also been updated to the latest release.
---
 app/requirements.txt          |  2 +-
 app/src/content/api_server.py |  7 +++++--
 app/src/modules/api_server.py | 15 +++++++++++----
 widget/app.js                 | 13 ++++++++++---
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/app/requirements.txt b/app/requirements.txt
index c2e7bcb..28ff963 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -14,7 +14,7 @@ bokeh==3.6.1
 
 evaluate==0.4.3
 faiss-cpu==1.9.0
-giskard==2.15.2
+giskard==2.15.5
 IPython==8.29.0
 langchain-cohere==0.3.1
 langchain-community==0.3.7
diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py
index 885b49b..4a2cb5e 100644
--- a/app/src/content/api_server.py
+++ b/app/src/content/api_server.py
@@ -45,8 +45,11 @@ def display_logs():
                 st.chat_message("human").write(msg["message"])
             else:
                 if state.rag_params["enable"]:
-                    st.chat_message("ai").write(msg["answer"])
-                    st_common.show_rag_refs(msg["context"])
+                    logger.info("msg[\"answer\"]")
+                    logger.info(msg)
+                    st.chat_message("ai").write(msg)
+                    if "context" in msg and msg["context"]:
+                        st_common.show_rag_refs(msg["context"])
                 else:
                     st.chat_message("ai").write(msg.content)
         except api_server.queue.Empty:
diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py
index 3fdda66..4dae468 100644
--- a/app/src/modules/api_server.py
+++ b/app/src/modules/api_server.py
@@ -103,9 +103,14 @@ def do_POST(self):  # pylint: disable=invalid-name
         try:
             # Parse the POST data as JSON
             post_json = json.loads(post_data)
+            messages = post_json.get("messages")
+
+            # Extract the content of 'user'
+            message = next(item['content'] for item in messages if item['role'] == 'user')
+
+            logger.info(messages)
+            logger.info(message)
 
-            # Extract the 'message' field from the JSON
-            message = post_json.get("message")
             response = None
             if message:
                 # Log the incoming message
@@ -123,6 +128,8 @@
                     stream=False,
                 )
                 self.send_response(200)
+                logger.info("RESPONSE:")
+                logger.info(response)
                 # Process response to JSON
             else:
                 json_response = {"error": "No 'message' field found in request."}
@@ -163,9 +170,9 @@
                     for i in range(max_items):
                         chunk = full_context[i]
                         sources.add(os.path.basename(chunk.metadata["source"]))
-                    json_response = {"answer": response["answer"], "sources": list(sources)}
+                    json_response = {"choices": [{"text":response["answer"],"index":0}], "sources": list(sources)}
                 else:
-                    json_response = {"answer": response.content}
+                    json_response = {"choices": [{"text":response["answer"],"index":0}]}
 
                 self.wfile.write(json.dumps(json_response).encode("utf-8"))
diff --git a/widget/app.js b/widget/app.js
index f90ff13..7cdbde1 100644
--- a/widget/app.js
+++ b/widget/app.js
@@ -48,10 +48,17 @@ document.addEventListener('DOMContentLoaded', () => {
     async function getBotResponse(userMessage) {
         const apiUrl = chatBotUrl;
         const messagePayload = {
-            message: userMessage
+            model: "",
+            messages: [
+                {
+                    role: "user",
+                    content: userMessage
+                }
+            ]
         };
 
+
         try {
             const response = await fetch(apiUrl, {
                 method: 'POST',
@@ -63,8 +70,8 @@ document.addEventListener('DOMContentLoaded', () => {
         });
 
         const data = await response.json();
-        return data.answer;
-        //return data.choices[0].message.content;
+        //return data.answer;
+        return data.choices[0].text+"\n Source: "+data.sources[0];
     } catch (error) {
         console.error('Error fetching API:', error);
         return "Sorry, I couldn't connect to the server.";
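Putting PATCH 8/8 together: the widget now sends an OpenAI-style chat payload, and the server replies with a `choices` array, plus `sources` when RAG context is attached. A minimal client sketch under those assumptions; the URL is a placeholder for the configured API server address:

```python
import json
import urllib.request

# OpenAI-style request body, matching what widget/app.js now sends.
payload = {
    "model": "",
    "messages": [{"role": "user", "content": "What is in the knowledge base?"}],
}
req = urllib.request.Request(
    "http://localhost:8000/",  # placeholder; use the configured API server URL
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    data = json.load(resp)

# The server answers with a 'choices' list; 'sources' is present when RAG
# context was used (see modules/api_server.py above).
print(data["choices"][0]["text"])
print(data.get("sources"))
```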