diff --git a/.gitignore b/.gitignore
index ff1c352..0939aa7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@ app/THIRD_PARTY_LICENSES.txt
 rag/**
 sbin/**
 helm/values.yaml
+**/*.bak
 
 ##############################################################################
 # Enviroment (PyVen, IDE, etc.)
diff --git a/app/Dockerfile b/app/Dockerfile
index e17fa05..4643ab7 100644
--- a/app/Dockerfile
+++ b/app/Dockerfile
@@ -12,7 +12,7 @@ RUN microdnf -y install python3.11 python3.11-pip
 RUN python3.11 -m venv --symlinks --upgrade-deps /opt/venv
 COPY ./requirements.txt /opt/requirements.txt
 RUN source /opt/venv/bin/activate && \
-    pip3 install --upgrade pip wheel && \
+    pip3 install --upgrade pip wheel setuptools && \
     pip3 install -r /opt/requirements.txt
 RUN groupadd $RUNUSER && useradd -u 10001 -g $RUNUSER -md /app $RUNUSER;
 RUN install -d -m 0700 -o $RUNUSER -g $RUNUSER /app/.oci
diff --git a/app/requirements.txt b/app/requirements.txt
index 6397168..d5c294d 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -7,31 +7,32 @@
 ## Installation Example (from repo root):
 ### python3.11 -m venv .venv
 ### source .venv/bin/activate
-### pip3 install --upgrade pip wheel
+### pip3 install --upgrade pip wheel setuptools
 ### pip3 install -r app/requirements.txt
 
 ## Top-Level brings in the required dependencies, if adding modules, try to find the minimum required
 bokeh==3.6.0
 evaluate==0.4.3
 faiss-cpu==1.9.0
-giskard==2.15.2
-IPython==8.28.0
+giskard==2.15.3
+IPython==8.29.0
 langchain-cohere==0.3.1
-langchain-community==0.3.2
-langchain-huggingface==0.1.0
+langchain-community==0.3.5
+langchain-huggingface==0.1.2
 langchain-ollama==0.2.0
-langchain-openai==0.2.2
-llama_index==0.11.18
+langchain-openai==0.2.6
+langgraph==0.2.45
+llama_index==0.11.22
 lxml==5.3.0
 matplotlib==3.9.2
 oci>=2.0.0
 oracledb>=2.0.0
 plotly==5.24.1
 streamlit==1.39.0
-umap-learn==0.5.6
+umap-learn==0.5.7
 
 ## For Licensing Purposes; ensures no GPU modules are installed
 ## as part of langchain-huggingface
 -f https://download.pytorch.org/whl/cpu/torch
-torch==2.4.1+cpu ; sys_platform == "linux"
-torch==2.2.2 ; sys_platform == "darwin"
\ No newline at end of file
+torch==2.5.1+cpu ; sys_platform == "linux"
+torch==2.5.1 ; sys_platform == "darwin"
\ No newline at end of file
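
The requirements bump keeps the CPU-only torch policy while aligning both platforms on the same 2.5.1 release: the `-f https://download.pytorch.org/whl/cpu/torch` index plus the `+cpu` local version pins Linux to the CPU wheel. A quick post-install sanity check — a sketch, not part of the patch — could be:

```python
# Hypothetical check after `pip3 install -r app/requirements.txt`:
# the licensing goal is that no CUDA-enabled torch build sneaks in.
import torch

print(torch.__version__)  # expect "2.5.1+cpu" on Linux, "2.5.1" on macOS
assert not torch.cuda.is_available(), "expected the CPU-only torch wheel"
```
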
""" -# spell-checker:ignore streamlit +# spell-checker:ignore streamlit, langchain, llms + import inspect -import threading import time +import threading # Streamlit import streamlit as st @@ -15,9 +16,11 @@ # Utilities import modules.st_common as st_common import modules.api_server as api_server - import modules.logging_config as logging_config +# History +from langchain_community.chat_message_histories import StreamlitChatMessageHistory + logger = logging_config.logging.getLogger("api_server") @@ -32,17 +35,20 @@ def initialize_streamlit(): def display_logs(): - log_placeholder = st.empty() # A placeholder to update logs - logs = [] # Store logs for display - try: while "server_thread" in st.session_state: try: # Retrieve log from queue (non-blocking) - log_item = api_server.log_queue.get_nowait() - logs.append(log_item) - # Update the placeholder with new logs - log_placeholder.text("\n".join(logs)) + msg = api_server.log_queue.get_nowait() + logger.info("API Msg: %s", msg) + if "message" in msg: + st.chat_message("human").write(msg["message"]) + else: + if state.rag_params["enable"]: + st.chat_message("ai").write(msg["answer"]) + st_common.show_rag_refs(msg["context"]) + else: + st.chat_message("ai").write(msg.content) except api_server.queue.Empty: time.sleep(0.1) # Avoid busy-waiting finally: @@ -50,31 +56,40 @@ def display_logs(): def api_server_start(): - state.api_server_config["port"] = state.user_api_server_port - state.api_server_config["key"] = state.user_api_server_key + chat_history = StreamlitChatMessageHistory(key="api_chat_history") + if "user_api_server_port" in state: + state.api_server_config["port"] = state.user_api_server_port + if "user_api_server_key" in state: + state.api_server_config["key"] = state.user_api_server_key if "initialized" in state and state.initialized: if "server_thread" not in state: - state.httpd = api_server.run_server( - state.api_server_config["port"], - state.chat_manager, - state.rag_params, - state.lm_instr, - state.context_instr, - state.api_server_config["key"], - ) - - # Start the server in the thread - def api_server_process(httpd): - httpd.serve_forever() - - state.server_thread = threading.Thread( - target=api_server_process, - # Trailing , ensures tuple is passed - args=(state.httpd,), - daemon=True, - ) - state.server_thread.start() - logger.info("Started API Server on port: %i", state.api_server_config["port"]) + try: + state.httpd = api_server.run_server( + state.api_server_config["port"], + state.chat_manager, + state.rag_params, + state.lm_instr, + state.context_instr, + state.api_server_config["key"], + chat_history, + state.user_chat_history, + ) + + # Start the server in the thread + def api_server_process(httpd): + httpd.serve_forever() + + state.server_thread = threading.Thread( + target=api_server_process, + # Trailing , ensures tuple is passed + args=(state.httpd,), + daemon=True, + ) + state.server_thread.start() + logger.info("Started API Server on port: %i", state.api_server_config["port"]) + except OSError: + if not state.api_server_config["auto_start"]: + st.error("Port is already in use.") else: st.warning("API Server is already running.") else: @@ -106,22 +121,37 @@ def api_server_stop(): ############################################################################# def main(): """Streamlit GUI""" - initialize_streamlit() st.header("API Server") - - # LLM Params - ll_model = st_common.lm_sidebar() - # Initialize RAG st_common.initialize_rag() + # Setup History + chat_history = 
StreamlitChatMessageHistory(key="api_chat_history") - # RAG - st_common.rag_sidebar() + ######################################################################### + # Sidebar Settings + ######################################################################### + enabled_llms = sum(model_info["enabled"] for model_info in state.ll_model_config.values()) + if enabled_llms > 0: + initialize_streamlit() + enable_history = st.sidebar.checkbox( + "Enable History and Context?", + value=True, + key="user_chat_history", + ) + if st.sidebar.button("Clear History", disabled=not enable_history): + chat_history.clear() + st.sidebar.divider() + ll_model = st_common.lm_sidebar() + st_common.rag_sidebar() + else: + st.error("No chat models are configured and/or enabled.", icon="🚨") + st.stop() ######################################################################### - # Initialize the Client + # Main ######################################################################### if "initialized" not in state: + api_server_stop() if not state.rag_params["enable"] or all( state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"] ): @@ -130,10 +160,6 @@ def main(): state.initialized = True st_common.update_rag() logger.debug("Force rerun to save state") - if "server_thread" in state: - logger.info("Restarting API Server") - api_server_stop() - api_server_start() st.rerun() except Exception as ex: logger.exception(ex, exc_info=False) @@ -143,18 +169,14 @@ def main(): st.rerun() st.stop() else: - # RAG Enabled but not configured - if "server_thread" in state: - logger.info("Stopping API Server") - api_server_stop() + st.error("Not all required RAG options are set, please review or disable RAG.") + st.stop() - ######################################################################### - # API Server - ######################################################################### server_running = False if "server_thread" in state: server_running = True - st.success("API Server is Running") + elif state.api_server_config["auto_start"]: + server_running = True left, right = st.columns([0.2, 0.8]) left.number_input( @@ -165,6 +187,7 @@ def main(): key="user_api_server_port", disabled=server_running, ) + right.text_input( "API Server Key:", type="password", @@ -173,15 +196,20 @@ def main(): disabled=server_running, ) - if "server_thread" in state: - st.button("Stop Server", type="primary", on_click=api_server_stop) - elif "initialized" in state and state.initialized: - st.button("Start Server", type="primary", on_click=api_server_start) + if state.api_server_config["auto_start"]: + api_server_start() + st.success("API Server automatically started.") else: - st.error("Not all required RAG options are set, please review or disable RAG.") + if server_running: + st.button("Stop Server", type="primary", on_click=api_server_stop) + else: + st.button("Start Server", type="primary", on_click=api_server_start) - st.subheader("Activity") + ######################################################################### + # API Server Centre + ######################################################################### if "server_thread" in state: + st.subheader("Activity") with st.container(border=True): display_logs() diff --git a/app/src/content/chatbot.py b/app/src/content/chatbot.py index 12e0953..4ec6ef3 100644 --- a/app/src/content/chatbot.py +++ b/app/src/content/chatbot.py @@ -21,50 +21,6 @@ logger = logging_config.logging.getLogger("chatbot") 
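
Both the page body and `api_server_start()` construct `StreamlitChatMessageHistory(key="api_chat_history")`. With LangChain's Streamlit-backed history, instances built with the same key read and write the same `st.session_state` entry, which is why the sidebar's "Clear History" button also clears the conversation the server thread passes to the chatbot. A minimal illustration — a sketch assuming those `langchain_community` semantics, and meant to run inside a Streamlit session:

```python
# Sketch: two views over one history; messages live in st.session_state[key].
from langchain_community.chat_message_histories import StreamlitChatMessageHistory

page_view = StreamlitChatMessageHistory(key="api_chat_history")
server_view = StreamlitChatMessageHistory(key="api_chat_history")

page_view.add_user_message("hello")
assert server_view.messages == page_view.messages  # same backing store
page_view.clear()  # empties the shared list for both views
```
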
diff --git a/app/src/content/chatbot.py b/app/src/content/chatbot.py
index 12e0953..4ec6ef3 100644
--- a/app/src/content/chatbot.py
+++ b/app/src/content/chatbot.py
@@ -21,50 +21,6 @@
 logger = logging_config.logging.getLogger("chatbot")
 
-#############################################################################
-# Functions
-#############################################################################
-def show_refs(context):
-    """When RAG Enabled, show the references"""
-    st.markdown(
-        """
-
-        """,
-        unsafe_allow_html=True,
-    )
-
-    column_sizes = [10, 8, 8, 8, 2, 24]
-    cols = st.columns(column_sizes)
-    # Create a button in each column
-    links = set()
-    with cols[0]:
-        st.markdown("**References:**")
-    # Limit the maximum number of items to 3 (counting from 0)
-    max_items = min(len(context), 3)
-
-    # Loop through the chunks and display them
-    for i in range(max_items):
-        with cols[i + 1]:
-            chunk = context[i]
-            links.add(chunk.metadata["source"])
-            with st.popover(f"Ref: {i+1}"):
-                st.markdown(chunk.metadata["source"])
-                st.markdown(chunk.page_content)
-                st.markdown(chunk.metadata["id"])
-
-    for link in links:
-        st.markdown("- " + link)
-
-
 #############################################################################
 # MAIN
 #############################################################################
@@ -156,7 +112,7 @@ def main():
                     full_context = chunk["context"]
             message_placeholder.markdown(full_answer)
             if full_context:
-                show_refs(full_context)
+                st_common.show_rag_refs(full_context)
         else:
             st.chat_message("ai").write_stream(response)
     except Exception as ex:
diff --git a/app/src/content/model_config.py b/app/src/content/model_config.py
index 490b9ee..fe9bbfd 100644
--- a/app/src/content/model_config.py
+++ b/app/src/content/model_config.py
@@ -157,7 +157,6 @@ def main():
     if update_ll_model:
         st.success("Language Model Configuration - Updated", icon="✅")
 
-
     st.header("Embedding Models")
     with st.form("update_embed_model_config"):
         # Create table header
@@ -210,5 +209,6 @@ def main():
     if update_embed_model:
         st.success("Embedding Model Configuration - Updated", icon="✅")
 
+
 if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename:
     main()
diff --git a/app/src/content/split_embed.py b/app/src/content/split_embed.py
index 8a15372..57ef864 100644
--- a/app/src/content/split_embed.py
+++ b/app/src/content/split_embed.py
@@ -450,7 +450,7 @@ def main():
                 model,
                 distance_metric,
                 split_docos,
-                0 if rate_limit is None else rate_limit
+                0 if rate_limit is None else rate_limit,
             )
             placeholder.empty()
             st_common.reset_rag()
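
The server module below reworks its configuration so that setting both `API_SERVER_PORT` and `API_SERVER_KEY` flags the server for unattended start-up; if either is missing, a free port and a random key are generated instead. The rule reduces to the following toy replica (not the module itself):

```python
# Both variables must be present and non-empty for auto_start to be True.
import os

port = os.environ.get("API_SERVER_PORT")  # e.g. "8000"
key = os.environ.get("API_SERVER_KEY")    # e.g. "change-me"
print(bool(port and key))                 # the auto_start flag
```
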
""" + # spell-checker:ignore streamlit, langchain import os @@ -12,15 +13,12 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import urlparse -from langchain_community.chat_message_histories import StreamlitChatMessageHistory - # Utilities import modules.chatbot as chatbot import modules.logging_config as logging_config logger = logging_config.logging.getLogger("modules.api_server") - # Create a queue to store the requests and responses log_queue = queue.Queue() @@ -40,64 +38,40 @@ def generate_api_key(length=32): # Generates a URL-safe, base64-encoded random string with the given length return secrets.token_urlsafe(length) - auto_port = find_available_port() - auto_api_key = generate_api_key() - return { - "port": os.environ.get("API_SERVER_PORT", default=auto_port), - "key": os.environ.get("API_SERVER_KEY", default=auto_api_key), - } + api_server_port = os.environ.get("API_SERVER_PORT") + api_server_key = os.environ.get("API_SERVER_KEY") + auto_start = bool(api_server_port and api_server_key) -def get_answer_fn( - question: str, - history=None, - chat_manager=None, - rag_params=None, - lm_instr=None, - context_instr=None, -) -> str: - """Send for completion""" - # Format appropriately the history for your RAG agent - chat_history_api = StreamlitChatMessageHistory(key="empty") - chat_history_api.clear() - if history: - for h in history: - if h["role"] == "assistant": - chat_history_api.add_ai_message(h["content"]) - else: - chat_history_api.add_user_message(h["content"]) - - try: - response = chatbot.generate_response( - chat_mgr=chat_manager, - input=question, - chat_history=chat_history_api, - enable_history=True, - rag_params=rag_params, - chat_instr=lm_instr, - context_instr=context_instr, - stream=False, - ) - logger.info("MSG from Chatbot API: %s", response) - if rag_params["enable"]: - return response["answer"] - else: - return response.content - except Exception as ex: - return f"I'm sorry, something's gone wrong: {ex}" + return { + "port": int(api_server_port) if api_server_port else find_available_port(), + "key": api_server_key if api_server_key else generate_api_key(), + "auto_start": auto_start, + } class ChatbotHTTPRequestHandler(BaseHTTPRequestHandler): """Handler for mini-chatbot""" def __init__( - self, *args, chat_manager=None, rag_params=None, lm_instr=None, context_instr=None, api_key=None, **kwargs + self, + *args, + chat_manager=None, + rag_params=None, + lm_instr=None, + context_instr=None, + api_key=None, + chat_history=None, + enable_history=False, + **kwargs, ): self.chat_manager = chat_manager self.rag_params = rag_params self.lm_instr = lm_instr self.context_instr = context_instr self.api_key = api_key + self.chat_history = chat_history + self.enable_history = enable_history super().__init__(*args, **kwargs) def do_OPTIONS(self): # pylint: disable=invalid-name @@ -126,38 +100,39 @@ def do_POST(self): # pylint: disable=invalid-name content_length = int(self.headers["Content-Length"]) post_data = self.rfile.read(content_length).decode("utf-8") - # Log the raw body - logger.info("Raw Body: %s", post_data) try: # Parse the POST data as JSON post_json = json.loads(post_data) # Extract the 'message' field from the JSON message = post_json.get("message") - + response = None if message: # Log the incoming message logger.info("MSG to Chatbot API: %s", message) # Call your function to get the chatbot response - answer = get_answer_fn( - message, None, self.chat_manager, self.rag_params, self.lm_instr, self.context_instr + response = 
diff --git a/app/src/modules/chatbot.py b/app/src/modules/chatbot.py
index 0427dab..48d8c55 100644
--- a/app/src/modules/chatbot.py
+++ b/app/src/modules/chatbot.py
@@ -19,6 +19,7 @@
 logger = logging_config.logging.getLogger("modules.chatbot")
 
+
 def generate_response(
     chat_mgr, input, chat_history, enable_history, rag_params, chat_instr, context_instr=None, stream=False
 ):
diff --git a/app/src/modules/logging_config.py b/app/src/modules/logging_config.py
index 694cfcc..eeb2dba 100644
--- a/app/src/modules/logging_config.py
+++ b/app/src/modules/logging_config.py
@@ -28,5 +28,6 @@ def setup_logging():
     # Sagemaker continuously complains about config, suppress
     logging.getLogger("sagemaker.config").setLevel(logging.WARNING)
 
+
 # Call setup_logging when this module is imported
 setup_logging()
diff --git a/app/src/modules/metadata.py b/app/src/modules/metadata.py
index 6edd150..4abfd12 100644
--- a/app/src/modules/metadata.py
+++ b/app/src/modules/metadata.py
@@ -210,7 +210,7 @@ def embedding_models():
             "api_key": "",
             "openai_compat": True,
             "chunk_max": 512,
-            "dimensions": 768
+            "dimensions": 768,
         },
         "text-embedding-3-small": {
             "enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -219,8 +219,7 @@ def embedding_models():
             "api_key": os.environ.get("OPENAI_API_KEY", default=""),
             "openai_compat": True,
             "chunk_max": 8191,
-            "dimensions": 1536
-
+            "dimensions": 1536,
         },
         "text-embedding-3-large": {
             "enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -229,7 +228,7 @@ def embedding_models():
             "api_key": os.environ.get("OPENAI_API_KEY", default=""),
             "openai_compat": True,
             "chunk_max": 8191,
-            "dimensions": 3072
+            "dimensions": 3072,
         },
         "text-embedding-ada-002": {
             "enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -238,7 +237,7 @@ def embedding_models():
             "api_key": os.environ.get("OPENAI_API_KEY", default=""),
             "openai_compat": True,
             "chunk_max": 8191,
-            "dimensions": 1536
+            "dimensions": 1536,
         },
         "embed-english-v3.0": {
             "enabled": os.getenv("COHERE_API_KEY") is not None,
@@ -247,7 +246,7 @@ def embedding_models():
             "api_key": os.environ.get("COHERE_API_KEY", default=""),
             "openai_compat": False,
             "chunk_max": 512,
-            "dimensions": 1024
+            "dimensions": 1024,
         },
         "embed-english-light-v3.0": {
             "enabled": os.getenv("COHERE_API_KEY") is not None,
@@ -256,7 +255,7 @@ def embedding_models():
             "api_key": os.environ.get("COHERE_API_KEY", default=""),
             "openai_compat": False,
             "chunk_max": 512,
-            "dimensions": 384
+            "dimensions": 384,
         },
         "mxbai-embed-large": {
             "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -265,7 +264,7 @@ def embedding_models():
             "api_key": "",
             "openai_compat": True,
             "chunk_max": 512,
-            "dimensions": 1024
+            "dimensions": 1024,
         },
         "nomic-embed-text": {
             "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -274,7 +273,7 @@ def embedding_models():
             "api_key": "",
             "openai_compat": True,
             "chunk_max": 8192,
-            "dimensions": 768
+            "dimensions": 768,
         },
         "all-minilm": {
             "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -283,7 +282,7 @@ def embedding_models():
             "api_key": "",
             "openai_compat": True,
             "chunk_max": 256,
-            "dimensions": 384
+            "dimensions": 384,
         },
     }
     return embedding_models_dict
diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py
index bde682d..c6a3d4f 100644
--- a/app/src/modules/report_utils.py
+++ b/app/src/modules/report_utils.py
@@ -67,9 +67,7 @@ def record_update():
     if "reference_answer_input" not in state:
         state.reference_answer_input = state.df.iloc[state.index]["reference_answer"]
     if "reference_context_input" not in state:
-        state.reference_context_input = state.df.iloc[state.index][
-            "reference_context"
-        ]
+        state.reference_context_input = state.df.iloc[state.index]["reference_context"]
     if "metadata_input" not in state:
         state.metadata_input = state.df.iloc[state.index]["metadata"]
 
@@ -86,12 +84,8 @@ def record_update():
             state.index -= 1
             state.hide_input = state.df.at[state.index, "hide"]
             state.question_input = state.df.at[state.index, "question"]
-            state.reference_answer_input = state.df.at[
-                state.index, "reference_answer"
-            ]
-            state.reference_context_input = state.df.at[
-                state.index, "reference_context"
-            ]
+            state.reference_answer_input = state.df.at[state.index, "reference_answer"]
+            state.reference_context_input = state.df.at[state.index, "reference_context"]
             state.metadata_input = state.df.at[state.index, "metadata"]
 
     # Button to move to the next question
@@ -102,12 +96,8 @@ def record_update():
             state.index += 1
             state.hide_input = state.df.at[state.index, "hide"]
             state.question_input = state.df.at[state.index, "question"]
-            state.reference_answer_input = state.df.at[
-                state.index, "reference_answer"
-            ]
-            state.reference_context_input = state.df.at[
-                state.index, "reference_context"
-            ]
+            state.reference_answer_input = state.df.at[state.index, "reference_answer"]
+            state.reference_context_input = state.df.at[state.index, "reference_context"]
             state.metadata_input = state.df.at[state.index, "metadata"]
 
     # Button to save the current input value
@@ -117,12 +107,8 @@ def record_update():
         # Save the current input value in the DataFrame
         state.df.at[state.index, "hide"] = state.hide_input
        state.df.at[state.index, "question"] = state.question_input
-        state.df.at[state.index, "reference_answer"] = (
-            state.reference_answer_input
-        )
-        state.df.at[state.index, "reference_context"] = (
-            state.reference_context_input
-        )
+        state.df.at[state.index, "reference_answer"] = state.reference_answer_input
+        state.df.at[state.index, "reference_context"] = state.reference_context_input
         # state.df.at[state.index, 'metadata'] = state.metadata_input  # It's read-only
 
         logger.info("--------SAVE----------------------")
@@ -141,20 +127,14 @@ def record_update():
     state.df.to_json(file_path, orient="records", lines=True, index=False)
 
     # Text input for the question, storing the user's input in the session state
-    state.index_output = st.write(
-        "Record: " + str(state.index + 1) + "/" + str(state.df.shape[0])
-    )
+    state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]))
     state.hide_input = st.checkbox("Hide", value=state.hide_input)
     state.question_input = st.text_area("question", height=1, value=state.question_input)
-    state.reference_answer_input = st.text_area(
-        "Reference answer", height=1, value=state.reference_answer_input
-    )
+    state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input)
     state.reference_context_input = st.text_area(
         "Reference context", height=10, value=state.reference_context_input, disabled=True
     )
-    state.metadata_input = st.text_area(
-        "Metadata", height=1, value=state.metadata_input, disabled=True
-    )
+    state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True)
 
     if save_clicked:
         st.success("Q&A saved successfully!")
diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py
index 7160f92..078e872 100644
--- a/app/src/modules/st_common.py
+++ b/app/src/modules/st_common.py
@@ -141,6 +141,47 @@ def initialize_rag():
         st.error("Application has not been initialized, please restart.", icon="⛑️")
 
 
+def show_rag_refs(context):
+    """When RAG Enabled, show the references"""
+    st.markdown(
+        """
+
+        """,
+        unsafe_allow_html=True,
+    )
+
+    column_sizes = [10, 8, 8, 8, 2, 24]
+    cols = st.columns(column_sizes)
+    # Create a button in each column
+    links = set()
+    with cols[0]:
+        st.markdown("**References:**")
+    # Limit the maximum number of items to 3 (counting from 0)
+    max_items = min(len(context), 3)
+
+    # Loop through the chunks and display them
+    for i in range(max_items):
+        with cols[i + 1]:
+            chunk = context[i]
+            links.add(chunk.metadata["source"])
+            with st.popover(f"Ref: {i+1}"):
+                st.markdown(chunk.metadata["source"])
+                st.markdown(chunk.page_content)
+                st.markdown(chunk.metadata["id"])
+
+    for link in links:
+        st.markdown("- " + link)
+
+
 def initialize_chatbot(ll_model):
     """Initialize the Chatbot"""
     logger.info("Initializing ChatBot using %s; RAG: %s", ll_model, state.rag_params["enable"])
@@ -273,87 +314,112 @@ def refresh_rag_filtered():
     if user_alias:
         rag_filt_alias = [user_alias]
     else:
-        rag_filt_alias = [
-            v.get("alias", None)
-            for v in state.vs_tables.values()
-            if (user_model is None or v["model"] == user_model)
-            and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
-            and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
-            and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
-        ]
+        try:
+            rag_filt_alias = [
+                v.get("alias", None)
+                for v in state.vs_tables.values()
+                if (user_model is None or v["model"] == user_model)
+                and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+                and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+                and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+            ]
+        except AttributeError:
+            # st.session_state has no attribute "vs_tables"
+            clear_initialized()
 
     if user_model:
         rag_filt_model = [user_model]
     else:
-        rag_filt_model = [
-            v["model"]
-            for v in state.vs_tables.values()
-            if (user_alias is None or v.get("alias", None) == user_alias)
-            and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
-            and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
-            and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
-        ]
+        try:
+            rag_filt_model = [
+                v["model"]
+                for v in state.vs_tables.values()
+                if (user_alias is None or v.get("alias", None) == user_alias)
+                and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+                and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+                and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+            ]
+        except AttributeError:
+            # st.session_state has no attribute "vs_tables"
+            clear_initialized()
 
     if user_chunk_size:
         rag_filt_chunk_size = [user_chunk_size]
     else:
-        rag_filt_chunk_size = [
-            v["chunk_size"]
-            for v in state.vs_tables.values()
-            if (user_alias is None or v.get("alias", None) == user_alias)
-            and (user_model is None or v["model"] == user_model)
-            and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
-            and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
-        ]
+        try:
+            rag_filt_chunk_size = [
+                v["chunk_size"]
+                for v in state.vs_tables.values()
+                if (user_alias is None or v.get("alias", None) == user_alias)
+                and (user_model is None or v["model"] == user_model)
+                and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+                and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+            ]
+        except AttributeError:
+            # st.session_state has no attribute "vs_tables"
+            clear_initialized()
 
v["model"] == user_model) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) - ] + try: + rag_filt_chunk_overlap = [ + v["chunk_overlap"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_model is None or v["model"] == user_model) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_distance_metric is None or v["distance_metric"] == user_distance_metric) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() if user_distance_metric: rag_filt_distance_metric = [user_distance_metric] else: - rag_filt_distance_metric = [ - v["distance_metric"] - for v in state.vs_tables.values() - if (user_alias is None or v.get("alias", None) == user_alias) - and (user_model is None or v["model"] == user_model) - and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) - and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) - ] + try: + rag_filt_distance_metric = [ + v["distance_metric"] + for v in state.vs_tables.values() + if (user_alias is None or v.get("alias", None) == user_alias) + and (user_model is None or v["model"] == user_model) + and (user_chunk_size is None or v["chunk_size"] == user_chunk_size) + and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap) + ] + except AttributeError: + # st.session_state has no attribute "vs_tables" + clear_initialized() # Remove duplicates and sort - state.rag_filter["alias"] = sorted(set(rag_filt_alias)) - state.rag_filter["model"] = sorted(set(rag_filt_model)) - state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size)) - state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap)) - state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric)) + try: + state.rag_filter["alias"] = sorted(set(rag_filt_alias)) + state.rag_filter["model"] = sorted(set(rag_filt_model)) + state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size)) + state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap)) + state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric)) + except UnboundLocalError: + clear_initialized() # (Re)set the index to previously selected option attributes = ["alias", "model", "chunk_size", "chunk_overlap", "distance_metric"] for attr in attributes: - filtered_list = state.rag_filter[attr] try: + filtered_list = state.rag_filter[attr] user_value = getattr(state, f"rag_user_{attr}") except AttributeError: setattr(state, f"rag_user_{attr}", []) + except KeyError: + clear_initialized() try: idx = 0 if len(filtered_list) == 1 else filtered_list.index(user_value) - except (ValueError, AttributeError): + except (ValueError, AttributeError, UnboundLocalError): idx = None state.rag_user_idx[attr] = idx - rag_enable = st.sidebar.checkbox( + st.sidebar.checkbox( "RAG?", value=state.rag_params["enable"], key="rag_user_enable", @@ -362,7 +428,7 @@ def refresh_rag_filtered(): on_change=reset_rag, ) - if rag_enable: + if state.rag_params["enable"]: set_default_state("rag_user_rerank", False) st.sidebar.checkbox( "Enable Re-Ranking?", diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py index 2c46da0..e005e16 100644 --- a/app/src/modules/utilities.py +++ b/app/src/modules/utilities.py @@ -51,29 +51,25 @@ # General ############################################################################### def 
diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py
index 2c46da0..e005e16 100644
--- a/app/src/modules/utilities.py
+++ b/app/src/modules/utilities.py
@@ -51,29 +51,25 @@
 # General
 ###############################################################################
 def is_url_accessible(url):
-    """Check that URL is Available"""
-    logger.debug("Checking %s is accessible", url)
+    """Check if the URL is accessible."""
+    logger.debug("Checking if %s is accessible", url)
+
     try:
         response = requests.get(url, timeout=2)
-        logger.info("Checking %s resulted in %s", url, response.status_code)
-        # Check if the response status code is 200 (OK) 403 (Forbidden) 404 (Not Found) 421 (Misdirected)
-        if response.status_code in [200, 403, 404, 421]:
+        logger.info("Response for %s: %s", url, response.status_code)
+
+        if response.status_code in {200, 403, 404, 421}:
             return True, None
-        else:
-            err_msg = f"{url} is not accessible. (Status: {response.status_code})"
-            logger.warning(err_msg)
-            return False, err_msg
-    except requests.exceptions.ConnectionError:
-        err_msg = f"{url} is not accessible. (Connection Error)"
+
+        err_msg = f"{url} is not accessible. (Status: {response.status_code})"
         logger.warning(err_msg)
         return False, err_msg
-    except requests.exceptions.Timeout:
-        err_msg = f"{url} is not accessible. (Request Timeout)"
+
+    except requests.exceptions.RequestException as ex:
+        err_msg = f"{url} is not accessible. ({type(ex).__name__})"
         logger.warning(err_msg)
-        return False, err_msg
-    except requests.RequestException as ex:
         logger.exception(ex, exc_info=False)
-        return False, ex
+        return False, err_msg
 
 
 ###############################################################################
@@ -97,6 +93,15 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     logger.debug("Matching LLM API: %s", llm_api)
 
+    common_params = {
+        "model": model,
+        "temperature": lm_params["temperature"][0],
+        "max_tokens": lm_params["max_tokens"][0],
+        "top_p": lm_params["top_p"][0],
+        "frequency_penalty": lm_params["frequency_penalty"][0],
+        "presence_penalty": lm_params["presence_penalty"][0],
+    }
+
     ## Start - Add Additional Model Authentication Here
     client = None
     if giskarded:
@@ -104,49 +109,13 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
         _client = OpenAI(api_key=giskard_key, base_url=f"{llm_url}/v1/")
         client = OpenAIClient(model=model, client=_client)
     elif llm_api == "OpenAI":
-        client = ChatOpenAI(
-            api_key=lm_params["api_key"],
-            model_name=model,
-            temperature=lm_params["temperature"][0],
-            max_tokens=lm_params["max_tokens"][0],
-            top_p=lm_params["top_p"][0],
-            frequency_penalty=lm_params["frequency_penalty"][0],
-            presence_penalty=lm_params["presence_penalty"][0],
-        )
+        client = ChatOpenAI(api_key=lm_params["api_key"], **common_params)
     elif llm_api == "Cohere":
-        client = ChatCohere(
-            cohere_api_key=lm_params["api_key"],
-            model=model,
-            temperature=lm_params["temperature"][0],
-            max_tokens=lm_params["max_tokens"][0],
-            top_p=lm_params["top_p"][0],
-            frequency_penalty=lm_params["frequency_penalty"][0],
-            presence_penalty=lm_params["presence_penalty"][0],
-        )
+        client = ChatCohere(cohere_api_key=lm_params["api_key"], **common_params)
     elif llm_api == "ChatPerplexity":
-        client = ChatPerplexity(
-            pplx_api_key=lm_params["api_key"],
-            model=model,
-            temperature=lm_params["temperature"][0],
-            max_tokens=lm_params["max_tokens"][0],
-            model_kwargs={
-                "top_p": lm_params["top_p"][0],
-                "frequency_penalty": lm_params["frequency_penalty"][0],
-                "presence_penalty": lm_params["presence_penalty"][0],
-            },
-        )
+        client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
     elif llm_api == "ChatOllama":
-        client = ChatOllama(
-            model=model,
-            base_url=llm_url,
-            temperature=lm_params["temperature"][0],
-            max_tokens=lm_params["max_tokens"][0],
-            model_kwargs={
-                "top_p": lm_params["top_p"][0],
-                "frequency_penalty": lm_params["frequency_penalty"][0],
-                "presence_penalty": lm_params["presence_penalty"][0],
-            },
-        )
+        client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params)
     ## End - Add Additional Model Authentication Here
 
     api_accessible, err_msg = is_url_accessible(llm_url)
@@ -264,7 +233,7 @@ def init_vs(db_conn, embedding_function, store_table, distance_metric):
 
 def get_vs_table(model, chunk_size, chunk_overlap, distance_metric, embed_alias=None):
-    """Get a list of Vector Store Tables"""
+    """Return the concatenated VS Table name and comment"""
     chunk_overlap_ceil = math.ceil(chunk_overlap)
     table_string = f"{model}_{chunk_size}_{chunk_overlap_ceil}_{distance_metric}"
     if embed_alias:
@@ -425,11 +394,11 @@ def oci_init_client(client_type, config=None, retries=True):
         retry_strategy = oci.retry.NoneRetryStrategy()
 
     # Initialize Client (Workload Identity, Token and API)
+    client = None
     if not config:
         logger.info("OCI Authentication with Workload Identity")
-        signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer()
-        # Region is required for endpoint generation; not sure its value matters
-        client = client_type(config={"region": "us-ashburn-1"}, signer=signer)
+        oke_workload_signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer()
+        client = client_type(config={}, signer=oke_workload_signer)
     elif config and config["security_token_file"]:
         logger.info("OCI Authentication with Security Token")
         token = None
@@ -437,7 +406,7 @@ def oci_init_client(client_type, config=None, retries=True):
             token = f.read()
         private_key = oci.signer.load_private_key_from_file(config["key_file"])
         signer = oci.auth.signers.SecurityTokenSigner(token, private_key)
-        client = client_type({"region": config["region"]}, signer=signer)
+        client = client_type(config={"region": config["region"]}, signer=signer)
     else:
         logger.info("OCI Authentication as Standard")
         client = client_type(config, retry_strategy=retry_strategy)
@@ -539,6 +508,8 @@ def oci_get_namespace(config, retries=True):
         raise OciException("Invalid Key Path") from ex
     except UnboundLocalError as ex:
         raise OciException("No Configuration - Disabling OCI") from ex
+    except Exception as ex:
+        raise OciException("Uncaught Exception - Disabling OCI") from ex
 
     return namespace
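
Two notable simplifications above: the per-client keyword lists collapse into one `common_params` dict unpacked into each constructor, and the three `requests` exception handlers collapse into a single `RequestException` branch. The latter is safe because `ConnectionError` and `Timeout` both subclass `RequestException`, and `type(ex).__name__` recovers the specific failure for the message:

```python
# Why one except clause covers all three former branches.
import requests

assert issubclass(requests.exceptions.ConnectionError, requests.exceptions.RequestException)
assert issubclass(requests.exceptions.Timeout, requests.exceptions.RequestException)
print(type(requests.exceptions.Timeout()).__name__)  # -> "Timeout"
```
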
diff --git a/app/src/oaim-sandbox.py b/app/src/oaim-sandbox.py
index 4e203e0..522a38c 100644
--- a/app/src/oaim-sandbox.py
+++ b/app/src/oaim-sandbox.py
@@ -19,6 +19,7 @@
 from content.db_config import initialize_streamlit as db_initialize
 from content.prompt_eng import initialize_streamlit as prompt_initialize
 from content.oci_config import initialize_streamlit as oci_initialize
+import content.api_server as api_server_content
 
 logger = logging_config.logging.getLogger("sandbox")
 
@@ -38,6 +39,7 @@ def main():
     model_initialize()
     prompt_initialize()
     oci_initialize()
+    api_server_content.initialize_streamlit()
 
     # Setup rag_params into state enable as default
     if "rag_params" not in state:
@@ -47,6 +49,12 @@ def main():
     if "rag_filter" not in state:
         state.rag_filter = {}
 
+    # Start the API server
+    if state.api_server_config["auto_start"]:
+        api_server_content.api_server_start()
+    if "user_chat_history" not in state:
+        state.user_chat_history = True
+
     # GUI Defaults
     css = """
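
Taken together with the sandbox changes above, exporting both variables before launch drives the new hands-free flow: `initialize_streamlit()` computes `auto_start`, and `main()` then brings the API server up without a button press. A possible launch sketch — the `streamlit run` entry point is assumed from the file's role, not shown in this diff:

```python
# Hypothetical launcher: setting both variables flips auto_start to True.
import os
import subprocess

env = {**os.environ, "API_SERVER_PORT": "8000", "API_SERVER_KEY": "change-me"}
subprocess.run(["streamlit", "run", "app/src/oaim-sandbox.py"], env=env, check=True)
```
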