Skip to content

Commit

Permalink
API Server (#38)
Browse files Browse the repository at this point in the history
* Enable API Server Autostart
* Return RAG results in API call
  • Loading branch information
gotsysdba authored Nov 7, 2024
1 parent 1735c7d commit df24730
Show file tree
Hide file tree
Showing 20 changed files with 396 additions and 357 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ app/THIRD_PARTY_LICENSES.txt
rag/**
sbin/**
helm/values.yaml
**/*.bak

##############################################################################
# Environment (PyVenv, IDE, etc.)
Expand Down
2 changes: 1 addition & 1 deletion app/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RUN microdnf -y install python3.11 python3.11-pip
RUN python3.11 -m venv --symlinks --upgrade-deps /opt/venv
COPY ./requirements.txt /opt/requirements.txt
RUN source /opt/venv/bin/activate && \
pip3 install --upgrade pip wheel && \
pip3 install --upgrade pip wheel setuptools && \
pip3 install -r /opt/requirements.txt
RUN groupadd $RUNUSER && useradd -u 10001 -g $RUNUSER -md /app $RUNUSER;
RUN install -d -m 0700 -o $RUNUSER -g $RUNUSER /app/.oci
Expand Down
21 changes: 11 additions & 10 deletions app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,32 @@
## Installation Example (from repo root):
### python3.11 -m venv .venv
### source .venv/bin/activate
### pip3 install --upgrade pip wheel
### pip3 install --upgrade pip wheel setuptools
### pip3 install -r app/requirements.txt

## Top-level packages bring in the required dependencies; if adding modules, try to find the minimum required
bokeh==3.6.0
evaluate==0.4.3
faiss-cpu==1.9.0
giskard==2.15.2
IPython==8.28.0
giskard==2.15.3
IPython==8.29.0
langchain-cohere==0.3.1
langchain-community==0.3.2
langchain-huggingface==0.1.0
langchain-community==0.3.5
langchain-huggingface==0.1.2
langchain-ollama==0.2.0
langchain-openai==0.2.2
llama_index==0.11.18
langchain-openai==0.2.6
langgraph==0.2.45
llama_index==0.11.22
lxml==5.3.0
matplotlib==3.9.2
oci>=2.0.0
oracledb>=2.0.0
plotly==5.24.1
streamlit==1.39.0
umap-learn==0.5.6
umap-learn==0.5.7

## For Licensing Purposes; ensures no GPU modules are installed
## as part of langchain-huggingface
-f https://download.pytorch.org/whl/cpu/torch
torch==2.4.1+cpu ; sys_platform == "linux"
torch==2.2.2 ; sys_platform == "darwin"
torch==2.5.1+cpu ; sys_platform == "linux"
torch==2.5.1 ; sys_platform == "darwin"
146 changes: 87 additions & 59 deletions app/src/content/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""

# spell-checker:ignore streamlit
# spell-checker:ignore streamlit, langchain, llms

import inspect
import threading
import time
import threading

# Streamlit
import streamlit as st
Expand All @@ -15,9 +16,11 @@
# Utilities
import modules.st_common as st_common
import modules.api_server as api_server

import modules.logging_config as logging_config

# History
from langchain_community.chat_message_histories import StreamlitChatMessageHistory

logger = logging_config.logging.getLogger("api_server")


Expand All @@ -32,49 +35,61 @@ def initialize_streamlit():


def display_logs():
log_placeholder = st.empty() # A placeholder to update logs
logs = [] # Store logs for display

try:
while "server_thread" in st.session_state:
try:
# Retrieve log from queue (non-blocking)
log_item = api_server.log_queue.get_nowait()
logs.append(log_item)
# Update the placeholder with new logs
log_placeholder.text("\n".join(logs))
msg = api_server.log_queue.get_nowait()
logger.info("API Msg: %s", msg)
if "message" in msg:
st.chat_message("human").write(msg["message"])
else:
if state.rag_params["enable"]:
st.chat_message("ai").write(msg["answer"])
st_common.show_rag_refs(msg["context"])
else:
st.chat_message("ai").write(msg.content)
except api_server.queue.Empty:
time.sleep(0.1) # Avoid busy-waiting
finally:
logger.info("API Server events display has stopped.")


def api_server_start():
state.api_server_config["port"] = state.user_api_server_port
state.api_server_config["key"] = state.user_api_server_key
chat_history = StreamlitChatMessageHistory(key="api_chat_history")
if "user_api_server_port" in state:
state.api_server_config["port"] = state.user_api_server_port
if "user_api_server_key" in state:
state.api_server_config["key"] = state.user_api_server_key
if "initialized" in state and state.initialized:
if "server_thread" not in state:
state.httpd = api_server.run_server(
state.api_server_config["port"],
state.chat_manager,
state.rag_params,
state.lm_instr,
state.context_instr,
state.api_server_config["key"],
)

# Start the server in the thread
def api_server_process(httpd):
httpd.serve_forever()

state.server_thread = threading.Thread(
target=api_server_process,
# Trailing , ensures tuple is passed
args=(state.httpd,),
daemon=True,
)
state.server_thread.start()
logger.info("Started API Server on port: %i", state.api_server_config["port"])
try:
state.httpd = api_server.run_server(
state.api_server_config["port"],
state.chat_manager,
state.rag_params,
state.lm_instr,
state.context_instr,
state.api_server_config["key"],
chat_history,
state.user_chat_history,
)

# Start the server in the thread
def api_server_process(httpd):
httpd.serve_forever()

state.server_thread = threading.Thread(
target=api_server_process,
# Trailing , ensures tuple is passed
args=(state.httpd,),
daemon=True,
)
state.server_thread.start()
logger.info("Started API Server on port: %i", state.api_server_config["port"])
except OSError:
if not state.api_server_config["auto_start"]:
st.error("Port is already in use.")
else:
st.warning("API Server is already running.")
else:
Expand Down Expand Up @@ -106,22 +121,37 @@ def api_server_stop():
#############################################################################
def main():
"""Streamlit GUI"""
initialize_streamlit()
st.header("API Server")

# LLM Params
ll_model = st_common.lm_sidebar()

# Initialize RAG
st_common.initialize_rag()
# Setup History
chat_history = StreamlitChatMessageHistory(key="api_chat_history")

# RAG
st_common.rag_sidebar()
#########################################################################
# Sidebar Settings
#########################################################################
enabled_llms = sum(model_info["enabled"] for model_info in state.ll_model_config.values())
if enabled_llms > 0:
initialize_streamlit()
enable_history = st.sidebar.checkbox(
"Enable History and Context?",
value=True,
key="user_chat_history",
)
if st.sidebar.button("Clear History", disabled=not enable_history):
chat_history.clear()
st.sidebar.divider()
ll_model = st_common.lm_sidebar()
st_common.rag_sidebar()
else:
st.error("No chat models are configured and/or enabled.", icon="🚨")
st.stop()

#########################################################################
# Initialize the Client
# Main
#########################################################################
if "initialized" not in state:
api_server_stop()
if not state.rag_params["enable"] or all(
state.rag_params[key] for key in ["model", "chunk_size", "chunk_overlap", "distance_metric"]
):
Expand All @@ -130,10 +160,6 @@ def main():
state.initialized = True
st_common.update_rag()
logger.debug("Force rerun to save state")
if "server_thread" in state:
logger.info("Restarting API Server")
api_server_stop()
api_server_start()
st.rerun()
except Exception as ex:
logger.exception(ex, exc_info=False)
Expand All @@ -143,18 +169,14 @@ def main():
st.rerun()
st.stop()
else:
# RAG Enabled but not configured
if "server_thread" in state:
logger.info("Stopping API Server")
api_server_stop()
st.error("Not all required RAG options are set, please review or disable RAG.")
st.stop()

#########################################################################
# API Server
#########################################################################
server_running = False
if "server_thread" in state:
server_running = True
st.success("API Server is Running")
elif state.api_server_config["auto_start"]:
server_running = True

left, right = st.columns([0.2, 0.8])
left.number_input(
Expand All @@ -165,6 +187,7 @@ def main():
key="user_api_server_port",
disabled=server_running,
)

right.text_input(
"API Server Key:",
type="password",
Expand All @@ -173,15 +196,20 @@ def main():
disabled=server_running,
)

if "server_thread" in state:
st.button("Stop Server", type="primary", on_click=api_server_stop)
elif "initialized" in state and state.initialized:
st.button("Start Server", type="primary", on_click=api_server_start)
if state.api_server_config["auto_start"]:
api_server_start()
st.success("API Server automatically started.")
else:
st.error("Not all required RAG options are set, please review or disable RAG.")
if server_running:
st.button("Stop Server", type="primary", on_click=api_server_stop)
else:
st.button("Start Server", type="primary", on_click=api_server_start)

st.subheader("Activity")
#########################################################################
# API Server Centre
#########################################################################
if "server_thread" in state:
st.subheader("Activity")
with st.container(border=True):
display_logs()

Expand Down
46 changes: 1 addition & 45 deletions app/src/content/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,50 +21,6 @@
logger = logging_config.logging.getLogger("chatbot")


#############################################################################
# Functions
#############################################################################
def show_refs(context):
    """When RAG is enabled, render references for the retrieved chunks.

    Displays up to the first three chunks of *context* in popovers (source,
    page content, and id), then lists each distinct source as a bullet link.
    """
    # Inject CSS to shrink the popover buttons and tighten list spacing.
    st.markdown(
        """
        <style>
            .stButton button {
                width: 8px; /* Adjust the width as needed */
                height: 8px; /* Adjust the height as needed */
                font-size: 8px; /* Adjust the font size as needed */
            }
            ul {
                padding: 0px
            }
        </style>
        """,
        unsafe_allow_html=True,
    )

    columns = st.columns([10, 8, 8, 8, 2, 24])
    with columns[0]:
        st.markdown("**References:**")

    # Show at most three chunks, collecting each distinct source document.
    sources = set()
    for idx, chunk in enumerate(context[:3]):
        sources.add(chunk.metadata["source"])
        with columns[idx + 1]:
            with st.popover(f"Ref: {idx+1}"):
                st.markdown(chunk.metadata["source"])
                st.markdown(chunk.page_content)
                st.markdown(chunk.metadata["id"])

    for source in sources:
        st.markdown("- " + source)


#############################################################################
# MAIN
#############################################################################
Expand Down Expand Up @@ -156,7 +112,7 @@ def main():
full_context = chunk["context"]
message_placeholder.markdown(full_answer)
if full_context:
show_refs(full_context)
st_common.show_rag_refs(full_context)
else:
st.chat_message("ai").write_stream(response)
except Exception as ex:
Expand Down
2 changes: 1 addition & 1 deletion app/src/content/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def main():
if update_ll_model:
st.success("Language Model Configuration - Updated", icon="✅")


st.header("Embedding Models")
with st.form("update_embed_model_config"):
# Create table header
Expand Down Expand Up @@ -210,5 +209,6 @@ def main():
if update_embed_model:
st.success("Embedding Model Configuration - Updated", icon="✅")


if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename:
main()
2 changes: 1 addition & 1 deletion app/src/content/split_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def main():
model,
distance_metric,
split_docos,
0 if rate_limit is None else rate_limit
0 if rate_limit is None else rate_limit,
)
placeholder.empty()
st_common.reset_rag()
Expand Down
Loading

0 comments on commit df24730

Please sign in to comment.