diff --git a/.gitignore b/.gitignore
index 6fbdeca..adc6f37 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,7 @@ spring_ai/create_user.sql
 spring_ai/drop.sql
 start.sh
 spring_ai/env.sh
+temp/rag_agent.ipynb
+temp/tools.ipynb
+temp/tools.py
+temp/json-dual.sql
diff --git a/README.md b/README.md
index 05bd1b5..cadea55 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,12 @@ To run the application on bare-metal; download the [source](https://github.com/o
    pip3 install -r app/requirements.txt
    ```
 
+1. Exit your shell, open a new one, and activate the virtual environment again:
+
+   ```bash
+   source .venv/bin/activate
+   ```
+
 1. Start Streamlit:
 
    ```bash
diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py
index 885b49b..4a2cb5e 100644
--- a/app/src/content/api_server.py
+++ b/app/src/content/api_server.py
@@ -45,8 +45,11 @@ def display_logs():
                 st.chat_message("human").write(msg["message"])
             else:
                 if state.rag_params["enable"]:
-                    st.chat_message("ai").write(msg["answer"])
-                    st_common.show_rag_refs(msg["context"])
+                    logger.info("msg[\"answer\"]")
+                    logger.info(msg)
+                    st.chat_message("ai").write(msg)
+                    if "context" in msg and msg["context"]:
+                        st_common.show_rag_refs(msg["context"])
                 else:
                     st.chat_message("ai").write(msg.content)
     except api_server.queue.Empty:
diff --git a/app/src/content/test_framework.py b/app/src/content/test_framework.py
index c6550cb..a5f2504 100644
--- a/app/src/content/test_framework.py
+++ b/app/src/content/test_framework.py
@@ -188,10 +188,6 @@ def main():
     st_common.set_default_state("test_set", None)
     st_common.set_default_state("temp_dir", None)
 
-    # TO CHECK
-    # file_path = "temp_file.csv"
-    # if os.path.exists(file_path):
-    #     os.remove(file_path)
 
     if state.toggle_generate_test:
         st.header("Q&A Test Dataset Generation")
@@ -256,6 +252,8 @@ def main():
                 # Generate Q&A
                 qa_file = os.path.join(state["temp_dir"], f"{file_name}_{str(qa_count)}_test_set.json")
                 state.qa_file = qa_file
+                logger.info("calling generate with client: ")
+                logger.info(llm_client)
                 state.test_set = utilities.generate_qa(qa_file, kb, qa_count, client=llm_client)
                 placeholder.empty()
                 st.success("Q&A Generation Succeeded.", icon="✅")
diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py
index 3fdda66..4dae468 100644
--- a/app/src/modules/api_server.py
+++ b/app/src/modules/api_server.py
@@ -103,9 +103,14 @@ def do_POST(self): # pylint: disable=invalid-name
         try:
             # Parse the POST data as JSON
             post_json = json.loads(post_data)
+            messages = post_json.get("messages")
+
+            # Extract the content of 'user'
+            message = next(item['content'] for item in messages if item['role'] == 'user')
+
+            logger.info(messages)
+            logger.info(message)
-            # Extract the 'message' field from the JSON
-            message = post_json.get("message")
 
             response = None
             if message:
                 # Log the incoming message
@@ -123,6 +128,8 @@ def do_POST(self): # pylint: disable=invalid-name
                     stream=False,
                 )
                 self.send_response(200)
+                logger.info("RESPONSE:")
+                logger.info(response)
                 # Process response to JSON
             else:
                 json_response = {"error": "No 'message' field found in request."}
@@ -163,9 +170,9 @@ def do_POST(self): # pylint: disable=invalid-name
                     for i in range(max_items):
                         chunk = full_context[i]
                         sources.add(os.path.basename(chunk.metadata["source"]))
-                    json_response = {"answer": response["answer"], "sources": list(sources)}
+                    json_response = {"choices": [{"text":response["answer"],"index":0}], "sources": list(sources)}
                 else:
-                    json_response = {"answer": response.content}
+                    json_response = {"choices": [{"text":response.content,"index":0}]}

                 self.wfile.write(json.dumps(json_response).encode("utf-8"))

diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py
index c6a3d4f..c5ae724 100644
--- a/app/src/modules/report_utils.py
+++ b/app/src/modules/report_utils.py
@@ -120,7 +120,7 @@ def record_update():
             "question": state.question_input,
             "reference_answer": state.reference_answer_input,
             "reference_context": state.reference_context_input,
-            "conversation_history": "",
+            "conversation_history": [],
         }
         for key, value in new_data.items():
             state.df.at[index, key] = value
@@ -129,12 +129,12 @@ def record_update():
     # Text input for the question, storing the user's input in the session state
     state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]))
     state.hide_input = st.checkbox("Hide", value=state.hide_input)
-    state.question_input = st.text_area("question", height=1, value=state.question_input)
-    state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input)
+    state.question_input = st.text_area("question", height=100, value=state.question_input)
+    state.reference_answer_input = st.text_area("Reference answer", height=100, value=state.reference_answer_input)
     state.reference_context_input = st.text_area(
-        "Reference context", height=10, value=state.reference_context_input, disabled=True
+        "Reference context", height=100, value=state.reference_context_input, disabled=True
     )
-    state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True)
+    state.metadata_input = st.text_area("Metadata", height=100, value=state.metadata_input, disabled=True)
 
     if save_clicked:
         st.success("Q&A saved successfully!")
diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py
index c3540bc..b668043 100644
--- a/app/src/modules/st_common.py
+++ b/app/src/modules/st_common.py
@@ -437,8 +437,13 @@ def create_zip(state_dict_filt, provider):
 
 
 def check_hybrid_conf(session_state_json):
     chatModel = state.ll_model_config.get(session_state_json["ll_model"])
-    embModel = state.embed_model_config.get(state.rag_params["model"])
+    if "rag_params" in state and "model" in state.rag_params:
+        embModel = state.embed_model_config.get(state.rag_params["model"])
+    else:
+        # Handle the case where rag_params or "model" key does not exist
+        embModel = None # or some default value
+
     logger.info("Model: %s",session_state_json["ll_model"])
     logger.info("Embedding Model embModel: %s",embModel)
     logger.info("Chat Model: %s",chatModel)
diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py
index e005e16..6c7ee95 100644
--- a/app/src/modules/utilities.py
+++ b/app/src/modules/utilities.py
@@ -82,7 +82,8 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     lm_params = ll_models_config[model]
 
     logger.info(
-        "Configuring LLM - URL: %s; Temp - %s; Max Tokens - %s",
+        "Configuring LLM - Model: %s; URL: %s; Temp - %s; Max Tokens - %s",
+        model,
         lm_params["url"],
         lm_params["temperature"][0],
         lm_params["max_tokens"][0],
@@ -115,7 +116,7 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     elif llm_api == "ChatPerplexity":
         client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
     elif llm_api == "ChatOllama":
-        client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params)
+        client = ChatOllama(model=model,base_url=lm_params["url"], model_kwargs=common_params)
     ## End - Add Additional Model Authentication Here
 
     api_accessible, err_msg = is_url_accessible(llm_url)
diff --git a/widget/app.js b/widget/app.js
index 2ce6e8e..7cdbde1 100644
--- a/widget/app.js
+++ b/widget/app.js
@@ -48,10 +48,17 @@ document.addEventListener('DOMContentLoaded', () => {
 
     async function getBotResponse(userMessage) {
         const apiUrl = chatBotUrl;
 
         const messagePayload = {
-            message: userMessage
+            model: "",
+            messages: [
+                {
+                    role: "user",
+                    content: userMessage
+                }
+            ]
         };
+
         try {
             const response = await fetch(apiUrl, {
                 method: 'POST',
@@ -63,7 +70,8 @@ document.addEventListener('DOMContentLoaded', () => {
             });
 
             const data = await response.json();
-            return data.choices[0].message.content;
+            //return data.answer;
+            return data.choices[0].text+"\n Source: "+data.sources[0];
         } catch (error) {
             console.error('Error fetching API:', error);
             return "Sorry, I couldn't connect to the server.";
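
Taken together, the widget/app.js and modules/api_server.py changes above move the chat endpoint to an OpenAI-style contract: the client now POSTs a "messages" array and reads the reply from choices[0].text, with "sources" added only on the RAG path. A minimal sketch of that round trip, assuming the requests package and a hypothetical endpoint URL standing in for the widget's chatBotUrl:

```python
import requests

# Hypothetical URL; in the widget the real value comes from chatBotUrl.
API_URL = "http://localhost:8000/v1/chat"

# OpenAI-style payload now built in widget/app.js; the server side
# (modules/api_server.py) extracts the first entry with role == "user".
payload = {
    "model": "",
    "messages": [{"role": "user", "content": "What is a vector index?"}],
}

data = requests.post(API_URL, json=payload, timeout=60).json()

# New response shape: "choices" is always present; "sources" is only
# included when the RAG branch attaches document references.
answer = data["choices"][0]["text"]
sources = data.get("sources", [])
print(answer, sources)
```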