diff --git a/.gitignore b/.gitignore
index ff1c352..0939aa7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@ app/THIRD_PARTY_LICENSES.txt
rag/**
sbin/**
helm/values.yaml
+**/*.bak
##############################################################################
# Enviroment (PyVen, IDE, etc.)
diff --git a/app/Dockerfile b/app/Dockerfile
index e17fa05..4643ab7 100644
--- a/app/Dockerfile
+++ b/app/Dockerfile
@@ -12,7 +12,7 @@ RUN microdnf -y install python3.11 python3.11-pip
RUN python3.11 -m venv --symlinks --upgrade-deps /opt/venv
COPY ./requirements.txt /opt/requirements.txt
RUN source /opt/venv/bin/activate && \
- pip3 install --upgrade pip wheel && \
+ pip3 install --upgrade pip wheel setuptools && \
pip3 install -r /opt/requirements.txt
RUN groupadd $RUNUSER && useradd -u 10001 -g $RUNUSER -md /app $RUNUSER;
RUN install -d -m 0700 -o $RUNUSER -g $RUNUSER /app/.oci
diff --git a/app/requirements.txt b/app/requirements.txt
index 6397168..d5c294d 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -7,31 +7,32 @@
## Installation Example (from repo root):
### python3.11 -m venv .venv
### source .venv/bin/activate
-### pip3 install --upgrade pip wheel
+### pip3 install --upgrade pip wheel setuptools
### pip3 install -r app/requirements.txt
## Top-Level brings in the required dependencies, if adding modules, try to find the minimum required
bokeh==3.6.0
evaluate==0.4.3
faiss-cpu==1.9.0
-giskard==2.15.2
-IPython==8.28.0
+giskard==2.15.3
+IPython==8.29.0
langchain-cohere==0.3.1
-langchain-community==0.3.2
-langchain-huggingface==0.1.0
+langchain-community==0.3.5
+langchain-huggingface==0.1.2
langchain-ollama==0.2.0
-langchain-openai==0.2.2
-llama_index==0.11.18
+langchain-openai==0.2.6
+langgraph==0.2.45
+llama_index==0.11.22
lxml==5.3.0
matplotlib==3.9.2
oci>=2.0.0
oracledb>=2.0.0
plotly==5.24.1
streamlit==1.39.0
-umap-learn==0.5.6
+umap-learn==0.5.7
## For Licensing Purposes; ensures no GPU modules are installed
## as part of langchain-huggingface
-f https://download.pytorch.org/whl/cpu/torch
-torch==2.4.1+cpu ; sys_platform == "linux"
-torch==2.2.2 ; sys_platform == "darwin"
\ No newline at end of file
+torch==2.5.1+cpu ; sys_platform == "linux"
+torch==2.5.1 ; sys_platform == "darwin"
\ No newline at end of file
diff --git a/app/src/content/api_server.py b/app/src/content/api_server.py
index 17f2d74..dbcf31f 100644
--- a/app/src/content/api_server.py
+++ b/app/src/content/api_server.py
@@ -3,10 +3,11 @@
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
-# spell-checker:ignore streamlit
+# spell-checker:ignore streamlit, langchain, llms
+
import inspect
-import threading
import time
+import threading
# Streamlit
import streamlit as st
@@ -15,9 +16,11 @@
# Utilities
import modules.st_common as st_common
import modules.api_server as api_server
-
import modules.logging_config as logging_config
+# History
+from langchain_community.chat_message_histories import StreamlitChatMessageHistory
+
logger = logging_config.logging.getLogger("api_server")
@@ -32,17 +35,20 @@ def initialize_streamlit():
def display_logs():
- log_placeholder = st.empty() # A placeholder to update logs
- logs = [] # Store logs for display
-
try:
while "server_thread" in st.session_state:
try:
# Retrieve log from queue (non-blocking)
- log_item = api_server.log_queue.get_nowait()
- logs.append(log_item)
- # Update the placeholder with new logs
- log_placeholder.text("\n".join(logs))
+ msg = api_server.log_queue.get_nowait()
+ logger.info("API Msg: %s", msg)
+ if "message" in msg:
+ st.chat_message("human").write(msg["message"])
+ else:
+ if state.rag_params["enable"]:
+ st.chat_message("ai").write(msg["answer"])
+ st_common.show_rag_refs(msg["context"])
+ else:
+ st.chat_message("ai").write(msg.content)
except api_server.queue.Empty:
time.sleep(0.1) # Avoid busy-waiting
finally:
@@ -50,31 +56,40 @@ def display_logs():
def api_server_start():
- state.api_server_config["port"] = state.user_api_server_port
- state.api_server_config["key"] = state.user_api_server_key
+ chat_history = StreamlitChatMessageHistory(key="api_chat_history")
+ if "user_api_server_port" in state:
+ state.api_server_config["port"] = state.user_api_server_port
+ if "user_api_server_key" in state:
+ state.api_server_config["key"] = state.user_api_server_key
if "initialized" in state and state.initialized:
if "server_thread" not in state:
- state.httpd = api_server.run_server(
- state.api_server_config["port"],
- state.chat_manager,
- state.rag_params,
- state.lm_instr,
- state.context_instr,
- state.api_server_config["key"],
- )
-
- # Start the server in the thread
- def api_server_process(httpd):
- httpd.serve_forever()
-
- state.server_thread = threading.Thread(
- target=api_server_process,
- # Trailing , ensures tuple is passed
- args=(state.httpd,),
- daemon=True,
- )
- state.server_thread.start()
- logger.info("Started API Server on port: %i", state.api_server_config["port"])
+ try:
+ state.httpd = api_server.run_server(
+ state.api_server_config["port"],
+ state.chat_manager,
+ state.rag_params,
+ state.lm_instr,
+ state.context_instr,
+ state.api_server_config["key"],
+ chat_history,
+ state.user_chat_history,
+ )
+
+ # Start the server in the thread
+ def api_server_process(httpd):
+ httpd.serve_forever()
+
+ state.server_thread = threading.Thread(
+ target=api_server_process,
+ # Trailing , ensures tuple is passed
+ args=(state.httpd,),
+ daemon=True,
+ )
+ state.server_thread.start()
+ logger.info("Started API Server on port: %i", state.api_server_config["port"])
+ except OSError:
+ if not state.api_server_config["auto_start"]:
+ st.error("Port is already in use.")
else:
st.warning("API Server is already running.")
else:
@@ -106,22 +121,37 @@ def api_server_stop():
#############################################################################
def main():
"""Streamlit GUI"""
- initialize_streamlit()
st.header("API Server")
-
- # LLM Params
- ll_model = st_common.lm_sidebar()
-
# Initialize RAG
st_common.initialize_rag()
+ # Setup History
+ chat_history = StreamlitChatMessageHistory(key="api_chat_history")
- # RAG
- st_common.rag_sidebar()
+ #########################################################################
+ # Sidebar Settings
+ #########################################################################
+ enabled_llms = sum(model_info["enabled"] for model_info in state.ll_model_config.values())
+ if enabled_llms > 0:
+ initialize_streamlit()
+ enable_history = st.sidebar.checkbox(
+ "Enable History and Context?",
+ value=True,
+ key="user_chat_history",
+ )
+ if st.sidebar.button("Clear History", disabled=not enable_history):
+ chat_history.clear()
+ st.sidebar.divider()
+ ll_model = st_common.lm_sidebar()
+ st_common.rag_sidebar()
+ else:
+ st.error("No chat models are configured and/or enabled.", icon="🚨")
+ st.stop()
#########################################################################
- # Initialize the Client
+ # Main
#########################################################################
if "initialized" not in state:
+ api_server_stop()
if not state.rag_params["enable"] or all(
state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"]
):
@@ -130,10 +160,6 @@ def main():
state.initialized = True
st_common.update_rag()
logger.debug("Force rerun to save state")
- if "server_thread" in state:
- logger.info("Restarting API Server")
- api_server_stop()
- api_server_start()
st.rerun()
except Exception as ex:
logger.exception(ex, exc_info=False)
@@ -143,18 +169,14 @@ def main():
st.rerun()
st.stop()
else:
- # RAG Enabled but not configured
- if "server_thread" in state:
- logger.info("Stopping API Server")
- api_server_stop()
+ st.error("Not all required RAG options are set, please review or disable RAG.")
+ st.stop()
- #########################################################################
- # API Server
- #########################################################################
server_running = False
if "server_thread" in state:
server_running = True
- st.success("API Server is Running")
+ elif state.api_server_config["auto_start"]:
+ server_running = True
left, right = st.columns([0.2, 0.8])
left.number_input(
@@ -165,6 +187,7 @@ def main():
key="user_api_server_port",
disabled=server_running,
)
+
right.text_input(
"API Server Key:",
type="password",
@@ -173,15 +196,20 @@ def main():
disabled=server_running,
)
- if "server_thread" in state:
- st.button("Stop Server", type="primary", on_click=api_server_stop)
- elif "initialized" in state and state.initialized:
- st.button("Start Server", type="primary", on_click=api_server_start)
+ if state.api_server_config["auto_start"]:
+ api_server_start()
+ st.success("API Server automatically started.")
else:
- st.error("Not all required RAG options are set, please review or disable RAG.")
+ if server_running:
+ st.button("Stop Server", type="primary", on_click=api_server_stop)
+ else:
+ st.button("Start Server", type="primary", on_click=api_server_start)
- st.subheader("Activity")
+ #########################################################################
+ # API Server Centre
+ #########################################################################
if "server_thread" in state:
+ st.subheader("Activity")
with st.container(border=True):
display_logs()
diff --git a/app/src/content/chatbot.py b/app/src/content/chatbot.py
index 12e0953..4ec6ef3 100644
--- a/app/src/content/chatbot.py
+++ b/app/src/content/chatbot.py
@@ -21,50 +21,6 @@
logger = logging_config.logging.getLogger("chatbot")
-#############################################################################
-# Functions
-#############################################################################
-def show_refs(context):
- """When RAG Enabled, show the references"""
- st.markdown(
- """
-
- """,
- unsafe_allow_html=True,
- )
-
- column_sizes = [10, 8, 8, 8, 2, 24]
- cols = st.columns(column_sizes)
- # Create a button in each column
- links = set()
- with cols[0]:
- st.markdown("**References:**")
- # Limit the maximum number of items to 3 (counting from 0)
- max_items = min(len(context), 3)
-
- # Loop through the chunks and display them
- for i in range(max_items):
- with cols[i + 1]:
- chunk = context[i]
- links.add(chunk.metadata["source"])
- with st.popover(f"Ref: {i+1}"):
- st.markdown(chunk.metadata["source"])
- st.markdown(chunk.page_content)
- st.markdown(chunk.metadata["id"])
-
- for link in links:
- st.markdown("- " + link)
-
-
#############################################################################
# MAIN
#############################################################################
@@ -156,7 +112,7 @@ def main():
full_context = chunk["context"]
message_placeholder.markdown(full_answer)
if full_context:
- show_refs(full_context)
+ st_common.show_rag_refs(full_context)
else:
st.chat_message("ai").write_stream(response)
except Exception as ex:
diff --git a/app/src/content/model_config.py b/app/src/content/model_config.py
index 490b9ee..fe9bbfd 100644
--- a/app/src/content/model_config.py
+++ b/app/src/content/model_config.py
@@ -157,7 +157,6 @@ def main():
if update_ll_model:
st.success("Language Model Configuration - Updated", icon="✅")
-
st.header("Embedding Models")
with st.form("update_embed_model_config"):
# Create table header
@@ -210,5 +209,6 @@ def main():
if update_embed_model:
st.success("Embedding Model Configuration - Updated", icon="✅")
+
if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename:
main()
diff --git a/app/src/content/split_embed.py b/app/src/content/split_embed.py
index 8a15372..57ef864 100644
--- a/app/src/content/split_embed.py
+++ b/app/src/content/split_embed.py
@@ -450,7 +450,7 @@ def main():
model,
distance_metric,
split_docos,
- 0 if rate_limit is None else rate_limit
+ 0 if rate_limit is None else rate_limit,
)
placeholder.empty()
st_common.reset_rag()
diff --git a/app/src/modules/api_server.py b/app/src/modules/api_server.py
index 179fb6c..3fdda66 100644
--- a/app/src/modules/api_server.py
+++ b/app/src/modules/api_server.py
@@ -2,6 +2,7 @@
Copyright (c) 2023, 2024, Oracle and/or its affiliates.
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
+
# spell-checker:ignore streamlit, langchain
import os
@@ -12,15 +13,12 @@
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import urlparse
-from langchain_community.chat_message_histories import StreamlitChatMessageHistory
-
# Utilities
import modules.chatbot as chatbot
import modules.logging_config as logging_config
logger = logging_config.logging.getLogger("modules.api_server")
-
# Create a queue to store the requests and responses
log_queue = queue.Queue()
@@ -40,64 +38,40 @@ def generate_api_key(length=32):
# Generates a URL-safe, base64-encoded random string with the given length
return secrets.token_urlsafe(length)
- auto_port = find_available_port()
- auto_api_key = generate_api_key()
- return {
- "port": os.environ.get("API_SERVER_PORT", default=auto_port),
- "key": os.environ.get("API_SERVER_KEY", default=auto_api_key),
- }
+ api_server_port = os.environ.get("API_SERVER_PORT")
+ api_server_key = os.environ.get("API_SERVER_KEY")
+ auto_start = bool(api_server_port and api_server_key)
-def get_answer_fn(
- question: str,
- history=None,
- chat_manager=None,
- rag_params=None,
- lm_instr=None,
- context_instr=None,
-) -> str:
- """Send for completion"""
- # Format appropriately the history for your RAG agent
- chat_history_api = StreamlitChatMessageHistory(key="empty")
- chat_history_api.clear()
- if history:
- for h in history:
- if h["role"] == "assistant":
- chat_history_api.add_ai_message(h["content"])
- else:
- chat_history_api.add_user_message(h["content"])
-
- try:
- response = chatbot.generate_response(
- chat_mgr=chat_manager,
- input=question,
- chat_history=chat_history_api,
- enable_history=True,
- rag_params=rag_params,
- chat_instr=lm_instr,
- context_instr=context_instr,
- stream=False,
- )
- logger.info("MSG from Chatbot API: %s", response)
- if rag_params["enable"]:
- return response["answer"]
- else:
- return response.content
- except Exception as ex:
- return f"I'm sorry, something's gone wrong: {ex}"
+ return {
+ "port": int(api_server_port) if api_server_port else find_available_port(),
+ "key": api_server_key if api_server_key else generate_api_key(),
+ "auto_start": auto_start,
+ }
class ChatbotHTTPRequestHandler(BaseHTTPRequestHandler):
"""Handler for mini-chatbot"""
def __init__(
- self, *args, chat_manager=None, rag_params=None, lm_instr=None, context_instr=None, api_key=None, **kwargs
+ self,
+ *args,
+ chat_manager=None,
+ rag_params=None,
+ lm_instr=None,
+ context_instr=None,
+ api_key=None,
+ chat_history=None,
+ enable_history=False,
+ **kwargs,
):
self.chat_manager = chat_manager
self.rag_params = rag_params
self.lm_instr = lm_instr
self.context_instr = context_instr
self.api_key = api_key
+ self.chat_history = chat_history
+ self.enable_history = enable_history
super().__init__(*args, **kwargs)
def do_OPTIONS(self): # pylint: disable=invalid-name
@@ -126,38 +100,39 @@ def do_POST(self): # pylint: disable=invalid-name
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length).decode("utf-8")
- # Log the raw body
- logger.info("Raw Body: %s", post_data)
try:
# Parse the POST data as JSON
post_json = json.loads(post_data)
# Extract the 'message' field from the JSON
message = post_json.get("message")
-
+ response = None
if message:
# Log the incoming message
logger.info("MSG to Chatbot API: %s", message)
# Call your function to get the chatbot response
- answer = get_answer_fn(
- message, None, self.chat_manager, self.rag_params, self.lm_instr, self.context_instr
+ response = chatbot.generate_response(
+ chat_mgr=self.chat_manager,
+ input=message,
+ chat_history=self.chat_history,
+ enable_history=self.enable_history,
+ rag_params=self.rag_params,
+ chat_instr=self.lm_instr,
+ context_instr=self.context_instr,
+ stream=False,
)
-
- # Prepare the response as JSON
- response = {"choices": [{"message": {"content": answer}}]}
self.send_response(200)
+ # Process response to JSON
else:
- # If no message is provided, return an error
- response = {"error": "No 'message' field found in request."}
+ json_response = {"error": "No 'message' field found in request."}
self.send_response(400) # Bad request
- # Add request/response to the queue
- log_queue.put(f"Request: {post_data}")
- log_queue.put(f"Response: {response}")
+ # Add request/response to the queue for output
+ log_queue.put(post_json)
+ log_queue.put(response)
except json.JSONDecodeError:
- # If JSON parsing fails, return an error
- response = {"error": "Invalid JSON in request."}
+ json_response = {"error": "Invalid JSON in request."}
self.send_response(400) # Bad request
else:
# Invalid or missing API Key
@@ -170,32 +145,47 @@ def do_POST(self): # pylint: disable=invalid-name
else:
# Return a 404 response for unknown paths
self.send_response(404)
- response = {"error": "Path not found."}
+ json_response = {"error": "Path not found."}
# Send the response
self.send_header("Access-Control-Allow-Origin", "*") # Add CORS header
self.send_header("Content-Type", "application/json")
self.end_headers()
- self.wfile.write(json.dumps(response).encode("utf-8"))
+ full_context = None
+ max_items = 0
+ if self.rag_params["enable"]:
+ if "context" in response:
+ full_context = response["context"]
+ max_items = min(len(full_context), 3)
+ if full_context:
+ sources = set()
+ for i in range(max_items):
+ chunk = full_context[i]
+ sources.add(os.path.basename(chunk.metadata["source"]))
+ json_response = {"answer": response["answer"], "sources": list(sources)}
+ else:
+ json_response = {"answer": response.content}
-def run_server(port, chat_manager, rag_params, lm_instr, context_instr, api_key):
- def create_handler(chat_manager, rag_params, lm_instr, context_instr, api_key):
- def handler(*args, **kwargs):
- ChatbotHTTPRequestHandler(
- *args,
- chat_manager=chat_manager,
- rag_params=rag_params,
- lm_instr=lm_instr,
- context_instr=context_instr,
- api_key=api_key,
- **kwargs,
- )
+ self.wfile.write(json.dumps(json_response).encode("utf-8"))
- return handler
- handler_with_params = create_handler(chat_manager, rag_params, lm_instr, context_instr, api_key)
+def run_server(port, chat_manager, rag_params, lm_instr, context_instr, api_key, chat_history, enable_history):
+ # Define the request handler function
+ def handler(*args, **kwargs):
+ return ChatbotHTTPRequestHandler(
+ *args,
+ chat_manager=chat_manager,
+ rag_params=rag_params,
+ lm_instr=lm_instr,
+ context_instr=context_instr,
+ api_key=api_key,
+ chat_history=chat_history,
+ enable_history=enable_history,
+ **kwargs,
+ )
+
server_address = ("", port)
- httpd = HTTPServer(server_address, handler_with_params)
+ httpd = HTTPServer(server_address, handler)
return httpd
diff --git a/app/src/modules/chatbot.py b/app/src/modules/chatbot.py
index 0427dab..48d8c55 100644
--- a/app/src/modules/chatbot.py
+++ b/app/src/modules/chatbot.py
@@ -19,6 +19,7 @@
logger = logging_config.logging.getLogger("modules.chatbot")
+
def generate_response(
chat_mgr, input, chat_history, enable_history, rag_params, chat_instr, context_instr=None, stream=False
):
diff --git a/app/src/modules/logging_config.py b/app/src/modules/logging_config.py
index 694cfcc..eeb2dba 100644
--- a/app/src/modules/logging_config.py
+++ b/app/src/modules/logging_config.py
@@ -28,5 +28,6 @@ def setup_logging():
# Sagemaker continuously complains about config, suppress
logging.getLogger("sagemaker.config").setLevel(logging.WARNING)
+
# Call setup_logging when this module is imported
setup_logging()
diff --git a/app/src/modules/metadata.py b/app/src/modules/metadata.py
index 6edd150..4abfd12 100644
--- a/app/src/modules/metadata.py
+++ b/app/src/modules/metadata.py
@@ -210,7 +210,7 @@ def embedding_models():
"api_key": "",
"openai_compat": True,
"chunk_max": 512,
- "dimensions": 768
+ "dimensions": 768,
},
"text-embedding-3-small": {
"enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -219,8 +219,7 @@ def embedding_models():
"api_key": os.environ.get("OPENAI_API_KEY", default=""),
"openai_compat": True,
"chunk_max": 8191,
- "dimensions": 1536
-
+ "dimensions": 1536,
},
"text-embedding-3-large": {
"enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -229,7 +228,7 @@ def embedding_models():
"api_key": os.environ.get("OPENAI_API_KEY", default=""),
"openai_compat": True,
"chunk_max": 8191,
- "dimensions": 3072
+ "dimensions": 3072,
},
"text-embedding-ada-002": {
"enabled": os.getenv("OPENAI_API_KEY") is not None,
@@ -238,7 +237,7 @@ def embedding_models():
"api_key": os.environ.get("OPENAI_API_KEY", default=""),
"openai_compat": True,
"chunk_max": 8191,
- "dimensions": 1536
+ "dimensions": 1536,
},
"embed-english-v3.0": {
"enabled": os.getenv("COHERE_API_KEY") is not None,
@@ -247,7 +246,7 @@ def embedding_models():
"api_key": os.environ.get("COHERE_API_KEY", default=""),
"openai_compat": False,
"chunk_max": 512,
- "dimensions": 1024
+ "dimensions": 1024,
},
"embed-english-light-v3.0": {
"enabled": os.getenv("COHERE_API_KEY") is not None,
@@ -256,7 +255,7 @@ def embedding_models():
"api_key": os.environ.get("COHERE_API_KEY", default=""),
"openai_compat": False,
"chunk_max": 512,
- "dimensions": 384
+ "dimensions": 384,
},
"mxbai-embed-large": {
"enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -265,7 +264,7 @@ def embedding_models():
"api_key": "",
"openai_compat": True,
"chunk_max": 512,
- "dimensions": 1024
+ "dimensions": 1024,
},
"nomic-embed-text": {
"enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -274,7 +273,7 @@ def embedding_models():
"api_key": "",
"openai_compat": True,
"chunk_max": 8192,
- "dimensions": 768
+ "dimensions": 768,
},
"all-minilm": {
"enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None,
@@ -283,7 +282,7 @@ def embedding_models():
"api_key": "",
"openai_compat": True,
"chunk_max": 256,
- "dimensions": 384
+ "dimensions": 384,
},
}
return embedding_models_dict
diff --git a/app/src/modules/report_utils.py b/app/src/modules/report_utils.py
index bde682d..c6a3d4f 100644
--- a/app/src/modules/report_utils.py
+++ b/app/src/modules/report_utils.py
@@ -67,9 +67,7 @@ def record_update():
if "reference_answer_input" not in state:
state.reference_answer_input = state.df.iloc[state.index]["reference_answer"]
if "reference_context_input" not in state:
- state.reference_context_input = state.df.iloc[state.index][
- "reference_context"
- ]
+ state.reference_context_input = state.df.iloc[state.index]["reference_context"]
if "metadata_input" not in state:
state.metadata_input = state.df.iloc[state.index]["metadata"]
@@ -86,12 +84,8 @@ def record_update():
state.index -= 1
state.hide_input = state.df.at[state.index, "hide"]
state.question_input = state.df.at[state.index, "question"]
- state.reference_answer_input = state.df.at[
- state.index, "reference_answer"
- ]
- state.reference_context_input = state.df.at[
- state.index, "reference_context"
- ]
+ state.reference_answer_input = state.df.at[state.index, "reference_answer"]
+ state.reference_context_input = state.df.at[state.index, "reference_context"]
state.metadata_input = state.df.at[state.index, "metadata"]
# Button to move to the next question
@@ -102,12 +96,8 @@ def record_update():
state.index += 1
state.hide_input = state.df.at[state.index, "hide"]
state.question_input = state.df.at[state.index, "question"]
- state.reference_answer_input = state.df.at[
- state.index, "reference_answer"
- ]
- state.reference_context_input = state.df.at[
- state.index, "reference_context"
- ]
+ state.reference_answer_input = state.df.at[state.index, "reference_answer"]
+ state.reference_context_input = state.df.at[state.index, "reference_context"]
state.metadata_input = state.df.at[state.index, "metadata"]
# Button to save the current input value
@@ -117,12 +107,8 @@ def record_update():
# Save the current input value in the DataFrame
state.df.at[state.index, "hide"] = state.hide_input
state.df.at[state.index, "question"] = state.question_input
- state.df.at[state.index, "reference_answer"] = (
- state.reference_answer_input
- )
- state.df.at[state.index, "reference_context"] = (
- state.reference_context_input
- )
+ state.df.at[state.index, "reference_answer"] = state.reference_answer_input
+ state.df.at[state.index, "reference_context"] = state.reference_context_input
# state.df.at[state.index, 'metadata'] = state.metadata_input # It's read-only
logger.info("--------SAVE----------------------")
@@ -141,20 +127,14 @@ def record_update():
state.df.to_json(file_path, orient="records", lines=True, index=False)
# Text input for the question, storing the user's input in the session state
- state.index_output = st.write(
- "Record: " + str(state.index + 1) + "/" + str(state.df.shape[0])
- )
+ state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]))
state.hide_input = st.checkbox("Hide", value=state.hide_input)
state.question_input = st.text_area("question", height=1, value=state.question_input)
- state.reference_answer_input = st.text_area(
- "Reference answer", height=1, value=state.reference_answer_input
- )
+ state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input)
state.reference_context_input = st.text_area(
"Reference context", height=10, value=state.reference_context_input, disabled=True
)
- state.metadata_input = st.text_area(
- "Metadata", height=1, value=state.metadata_input, disabled=True
- )
+ state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True)
if save_clicked:
st.success("Q&A saved successfully!")
diff --git a/app/src/modules/st_common.py b/app/src/modules/st_common.py
index 7160f92..078e872 100644
--- a/app/src/modules/st_common.py
+++ b/app/src/modules/st_common.py
@@ -141,6 +141,47 @@ def initialize_rag():
st.error("Application has not been initialized, please restart.", icon="⛑️")
+def show_rag_refs(context):
+ """When RAG Enabled, show the references"""
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+ )
+
+ column_sizes = [10, 8, 8, 8, 2, 24]
+ cols = st.columns(column_sizes)
+ # Create a button in each column
+ links = set()
+ with cols[0]:
+ st.markdown("**References:**")
+ # Limit the maximum number of items to 3 (counting from 0)
+ max_items = min(len(context), 3)
+
+ # Loop through the chunks and display them
+ for i in range(max_items):
+ with cols[i + 1]:
+ chunk = context[i]
+ links.add(chunk.metadata["source"])
+ with st.popover(f"Ref: {i+1}"):
+ st.markdown(chunk.metadata["source"])
+ st.markdown(chunk.page_content)
+ st.markdown(chunk.metadata["id"])
+
+ for link in links:
+ st.markdown("- " + link)
+
+
def initialize_chatbot(ll_model):
"""Initialize the Chatbot"""
logger.info("Initializing ChatBot using %s; RAG: %s", ll_model, state.rag_params["enable"])
@@ -273,87 +314,112 @@ def refresh_rag_filtered():
if user_alias:
rag_filt_alias = [user_alias]
else:
- rag_filt_alias = [
- v.get("alias", None)
- for v in state.vs_tables.values()
- if (user_model is None or v["model"] == user_model)
- and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
- and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
- and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
- ]
+ try:
+ rag_filt_alias = [
+ v.get("alias", None)
+ for v in state.vs_tables.values()
+ if (user_model is None or v["model"] == user_model)
+ and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+ and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+ and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+ ]
+ except AttributeError:
+ # st.session_state has no attribute "vs_tables"
+ clear_initialized()
if user_model:
rag_filt_model = [user_model]
else:
- rag_filt_model = [
- v["model"]
- for v in state.vs_tables.values()
- if (user_alias is None or v.get("alias", None) == user_alias)
- and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
- and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
- and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
- ]
+ try:
+ rag_filt_model = [
+ v["model"]
+ for v in state.vs_tables.values()
+ if (user_alias is None or v.get("alias", None) == user_alias)
+ and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+ and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+ and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+ ]
+ except AttributeError:
+ # st.session_state has no attribute "vs_tables"
+ clear_initialized()
if user_chunk_size:
rag_filt_chunk_size = [user_chunk_size]
else:
- rag_filt_chunk_size = [
- v["chunk_size"]
- for v in state.vs_tables.values()
- if (user_alias is None or v.get("alias", None) == user_alias)
- and (user_model is None or v["model"] == user_model)
- and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
- and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
- ]
+ try:
+ rag_filt_chunk_size = [
+ v["chunk_size"]
+ for v in state.vs_tables.values()
+ if (user_alias is None or v.get("alias", None) == user_alias)
+ and (user_model is None or v["model"] == user_model)
+ and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+ and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+ ]
+ except AttributeError:
+ # st.session_state has no attribute "vs_tables"
+ clear_initialized()
if user_chunk_overlap:
rag_filt_chunk_overlap = [user_chunk_overlap]
else:
- rag_filt_chunk_overlap = [
- v["chunk_overlap"]
- for v in state.vs_tables.values()
- if (user_alias is None or v.get("alias", None) == user_alias)
- and (user_model is None or v["model"] == user_model)
- and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
- and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
- ]
+ try:
+ rag_filt_chunk_overlap = [
+ v["chunk_overlap"]
+ for v in state.vs_tables.values()
+ if (user_alias is None or v.get("alias", None) == user_alias)
+ and (user_model is None or v["model"] == user_model)
+ and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+ and (user_distance_metric is None or v["distance_metric"] == user_distance_metric)
+ ]
+ except AttributeError:
+ # st.session_state has no attribute "vs_tables"
+ clear_initialized()
if user_distance_metric:
rag_filt_distance_metric = [user_distance_metric]
else:
- rag_filt_distance_metric = [
- v["distance_metric"]
- for v in state.vs_tables.values()
- if (user_alias is None or v.get("alias", None) == user_alias)
- and (user_model is None or v["model"] == user_model)
- and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
- and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
- ]
+ try:
+ rag_filt_distance_metric = [
+ v["distance_metric"]
+ for v in state.vs_tables.values()
+ if (user_alias is None or v.get("alias", None) == user_alias)
+ and (user_model is None or v["model"] == user_model)
+ and (user_chunk_size is None or v["chunk_size"] == user_chunk_size)
+ and (user_chunk_overlap is None or v["chunk_overlap"] == user_chunk_overlap)
+ ]
+ except AttributeError:
+ # st.session_state has no attribute "vs_tables"
+ clear_initialized()
# Remove duplicates and sort
- state.rag_filter["alias"] = sorted(set(rag_filt_alias))
- state.rag_filter["model"] = sorted(set(rag_filt_model))
- state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size))
- state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap))
- state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric))
+ try:
+ state.rag_filter["alias"] = sorted(set(rag_filt_alias))
+ state.rag_filter["model"] = sorted(set(rag_filt_model))
+ state.rag_filter["chunk_size"] = sorted(set(rag_filt_chunk_size))
+ state.rag_filter["chunk_overlap"] = sorted(set(rag_filt_chunk_overlap))
+ state.rag_filter["distance_metric"] = sorted(set(rag_filt_distance_metric))
+ except UnboundLocalError:
+ clear_initialized()
# (Re)set the index to previously selected option
attributes = ["alias", "model", "chunk_size", "chunk_overlap", "distance_metric"]
for attr in attributes:
- filtered_list = state.rag_filter[attr]
try:
+ filtered_list = state.rag_filter[attr]
user_value = getattr(state, f"rag_user_{attr}")
except AttributeError:
setattr(state, f"rag_user_{attr}", [])
+ except KeyError:
+ clear_initialized()
try:
idx = 0 if len(filtered_list) == 1 else filtered_list.index(user_value)
- except (ValueError, AttributeError):
+ except (ValueError, AttributeError, UnboundLocalError):
idx = None
state.rag_user_idx[attr] = idx
- rag_enable = st.sidebar.checkbox(
+ st.sidebar.checkbox(
"RAG?",
value=state.rag_params["enable"],
key="rag_user_enable",
@@ -362,7 +428,7 @@ def refresh_rag_filtered():
on_change=reset_rag,
)
- if rag_enable:
+ if state.rag_params["enable"]:
set_default_state("rag_user_rerank", False)
st.sidebar.checkbox(
"Enable Re-Ranking?",
diff --git a/app/src/modules/utilities.py b/app/src/modules/utilities.py
index 2c46da0..e005e16 100644
--- a/app/src/modules/utilities.py
+++ b/app/src/modules/utilities.py
@@ -51,29 +51,25 @@
# General
###############################################################################
def is_url_accessible(url):
- """Check that URL is Available"""
- logger.debug("Checking %s is accessible", url)
+ """Check if the URL is accessible."""
+ logger.debug("Checking if %s is accessible", url)
+
try:
response = requests.get(url, timeout=2)
- logger.info("Checking %s resulted in %s", url, response.status_code)
- # Check if the response status code is 200 (OK) 403 (Forbidden) 404 (Not Found) 421 (Misdirected)
- if response.status_code in [200, 403, 404, 421]:
+ logger.info("Response for %s: %s", url, response.status_code)
+
+ if response.status_code in {200, 403, 404, 421}:
return True, None
- else:
- err_msg = f"{url} is not accessible. (Status: {response.status_code})"
- logger.warning(err_msg)
- return False, err_msg
- except requests.exceptions.ConnectionError:
- err_msg = f"{url} is not accessible. (Connection Error)"
+
+ err_msg = f"{url} is not accessible. (Status: {response.status_code})"
logger.warning(err_msg)
return False, err_msg
- except requests.exceptions.Timeout:
- err_msg = f"{url} is not accessible. (Request Timeout)"
+
+ except requests.exceptions.RequestException as ex:
+ err_msg = f"{url} is not accessible. ({type(ex).__name__})"
logger.warning(err_msg)
- return False, err_msg
- except requests.RequestException as ex:
logger.exception(ex, exc_info=False)
- return False, ex
+ return False, err_msg
###############################################################################
@@ -97,6 +93,15 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
logger.debug("Matching LLM API: %s", llm_api)
+ common_params = {
+ "model": model,
+ "temperature": lm_params["temperature"][0],
+ "max_tokens": lm_params["max_tokens"][0],
+ "top_p": lm_params["top_p"][0],
+ "frequency_penalty": lm_params["frequency_penalty"][0],
+ "presence_penalty": lm_params["presence_penalty"][0],
+ }
+
## Start - Add Additional Model Authentication Here
client = None
if giskarded:
@@ -104,49 +109,13 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
_client = OpenAI(api_key=giskard_key, base_url=f"{llm_url}/v1/")
client = OpenAIClient(model=model, client=_client)
elif llm_api == "OpenAI":
- client = ChatOpenAI(
- api_key=lm_params["api_key"],
- model_name=model,
- temperature=lm_params["temperature"][0],
- max_tokens=lm_params["max_tokens"][0],
- top_p=lm_params["top_p"][0],
- frequency_penalty=lm_params["frequency_penalty"][0],
- presence_penalty=lm_params["presence_penalty"][0],
- )
+ client = ChatOpenAI(api_key=lm_params["api_key"], **common_params)
elif llm_api == "Cohere":
- client = ChatCohere(
- cohere_api_key=lm_params["api_key"],
- model=model,
- temperature=lm_params["temperature"][0],
- max_tokens=lm_params["max_tokens"][0],
- top_p=lm_params["top_p"][0],
- frequency_penalty=lm_params["frequency_penalty"][0],
- presence_penalty=lm_params["presence_penalty"][0],
- )
+ client = ChatCohere(cohere_api_key=lm_params["api_key"], **common_params)
elif llm_api == "ChatPerplexity":
- client = ChatPerplexity(
- pplx_api_key=lm_params["api_key"],
- model=model,
- temperature=lm_params["temperature"][0],
- max_tokens=lm_params["max_tokens"][0],
- model_kwargs={
- "top_p": lm_params["top_p"][0],
- "frequency_penalty": lm_params["frequency_penalty"][0],
- "presence_penalty": lm_params["presence_penalty"][0],
- },
- )
+ client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
elif llm_api == "ChatOllama":
- client = ChatOllama(
- model=model,
- base_url=llm_url,
- temperature=lm_params["temperature"][0],
- max_tokens=lm_params["max_tokens"][0],
- model_kwargs={
- "top_p": lm_params["top_p"][0],
- "frequency_penalty": lm_params["frequency_penalty"][0],
- "presence_penalty": lm_params["presence_penalty"][0],
- },
- )
+ client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params)
## End - Add Additional Model Authentication Here
api_accessible, err_msg = is_url_accessible(llm_url)
@@ -264,7 +233,7 @@ def init_vs(db_conn, embedding_function, store_table, distance_metric):
def get_vs_table(model, chunk_size, chunk_overlap, distance_metric, embed_alias=None):
- """Get a list of Vector Store Tables"""
+ """Return the concatenated VS Table name and comment"""
chunk_overlap_ceil = math.ceil(chunk_overlap)
table_string = f"{model}_{chunk_size}_{chunk_overlap_ceil}_{distance_metric}"
if embed_alias:
@@ -425,11 +394,11 @@ def oci_init_client(client_type, config=None, retries=True):
retry_strategy = oci.retry.NoneRetryStrategy()
# Initialize Client (Workload Identity, Token and API)
+ client = None
if not config:
logger.info("OCI Authentication with Workload Identity")
- signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer()
- # Region is required for endpoint generation; not sure its value matters
- client = client_type(config={"region": "us-ashburn-1"}, signer=signer)
+ oke_workload_signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer()
+ client = client_type(config={}, signer=oke_workload_signer)
elif config and config["security_token_file"]:
logger.info("OCI Authentication with Security Token")
token = None
@@ -437,7 +406,7 @@ def oci_init_client(client_type, config=None, retries=True):
token = f.read()
private_key = oci.signer.load_private_key_from_file(config["key_file"])
signer = oci.auth.signers.SecurityTokenSigner(token, private_key)
- client = client_type({"region": config["region"]}, signer=signer)
+ client = client_type(config={"region": config["region"]}, signer=signer)
else:
logger.info("OCI Authentication as Standard")
client = client_type(config, retry_strategy=retry_strategy)
@@ -539,6 +508,8 @@ def oci_get_namespace(config, retries=True):
raise OciException("Invalid Key Path") from ex
except UnboundLocalError as ex:
raise OciException("No Configuration - Disabling OCI") from ex
+ except Exception as ex:
+ raise OciException("Uncaught Exception - Disabling OCI") from ex
return namespace
diff --git a/app/src/oaim-sandbox.py b/app/src/oaim-sandbox.py
index 4e203e0..522a38c 100644
--- a/app/src/oaim-sandbox.py
+++ b/app/src/oaim-sandbox.py
@@ -19,6 +19,7 @@
from content.db_config import initialize_streamlit as db_initialize
from content.prompt_eng import initialize_streamlit as prompt_initialize
from content.oci_config import initialize_streamlit as oci_initialize
+import content.api_server as api_server_content
logger = logging_config.logging.getLogger("sandbox")
@@ -38,6 +39,7 @@ def main():
model_initialize()
prompt_initialize()
oci_initialize()
+ api_server_content.initialize_streamlit()
# Setup rag_params into state enable as default
if "rag_params" not in state:
@@ -47,6 +49,12 @@ def main():
if "rag_filter" not in state:
state.rag_filter = {}
+ # Start the API server
+ if state.api_server_config["auto_start"]:
+ api_server_content.api_server_start()
+ if "user_chat_history" not in state:
+ state.user_chat_history = True
+
# GUI Defaults
css = """