From cbeb265abfb2aa3270540050b3a1af089e1eb190 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Mon, 21 Oct 2024 11:51:20 +0000 Subject: [PATCH 01/13] restruct --- app/requirements.txt | 1 + app/src/content/api_server.py | 146 ++++++++++++----------- app/src/content/api_server.py.bak | 190 ++++++++++++++++++++++++++++++ app/src/content/chatbot.py | 47 +------- app/src/modules/api_server.py | 23 ++-- app/src/modules/st_common.py | 39 ++++++ app/src/modules/utilities.py | 2 +- test.sh | 5 + 8 files changed, 322 insertions(+), 131 deletions(-) create mode 100644 app/src/content/api_server.py.bak create mode 100644 test.sh diff --git a/app/requirements.txt b/app/requirements.txt index 6397168..8244fbb 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -21,6 +21,7 @@ langchain-community==0.3.2 langchain-huggingface==0.1.0 langchain-ollama==0.2.0 langchain-openai==0.2.2 +langgraph==0.2.39 llama_index==0.11.18 lxml==5.3.0 matplotlib==3.9.2 diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index 17f2d74..22c8c67 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -2,11 +2,12 @@ Copyright (c) 2023, 2024, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# spell-checker:ignore streamlit, langchain, llms -# spell-checker:ignore streamlit import inspect -import threading import time +import threading +import json # Streamlit import streamlit as st @@ -15,9 +16,11 @@ # Utilities import modules.st_common as st_common import modules.api_server as api_server - import modules.logging_config as logging_config +# History +from langchain_community.chat_message_histories import StreamlitChatMessageHistory + logger = logging_config.logging.getLogger("api_server") @@ -30,23 +33,26 @@ def initialize_streamlit(): state.api_server_config = api_server.config() logger.info("initialized API Server Config") - -def display_logs(): - log_placeholder = st.empty() # A placeholder to update logs - logs = [] # Store logs for display - - try: - while "server_thread" in st.session_state: - try: - # Retrieve log from queue (non-blocking) - log_item = api_server.log_queue.get_nowait() - logs.append(log_item) - # Update the placeholder with new logs - log_placeholder.text("\n".join(logs)) - except api_server.queue.Empty: - time.sleep(0.1) # Avoid busy-waiting - finally: - logger.info("API Server events display has stopped.") +# def display_logs(chat_history): +# log_placeholder = st.empty() # A placeholder to update logs +# logs = [] # Store logs for display + +# try: +# while "server_thread" in st.session_state: +# try: +# # Retrieve log from queue (non-blocking) +# log_item = api_server.log_queue.get_nowait() +# if log_item['type'] == "AIMessageChunk": +# st.chat_message("ai").write(log_item['message']) +# except (KeyError, TypeError): +# st.chat_message("human").write(log_item['message']) +# #logs.append(log_item) +# # Update the placeholder with new logs +# # log_placeholder.text("\n".join(logs)) +# except api_server.queue.Empty: +# time.sleep(0.1) # Avoid busy-waiting +# finally: +# logger.info("API Server events display has stopped.") def api_server_start(): @@ -106,14 +112,55 @@ def api_server_stop(): ############################################################################# def main(): """Streamlit GUI""" - initialize_streamlit() st.header("API Server") - - # LLM Params - ll_model = st_common.lm_sidebar() - # Initialize RAG st_common.initialize_rag() + # Setup History + chat_history = StreamlitChatMessageHistory(key="api_chat_history") + + ######################################################################### + # Sidebar Settings + ######################################################################### + enabled_llms = sum(model_info["enabled"] for model_info in state.ll_model_config.values()) + if enabled_llms > 0: + initialize_streamlit() + server_running = False + if "server_thread" in state: + server_running = True + st.success("API Server is Running") + + left, right = st.columns([0.2, 0.8]) + left.number_input( + "API Server Port:", + value=state.api_server_config["port"], + min_value=1, + max_value=65535, + key="user_api_server_port", + disabled=server_running, + ) + right.text_input( + "API Server Key:", + type="password", + value=state.api_server_config["key"], + key="user_api_server_key", + disabled=server_running, + ) + enable_history = st.sidebar.checkbox( + "Enable History and Context?", + value=True, + key="user_chat_history", + ) + if st.sidebar.button("Clear History", disabled=not enable_history): + chat_history.clear() + st.sidebar.divider() + ll_model = st_common.lm_sidebar() + if "server_thread" in state: + st.button("Stop Server", type="primary", on_click=api_server_stop) + elif "initialized" in state and state.initialized: + st.button("Start Server", type="primary", on_click=api_server_start) + else: + st.error("No chat models are configured and/or enabled.", icon="🚨") + st.stop() # RAG st_common.rag_sidebar() @@ -130,10 +177,6 @@ def main(): state.initialized = True st_common.update_rag() logger.debug("Force rerun to save state") - if "server_thread" in state: - logger.info("Restarting API Server") - api_server_stop() - api_server_start() st.rerun() except Exception as ex: logger.exception(ex, exc_info=False) @@ -143,47 +186,14 @@ def main(): st.rerun() st.stop() else: - # RAG Enabled but not configured - if "server_thread" in state: - logger.info("Stopping API Server") - api_server_stop() - + st.error("Not all required RAG options are set, please review or disable RAG.") ######################################################################### - # API Server + # API Server Centre ######################################################################### - server_running = False - if "server_thread" in state: - server_running = True - st.success("API Server is Running") - - left, right = st.columns([0.2, 0.8]) - left.number_input( - "API Server Port:", - value=state.api_server_config["port"], - min_value=1, - max_value=65535, - key="user_api_server_port", - disabled=server_running, - ) - right.text_input( - "API Server Key:", - type="password", - value=state.api_server_config["key"], - key="user_api_server_key", - disabled=server_running, - ) - - if "server_thread" in state: - st.button("Stop Server", type="primary", on_click=api_server_stop) - elif "initialized" in state and state.initialized: - st.button("Start Server", type="primary", on_click=api_server_start) - else: - st.error("Not all required RAG options are set, please review or disable RAG.") - - st.subheader("Activity") - if "server_thread" in state: - with st.container(border=True): - display_logs() + # if "server_thread" in state: + # st.subheader("Activity") + # with st.container(border=True): + # display_logs(chat_history) if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: diff --git a/app/src/content/api_server.py.bak b/app/src/content/api_server.py.bak new file mode 100644 index 0000000..7987193 --- /dev/null +++ b/app/src/content/api_server.py.bak @@ -0,0 +1,190 @@ +""" +Copyright (c) 2023, 2024, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" + +# spell-checker:ignore streamlit +import inspect +import threading +import time + +# Streamlit +import streamlit as st +from streamlit import session_state as state + +# Utilities +import modules.st_common as st_common +import modules.api_server as api_server + +import modules.logging_config as logging_config + +logger = logging_config.logging.getLogger("api_server") + + +############################################################################# +# Functions +############################################################################# +def initialize_streamlit(): + """initialize Streamlit Session State""" + if "api_server_config" not in state: + state.api_server_config = api_server.config() + logger.info("initialized API Server Config") + + +def display_logs(): + log_placeholder = st.empty() # A placeholder to update logs + logs = [] # Store logs for display + + try: + while "server_thread" in st.session_state: + try: + # Retrieve log from queue (non-blocking) + log_item = api_server.log_queue.get_nowait() + logs.append(log_item) + # Update the placeholder with new logs + log_placeholder.text("\n".join(logs)) + except api_server.queue.Empty: + time.sleep(0.1) # Avoid busy-waiting + finally: + logger.info("API Server events display has stopped.") + + +def api_server_start(): + state.api_server_config["port"] = state.user_api_server_port + state.api_server_config["key"] = state.user_api_server_key + if "initialized" in state and state.initialized: + if "server_thread" not in state: + state.httpd = api_server.run_server( + state.api_server_config["port"], + state.chat_manager, + state.rag_params, + state.lm_instr, + state.context_instr, + state.api_server_config["key"], + ) + + # Start the server in the thread + def api_server_process(httpd): + httpd.serve_forever() + + state.server_thread = threading.Thread( + target=api_server_process, + # Trailing , ensures tuple is passed + args=(state.httpd,), + daemon=True, + ) + state.server_thread.start() + logger.info("Started API Server on port: %i", state.api_server_config["port"]) + else: + st.warning("API Server is already running.") + else: + logger.warning("Unable to start API Server; ChatMgr not configured") + st.warning("Failed to start API Server: Chatbot not initialized.") + + +def api_server_stop(): + if "server_thread" in state: + if state.server_thread.is_alive(): + state.httpd.shutdown() # Shut down the server + state.server_thread.join() # Wait for the thread to exit + + del state.server_thread # Clean up the thread reference + del state.httpd # Clean up the server reference + + logger.info("API Server stopped successfully.") + st.success("API Server stopped successfully.") + else: + logger.warning("API Server thread is not running - cleaning up.") + st.warning("API Server thread is not running.") + del state.server_thread + else: + logger.info("Unable to stop API Server - not running.") + + +############################################################################# +# MAIN +############################################################################# +def main(): + """Streamlit GUI""" + initialize_streamlit() + + + # LLM Params + ll_model = st_common.lm_sidebar() + + # Initialize RAG + st_common.initialize_rag() + + # RAG + st_common.rag_sidebar() + + ######################################################################### + # Initialize the Client + ######################################################################### + if "initialized" not in state: + if not state.rag_params["enable"] or all( + state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"] + ): + try: + state.chat_manager = st_common.initialize_chatbot(ll_model) + state.initialized = True + st_common.update_rag() + logger.debug("Force rerun to save state") + if "server_thread" in state: + logger.info("Restarting API Server") + api_server_stop() + api_server_start() + st.rerun() + except Exception as ex: + logger.exception(ex, exc_info=False) + st.error(f"Failed to initialize the chat client: {ex}") + st_common.clear_initialized() + if st.button("Retry", key="retry_initialize"): + st.rerun() + st.stop() + else: + # RAG Enabled but not configured + if "server_thread" in state: + logger.info("Stopping API Server") + api_server_stop() + + ######################################################################### + # API Server + ######################################################################### + server_running = False + if "server_thread" in state: + server_running = True + st.success("API Server is Running") + + left, right = st.columns([0.2, 0.8]) + left.number_input( + "API Server Port:", + value=state.api_server_config["port"], + min_value=1, + max_value=65535, + key="user_api_server_port", + disabled=server_running, + ) + right.text_input( + "API Server Key:", + type="password", + value=state.api_server_config["key"], + key="user_api_server_key", + disabled=server_running, + ) + + if "server_thread" in state: + st.button("Stop Server", type="primary", on_click=api_server_stop) + elif "initialized" in state and state.initialized: + st.button("Start Server", type="primary", on_click=api_server_start) + else: + st.error("Not all required RAG options are set, please review or disable RAG.") + + st.subheader("Activity") + if "server_thread" in state: + with st.container(border=True): + display_logs() + + +if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: + main() diff --git a/app/src/content/chatbot.py b/app/src/content/chatbot.py index 12e0953..ee205db 100644 --- a/app/src/content/chatbot.py +++ b/app/src/content/chatbot.py @@ -20,51 +20,6 @@ logger = logging_config.logging.getLogger("chatbot") - -############################################################################# -# Functions -############################################################################# -def show_refs(context): - """When RAG Enabled, show the references""" - st.markdown( - """ - - """, - unsafe_allow_html=True, - ) - - column_sizes = [10, 8, 8, 8, 2, 24] - cols = st.columns(column_sizes) - # Create a button in each column - links = set() - with cols[0]: - st.markdown("**References:**") - # Limit the maximum number of items to 3 (counting from 0) - max_items = min(len(context), 3) - - # Loop through the chunks and display them - for i in range(max_items): - with cols[i + 1]: - chunk = context[i] - links.add(chunk.metadata["source"]) - with st.popover(f"Ref: {i+1}"): - st.markdown(chunk.metadata["source"]) - st.markdown(chunk.page_content) - st.markdown(chunk.metadata["id"]) - - for link in links: - st.markdown("- " + link) - - ############################################################################# # MAIN ############################################################################# @@ -156,7 +111,7 @@ def main(): full_context = chunk["context"] message_placeholder.markdown(full_answer) if full_context: - show_refs(full_context) + st_common.show_rag_refs(full_context) else: st.chat_message("ai").write_stream(response) except Exception as ex: diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py index 179fb6c..57c93b9 100644 --- a/app/src/modules/api_server.py +++ b/app/src/modules/api_server.py @@ -12,15 +12,12 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import urlparse -from langchain_community.chat_message_histories import StreamlitChatMessageHistory - # Utilities import modules.chatbot as chatbot import modules.logging_config as logging_config logger = logging_config.logging.getLogger("modules.api_server") - # Create a queue to store the requests and responses log_queue = queue.Queue() @@ -49,6 +46,8 @@ def generate_api_key(length=32): def get_answer_fn( + + question: str, history=None, chat_manager=None, @@ -57,29 +56,20 @@ def get_answer_fn( context_instr=None, ) -> str: """Send for completion""" - # Format appropriately the history for your RAG agent - chat_history_api = StreamlitChatMessageHistory(key="empty") - chat_history_api.clear() - if history: - for h in history: - if h["role"] == "assistant": - chat_history_api.add_ai_message(h["content"]) - else: - chat_history_api.add_user_message(h["content"]) - try: response = chatbot.generate_response( chat_mgr=chat_manager, input=question, - chat_history=chat_history_api, - enable_history=True, + chat_history={}, + enable_history=False, rag_params=rag_params, chat_instr=lm_instr, context_instr=context_instr, stream=False, ) - logger.info("MSG from Chatbot API: %s", response) + logger.info("Raw MSG from Chatbot API: %s", response) if rag_params["enable"]: + return {"messages": response} return response["answer"] else: return response.content @@ -98,6 +88,7 @@ def __init__( self.lm_instr = lm_instr self.context_instr = context_instr self.api_key = api_key + self.workflow = StateGraph(state_schema=MessagesState) super().__init__(*args, **kwargs) def do_OPTIONS(self): # pylint: disable=invalid-name diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py index 7160f92..c1bb735 100644 --- a/app/src/modules/st_common.py +++ b/app/src/modules/st_common.py @@ -140,6 +140,45 @@ def initialize_rag(): except AttributeError: st.error("Application has not been initialized, please restart.", icon="⛑️") +def show_rag_refs(context): + """When RAG Enabled, show the references""" + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + column_sizes = [10, 8, 8, 8, 2, 24] + cols = st.columns(column_sizes) + # Create a button in each column + links = set() + with cols[0]: + st.markdown("**References:**") + # Limit the maximum number of items to 3 (counting from 0) + max_items = min(len(context), 3) + + # Loop through the chunks and display them + for i in range(max_items): + with cols[i + 1]: + chunk = context[i] + links.add(chunk.metadata["source"]) + with st.popover(f"Ref: {i+1}"): + st.markdown(chunk.metadata["source"]) + st.markdown(chunk.page_content) + st.markdown(chunk.metadata["id"]) + + for link in links: + st.markdown("- " + link) def initialize_chatbot(ll_model): """Initialize the Chatbot""" diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index 2c46da0..61e37ad 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -264,7 +264,7 @@ def init_vs(db_conn, embedding_function, store_table, distance_metric): def get_vs_table(model, chunk_size, chunk_overlap, distance_metric, embed_alias=None): - """Get a list of Vector Store Tables""" + """Return the concatenated VS Table name and comment""" chunk_overlap_ceil = math.ceil(chunk_overlap) table_string = f"{model}_{chunk_size}_{chunk_overlap_ceil}_{distance_metric}" if embed_alias: diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..75c7593 --- /dev/null +++ b/test.sh @@ -0,0 +1,5 @@ +curl -X POST \ + -H 'Content-Type: application/json' \ + -H 'Authorization: Bearer abc' \ + -d '{"message":"How do I determine the accuracy of my vector indexes?"}' \ + http://127.0.0.1:8000/v1/chat/completions From 3032dc57675869c88a6a23cb67ef37cb9ad12b46 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Wed, 30 Oct 2024 09:50:23 +0000 Subject: [PATCH 02/13] api server --- app/src/content/api_server.py | 115 ++++++++++++++++++++------------- app/src/modules/api_server.py | 113 +++++++++++++------------------- app/src/modules/chatbot.py | 4 ++ app/src/modules/langgraph.py | 117 ++++++++++++++++++++++++++++++++++ app/src/modules/utilities.py | 2 + test.sh | 8 ++- 6 files changed, 243 insertions(+), 116 deletions(-) create mode 100644 app/src/modules/langgraph.py diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index 22c8c67..e5f6393 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -2,12 +2,13 @@ Copyright (c) 2023, 2024, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ + # spell-checker:ignore streamlit, langchain, llms import inspect import time -import threading import json +import threading # Streamlit import streamlit as st @@ -33,54 +34,76 @@ def initialize_streamlit(): state.api_server_config = api_server.config() logger.info("initialized API Server Config") -# def display_logs(chat_history): -# log_placeholder = st.empty() # A placeholder to update logs -# logs = [] # Store logs for display - -# try: -# while "server_thread" in st.session_state: -# try: -# # Retrieve log from queue (non-blocking) -# log_item = api_server.log_queue.get_nowait() -# if log_item['type'] == "AIMessageChunk": -# st.chat_message("ai").write(log_item['message']) -# except (KeyError, TypeError): -# st.chat_message("human").write(log_item['message']) -# #logs.append(log_item) -# # Update the placeholder with new logs -# # log_placeholder.text("\n".join(logs)) -# except api_server.queue.Empty: -# time.sleep(0.1) # Avoid busy-waiting -# finally: -# logger.info("API Server events display has stopped.") + +def display_logs(): + # log_placeholder = st.empty() # A placeholder to update logs + # logs = [] # Store logs for display + + try: + while "server_thread" in st.session_state: + try: + # Retrieve log from queue (non-blocking) + msg = api_server.log_queue.get_nowait() + print(msg) + # if "content" in msg: + # st.chat_message("ai").write(dir(msg)) + # else: + # st.chat_message("human").write(msg['message']) + #print(msg) + # try: + # print(msg.message) + # except: + # print(msg) + + # if msg.type == "AIMessageChunk": + # st.chat_message("ai").write(msg.content) + # else: + # st.chat_message(msg.type).write(msg.content) + # try: + # st.chat_message("ai").write_stream(msg) + # except (KeyError, TypeError): + # st.chat_message("human").write(msg) + # logs.append(log_item) + # Update the placeholder with new logs + # log_placeholder.text("\n".join(logs)) + except api_server.queue.Empty: + time.sleep(0.1) # Avoid busy-waiting + finally: + logger.info("API Server events display has stopped.") def api_server_start(): + chat_history = StreamlitChatMessageHistory(key="api_chat_history") state.api_server_config["port"] = state.user_api_server_port state.api_server_config["key"] = state.user_api_server_key if "initialized" in state and state.initialized: if "server_thread" not in state: - state.httpd = api_server.run_server( - state.api_server_config["port"], - state.chat_manager, - state.rag_params, - state.lm_instr, - state.context_instr, - state.api_server_config["key"], - ) - - # Start the server in the thread - def api_server_process(httpd): - httpd.serve_forever() - - state.server_thread = threading.Thread( - target=api_server_process, - # Trailing , ensures tuple is passed - args=(state.httpd,), - daemon=True, - ) - state.server_thread.start() - logger.info("Started API Server on port: %i", state.api_server_config["port"]) + try: + state.httpd = api_server.run_server( + state.api_server_config["port"], + state.chat_manager, + state.rag_params, + state.lm_instr, + state.context_instr, + state.api_server_config["key"], + chat_history, + False, + ) + + # Start the server in the thread + def api_server_process(httpd): + httpd.serve_forever() + + state.server_thread = threading.Thread( + target=api_server_process, + # Trailing , ensures tuple is passed + args=(state.httpd,), + daemon=True, + ) + state.server_thread.start() + logger.info("Started API Server on port: %i", state.api_server_config["port"]) + except OSError: + st.error("Port is already in use.") else: st.warning("API Server is already running.") else: @@ -190,10 +213,10 @@ def main(): ######################################################################### # API Server Centre ######################################################################### - # if "server_thread" in state: - # st.subheader("Activity") - # with st.container(border=True): - # display_logs(chat_history) + if "server_thread" in state: + st.subheader("Activity") + with st.container(border=True): + display_logs() if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py index 57c93b9..96eface 100644 --- a/app/src/modules/api_server.py +++ b/app/src/modules/api_server.py @@ -2,6 +2,7 @@ Copyright (c) 2023, 2024, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ + # spell-checker:ignore streamlit, langchain import os @@ -12,6 +13,8 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import urlparse +from langchain_community.chat_message_histories import StreamlitChatMessageHistory + # Utilities import modules.chatbot as chatbot import modules.logging_config as logging_config @@ -44,51 +47,28 @@ def generate_api_key(length=32): "key": os.environ.get("API_SERVER_KEY", default=auto_api_key), } - -def get_answer_fn( - - - question: str, - history=None, - chat_manager=None, - rag_params=None, - lm_instr=None, - context_instr=None, -) -> str: - """Send for completion""" - try: - response = chatbot.generate_response( - chat_mgr=chat_manager, - input=question, - chat_history={}, - enable_history=False, - rag_params=rag_params, - chat_instr=lm_instr, - context_instr=context_instr, - stream=False, - ) - logger.info("Raw MSG from Chatbot API: %s", response) - if rag_params["enable"]: - return {"messages": response} - return response["answer"] - else: - return response.content - except Exception as ex: - return f"I'm sorry, something's gone wrong: {ex}" - - class ChatbotHTTPRequestHandler(BaseHTTPRequestHandler): """Handler for mini-chatbot""" def __init__( - self, *args, chat_manager=None, rag_params=None, lm_instr=None, context_instr=None, api_key=None, **kwargs + self, + *args, + chat_manager=None, + rag_params=None, + lm_instr=None, + context_instr=None, + api_key=None, + chat_history=None, + enable_history=False, + **kwargs ): self.chat_manager = chat_manager self.rag_params = rag_params self.lm_instr = lm_instr self.context_instr = context_instr self.api_key = api_key - self.workflow = StateGraph(state_schema=MessagesState) + self.chat_history = chat_history + self.enable_history = enable_history super().__init__(*args, **kwargs) def do_OPTIONS(self): # pylint: disable=invalid-name @@ -117,38 +97,37 @@ def do_POST(self): # pylint: disable=invalid-name content_length = int(self.headers["Content-Length"]) post_data = self.rfile.read(content_length).decode("utf-8") - # Log the raw body - logger.info("Raw Body: %s", post_data) try: # Parse the POST data as JSON post_json = json.loads(post_data) # Extract the 'message' field from the JSON message = post_json.get("message") - + response = None if message: # Log the incoming message logger.info("MSG to Chatbot API: %s", message) # Call your function to get the chatbot response - answer = get_answer_fn( - message, None, self.chat_manager, self.rag_params, self.lm_instr, self.context_instr + response = chatbot.generate_response( + chat_mgr=self.chat_manager, + input=message, + chat_history=self.chat_history, + enable_history=self.enable_history, + rag_params=self.rag_params, + chat_instr=self.lm_instr, + context_instr=self.context_instr, + stream=False, ) - - # Prepare the response as JSON - response = {"choices": [{"message": {"content": answer}}]} self.send_response(200) + # Process response to JSON else: - # If no message is provided, return an error - response = {"error": "No 'message' field found in request."} self.send_response(400) # Bad request - # Add request/response to the queue - log_queue.put(f"Request: {post_data}") - log_queue.put(f"Response: {response}") + # Add request/response to the queue for output + log_queue.put(post_json) + log_queue.put(response) except json.JSONDecodeError: - # If JSON parsing fails, return an error - response = {"error": "Invalid JSON in request."} self.send_response(400) # Bad request else: # Invalid or missing API Key @@ -156,37 +135,33 @@ def do_POST(self): # pylint: disable=invalid-name self.send_response(401) self.send_header("Content-type", "application/json") self.end_headers() - self.wfile.write(b'{"error": "Unauthorized. Invalid API Key."}') return else: # Return a 404 response for unknown paths self.send_response(404) - response = {"error": "Path not found."} # Send the response self.send_header("Access-Control-Allow-Origin", "*") # Add CORS header self.send_header("Content-Type", "application/json") self.end_headers() - self.wfile.write(json.dumps(response).encode("utf-8")) -def run_server(port, chat_manager, rag_params, lm_instr, context_instr, api_key): - def create_handler(chat_manager, rag_params, lm_instr, context_instr, api_key): - def handler(*args, **kwargs): - ChatbotHTTPRequestHandler( - *args, - chat_manager=chat_manager, - rag_params=rag_params, - lm_instr=lm_instr, - context_instr=context_instr, - api_key=api_key, - **kwargs, - ) - - return handler +def run_server(port, chat_manager, rag_params, lm_instr, context_instr, api_key, chat_history, enable_history): + # Define the request handler function + def handler(*args, **kwargs): + return ChatbotHTTPRequestHandler( + *args, + chat_manager=chat_manager, + rag_params=rag_params, + lm_instr=lm_instr, + context_instr=context_instr, + api_key=api_key, + chat_history=chat_history, + enable_history=enable_history, + **kwargs, + ) - handler_with_params = create_handler(chat_manager, rag_params, lm_instr, context_instr, api_key) server_address = ("", port) - httpd = HTTPServer(server_address, handler_with_params) + httpd = HTTPServer(server_address, handler) return httpd diff --git a/app/src/modules/chatbot.py b/app/src/modules/chatbot.py index 0427dab..e390552 100644 --- a/app/src/modules/chatbot.py +++ b/app/src/modules/chatbot.py @@ -138,6 +138,10 @@ def langchain_rag(self, rag_params, chat_instr, context_instr, input, chat_histo # History Aware Chain rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + ### Statefully manage chat history ### + + + conversational_rag_chain = RunnableWithMessageHistory( rag_chain, lambda session_id: chat_history, diff --git a/app/src/modules/langgraph.py b/app/src/modules/langgraph.py new file mode 100644 index 0000000..2689bff --- /dev/null +++ b/app/src/modules/langgraph.py @@ -0,0 +1,117 @@ +""" +Copyright (c) 2023, 2024, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker:ignore langchain, langgraph, openai + +from typing import Sequence + +import bs4 +from langchain.chains import create_history_aware_retriever, create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.document_loaders import WebBaseLoader +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.vectorstores import InMemoryVectorStore +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import START, StateGraph +from langgraph.graph.message import add_messages +from typing_extensions import Annotated, TypedDict + +llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) + + +### Construct retriever ### +loader = WebBaseLoader( + web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",), + bs_kwargs=dict(parse_only=bs4.SoupStrainer(class_=("post-content", "post-title", "post-header"))), +) +docs = loader.load() + +text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) +splits = text_splitter.split_documents(docs) + +vectorstore = InMemoryVectorStore(embedding=OpenAIEmbeddings()) +vectorstore.add_documents(documents=splits) +retriever = vectorstore.as_retriever() + + +### Contextualize question ### +contextualize_q_system_prompt = ( + "Given a chat history and the latest user question " + "which might reference context in the chat history, " + "formulate a standalone question which can be understood " + "without the chat history. Do NOT answer the question, " + "just reformulate it if needed and otherwise return it as is." +) +contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) +history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt) + + +### Answer question ### +system_prompt = ( + "You are an assistant for question-answering tasks. " + "Use the following pieces of retrieved context to answer " + "the question. If you don't know the answer, say that you " + "don't know. Use three sentences maximum and keep the " + "answer concise." + "\n\n" + "{context}" +) +qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) +question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + +rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + + +### Statefully manage chat history ### + + +# We define a dict representing the state of the application. +# This state has the same input and output keys as `rag_chain`. +class State(TypedDict): + input: str + chat_history: Annotated[Sequence[BaseMessage], add_messages] + context: str + answer: str + + +# We then define a simple node that runs the `rag_chain`. +# The `return` values of the node update the graph state, so here we just +# update the chat history with the input message and response. +def call_model(state: State): + response = rag_chain.invoke(state) + return { + "chat_history": [ + HumanMessage(state["input"]), + AIMessage(response["answer"]), + ], + "context": response["context"], + "answer": response["answer"], + } + + +# Our graph consists only of one node: +workflow = StateGraph(state_schema=State) +workflow.add_edge(START, "model") +workflow.add_node("model", call_model) + +# Finally, we compile the graph with a checkpointer object. +# This persists the state, in this case in memory. +memory = MemorySaver() +app = workflow.compile(checkpointer=memory) diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index 61e37ad..40c06ac 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -535,6 +535,8 @@ def oci_get_namespace(config, retries=True): raise OciException("AuthN Error - Disabling OCI") from ex except oci.exceptions.RequestException as ex: raise OciException("No Network Access - Disabling OCI") from ex + except oci.exceptions.ConnectTimeout as ex: + raise OciException("No Network Access - Disabling OCI") from ex except FileNotFoundError as ex: raise OciException("Invalid Key Path") from ex except UnboundLocalError as ex: diff --git a/test.sh b/test.sh index 75c7593..af05dc7 100644 --- a/test.sh +++ b/test.sh @@ -1,5 +1,11 @@ +#curl -X POST \ +# -H 'Content-Type: application/json' \ +# -H 'Authorization: Bearer abc' \ +# -d '{"message":"How do I determine the accuracy of my vector indexes?"}' \ +# http://127.0.0.1:8000/v1/chat/completions +# curl -X POST \ -H 'Content-Type: application/json' \ -H 'Authorization: Bearer abc' \ - -d '{"message":"How do I determine the accuracy of my vector indexes?"}' \ + -d '{"message":"Are you sure?"}' \ http://127.0.0.1:8000/v1/chat/completions From fee96e7993283936d907d419649db4bba206427f Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Wed, 30 Oct 2024 10:29:55 +0000 Subject: [PATCH 03/13] bump --- app/requirements.txt | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/app/requirements.txt b/app/requirements.txt index 8244fbb..5e017d0 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -7,32 +7,32 @@ ## Installation Example (from repo root): ### python3.11 -m venv .venv ### source .venv/bin/activate -### pip3 install --upgrade pip wheel +### pip3 install --upgrade pip wheel setuptools ### pip3 install -r app/requirements.txt ## Top-Level brings in the required dependencies, if adding modules, try to find the minimum required bokeh==3.6.0 evaluate==0.4.3 faiss-cpu==1.9.0 -giskard==2.15.2 -IPython==8.28.0 +giskard==2.15.3 +IPython==8.29.0 langchain-cohere==0.3.1 -langchain-community==0.3.2 -langchain-huggingface==0.1.0 +langchain-community==0.3.3 +langchain-huggingface==0.1.1 langchain-ollama==0.2.0 -langchain-openai==0.2.2 +langchain-openai==0.2.4 langgraph==0.2.39 -llama_index==0.11.18 +llama_index==0.11.20 lxml==5.3.0 matplotlib==3.9.2 oci>=2.0.0 oracledb>=2.0.0 plotly==5.24.1 streamlit==1.39.0 -umap-learn==0.5.6 +umap-learn==0.5.7 ## For Licensing Purposes; ensures no GPU modules are installed ## as part of langchain-huggingface -f https://download.pytorch.org/whl/cpu/torch -torch==2.4.1+cpu ; sys_platform == "linux" -torch==2.2.2 ; sys_platform == "darwin" \ No newline at end of file +torch==2.5.1+cpu ; sys_platform == "linux" +torch==2.5.1 ; sys_platform == "darwin" \ No newline at end of file From 9b297c9fb42638eae47b1b2aa1071e6cbcca7757 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Wed, 30 Oct 2024 10:58:21 +0000 Subject: [PATCH 04/13] msg printed --- app/src/content/api_server.py | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index e5f6393..8de0d1d 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -7,10 +7,10 @@ import inspect import time -import json import threading # Streamlit +import msgpack import streamlit as st from streamlit import session_state as state @@ -36,36 +36,15 @@ def initialize_streamlit(): def display_logs(): - # log_placeholder = st.empty() # A placeholder to update logs - # logs = [] # Store logs for display - try: while "server_thread" in st.session_state: try: # Retrieve log from queue (non-blocking) msg = api_server.log_queue.get_nowait() - print(msg) - # if "content" in msg: - # st.chat_message("ai").write(dir(msg)) - # else: - # st.chat_message("human").write(msg['message']) - #print(msg) - # try: - # print(msg.message) - # except: - # print(msg) - - # if msg.type == "AIMessageChunk": - # st.chat_message("ai").write(msg.content) - # else: - # st.chat_message(msg.type).write(msg.content) - # try: - # st.chat_message("ai").write_stream(msg) - # except (KeyError, TypeError): - # st.chat_message("human").write(msg) - # logs.append(log_item) - # Update the placeholder with new logs - # log_placeholder.text("\n".join(logs)) + if 'message' in msg: + st.chat_message("human").write(msg['message']) + else: + st.chat_message("ai").write(msg.content) except api_server.queue.Empty: time.sleep(0.1) # Avoid busy-waiting finally: From e5ffabf440115b75d887034b93cfef40697b69a2 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Wed, 30 Oct 2024 11:50:30 +0000 Subject: [PATCH 05/13] catch exceptions --- app/src/content/api_server.py | 14 ++-- app/src/modules/api_server.py | 2 - app/src/modules/st_common.py | 119 ++++++++++++++++++++-------------- 3 files changed, 81 insertions(+), 54 deletions(-) diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index 8de0d1d..5be3b39 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -10,7 +10,6 @@ import threading # Streamlit -import msgpack import streamlit as st from streamlit import session_state as state @@ -41,10 +40,15 @@ def display_logs(): try: # Retrieve log from queue (non-blocking) msg = api_server.log_queue.get_nowait() - if 'message' in msg: - st.chat_message("human").write(msg['message']) + logger.info("API Msg: %s", msg) + if "message" in msg: + st.chat_message("human").write(msg["message"]) else: - st.chat_message("ai").write(msg.content) + if state.rag_params["enable"]: + st.chat_message("ai").write(msg["answer"]) + st_common.show_rag_refs(msg["context"]) + else: + st.chat_message("ai").write(msg.content) except api_server.queue.Empty: time.sleep(0.1) # Avoid busy-waiting finally: @@ -66,7 +70,7 @@ def api_server_start(): state.context_instr, state.api_server_config["key"], chat_history, - False, + state.user_chat_history, ) # Start the server in the thread diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py index 96eface..3457616 100644 --- a/app/src/modules/api_server.py +++ b/app/src/modules/api_server.py @@ -13,8 +13,6 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import urlparse -from langchain_community.chat_message_histories import StreamlitChatMessageHistory - # Utilities import modules.chatbot as chatbot import modules.logging_config as logging_config diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py index c1bb735..b68580b 100644 --- a/app/src/modules/st_common.py +++ b/app/src/modules/st_common.py @@ -312,82 +312,107 @@ def refresh_rag_filtered(): if user_alias: rag_filt_alias = [user_alias] else: - rag_filt_alias = [ - v.get("alias", None) - for v in state.vs_tables.values() - if (user_model is None or v["model"] == user_model) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) - and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) - ] + try: + rag_filt_alias = [ + v.get("alias", None) + for v in state.vs_tables.values() + if (user_model is None or v["model"] == user_model) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) + and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() if user_model: rag_filt_model = [user_model] else: - rag_filt_model = [ - v["model"] - for v in state.vs_tables.values() - if (user_alias is None or v.get("alias", None) == user_alias) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) - and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) - ] + try: + rag_filt_model = [ + v["model"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) + and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() if user_chunk_size: rag_filt_chunk_size = [user_chunk_size] else: - rag_filt_chunk_size = [ - v["chunk_size"] - for v in state.vs_tables.values() - if (user_alias is None or v.get("alias", None) == user_alias) - and (user_model is None or v["model"] == user_model) - and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) - and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) - ] + try: + rag_filt_chunk_size = [ + v["chunk_size"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_model is None or v["model"] == user_model) + and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) + and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() if user_chunk_overlap: rag_filt_chunk_overlap = [user_chunk_overlap] else: - rag_filt_chunk_overlap = [ - v["chunk_overlap"] - for v in state.vs_tables.values() - if (user_alias is None or v.get("alias", None) == user_alias) - and (user_model is None or v["model"] == user_model) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) - ] + try: + rag_filt_chunk_overlap = [ + v["chunk_overlap"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_model is None or v["model"] == user_model) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() if user_distance_metric: rag_filt_distance_metric = [user_distance_metric] else: - rag_filt_distance_metric = [ - v["distance_metric"] - for v in state.vs_tables.values() - if (user_alias is None or v.get("alias", None) == user_alias) - and (user_model is None or v["model"] == user_model) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) - ] + try: + rag_filt_distance_metric = [ + v["distance_metric"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_model is None or v["model"] == user_model) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() # Remove duplicates and sort - state.rag_filter["alias"] = sorted(set(rag_filt_alias)) - state.rag_filter["model"] = sorted(set(rag_filt_model)) - state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size)) - state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap)) - state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric)) + try: + state.rag_filter["alias"] = sorted(set(rag_filt_alias)) + state.rag_filter["model"] = sorted(set(rag_filt_model)) + state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size)) + state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap)) + state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric)) + except UnboundLocalError: + clear_initialized() # (Re)set the index to previously selected option attributes = ["alias", "model", "chunk_size", "chunk_overlap", "distance_metric"] for attr in attributes: - filtered_list = state.rag_filter[attr] try: + filtered_list = state.rag_filter[attr] user_value = getattr(state, f"rag_user_{attr}") except AttributeError: setattr(state, f"rag_user_{attr}", []) + except KeyError: + clear_initialized() try: idx = 0 if len(filtered_list) == 1 else filtered_list.index(user_value) - except (ValueError, AttributeError): + except (ValueError, AttributeError, UnboundLocalError): idx = None state.rag_user_idx[attr] = idx From c8b474132abfd915d3f7b4e33ffbde340ff7e462 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Wed, 30 Oct 2024 11:52:33 +0000 Subject: [PATCH 06/13] Remove old --- app/src/content/api_server.py.bak | 190 ------------------------------ test.sh | 11 -- 2 files changed, 201 deletions(-) delete mode 100644 app/src/content/api_server.py.bak delete mode 100644 test.sh diff --git a/app/src/content/api_server.py.bak b/app/src/content/api_server.py.bak deleted file mode 100644 index 7987193..0000000 --- a/app/src/content/api_server.py.bak +++ /dev/null @@ -1,190 +0,0 @@ -""" -Copyright (c) 2023, 2024, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" - -# spell-checker:ignore streamlit -import inspect -import threading -import time - -# Streamlit -import streamlit as st -from streamlit import session_state as state - -# Utilities -import modules.st_common as st_common -import modules.api_server as api_server - -import modules.logging_config as logging_config - -logger = logging_config.logging.getLogger("api_server") - - -############################################################################# -# Functions -############################################################################# -def initialize_streamlit(): - """initialize Streamlit Session State""" - if "api_server_config" not in state: - state.api_server_config = api_server.config() - logger.info("initialized API Server Config") - - -def display_logs(): - log_placeholder = st.empty() # A placeholder to update logs - logs = [] # Store logs for display - - try: - while "server_thread" in st.session_state: - try: - # Retrieve log from queue (non-blocking) - log_item = api_server.log_queue.get_nowait() - logs.append(log_item) - # Update the placeholder with new logs - log_placeholder.text("\n".join(logs)) - except api_server.queue.Empty: - time.sleep(0.1) # Avoid busy-waiting - finally: - logger.info("API Server events display has stopped.") - - -def api_server_start(): - state.api_server_config["port"] = state.user_api_server_port - state.api_server_config["key"] = state.user_api_server_key - if "initialized" in state and state.initialized: - if "server_thread" not in state: - state.httpd = api_server.run_server( - state.api_server_config["port"], - state.chat_manager, - state.rag_params, - state.lm_instr, - state.context_instr, - state.api_server_config["key"], - ) - - # Start the server in the thread - def api_server_process(httpd): - httpd.serve_forever() - - state.server_thread = threading.Thread( - target=api_server_process, - # Trailing , ensures tuple is passed - args=(state.httpd,), - daemon=True, - ) - state.server_thread.start() - logger.info("Started API Server on port: %i", state.api_server_config["port"]) - else: - st.warning("API Server is already running.") - else: - logger.warning("Unable to start API Server; ChatMgr not configured") - st.warning("Failed to start API Server: Chatbot not initialized.") - - -def api_server_stop(): - if "server_thread" in state: - if state.server_thread.is_alive(): - state.httpd.shutdown() # Shut down the server - state.server_thread.join() # Wait for the thread to exit - - del state.server_thread # Clean up the thread reference - del state.httpd # Clean up the server reference - - logger.info("API Server stopped successfully.") - st.success("API Server stopped successfully.") - else: - logger.warning("API Server thread is not running - cleaning up.") - st.warning("API Server thread is not running.") - del state.server_thread - else: - logger.info("Unable to stop API Server - not running.") - - -############################################################################# -# MAIN -############################################################################# -def main(): - """Streamlit GUI""" - initialize_streamlit() - - - # LLM Params - ll_model = st_common.lm_sidebar() - - # Initialize RAG - st_common.initialize_rag() - - # RAG - st_common.rag_sidebar() - - ######################################################################### - # Initialize the Client - ######################################################################### - if "initialized" not in state: - if not state.rag_params["enable"] or all( - state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"] - ): - try: - state.chat_manager = st_common.initialize_chatbot(ll_model) - state.initialized = True - st_common.update_rag() - logger.debug("Force rerun to save state") - if "server_thread" in state: - logger.info("Restarting API Server") - api_server_stop() - api_server_start() - st.rerun() - except Exception as ex: - logger.exception(ex, exc_info=False) - st.error(f"Failed to initialize the chat client: {ex}") - st_common.clear_initialized() - if st.button("Retry", key="retry_initialize"): - st.rerun() - st.stop() - else: - # RAG Enabled but not configured - if "server_thread" in state: - logger.info("Stopping API Server") - api_server_stop() - - ######################################################################### - # API Server - ######################################################################### - server_running = False - if "server_thread" in state: - server_running = True - st.success("API Server is Running") - - left, right = st.columns([0.2, 0.8]) - left.number_input( - "API Server Port:", - value=state.api_server_config["port"], - min_value=1, - max_value=65535, - key="user_api_server_port", - disabled=server_running, - ) - right.text_input( - "API Server Key:", - type="password", - value=state.api_server_config["key"], - key="user_api_server_key", - disabled=server_running, - ) - - if "server_thread" in state: - st.button("Stop Server", type="primary", on_click=api_server_stop) - elif "initialized" in state and state.initialized: - st.button("Start Server", type="primary", on_click=api_server_start) - else: - st.error("Not all required RAG options are set, please review or disable RAG.") - - st.subheader("Activity") - if "server_thread" in state: - with st.container(border=True): - display_logs() - - -if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: - main() diff --git a/test.sh b/test.sh deleted file mode 100644 index af05dc7..0000000 --- a/test.sh +++ /dev/null @@ -1,11 +0,0 @@ -#curl -X POST \ -# -H 'Content-Type: application/json' \ -# -H 'Authorization: Bearer abc' \ -# -d '{"message":"How do I determine the accuracy of my vector indexes?"}' \ -# http://127.0.0.1:8000/v1/chat/completions -# -curl -X POST \ - -H 'Content-Type: application/json' \ - -H 'Authorization: Bearer abc' \ - -d '{"message":"Are you sure?"}' \ - http://127.0.0.1:8000/v1/chat/completions From 251196a2062acff6c3ef11979a5fc16b69f3261b Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Thu, 31 Oct 2024 11:24:33 +0000 Subject: [PATCH 07/13] start langgraph --- .gitignore | 1 + app/src/modules/langgraph.py | 117 ----------------------------------- 2 files changed, 1 insertion(+), 117 deletions(-) delete mode 100644 app/src/modules/langgraph.py diff --git a/.gitignore b/.gitignore index ff1c352..0939aa7 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ app/THIRD_PARTY_LICENSES.txt rag/** sbin/** helm/values.yaml +**/*.bak ############################################################################## # Enviroment (PyVen, IDE, etc.) diff --git a/app/src/modules/langgraph.py b/app/src/modules/langgraph.py deleted file mode 100644 index 2689bff..0000000 --- a/app/src/modules/langgraph.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Copyright (c) 2023, 2024, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker:ignore langchain, langgraph, openai - -from typing import Sequence - -import bs4 -from langchain.chains import create_history_aware_retriever, create_retrieval_chain -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain_community.document_loaders import WebBaseLoader -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_core.vectorstores import InMemoryVectorStore -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_text_splitters import RecursiveCharacterTextSplitter -from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph import START, StateGraph -from langgraph.graph.message import add_messages -from typing_extensions import Annotated, TypedDict - -llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) - - -### Construct retriever ### -loader = WebBaseLoader( - web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",), - bs_kwargs=dict(parse_only=bs4.SoupStrainer(class_=("post-content", "post-title", "post-header"))), -) -docs = loader.load() - -text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) -splits = text_splitter.split_documents(docs) - -vectorstore = InMemoryVectorStore(embedding=OpenAIEmbeddings()) -vectorstore.add_documents(documents=splits) -retriever = vectorstore.as_retriever() - - -### Contextualize question ### -contextualize_q_system_prompt = ( - "Given a chat history and the latest user question " - "which might reference context in the chat history, " - "formulate a standalone question which can be understood " - "without the chat history. Do NOT answer the question, " - "just reformulate it if needed and otherwise return it as is." -) -contextualize_q_prompt = ChatPromptTemplate.from_messages( - [ - ("system", contextualize_q_system_prompt), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) -history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt) - - -### Answer question ### -system_prompt = ( - "You are an assistant for question-answering tasks. " - "Use the following pieces of retrieved context to answer " - "the question. If you don't know the answer, say that you " - "don't know. Use three sentences maximum and keep the " - "answer concise." - "\n\n" - "{context}" -) -qa_prompt = ChatPromptTemplate.from_messages( - [ - ("system", system_prompt), - MessagesPlaceholder("chat_history"), - ("human", "{input}"), - ] -) -question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) - -rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) - - -### Statefully manage chat history ### - - -# We define a dict representing the state of the application. -# This state has the same input and output keys as `rag_chain`. -class State(TypedDict): - input: str - chat_history: Annotated[Sequence[BaseMessage], add_messages] - context: str - answer: str - - -# We then define a simple node that runs the `rag_chain`. -# The `return` values of the node update the graph state, so here we just -# update the chat history with the input message and response. -def call_model(state: State): - response = rag_chain.invoke(state) - return { - "chat_history": [ - HumanMessage(state["input"]), - AIMessage(response["answer"]), - ], - "context": response["context"], - "answer": response["answer"], - } - - -# Our graph consists only of one node: -workflow = StateGraph(state_schema=State) -workflow.add_edge(START, "model") -workflow.add_node("model", call_model) - -# Finally, we compile the graph with a checkpointer object. -# This persists the state, in this case in memory. -memory = MemorySaver() -app = workflow.compile(checkpointer=memory) From 31f2c9753e8ef468fa9e56862e79e648c4387596 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Fri, 1 Nov 2024 10:37:15 +0000 Subject: [PATCH 08/13] API Server Autostart --- app/src/content/api_server.py | 66 +++++++++++--------- app/src/modules/api_server.py | 12 ++-- app/src/modules/st_common.py | 4 +- app/src/modules/utilities.py | 109 ++++++++++++---------------------- 4 files changed, 85 insertions(+), 106 deletions(-) diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index 5be3b39..69bd3f2 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -130,27 +130,6 @@ def main(): enabled_llms = sum(model_info["enabled"] for model_info in state.ll_model_config.values()) if enabled_llms > 0: initialize_streamlit() - server_running = False - if "server_thread" in state: - server_running = True - st.success("API Server is Running") - - left, right = st.columns([0.2, 0.8]) - left.number_input( - "API Server Port:", - value=state.api_server_config["port"], - min_value=1, - max_value=65535, - key="user_api_server_port", - disabled=server_running, - ) - right.text_input( - "API Server Key:", - type="password", - value=state.api_server_config["key"], - key="user_api_server_key", - disabled=server_running, - ) enable_history = st.sidebar.checkbox( "Enable History and Context?", value=True, @@ -160,21 +139,16 @@ def main(): chat_history.clear() st.sidebar.divider() ll_model = st_common.lm_sidebar() - if "server_thread" in state: - st.button("Stop Server", type="primary", on_click=api_server_stop) - elif "initialized" in state and state.initialized: - st.button("Start Server", type="primary", on_click=api_server_start) + st_common.rag_sidebar() else: st.error("No chat models are configured and/or enabled.", icon="🚨") st.stop() - # RAG - st_common.rag_sidebar() - ######################################################################### - # Initialize the Client + # Main ######################################################################### if "initialized" not in state: + api_server_stop() if not state.rag_params["enable"] or all( state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"] ): @@ -193,6 +167,40 @@ def main(): st.stop() else: st.error("Not all required RAG options are set, please review or disable RAG.") + st.stop() + + server_running = False + if "server_thread" in state: + server_running = True + elif state.api_server_config["auto_start"]: + server_running = True + + left, right = st.columns([0.2, 0.8]) + left.number_input( + "API Server Port:", + value=state.api_server_config["port"], + min_value=1, + max_value=65535, + key="user_api_server_port", + disabled=server_running, + ) + right.text_input( + "API Server Key:", + type="password", + value=state.api_server_config["key"], + key="user_api_server_key", + disabled=server_running, + ) + + if state.api_server_config["auto_start"]: + st.success("API Server automatically started.") + api_server_start() + else: + if server_running: + st.button("Stop Server", type="primary", on_click=api_server_stop) + else: + st.button("Start Server", type="primary", on_click=api_server_start) + ######################################################################### # API Server Centre ######################################################################### diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py index 3457616..807efa7 100644 --- a/app/src/modules/api_server.py +++ b/app/src/modules/api_server.py @@ -38,11 +38,15 @@ def generate_api_key(length=32): # Generates a URL-safe, base64-encoded random string with the given length return secrets.token_urlsafe(length) - auto_port = find_available_port() - auto_api_key = generate_api_key() + api_server_port = os.environ.get("API_SERVER_PORT") + api_server_key = os.environ.get("API_SERVER_KEY") + + auto_start = bool(api_server_port and api_server_key) + return { - "port": os.environ.get("API_SERVER_PORT", default=auto_port), - "key": os.environ.get("API_SERVER_KEY", default=auto_api_key), + "port": int(api_server_port) if api_server_port else find_available_port(), + "key": api_server_key if api_server_key else generate_api_key(), + "auto_start": auto_start } class ChatbotHTTPRequestHandler(BaseHTTPRequestHandler): diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py index b68580b..3dcae03 100644 --- a/app/src/modules/st_common.py +++ b/app/src/modules/st_common.py @@ -417,7 +417,7 @@ def refresh_rag_filtered(): state.rag_user_idx[attr] = idx - rag_enable = st.sidebar.checkbox( + st.sidebar.checkbox( "RAG?", value=state.rag_params["enable"], key="rag_user_enable", @@ -426,7 +426,7 @@ def refresh_rag_filtered(): on_change=reset_rag, ) - if rag_enable: + if state.rag_params["enable"]: set_default_state("rag_user_rerank", False) st.sidebar.checkbox( "Enable Re-Ranking?", diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index 40c06ac..cf28ba0 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -51,29 +51,25 @@ # General ############################################################################### def is_url_accessible(url): - """Check that URL is Available""" - logger.debug("Checking %s is accessible", url) + """Check if the URL is accessible.""" + logger.debug("Checking if %s is accessible", url) + try: response = requests.get(url, timeout=2) - logger.info("Checking %s resulted in %s", url, response.status_code) - # Check if the response status code is 200 (OK) 403 (Forbidden) 404 (Not Found) 421 (Misdirected) - if response.status_code in [200, 403, 404, 421]: + logger.info("Response for %s: %s", url, response.status_code) + + if response.status_code in {200, 403, 404, 421}: return True, None - else: - err_msg = f"{url} is not accessible. (Status: {response.status_code})" - logger.warning(err_msg) - return False, err_msg - except requests.exceptions.ConnectionError: - err_msg = f"{url} is not accessible. (Connection Error)" + + err_msg = f"{url} is not accessible. (Status: {response.status_code})" logger.warning(err_msg) return False, err_msg - except requests.exceptions.Timeout: - err_msg = f"{url} is not accessible. (Request Timeout)" + + except requests.exceptions.RequestException as ex: + err_msg = f"{url} is not accessible. ({type(ex).__name__})" logger.warning(err_msg) - return False, err_msg - except requests.RequestException as ex: logger.exception(ex, exc_info=False) - return False, ex + return False, err_msg ############################################################################### @@ -97,6 +93,15 @@ def get_ll_model(model, ll_models_config=None, giskarded=False): logger.debug("Matching LLM API: %s", llm_api) + common_params = { + "model": model, + "temperature": lm_params["temperature"][0], + "max_tokens": lm_params["max_tokens"][0], + "top_p": lm_params["top_p"][0], + "frequency_penalty": lm_params["frequency_penalty"][0], + "presence_penalty": lm_params["presence_penalty"][0], + } + ## Start - Add Additional Model Authentication Here client = None if giskarded: @@ -104,49 +109,13 @@ def get_ll_model(model, ll_models_config=None, giskarded=False): _client = OpenAI(api_key=giskard_key, base_url=f"{llm_url}/v1/") client = OpenAIClient(model=model, client=_client) elif llm_api == "OpenAI": - client = ChatOpenAI( - api_key=lm_params["api_key"], - model_name=model, - temperature=lm_params["temperature"][0], - max_tokens=lm_params["max_tokens"][0], - top_p=lm_params["top_p"][0], - frequency_penalty=lm_params["frequency_penalty"][0], - presence_penalty=lm_params["presence_penalty"][0], - ) + client = ChatOpenAI(api_key=lm_params["api_key"], **common_params) elif llm_api == "Cohere": - client = ChatCohere( - cohere_api_key=lm_params["api_key"], - model=model, - temperature=lm_params["temperature"][0], - max_tokens=lm_params["max_tokens"][0], - top_p=lm_params["top_p"][0], - frequency_penalty=lm_params["frequency_penalty"][0], - presence_penalty=lm_params["presence_penalty"][0], - ) + client = ChatCohere(cohere_api_key=lm_params["api_key"], **common_params) elif llm_api == "ChatPerplexity": - client = ChatPerplexity( - pplx_api_key=lm_params["api_key"], - model=model, - temperature=lm_params["temperature"][0], - max_tokens=lm_params["max_tokens"][0], - model_kwargs={ - "top_p": lm_params["top_p"][0], - "frequency_penalty": lm_params["frequency_penalty"][0], - "presence_penalty": lm_params["presence_penalty"][0], - }, - ) + client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params) elif llm_api == "ChatOllama": - client = ChatOllama( - model=model, - base_url=llm_url, - temperature=lm_params["temperature"][0], - max_tokens=lm_params["max_tokens"][0], - model_kwargs={ - "top_p": lm_params["top_p"][0], - "frequency_penalty": lm_params["frequency_penalty"][0], - "presence_penalty": lm_params["presence_penalty"][0], - }, - ) + client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params) ## End - Add Additional Model Authentication Here api_accessible, err_msg = is_url_accessible(llm_url) @@ -514,35 +483,33 @@ def oci_config_from_file(file=None, profile=None): def oci_get_namespace(config, retries=True): - """Get the Object Storage Namespace. Also used for testing AuthN""" + """Get the Object Storage Namespace. Also used for testing AuthN.""" logger.info("Getting Object Storage Namespace") + + client = None try: client = oci_init_client(oci.object_storage.ObjectStorageClient, config, retries) - except oci.exceptions.InvalidConfig: + except (oci.exceptions.InvalidConfig, FileNotFoundError): try: client = oci_init_client(oci.object_storage.ObjectStorageClient, retries=retries) except ValueError: - pass - except FileNotFoundError: - pass + logger.error("Failed to initialize client without config") + return None + + if client is None: + logger.error("Client could not be initialized") + raise OciException("No Configuration - Disabling OCI") try: namespace = client.get_namespace().data logger.info("Succeeded - Namespace = %s", namespace) - except oci.exceptions.InvalidConfig as ex: - raise OciException("Invalid Config - Disabling OCI") from ex - except oci.exceptions.ServiceError as ex: + return namespace + except (oci.exceptions.InvalidConfig, oci.exceptions.ServiceError) as ex: raise OciException("AuthN Error - Disabling OCI") from ex - except oci.exceptions.RequestException as ex: - raise OciException("No Network Access - Disabling OCI") from ex - except oci.exceptions.ConnectTimeout as ex: + except (oci.exceptions.RequestException, oci.exceptions.ConnectTimeout) as ex: raise OciException("No Network Access - Disabling OCI") from ex except FileNotFoundError as ex: raise OciException("Invalid Key Path") from ex - except UnboundLocalError as ex: - raise OciException("No Configuration - Disabling OCI") from ex - - return namespace def oci_get_compartments(config, retries=True): From 1e0b3409c78c9922c489d0f3119688292e535dbb Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Fri, 1 Nov 2024 10:43:04 +0000 Subject: [PATCH 09/13] Linting --- app/src/content/chatbot.py | 1 + app/src/content/model_config.py | 2 +- app/src/content/split_embed.py | 2 +- app/src/modules/api_server.py | 5 ++-- app/src/modules/chatbot.py | 5 +--- app/src/modules/logging_config.py | 1 + app/src/modules/metadata.py | 19 +++++++-------- app/src/modules/report_utils.py | 40 ++++++++----------------------- app/src/modules/st_common.py | 2 ++ 9 files changed, 29 insertions(+), 48 deletions(-) diff --git a/app/src/content/chatbot.py b/app/src/content/chatbot.py index ee205db..4ec6ef3 100644 --- a/app/src/content/chatbot.py +++ b/app/src/content/chatbot.py @@ -20,6 +20,7 @@ logger = logging_config.logging.getLogger("chatbot") + ############################################################################# # MAIN ############################################################################# diff --git a/app/src/content/model_config.py b/app/src/content/model_config.py index 490b9ee..fe9bbfd 100644 --- a/app/src/content/model_config.py +++ b/app/src/content/model_config.py @@ -157,7 +157,6 @@ def main(): if update_ll_model: st.success("Language Model Configuration - Updated", icon="✅") - st.header("Embedding Models") with st.form("update_embed_model_config"): # Create table header @@ -210,5 +209,6 @@ def main(): if update_embed_model: st.success("Embedding Model Configuration - Updated", icon="✅") + if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: main() diff --git a/app/src/content/split_embed.py b/app/src/content/split_embed.py index 8a15372..57ef864 100644 --- a/app/src/content/split_embed.py +++ b/app/src/content/split_embed.py @@ -450,7 +450,7 @@ def main(): model, distance_metric, split_docos, - 0 if rate_limit is None else rate_limit + 0 if rate_limit is None else rate_limit, ) placeholder.empty() st_common.reset_rag() diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py index 807efa7..86a0fe2 100644 --- a/app/src/modules/api_server.py +++ b/app/src/modules/api_server.py @@ -46,9 +46,10 @@ def generate_api_key(length=32): return { "port": int(api_server_port) if api_server_port else find_available_port(), "key": api_server_key if api_server_key else generate_api_key(), - "auto_start": auto_start + "auto_start": auto_start, } + class ChatbotHTTPRequestHandler(BaseHTTPRequestHandler): """Handler for mini-chatbot""" @@ -62,7 +63,7 @@ def __init__( api_key=None, chat_history=None, enable_history=False, - **kwargs + **kwargs, ): self.chat_manager = chat_manager self.rag_params = rag_params diff --git a/app/src/modules/chatbot.py b/app/src/modules/chatbot.py index e390552..48d8c55 100644 --- a/app/src/modules/chatbot.py +++ b/app/src/modules/chatbot.py @@ -19,6 +19,7 @@ logger = logging_config.logging.getLogger("modules.chatbot") + def generate_response( chat_mgr, input, chat_history, enable_history, rag_params, chat_instr, context_instr=None, stream=False ): @@ -138,10 +139,6 @@ def langchain_rag(self, rag_params, chat_instr, context_instr, input, chat_histo # History Aware Chain rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) - ### Statefully manage chat history ### - - - conversational_rag_chain = RunnableWithMessageHistory( rag_chain, lambda session_id: chat_history, diff --git a/app/src/modules/logging_config.py b/app/src/modules/logging_config.py index 694cfcc..eeb2dba 100644 --- a/app/src/modules/logging_config.py +++ b/app/src/modules/logging_config.py @@ -28,5 +28,6 @@ def setup_logging(): # Sagemaker continuously complains about config, suppress logging.getLogger("sagemaker.config").setLevel(logging.WARNING) + # Call setup_logging when this module is imported setup_logging() diff --git a/app/src/modules/metadata.py b/app/src/modules/metadata.py index 6edd150..4abfd12 100644 --- a/app/src/modules/metadata.py +++ b/app/src/modules/metadata.py @@ -210,7 +210,7 @@ def embedding_models(): "api_key": "", "openai_compat": True, "chunk_max": 512, - "dimensions": 768 + "dimensions": 768, }, "text-embedding-3-small": { "enabled": os.getenv("OPENAI_API_KEY") is not None, @@ -219,8 +219,7 @@ def embedding_models(): "api_key": os.environ.get("OPENAI_API_KEY", default=""), "openai_compat": True, "chunk_max": 8191, - "dimensions": 1536 - + "dimensions": 1536, }, "text-embedding-3-large": { "enabled": os.getenv("OPENAI_API_KEY") is not None, @@ -229,7 +228,7 @@ def embedding_models(): "api_key": os.environ.get("OPENAI_API_KEY", default=""), "openai_compat": True, "chunk_max": 8191, - "dimensions": 3072 + "dimensions": 3072, }, "text-embedding-ada-002": { "enabled": os.getenv("OPENAI_API_KEY") is not None, @@ -238,7 +237,7 @@ def embedding_models(): "api_key": os.environ.get("OPENAI_API_KEY", default=""), "openai_compat": True, "chunk_max": 8191, - "dimensions": 1536 + "dimensions": 1536, }, "embed-english-v3.0": { "enabled": os.getenv("COHERE_API_KEY") is not None, @@ -247,7 +246,7 @@ def embedding_models(): "api_key": os.environ.get("COHERE_API_KEY", default=""), "openai_compat": False, "chunk_max": 512, - "dimensions": 1024 + "dimensions": 1024, }, "embed-english-light-v3.0": { "enabled": os.getenv("COHERE_API_KEY") is not None, @@ -256,7 +255,7 @@ def embedding_models(): "api_key": os.environ.get("COHERE_API_KEY", default=""), "openai_compat": False, "chunk_max": 512, - "dimensions": 384 + "dimensions": 384, }, "mxbai-embed-large": { "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None, @@ -265,7 +264,7 @@ def embedding_models(): "api_key": "", "openai_compat": True, "chunk_max": 512, - "dimensions": 1024 + "dimensions": 1024, }, "nomic-embed-text": { "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None, @@ -274,7 +273,7 @@ def embedding_models(): "api_key": "", "openai_compat": True, "chunk_max": 8192, - "dimensions": 768 + "dimensions": 768, }, "all-minilm": { "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None, @@ -283,7 +282,7 @@ def embedding_models(): "api_key": "", "openai_compat": True, "chunk_max": 256, - "dimensions": 384 + "dimensions": 384, }, } return embedding_models_dict diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py index bde682d..c6a3d4f 100644 --- a/app/src/modules/report_utils.py +++ b/app/src/modules/report_utils.py @@ -67,9 +67,7 @@ def record_update(): if "reference_answer_input" not in state: state.reference_answer_input = state.df.iloc[state.index]["reference_answer"] if "reference_context_input" not in state: - state.reference_context_input = state.df.iloc[state.index][ - "reference_context" - ] + state.reference_context_input = state.df.iloc[state.index]["reference_context"] if "metadata_input" not in state: state.metadata_input = state.df.iloc[state.index]["metadata"] @@ -86,12 +84,8 @@ def record_update(): state.index -= 1 state.hide_input = state.df.at[state.index, "hide"] state.question_input = state.df.at[state.index, "question"] - state.reference_answer_input = state.df.at[ - state.index, "reference_answer" - ] - state.reference_context_input = state.df.at[ - state.index, "reference_context" - ] + state.reference_answer_input = state.df.at[state.index, "reference_answer"] + state.reference_context_input = state.df.at[state.index, "reference_context"] state.metadata_input = state.df.at[state.index, "metadata"] # Button to move to the next question @@ -102,12 +96,8 @@ def record_update(): state.index += 1 state.hide_input = state.df.at[state.index, "hide"] state.question_input = state.df.at[state.index, "question"] - state.reference_answer_input = state.df.at[ - state.index, "reference_answer" - ] - state.reference_context_input = state.df.at[ - state.index, "reference_context" - ] + state.reference_answer_input = state.df.at[state.index, "reference_answer"] + state.reference_context_input = state.df.at[state.index, "reference_context"] state.metadata_input = state.df.at[state.index, "metadata"] # Button to save the current input value @@ -117,12 +107,8 @@ def record_update(): # Save the current input value in the DataFrame state.df.at[state.index, "hide"] = state.hide_input state.df.at[state.index, "question"] = state.question_input - state.df.at[state.index, "reference_answer"] = ( - state.reference_answer_input - ) - state.df.at[state.index, "reference_context"] = ( - state.reference_context_input - ) + state.df.at[state.index, "reference_answer"] = state.reference_answer_input + state.df.at[state.index, "reference_context"] = state.reference_context_input # state.df.at[state.index, 'metadata'] = state.metadata_input # It's read-only logger.info("--------SAVE----------------------") @@ -141,20 +127,14 @@ def record_update(): state.df.to_json(file_path, orient="records", lines=True, index=False) # Text input for the question, storing the user's input in the session state - state.index_output = st.write( - "Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]) - ) + state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0])) state.hide_input = st.checkbox("Hide", value=state.hide_input) state.question_input = st.text_area("question", height=1, value=state.question_input) - state.reference_answer_input = st.text_area( - "Reference answer", height=1, value=state.reference_answer_input - ) + state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input) state.reference_context_input = st.text_area( "Reference context", height=10, value=state.reference_context_input, disabled=True ) - state.metadata_input = st.text_area( - "Metadata", height=1, value=state.metadata_input, disabled=True - ) + state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True) if save_clicked: st.success("Q&A saved successfully!") diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py index 3dcae03..078e872 100644 --- a/app/src/modules/st_common.py +++ b/app/src/modules/st_common.py @@ -140,6 +140,7 @@ def initialize_rag(): except AttributeError: st.error("Application has not been initialized, please restart.", icon="⛑️") + def show_rag_refs(context): """When RAG Enabled, show the references""" st.markdown( @@ -180,6 +181,7 @@ def show_rag_refs(context): for link in links: st.markdown("- " + link) + def initialize_chatbot(ll_model): """Initialize the Chatbot""" logger.info("Initializing ChatBot using %s; RAG: %s", ll_model, state.rag_params["enable"]) From 93fb72720a83b6b45ef8be2ce3ac4d41e191000a Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Fri, 1 Nov 2024 10:57:12 +0000 Subject: [PATCH 10/13] Fix Namespace --- app/src/modules/utilities.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index cf28ba0..2b7d415 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -483,33 +483,33 @@ def oci_config_from_file(file=None, profile=None): def oci_get_namespace(config, retries=True): - """Get the Object Storage Namespace. Also used for testing AuthN.""" + """Get the Object Storage Namespace. Also used for testing AuthN""" logger.info("Getting Object Storage Namespace") - - client = None try: client = oci_init_client(oci.object_storage.ObjectStorageClient, config, retries) - except (oci.exceptions.InvalidConfig, FileNotFoundError): + except oci.exceptions.InvalidConfig: try: client = oci_init_client(oci.object_storage.ObjectStorageClient, retries=retries) except ValueError: - logger.error("Failed to initialize client without config") - return None - - if client is None: - logger.error("Client could not be initialized") - raise OciException("No Configuration - Disabling OCI") + pass + except FileNotFoundError: + pass try: namespace = client.get_namespace().data logger.info("Succeeded - Namespace = %s", namespace) - return namespace - except (oci.exceptions.InvalidConfig, oci.exceptions.ServiceError) as ex: + except oci.exceptions.InvalidConfig as ex: + raise OciException("Invalid Config - Disabling OCI") from ex + except oci.exceptions.ServiceError as ex: raise OciException("AuthN Error - Disabling OCI") from ex - except (oci.exceptions.RequestException, oci.exceptions.ConnectTimeout) as ex: + except oci.exceptions.RequestException as ex: raise OciException("No Network Access - Disabling OCI") from ex except FileNotFoundError as ex: raise OciException("Invalid Key Path") from ex + except UnboundLocalError as ex: + raise OciException("No Configuration - Disabling OCI") from ex + + return namespace def oci_get_compartments(config, retries=True): From 05525ccd375c379321e15b8bdcdc554db4de52e2 Mon Sep 17 00:00:00 2001 From: John Lathouwers Date: Tue, 5 Nov 2024 11:57:39 +0000 Subject: [PATCH 11/13] Helm for API Server --- app/Dockerfile | 2 +- app/requirements.txt | 10 +++++----- app/src/content/api_server.py | 12 ++++++++---- app/src/modules/utilities.py | 10 ++++++---- app/src/oaim-sandbox.py | 6 ++++++ helm/README.md | 1 + helm/templates/configmap.yaml | 2 +- helm/templates/deployment.yaml | 20 +++++++++++++++++++- helm/templates/ingress.yaml | 12 ++++++------ helm/templates/service.yaml | 24 +++++++++++++++++++++--- 10 files changed, 74 insertions(+), 25 deletions(-) diff --git a/app/Dockerfile b/app/Dockerfile index e17fa05..4643ab7 100644 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -12,7 +12,7 @@ RUN microdnf -y install python3.11 python3.11-pip RUN python3.11 -m venv --symlinks --upgrade-deps /opt/venv COPY ./requirements.txt /opt/requirements.txt RUN source /opt/venv/bin/activate && \ - pip3 install --upgrade pip wheel && \ + pip3 install --upgrade pip wheel setuptools && \ pip3 install -r /opt/requirements.txt RUN groupadd $RUNUSER && useradd -u 10001 -g $RUNUSER -md /app $RUNUSER; RUN install -d -m 0700 -o $RUNUSER -g $RUNUSER /app/.oci diff --git a/app/requirements.txt b/app/requirements.txt index 5e017d0..67e7b21 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -17,12 +17,12 @@ faiss-cpu==1.9.0 giskard==2.15.3 IPython==8.29.0 langchain-cohere==0.3.1 -langchain-community==0.3.3 -langchain-huggingface==0.1.1 +langchain-community==0.3.5 +langchain-huggingface==0.1.2 langchain-ollama==0.2.0 -langchain-openai==0.2.4 -langgraph==0.2.39 -llama_index==0.11.20 +langchain-openai==0.2.5 +langgraph==0.2.45 +llama_index==0.11.21 lxml==5.3.0 matplotlib==3.9.2 oci>=2.0.0 diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py index 69bd3f2..dbcf31f 100644 --- a/app/src/content/api_server.py +++ b/app/src/content/api_server.py @@ -57,8 +57,10 @@ def display_logs(): def api_server_start(): chat_history = StreamlitChatMessageHistory(key="api_chat_history") - state.api_server_config["port"] = state.user_api_server_port - state.api_server_config["key"] = state.user_api_server_key + if "user_api_server_port" in state: + state.api_server_config["port"] = state.user_api_server_port + if "user_api_server_key" in state: + state.api_server_config["key"] = state.user_api_server_key if "initialized" in state and state.initialized: if "server_thread" not in state: try: @@ -86,7 +88,8 @@ def api_server_process(httpd): state.server_thread.start() logger.info("Started API Server on port: %i", state.api_server_config["port"]) except OSError: - st.error("Port is already in use.") + if not state.api_server_config["auto_start"]: + st.error("Port is already in use.") else: st.warning("API Server is already running.") else: @@ -184,6 +187,7 @@ def main(): key="user_api_server_port", disabled=server_running, ) + right.text_input( "API Server Key:", type="password", @@ -193,8 +197,8 @@ def main(): ) if state.api_server_config["auto_start"]: - st.success("API Server automatically started.") api_server_start() + st.success("API Server automatically started.") else: if server_running: st.button("Stop Server", type="primary", on_click=api_server_stop) diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index 2b7d415..e005e16 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -394,11 +394,11 @@ def oci_init_client(client_type, config=None, retries=True): retry_strategy = oci.retry.NoneRetryStrategy() # Initialize Client (Workload Identity, Token and API) + client = None if not config: logger.info("OCI Authentication with Workload Identity") - signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer() - # Region is required for endpoint generation; not sure its value matters - client = client_type(config={"region": "us-ashburn-1"}, signer=signer) + oke_workload_signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer() + client = client_type(config={}, signer=oke_workload_signer) elif config and config["security_token_file"]: logger.info("OCI Authentication with Security Token") token = None @@ -406,7 +406,7 @@ def oci_init_client(client_type, config=None, retries=True): token = f.read() private_key = oci.signer.load_private_key_from_file(config["key_file"]) signer = oci.auth.signers.SecurityTokenSigner(token, private_key) - client = client_type({"region": config["region"]}, signer=signer) + client = client_type(config={"region": config["region"]}, signer=signer) else: logger.info("OCI Authentication as Standard") client = client_type(config, retry_strategy=retry_strategy) @@ -508,6 +508,8 @@ def oci_get_namespace(config, retries=True): raise OciException("Invalid Key Path") from ex except UnboundLocalError as ex: raise OciException("No Configuration - Disabling OCI") from ex + except Exception as ex: + raise OciException("Uncaught Exception - Disabling OCI") from ex return namespace diff --git a/app/src/oaim-sandbox.py b/app/src/oaim-sandbox.py index 4e203e0..74c1e7f 100644 --- a/app/src/oaim-sandbox.py +++ b/app/src/oaim-sandbox.py @@ -19,6 +19,7 @@ from content.db_config import initialize_streamlit as db_initialize from content.prompt_eng import initialize_streamlit as prompt_initialize from content.oci_config import initialize_streamlit as oci_initialize +import content.api_server as api_server_content logger = logging_config.logging.getLogger("sandbox") @@ -38,6 +39,7 @@ def main(): model_initialize() prompt_initialize() oci_initialize() + api_server_content.initialize_streamlit() # Setup rag_params into state enable as default if "rag_params" not in state: @@ -47,6 +49,10 @@ def main(): if "rag_filter" not in state: state.rag_filter = {} + # Start the API server + if state.api_server_config["auto_start"]: + api_server_content.api_server_start() + # GUI Defaults css = """