diff --git a/.gitignore b/.gitignore
index 57e26ae..0d1b862 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,8 @@ dist/
downloads/
eggs/
.eggs/
-lib/
+.idea/
+
lib64/
parts/
sdist/
@@ -163,5 +164,5 @@ cython_debug/
# dataset files
data/
.streamlit/
-*.ipynb
+#*.ipynb
.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 8a9d1ed..977be5b 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ In conclusion, with ChatData, you can effortlessly navigate through vast amounts
➡️ Dive in and experience ChatData on [Hugging Face](https://huggingface.co/spaces/myscale/ChatData)🤗
-![ChatData Homepage](assets/chatdata-homepage.png)
+![ChatData Homepage](assets/home.png)
### Data schema
@@ -117,15 +117,6 @@ And for overall table schema, please refer to [table creation section in docs/se
If you want to use this database with `langchain.chains.sql_database.base.SQLDatabaseChain` or `langchain.retrievers.SQLDatabaseRetriever`, please follow guides on [data preparation section](docs/vector-sql.md#prepare-the-database) and [chain creation section](docs/vector-sql.md#create-the-sqldatabasechain) in docs/vector-sql.md
-### How to run ChatData
-
-
-
-```bash
-python3 -m pip install requirements.txt
-python3 -m streamlit run app.py
-```
-
### Where can I get those arXiv data?
- [From parquet files on S3](docs/self-query.md#insert-data)
@@ -167,18 +158,12 @@ cd app/
2. Create a virtual environment
```bash
-python3 -m venv .venv
-source .venv/bin/activate
+python3 -m venv venv
+source venv/bin/activate
```
3. Install dependencies
-> This app is currently using [MyScale's technical preview of LangChain](https://github.com/myscale/langchain/tree/preview).
->
->> It contains improved SQLDatabaseChain in [this PR](https://github.com/hwchase17/langchain/pull/7454)
->>
->> It contains [improved prompts](https://github.com/hwchase17/langchain/pull/6737#discussion_r1243527112) for comparators `LIKE` and `CONTAIN` in [MyScale self-query retriever](https://github.com/hwchase17/langchain/pull/6143).
-
```bash
python3 -m pip install -r requirements.txt
```
diff --git a/app/.streamlit/config.toml b/app/.streamlit/config.toml
index 35fc398..37b9ad3 100644
--- a/app/.streamlit/config.toml
+++ b/app/.streamlit/config.toml
@@ -1,6 +1,2 @@
[theme]
-primaryColor="#523EFD"
-backgroundColor="#FFFFFF"
-secondaryBackgroundColor="#D4CEFF"
-textColor="#262730"
-font="sans serif"
\ No newline at end of file
+base="dark"
diff --git a/app/.streamlit/secrets.example.toml b/app/.streamlit/secrets.example.toml
index 9c5d344..add2dbf 100644
--- a/app/.streamlit/secrets.example.toml
+++ b/app/.streamlit/secrets.example.toml
@@ -1,7 +1,8 @@
-MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
+MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com" # read-only database provided by MyScale
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
+MYSCALE_ENABLE_HTTPS = true
OPENAI_API_BASE = "https://api.openai.com/v1"
OPENAI_API_KEY = ""
UNSTRUCTURED_API = "" # optional if you don't upload documents
diff --git a/app/app.py b/app/app.py
index ecae84a..df1bbfd 100644
--- a/app/app.py
+++ b/app/app.py
@@ -1,133 +1,86 @@
-import pandas as pd
-from os import environ
+import os
+import time
+
import streamlit as st
-from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
- ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
- ChatDataSQLAskCallBackHandler
+from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STARTED, DATA_INITIALIZE_COMPLETED, \
+ DATA_INITIALIZE_STARTED
+from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
+ TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
+from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
+from backend.types.global_config import GlobalConfig
+from logger import logger
+from ui.chat_page import chat_page
+from ui.home import render_home
+from ui.retrievers import render_retrievers
-from chat import chat_page
-from login import login, back_to_main
-from lib.helper import build_tools, build_all, sel_map, display
+# warnings.filterwarnings("ignore", category=UserWarning)
-environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
+def prepare_environment():
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ os.environ["LANGCHAIN_TRACING_V2"] = "false"
+ # os.environ["LANGCHAIN_API_KEY"] = ""
+ os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE']
+ os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']
+ os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID']
+ os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN']
+
+ update_global_config(GlobalConfig(
+ openai_api_base=st.secrets['OPENAI_API_BASE'],
+ openai_api_key=st.secrets['OPENAI_API_KEY'],
+ auth0_client_id=st.secrets['AUTH0_CLIENT_ID'],
+ auth0_domain=st.secrets['AUTH0_DOMAIN'],
+ myscale_user=st.secrets['MYSCALE_USER'],
+ myscale_password=st.secrets['MYSCALE_PASSWORD'],
+ myscale_host=st.secrets['MYSCALE_HOST'],
+ myscale_port=st.secrets['MYSCALE_PORT'],
+ query_model="gpt-3.5-turbo-0125",
+ chat_model="gpt-3.5-turbo-0125",
+ untrusted_api=st.secrets['UNSTRUCTURED_API'],
+ myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True),
+ ))
-st.set_page_config(page_title="ChatData",
- page_icon="https://myscale.com/favicon.ico")
-st.markdown(
- f"""
- """,
- unsafe_allow_html=True,
-)
-st.header("ChatData")
-if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state:
- st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all()
- st.session_state["tools"] = build_tools()
+# When the browser is refreshed, all session state keys are cleared.
+def initialize_session_state():
+ if DATA_INITIALIZE_STATUS not in st.session_state:
+        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STARTED
+        logger.info(f"Initialized session state key: {DATA_INITIALIZE_STATUS}")
+    if JUMP_QUERY_ASK not in st.session_state:
+        st.session_state[JUMP_QUERY_ASK] = False
+        logger.info(f"Initialized session state key: {JUMP_QUERY_ASK}")
-if login():
- if "user_name" in st.session_state:
- chat_page()
- elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
- sel = st.selectbox('Choose the knowledge base you want to ask with:',
- options=['ArXiv Papers', 'Wikipedia'])
- sel_map[sel]['hint']()
- tab_sql, tab_self_query = st.tabs(
- ['Vector SQL', 'Self-Query Retrievers'])
- with tab_sql:
- sel_map[sel]['hint_sql']()
- st.text_input("Ask a question:", key='query_sql')
- cols = st.columns([1, 1, 1, 4])
- cols[0].button("Query", key='search_sql')
- cols[1].button("Ask", key='ask_sql')
- cols[2].button("Back", key='back_sql', on_click=back_to_main)
- plc_hldr = st.empty()
- if st.session_state.search_sql:
- plc_hldr = st.empty()
- print(st.session_state.query_sql)
- with plc_hldr.expander('Query Log', expanded=True):
- callback = ChatDataSQLSearchCallBackHandler()
- try:
- docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
- st.session_state.query_sql, callbacks=[callback])
- callback.progress_bar.progress(value=1.0, text="Done!")
- docs = pd.DataFrame(
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
- display(docs)
- except Exception as e:
- st.write('Oops 😵 Something bad happened...')
- raise e
+def initialize_chat_data():
+ if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED:
+ start_time = time.time()
+ st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
+ st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
+ st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
+ st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
+        # Mark data initialization as finished.
+        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
+        end_time = time.time()
+        logger.info(f"ChatData initialization finished in {round(end_time - start_time, 3)} seconds, "
+                    f"session state keys: {list(st.session_state.keys())}")
- if st.session_state.ask_sql:
- plc_hldr = st.empty()
- print(st.session_state.query_sql)
- with plc_hldr.expander('Chat Log', expanded=True):
- callback = ChatDataSQLAskCallBackHandler()
- try:
- ret = st.session_state.sel_map_obj[sel]["sql_chain"](
- st.session_state.query_sql, callbacks=[callback])
- callback.progress_bar.progress(value=1.0, text="Done!")
- st.markdown(
- f"### Answer from LLM\n{ret['answer']}\n### References")
- docs = ret['sources']
- docs = pd.DataFrame(
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
- display(
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
- except Exception as e:
- st.write('Oops 😵 Something bad happened...')
- raise e
- with tab_self_query:
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
- st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
- st.text_input("Ask a question:", key='query_self')
- cols = st.columns([1, 1, 1, 4])
- cols[0].button("Query", key='search_self')
- cols[1].button("Ask", key='ask_self')
- cols[2].button("Back", key='back_self', on_click=back_to_main)
- plc_hldr = st.empty()
- if st.session_state.search_self:
- plc_hldr = st.empty()
- print(st.session_state.query_self)
- with plc_hldr.expander('Query Log', expanded=True):
- call_back = None
- callback = ChatDataSelfSearchCallBackHandler()
- try:
- docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
- st.session_state.query_self, callbacks=[callback])
- print(docs)
- callback.progress_bar.progress(value=1.0, text="Done!")
- docs = pd.DataFrame(
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
- display(docs, sel_map[sel]["must_have_cols"])
- except Exception as e:
- st.write('Oops 😵 Something bad happened...')
- raise e
+st.set_page_config(
+ page_title="ChatData",
+ page_icon="https://myscale.com/favicon.ico",
+ initial_sidebar_state="expanded",
+ layout="wide",
+)
+
+prepare_environment()
+initialize_session_state()
+initialize_chat_data()
- if st.session_state.ask_self:
- plc_hldr = st.empty()
- print(st.session_state.query_self)
- with plc_hldr.expander('Chat Log', expanded=True):
- call_back = None
- callback = ChatDataSelfAskCallBackHandler()
- try:
- ret = st.session_state.sel_map_obj[sel]["chain"](
- st.session_state.query_self, callbacks=[callback])
- callback.progress_bar.progress(value=1.0, text="Done!")
- st.markdown(
- f"### Answer from LLM\n{ret['answer']}\n### References")
- docs = ret['sources']
- docs = pd.DataFrame(
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
- display(
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
- except Exception as e:
- st.write('Oops 😵 Something bad happened...')
- raise e
+if USER_NAME in st.session_state:
+ chat_page()
+else:
+ if st.session_state[JUMP_QUERY_ASK]:
+ render_retrievers()
+ else:
+ render_home()
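
Streamlit reruns `app.py` from top to bottom on every widget interaction, so the `DATA_INITIALIZE_STATUS` guard above is what keeps the expensive model/chain construction from repeating; a browser refresh clears `st.session_state` and triggers a full rebuild. A stripped-down sketch of the same pattern, with placeholder work standing in for the real builders:

```python
# sketch.py -- run with `streamlit run sketch.py`; illustrative only.
import streamlit as st

if "status" not in st.session_state:
    st.session_state["status"] = "not_started"  # survives reruns, reset on refresh

if st.session_state["status"] != "completed":
    st.session_state["status"] = "started"
    # Placeholder for the heavy setup (embedding models, chains, retrievers).
    st.session_state["payload"] = sum(range(1_000_000))
    st.session_state["status"] = "completed"

st.write("initialized:", st.session_state["payload"])
```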
diff --git a/app/backend/__init__.py b/app/backend/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/callbacks/__init__.py b/app/backend/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/callbacks/arxiv_callbacks.py b/app/backend/callbacks/arxiv_callbacks.py
new file mode 100644
index 0000000..49f8fd1
--- /dev/null
+++ b/app/backend/callbacks/arxiv_callbacks.py
@@ -0,0 +1,46 @@
+import json
+import textwrap
+from typing import Dict, Any, List
+
+from langchain.callbacks.streamlit.streamlit_callback_handler import (
+ LLMThought,
+ StreamlitCallbackHandler,
+)
+
+
+class LLMThoughtWithKnowledgeBase(LLMThought):
+ def on_tool_end(
+ self,
+ output: str,
+ color=None,
+ observation_prefix=None,
+ llm_prefix=None,
+ **kwargs: Any,
+ ) -> None:
+ try:
+ self._container.markdown(
+ "\n\n".join(
+ ["### Retrieved Documents:"]
+ + [
+ f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
+ for i, r in enumerate(json.loads(output))
+ ]
+ )
+ )
+        except Exception:  # fall back to the default rendering if output is not document JSON
+ super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
+
+
+class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+ if self._current_thought is None:
+ self._current_thought = LLMThoughtWithKnowledgeBase(
+ parent_container=self._parent_container,
+ expanded=self._expand_new_thoughts,
+ collapse_on_complete=self._collapse_completed_thoughts,
+ labeler=self._thought_labeler,
+ )
+
+ self._current_thought.on_llm_start(serialized, prompts)
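
`on_tool_end` above assumes the tool's output is a JSON array of objects carrying `page_content`; anything else raises inside the `try` and falls back to the default rendering. A self-contained illustration of the happy path, with made-up documents:

```python
import json
import textwrap

# Hypothetical tool output in the shape on_tool_end expects.
output = json.dumps([
    {"page_content": "Capsule networks route information between capsules by agreement."},
    {"page_content": "GANs have been applied to image synthesis and super-resolution."},
])

lines = [
    f"**{i + 1}**: {textwrap.shorten(r['page_content'], width=80)}"
    for i, r in enumerate(json.loads(output))
]
print("\n\n".join(["### Retrieved Documents:"] + lines))
```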
diff --git a/app/backend/callbacks/llm_thought_with_table.py b/app/backend/callbacks/llm_thought_with_table.py
new file mode 100644
index 0000000..56fb0a1
--- /dev/null
+++ b/app/backend/callbacks/llm_thought_with_table.py
@@ -0,0 +1,36 @@
+from typing import Any, Dict, List
+
+import streamlit as st
+from langchain_core.outputs import LLMResult
+from streamlit.external.langchain import StreamlitCallbackHandler
+
+
+class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
+ def __init__(self):
+ super().__init__(st.container())
+ self._current_thought = None
+ self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+        self.progress_bar.progress(value=0.35, text="Communicating with LLM...")
+
+ def on_chain_end(self, outputs, **kwargs) -> None:
+ if len(kwargs['tags']) == 0:
+ self.progress_bar.progress(value=0.75, text="Searching in DB...")
+
+    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        st.markdown("### Filter generated by LLM \n"
+                    "> Here are the `query_constructor` results \n\n")
+        self.progress_bar.progress(value=0.5, text="Generating filter with LLM...")
+        for item in response.generations:
+            st.markdown(f"{item[0].text}")
diff --git a/app/backend/callbacks/self_query_callbacks.py b/app/backend/callbacks/self_query_callbacks.py
new file mode 100644
index 0000000..b8237b8
--- /dev/null
+++ b/app/backend/callbacks/self_query_callbacks.py
@@ -0,0 +1,57 @@
+from typing import Dict, Any, List
+
+import streamlit as st
+from langchain.callbacks.streamlit.streamlit_callback_handler import (
+ StreamlitCallbackHandler,
+)
+from langchain.schema.output import LLMResult
+
+
+class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler):
+ def __init__(self):
+ super().__init__(st.container())
+ self._current_thought = None
+ self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...")
+
+ def on_llm_start(
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+ ) -> None:
+        self.progress_bar.progress(value=0.35, text="Communicating with LLM...")
+
+    def on_chain_end(self, outputs, **kwargs) -> None:
+        if len(kwargs['tags']) == 0:
+            self.progress_bar.progress(value=0.75, text="Searching in DB...")
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        st.markdown("### Filter generated by LLM \n"
+                    "> Here are the `query_constructor` results \n\n")
+        self.progress_bar.progress(value=0.5, text="Generating filter with LLM...")
+        for item in response.generations:
+            st.markdown(f"{item[0].text}")
+
+
+class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
+ def __init__(self) -> None:
+ super().__init__(st.container())
+ self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...")
+
+ def on_llm_start(self, serialized, prompts, **kwargs) -> None:
+ pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        if len(kwargs['tags']) != 0:
+            self.progress_bar.progress(value=0.5, text="Got filter info from LLM...")
+        st.markdown("### Filter generated by LLM \n"
+                    "> Here are the `query_constructor` results \n\n")
+        for item in response.generations:
+            st.markdown(f"{item[0].text}")
+
+ def on_chain_start(self, serialized, inputs, **kwargs) -> None:
+ cid = ".".join(serialized["id"])
+ if cid.endswith(".CustomStuffDocumentChain"):
+ self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...")
diff --git a/app/backend/callbacks/vector_sql_callbacks.py b/app/backend/callbacks/vector_sql_callbacks.py
new file mode 100644
index 0000000..b83cf76
--- /dev/null
+++ b/app/backend/callbacks/vector_sql_callbacks.py
@@ -0,0 +1,53 @@
+import streamlit as st
+from langchain.callbacks.streamlit.streamlit_callback_handler import (
+ StreamlitCallbackHandler,
+)
+from langchain.schema.output import LLMResult
+from sql_formatter.core import format_sql
+
+
+class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
+ def __init__(self) -> None:
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
+ self.status_bar = st.empty()
+ self.prog_value = 0
+ self.prog_interval = 0.2
+
+ def on_llm_start(self, serialized, prompts, **kwargs) -> None:
+ pass
+
+ def on_llm_end(
+ self,
+ response: LLMResult,
+ *args,
+ **kwargs,
+ ):
+ text = response.generations[0][0].text
+ if text.replace(" ", "").upper().startswith("SELECT"):
+ st.markdown("### Generated Vector Search SQL Statement \n"
+ "> This sql statement is generated by LLM \n\n")
+ st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
+ self.prog_value += self.prog_interval
+ self.progress_bar.progress(
+ value=self.prog_value, text="Searching in DB...")
+
+ def on_chain_start(self, serialized, inputs, **kwargs) -> None:
+ cid = ".".join(serialized["id"])
+ self.prog_value += self.prog_interval
+ self.progress_bar.progress(
+ value=self.prog_value, text=f"Running Chain `{cid}`..."
+ )
+
+ def on_chain_end(self, outputs, **kwargs) -> None:
+ pass
+
+
+class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
+ def __init__(self, table: str) -> None:
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
+ self.status_bar = st.empty()
+ self.prog_value = 0
+ self.prog_interval = 0.1
+ self.table = table
+
+
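
The pretty-printing step above leans on the `sql_formatter` package; a quick standalone check of that rendering step, using a made-up query:

```python
from sql_formatter.core import format_sql

raw = "select title, id, authors from ChatArXiv where length(categories) > 2 order by pubdate limit 10"
print(format_sql(raw, max_len=80))
```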
diff --git a/app/backend/chains/__init__.py b/app/backend/chains/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/chains/retrieval_qa_with_sources.py b/app/backend/chains/retrieval_qa_with_sources.py
new file mode 100644
index 0000000..06b1f39
--- /dev/null
+++ b/app/backend/chains/retrieval_qa_with_sources.py
@@ -0,0 +1,70 @@
+import inspect
+from typing import Dict, Any, Optional, List
+
+from langchain.callbacks.manager import (
+ AsyncCallbackManagerForChainRun,
+ CallbackManagerForChainRun,
+)
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
+from langchain.docstore.document import Document
+
+from logger import logger
+
+
+class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
+ """QA with source chain for Chat ArXiv app with references
+
+ This chain will automatically assign reference number to the article,
+ Then parse it back to titles or anything else.
+ """
+
+ def _call(
+ self,
+ inputs: Dict[str, Any],
+ run_manager: Optional[CallbackManagerForChainRun] = None,
+ ) -> Dict[str, str]:
+ logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m")
+ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+ accepts_run_manager = (
+ "run_manager" in inspect.signature(self._get_docs).parameters
+ )
+ if accepts_run_manager:
+ docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager)
+ else:
+ docs: List[Document] = self._get_docs(inputs) # type: ignore[call-arg]
+
+ answer = self.combine_documents_chain.run(
+ input_documents=docs, callbacks=_run_manager.get_child(), **inputs
+ )
+ # parse source with ref_id
+ sources = []
+ ref_cnt = 1
+ for d in docs:
+ ref_id = d.metadata['ref_id']
+ if f"Doc #{ref_id}" in answer:
+ answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
+ if f"#{ref_id}" in answer:
+ title = d.metadata['title'].replace('\n', '')
+ d.metadata['ref_id'] = ref_cnt
+ answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
+ sources.append(d)
+ ref_cnt += 1
+
+ result: Dict[str, Any] = {
+ self.answer_key: answer,
+ self.sources_answer_key: sources,
+ }
+ if self.return_source_documents:
+ result["source_documents"] = docs
+ return result
+
+ async def _acall(
+ self,
+ inputs: Dict[str, Any],
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+ ) -> Dict[str, Any]:
+ raise NotImplementedError
+
+ @property
+ def _chain_type(self) -> str:
+ return "custom_retrieval_qa_with_sources_chain"
diff --git a/app/backend/chains/stuff_documents.py b/app/backend/chains/stuff_documents.py
new file mode 100644
index 0000000..0f6c762
--- /dev/null
+++ b/app/backend/chains/stuff_documents.py
@@ -0,0 +1,65 @@
+from typing import Any, List, Tuple
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.docstore.document import Document
+from langchain.schema.prompt_template import format_document
+
+
+class CustomStuffDocumentChain(StuffDocumentsChain):
+ """Combine arxiv documents with PDF reference number"""
+
+ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+ """Construct inputs from kwargs and docs.
+
+        Format and then join all the documents together into one input with the
+        name `self.document_variable_name`. Then pluck any additional variables
+        from **kwargs.
+
+ Args:
+ docs: List of documents to format and then join into single input
+ **kwargs: additional inputs to chain, will pluck any other required
+ arguments from here.
+
+ Returns:
+ dictionary of inputs to LLMChain
+ """
+ # Format each document according to the prompt
+ doc_strings = []
+ for doc_id, doc in enumerate(docs):
+ # add temp reference number in metadata
+ doc.metadata.update({'ref_id': doc_id})
+ doc.page_content = doc.page_content.replace('\n', ' ')
+ doc_strings.append(format_document(doc, self.document_prompt))
+ # Join the documents together to put them in the prompt.
+ inputs = {
+ k: v
+ for k, v in kwargs.items()
+ if k in self.llm_chain.prompt.input_variables
+ }
+ inputs[self.document_variable_name] = self.document_separator.join(
+ doc_strings)
+ return inputs
+
+ def combine_docs(
+ self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+ ) -> Tuple[str, dict]:
+ """Stuff all documents into one prompt and pass to LLM.
+
+ Args:
+ docs: List of documents to join together into one variable
+ callbacks: Optional callbacks to pass along
+ **kwargs: additional parameters to use to get inputs to LLMChain.
+
+ Returns:
+ The first element returned is the single string output. The second
+ element returned is a dictionary of other keys to return.
+ """
+ inputs = self._get_inputs(docs, **kwargs)
+ # Call predict on the LLM.
+ output = self.llm_chain.predict(callbacks=callbacks, **inputs)
+ return output, {}
+
+ @property
+ def _chain_type(self) -> str:
+ return "custom_stuff_document_chain"
diff --git a/app/backend/chat_bot/__init__.py b/app/backend/chat_bot/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/chat_bot/chat.py b/app/backend/chat_bot/chat.py
new file mode 100644
index 0000000..cfa1aef
--- /dev/null
+++ b/app/backend/chat_bot/chat.py
@@ -0,0 +1,225 @@
+import time
+
+from os import environ
+from time import sleep
+import streamlit as st
+
+from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
+from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \
+ CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \
+ EL_BUILD_KB_WITH_FILES, \
+ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
+ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
+ EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES
+from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS
+from backend.construct.build_agents import build_agents
+from backend.chat_bot.session_manager import SessionManager
+from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
+
+from logger import logger
+
+environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
+
+TOOL_NAMES = {
+ "langchain_retriever_tool": "Self-querying retriever",
+ "vecsql_retriever_tool": "Vector SQL",
+}
+
+
+def on_chat_submit():
+ with st.session_state.next_round.container():
+ with st.chat_message("user"):
+ st.write(st.session_state.chat_input)
+ with st.chat_message("assistant"):
+ container = st.container()
+ st_callback = ChatDataAgentCallBackHandler(
+ container, collapse_completed_thoughts=False
+ )
+ ret = st.session_state.agent(
+ {"input": st.session_state.chat_input}, callbacks=[st_callback]
+ )
+ logger.info(f"ret:{ret}")
+
+
+def clear_history():
+ if "agent" in st.session_state:
+ st.session_state.agent.memory.clear()
+
+
+def back_to_main():
+ if USER_INFO in st.session_state:
+ del st.session_state[USER_INFO]
+ if USER_NAME in st.session_state:
+ del st.session_state[USER_NAME]
+ if JUMP_QUERY_ASK in st.session_state:
+ del st.session_state[JUMP_QUERY_ASK]
+ if EL_SESSION_SELECTOR in st.session_state:
+ del st.session_state[EL_SESSION_SELECTOR]
+ if CHAT_CURRENT_USER_SESSIONS in st.session_state:
+ del st.session_state[CHAT_CURRENT_USER_SESSIONS]
+
+
+def refresh_sessions():
+ chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER]
+ current_user_name = st.session_state[USER_NAME]
+ current_user_sessions = chat_session_manager.list_sessions(current_user_name)
+
+    if not isinstance(current_user_sessions, list) or not current_user_sessions:
+ # generate a default session for current user.
+ chat_session_manager.add_session(
+ user_id=current_user_name,
+ session_id=f"{current_user_name}?default",
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
+ )
+ st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name)
+ current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS]
+ else:
+ st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions
+
+ # load current user files.
+ st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(
+ current_user_name
+ )
+ # load current user private knowledge bases.
+ st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
+ st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
+ logger.info(f"current user name: {current_user_name}, "
+ f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
+ f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
+ st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
+ # public retrieval tools
+ **st.session_state[RETRIEVER_TOOLS],
+ # private retrieval tools
+ **st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
+ }
+ # print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}")
+ print(f"current_user_sessions is {current_user_sessions}")
+ st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]
+
+
+# process for session add and delete.
+def on_session_change_submit():
+ if "session_manager" in st.session_state and "session_editor" in st.session_state:
+ try:
+ for elem in st.session_state.session_editor["added_rows"]:
+ if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
+ if elem["session_id"] != "" and "?" not in elem["session_id"]:
+ st.session_state.session_manager.add_session(
+ user_id=st.session_state.user_name,
+ session_id=f"{st.session_state.user_name}?{elem['session_id']}",
+ system_prompt=elem["system_prompt"],
+ )
+ else:
+ st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="❌")
+ raise KeyError(
+ "`session_id` shouldn't be neither empty nor contain char `?`."
+ )
+ else:
+ st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="❌")
+ raise KeyError(
+ "You should fill both `session_id` and `system_prompt` to add a column!"
+ )
+ for elem in st.session_state.session_editor["deleted_rows"]:
+ user_name = st.session_state[USER_NAME]
+ session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id']
+ user_with_session_id = f"{user_name}?{session_id}"
+ st.session_state.session_manager.remove_session(session_id=user_with_session_id)
+ st.toast(f"session `{user_with_session_id}` removed.", icon="✅")
+
+ refresh_sessions()
+ except Exception as e:
+ sleep(2)
+ st.error(f"{type(e)}: {str(e)}")
+ finally:
+ st.session_state.session_editor["added_rows"] = []
+ st.session_state.session_editor["deleted_rows"] = []
+ refresh_agent()
+
+
+def create_private_knowledge_base_as_tool():
+ current_user_name = st.session_state[USER_NAME]
+
+ if (
+ EL_PERSONAL_KB_NAME in st.session_state
+ and EL_PERSONAL_KB_DESCRIPTION in st.session_state
+ and EL_BUILD_KB_WITH_FILES in st.session_state
+ and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0
+ and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0
+ and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0
+ ):
+ st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base(
+ user_id=current_user_name,
+ tool_name=st.session_state[EL_PERSONAL_KB_NAME],
+ tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION],
+ files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]],
+ )
+ refresh_sessions()
+ else:
+ st.session_state[EL_UPLOAD_FILES_STATUS].error(
+ "You should fill all fields to build up a tool!"
+ )
+ sleep(2)
+
+
+def remove_private_knowledge_bases():
+ if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]:
+ private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]
+ private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove]
+ # remove these private knowledge bases.
+ st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases(
+ user_id=st.session_state[USER_NAME],
+ private_knowledge_bases=private_knowledge_base_names
+ )
+ refresh_sessions()
+ else:
+ st.session_state[EL_UPLOAD_FILES_STATUS].error(
+ "You should specify at least one private knowledge base to delete!"
+ )
+ time.sleep(2)
+
+
+def refresh_agent():
+ with st.spinner("Initializing session..."):
+ user_name = st.session_state[USER_NAME]
+ session_id = st.session_state[EL_SESSION_SELECTOR]['session_id']
+ user_with_session_id = f"{user_name}?{session_id}"
+
+ if EL_SELECTED_KBS in st.session_state:
+ selected_knowledge_bases = st.session_state[EL_SELECTED_KBS]
+ else:
+ selected_knowledge_bases = ["Wikipedia + Vector SQL"]
+
+ logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}")
+ if EL_SESSION_SELECTOR in st.session_state:
+ system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"]
+ else:
+ system_prompt = DEFAULT_SYSTEM_PROMPT
+
+ st.session_state["agent"] = build_agents(
+ session_id=user_with_session_id,
+ tool_names=selected_knowledge_bases,
+ system_prompt=system_prompt
+ )
+
+
+def add_file():
+ user_name = st.session_state[USER_NAME]
+ if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0:
+ st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️")
+ sleep(2)
+ return
+ try:
+ st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...")
+ st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file(
+ user_id=user_name,
+ files=st.session_state[EL_UPLOAD_FILES]
+ )
+ refresh_sessions()
+ except ValueError as e:
+ st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e))
+ sleep(2)
+
+
+def clear_files():
+ st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
+ refresh_sessions()
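
Throughout this module a session is addressed by the composite key `"{user_name}?{session_id}"`; `?` is rejected in user-supplied session ids (see `on_session_change_submit` above), which keeps the split unambiguous. A tiny illustration:

```python
user_name, session_id = "alice", "default"
user_with_session_id = f"{user_name}?{session_id}"   # "alice?default"

# Recover the parts; maxsplit=1 guards against any stray "?" later in the id.
recovered_user, recovered_session = user_with_session_id.split("?", 1)
assert (recovered_user, recovered_session) == ("alice", "default")
```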
diff --git a/app/lib/json_conv.py b/app/backend/chat_bot/json_decoder.py
similarity index 100%
rename from app/lib/json_conv.py
rename to app/backend/chat_bot/json_decoder.py
diff --git a/app/backend/chat_bot/message_converter.py b/app/backend/chat_bot/message_converter.py
new file mode 100644
index 0000000..0a7ad0c
--- /dev/null
+++ b/app/backend/chat_bot/message_converter.py
@@ -0,0 +1,67 @@
+import hashlib
+import json
+import time
+from typing import Any
+
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
+from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
+from langchain.schema.messages import ToolMessage
+from sqlalchemy.orm import declarative_base
+
+from backend.chat_bot.tools import create_message_history_table
+
+
+def _message_from_dict(message: dict) -> BaseMessage:
+ _type = message["type"]
+ if _type == "human":
+ return HumanMessage(**message["data"])
+ elif _type == "ai":
+ return AIMessage(**message["data"])
+ elif _type == "system":
+ return SystemMessage(**message["data"])
+ elif _type == "chat":
+ return ChatMessage(**message["data"])
+ elif _type == "function":
+ return FunctionMessage(**message["data"])
+ elif _type == "tool":
+ return ToolMessage(**message["data"])
+ elif _type == "AIMessageChunk":
+ message["data"]["type"] = "ai"
+ return AIMessage(**message["data"])
+ else:
+ raise ValueError(f"Got unexpected message type: {_type}")
+
+
+class DefaultClickhouseMessageConverter(DefaultMessageConverter):
+ """The default message converter for SQLChatMessageHistory."""
+
+ def __init__(self, table_name: str):
+ super().__init__(table_name)
+ self.model_class = create_message_history_table(table_name, declarative_base())
+
+ def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
+ time_stamp = time.time()
+ msg_id = hashlib.sha256(
+ f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest()
+ user_id, _ = session_id.split("?")
+ return self.model_class(
+ id=time_stamp,
+ msg_id=msg_id,
+ user_id=user_id,
+ session_id=session_id,
+ type=message.type,
+ addtionals=json.dumps(message.additional_kwargs),
+ message=json.dumps({
+ "type": message.type,
+ "additional_kwargs": {"timestamp": time_stamp},
+ "data": message.dict()})
+ )
+
+ def from_sql_model(self, sql_message: Any) -> BaseMessage:
+ msg_dump = json.loads(sql_message.message)
+ msg = _message_from_dict(msg_dump)
+ msg.additional_kwargs = msg_dump["additional_kwargs"]
+ return msg
+
+ def get_sql_model_class(self) -> Any:
+ return self.model_class
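
The converter persists each message as a JSON envelope (`type` plus `data`) so `_message_from_dict` can rebuild it later; a hedged round-trip sketch with the timestamp handling simplified:

```python
import json
import time

from langchain.schema import HumanMessage

msg = HumanMessage(content="What is Iron Gwazi?")
stored = json.dumps({
    "type": msg.type,
    "additional_kwargs": {"timestamp": time.time()},
    "data": msg.dict(),
})

decoded = HumanMessage(**json.loads(stored)["data"])
assert decoded.content == "What is Iron Gwazi?"
```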
diff --git a/app/backend/chat_bot/private_knowledge_base.py b/app/backend/chat_bot/private_knowledge_base.py
new file mode 100644
index 0000000..89e8f2b
--- /dev/null
+++ b/app/backend/chat_bot/private_knowledge_base.py
@@ -0,0 +1,167 @@
+import hashlib
+from datetime import datetime
+from typing import List, Optional
+
+import pandas as pd
+from clickhouse_connect import get_client
+from langchain.schema.embeddings import Embeddings
+from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
+from streamlit.runtime.uploaded_file_manager import UploadedFile
+
+from backend.chat_bot.tools import parse_files, extract_embedding
+from backend.construct.build_retriever_tool import create_retriever_tool
+from logger import logger
+
+
+class ChatBotKnowledgeTable:
+ def __init__(self, host, port, username, password,
+ embedding: Embeddings, parser_api_key: str, db="chat",
+ kb_table="private_kb", tool_table="private_tool") -> None:
+ super().__init__()
+ personal_files_schema_ = f"""
+ CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
+ entity_id String,
+ file_name String,
+ text String,
+ user_id String,
+ created_by DateTime,
+ vector Array(Float32),
+ CONSTRAINT cons_vec_len CHECK length(vector) = 768,
+ VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
+ ) ENGINE = ReplacingMergeTree ORDER BY entity_id
+ """
+
+        # `tool_name` represents the name of a private knowledge base.
+ private_knowledge_base_schema_ = f"""
+ CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
+ tool_id String,
+ tool_name String,
+ file_names Array(String),
+ user_id String,
+ created_by DateTime,
+ tool_description String
+ ) ENGINE = ReplacingMergeTree ORDER BY tool_id
+ """
+ self.personal_files_table = kb_table
+ self.private_knowledge_base_table = tool_table
+ config = MyScaleSettings(
+ host=host,
+ port=port,
+ username=username,
+ password=password,
+ database=db,
+ table=kb_table,
+ )
+ self.client = get_client(
+ host=config.host,
+ port=config.port,
+ username=config.username,
+ password=config.password,
+ )
+ self.client.command("SET allow_experimental_object_type=1")
+ self.client.command(personal_files_schema_)
+ self.client.command(private_knowledge_base_schema_)
+ self.parser_api_key = parser_api_key
+ self.vector_store = MyScaleWithoutJSON(
+ embedding=embedding,
+ config=config,
+ must_have_cols=["file_name", "text", "created_by"],
+ )
+
+    # List all files for a given `user_id`.
+ def list_files(self, user_id: str):
+ query = f"""
+ SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
+ arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
+ FROM {self.vector_store.config.database}.{self.personal_files_table}
+ WHERE user_id = '{user_id}' GROUP BY file_name
+ """
+ return [r for r in self.vector_store.client.query(query).named_results()]
+
+    # Parse files and compute their embeddings.
+ def add_by_file(self, user_id, files: List[UploadedFile]):
+ data = parse_files(self.parser_api_key, user_id, files)
+ data = extract_embedding(self.vector_store.embeddings, data)
+ self.vector_store.client.insert_df(
+ table=self.personal_files_table,
+ df=pd.DataFrame(data),
+ database=self.vector_store.config.database,
+ )
+
+    # Remove all files and private knowledge bases for a given `user_id`.
+ def clear(self, user_id: str):
+ self.vector_store.client.command(
+ f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
+ f"WHERE user_id='{user_id}'"
+ )
+ query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
+ WHERE user_id = '{user_id}'"""
+ self.vector_store.client.command(query)
+
+ def create_private_knowledge_base(
+ self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
+ ):
+ self.vector_store.client.insert_df(
+ self.private_knowledge_base_table,
+ pd.DataFrame(
+ [
+ {
+ "tool_id": hashlib.sha256(
+ (user_id + tool_name).encode("utf-8")
+ ).hexdigest(),
+ "tool_name": tool_name, # tool_name represent user's private knowledge base.
+ "file_names": files,
+ "user_id": user_id,
+ "created_by": datetime.now(),
+ "tool_description": tool_description,
+ }
+ ]
+ ),
+ database=self.vector_store.config.database,
+ )
+
+    # List all private knowledge bases for a given `user_id`.
+ def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
+ extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
+ query = f"""
+ SELECT tool_name, tool_description, length(file_names)
+ FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
+ WHERE user_id = '{user_id}' {extended_where}
+ """
+ return [r for r in self.vector_store.client.query(query).named_results()]
+
+ def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
+ unique_list = list(set(private_knowledge_bases))
+ unique_list = ",".join([f"'{t}'" for t in unique_list])
+ query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
+ WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
+ self.vector_store.client.command(query)
+
+ def as_retrieval_tools(self, user_id, tool_name=None):
+ logger.info(f"")
+ private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, private_knowledge_base=tool_name)
+ retrievers = {}
+ for private_kb in private_knowledge_bases:
+ file_names_sql = f"""
+ SELECT arrayJoin(file_names) FROM (
+ SELECT file_names
+ FROM chat.private_tool
+ WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
+ )
+ """
+ logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
+ res = self.client.query(file_names_sql)
+ file_names = []
+ for line in res.result_rows:
+ file_names.append(line[0])
+ file_names = ', '.join(f"'{item}'" for item in file_names)
+ logger.info(f"user_id is {user_id}, file_names is {file_names}")
+ retrievers[private_kb["tool_name"]] = create_retriever_tool(
+ self.vector_store.as_retriever(
+ search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
+ ),
+ tool_name=private_kb["tool_name"],
+ description=private_kb["tool_description"],
+ )
+ return retrievers
+
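
A minimal sketch of standing the table up; the host and keys are placeholders (instantiation runs the `CREATE TABLE` DDL against the cluster), and the embedding model matches the Wikipedia config used elsewhere in this PR:

```python
from langchain_community.embeddings import SentenceTransformerEmbeddings

from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable

knowledge_table = ChatBotKnowledgeTable(
    host="msc-xxxxxxxx.us-east-1.aws.myscale.com",  # placeholder cluster host
    port=443,
    username="chatdata",
    password="***",
    embedding=SentenceTransformerEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    ),  # 768-dim vectors, matching the schema's length constraint
    parser_api_key="***",  # Unstructured API key
)
print(knowledge_table.list_files(user_id="alice"))
```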
diff --git a/app/backend/chat_bot/session_manager.py b/app/backend/chat_bot/session_manager.py
new file mode 100644
index 0000000..bb061f9
--- /dev/null
+++ b/app/backend/chat_bot/session_manager.py
@@ -0,0 +1,96 @@
+import json
+
+from backend.chat_bot.tools import create_session_table, create_message_history_table
+from backend.constants.variables import GLOBAL_CONFIG
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+from datetime import datetime
+from sqlalchemy import orm, create_engine
+from logger import logger
+
+
+def get_sessions(engine, model_class, user_id):
+    with orm.sessionmaker(engine)() as session:
+        result = (
+            session.query(model_class)
+            .where(
+                model_class.user_id == user_id
+            )
+            .order_by(model_class.create_by.desc())
+        )
+        # `result` is a lazy SQLAlchemy Query, not JSON; materialize it into a list.
+        return list(result)
+
+
+class SessionManager:
+ def __init__(
+ self,
+ session_state,
+ host,
+ port,
+ username,
+ password,
+ db='chat',
+ session_table='sessions',
+ msg_table='chat_memory'
+ ) -> None:
+        protocol = 'https' if GLOBAL_CONFIG.myscale_enable_https else 'http'
+        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol={protocol}'
+ self.engine = create_engine(conn_str, echo=False)
+ self.session_model_class = create_session_table(
+ session_table, declarative_base())
+ self.session_model_class.metadata.create_all(self.engine)
+ self.msg_model_class = create_message_history_table(msg_table, declarative_base())
+ self.msg_model_class.metadata.create_all(self.engine)
+ self.session_orm = orm.sessionmaker(self.engine)
+ self.session_state = session_state
+
+ def list_sessions(self, user_id: str):
+ with self.session_orm() as session:
+ result = (
+ session.query(self.session_model_class)
+ .where(
+ self.session_model_class.user_id == user_id
+ )
+ .order_by(self.session_model_class.create_by.desc())
+ )
+ sessions = []
+ for r in result:
+ sessions.append({
+ "session_id": r.session_id.split("?")[-1],
+ "system_prompt": r.system_prompt,
+ })
+ return sessions
+
+    # Update the system prompt for a given session_id.
+ def modify_system_prompt(self, session_id, sys_prompt):
+ with self.session_orm() as session:
+ obj = session.query(self.session_model_class).where(
+ self.session_model_class.session_id == session_id).first()
+ if obj:
+ obj.system_prompt = sys_prompt
+ session.commit()
+ else:
+ logger.warning(f"Session {session_id} not found")
+
+ # Add a session(session_id, sys_prompt)
+ def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
+ with self.session_orm() as session:
+ elem = self.session_model_class(
+ user_id=user_id, session_id=session_id, system_prompt=system_prompt,
+ create_by=datetime.now(), additionals=json.dumps(kwargs)
+ )
+ session.add(elem)
+ session.commit()
+
+ # Remove a session and related chat history.
+ def remove_session(self, session_id: str):
+ with self.session_orm() as session:
+ # remove session
+ session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete()
+ # remove related chat history.
+ session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
diff --git a/app/backend/chat_bot/tools.py b/app/backend/chat_bot/tools.py
new file mode 100644
index 0000000..7f2c82f
--- /dev/null
+++ b/app/backend/chat_bot/tools.py
@@ -0,0 +1,100 @@
+import hashlib
+from datetime import datetime
+from multiprocessing.pool import ThreadPool
+from typing import List
+
+import requests
+from clickhouse_sqlalchemy import types, engines
+from langchain.schema.embeddings import Embeddings
+from sqlalchemy import Column, Text
+from streamlit.runtime.uploaded_file_manager import UploadedFile
+
+
+def parse_files(api_key, user_id, files: List[UploadedFile]):
+ def parse_file(file: UploadedFile):
+ headers = {
+ "accept": "application/json",
+ "unstructured-api-key": api_key,
+ }
+ data = {"strategy": "auto", "ocr_languages": ["eng"]}
+ file_hash = hashlib.sha256(file.read()).hexdigest()
+ file_data = {"files": (file.name, file.getvalue(), file.type)}
+ response = requests.post(
+ url="https://api.unstructured.io/general/v0/general",
+ headers=headers,
+ data=data,
+ files=file_data
+ )
+ json_response = response.json()
+ if response.status_code != 200:
+ raise ValueError(str(json_response))
+ texts = [
+ {
+ "text": t["text"],
+ "file_name": t["metadata"]["filename"],
+ "entity_id": hashlib.sha256(
+ (file_hash + t["text"]).encode()
+ ).hexdigest(),
+ "user_id": user_id,
+ "created_by": datetime.now(),
+ }
+ for t in json_response
+ if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
+ ]
+ return texts
+
+ with ThreadPool(8) as p:
+ rows = []
+ for r in p.imap_unordered(parse_file, files):
+ rows.extend(r)
+ return rows
+
+
+def extract_embedding(embeddings: Embeddings, texts):
+    if len(texts) == 0:
+        raise ValueError("No texts extracted!")
+    # Name the result `vectors` so it does not shadow the `embeddings` argument.
+    vectors = embeddings.embed_documents([t["text"] for t in texts])
+    for i, _ in enumerate(texts):
+        texts[i]["vector"] = vectors[i]
+    return texts
+
+
+def create_message_history_table(table_name: str, base_class):
+ class Message(base_class):
+ __tablename__ = table_name
+ id = Column(types.Float64)
+ session_id = Column(Text)
+ user_id = Column(Text)
+ msg_id = Column(Text, primary_key=True)
+ type = Column(Text)
+        # Misspelled as `addtionals` by an earlier developer; the name is kept
+        # so it matches existing tables.
+        addtionals = Column(Text)
+ message = Column(Text)
+ __table_args__ = (
+ engines.MergeTree(
+ partition_by='session_id',
+ order_by=('id', 'msg_id')
+ ),
+ {'comment': 'Store Chat History'}
+ )
+
+ return Message
+
+
+def create_session_table(table_name: str, DynamicBase):
+ class Session(DynamicBase):
+ __tablename__ = table_name
+ user_id = Column(Text)
+ session_id = Column(Text, primary_key=True)
+ system_prompt = Column(Text)
+        # Session creation time.
+        create_by = Column(types.DateTime)
+        # Extra key-value data serialized as JSON.
+        additionals = Column(Text)
+ __table_args__ = (
+ engines.MergeTree(order_by=session_id),
+ {'comment': 'Store Session and Prompts'}
+ )
+
+ return Session
diff --git a/app/backend/constants/__init__.py b/app/backend/constants/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/constants/myscale_tables.py b/app/backend/constants/myscale_tables.py
new file mode 100644
index 0000000..87c31d9
--- /dev/null
+++ b/app/backend/constants/myscale_tables.py
@@ -0,0 +1,128 @@
+from typing import Dict, List
+import streamlit as st
+from langchain.chains.query_constructor.schema import AttributeInfo
+from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
+from langchain.prompts import PromptTemplate
+
+from backend.types.table_config import TableConfig
+
+
+def hint_arxiv():
+ st.markdown("Here we provide some query samples.")
+ st.markdown("- If you want to search papers with filters")
+ st.markdown("1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
+ "than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```")
+ st.markdown("2. ```What is a Bayesian network? Please use articles published later than Feb 2018```")
+ st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
+ st.markdown("1. ```Did Geoffrey Hinton wrote paper about Capsule Neural Networks?```")
+ st.markdown("2. ```Introduce some applications of GANs published around 2019.```")
+ st.markdown("3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```")
+
+
+def hint_sql_arxiv():
+ st.markdown('''```sql
+CREATE TABLE default.ChatArXiv (
+ `abstract` String,
+ `id` String,
+ `vector` Array(Float32),
+ `metadata` Object('JSON'),
+ `pubdate` DateTime,
+ `title` String,
+ `categories` Array(String),
+ `authors` Array(String),
+ `comment` String,
+ `primary_category` String,
+ VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
+ CONSTRAINT vec_len CHECK length(vector) = 768)
+ENGINE = ReplacingMergeTree ORDER BY id
+```''')
+
+
+def hint_wiki():
+ st.markdown("Here we provide some query samples.")
+ st.markdown("1. ```Which company did Elon Musk found?```")
+ st.markdown("2. ```What is Iron Gwazi?```")
+ st.markdown("3. ```苹果的发源地是哪里?```")
+ st.markdown("4. ```What is a Ring in mathematics?```")
+ st.markdown("5. ```The producer of Rick and Morty.```")
+ st.markdown("6. ```How low is the temperature on Pluto?```")
+
+
+def hint_sql_wiki():
+ st.markdown('''```sql
+CREATE TABLE wiki.Wikipedia (
+ `id` String,
+ `title` String,
+ `text` String,
+ `url` String,
+ `wiki_id` UInt64,
+ `views` Float32,
+ `paragraph_id` UInt64,
+ `langs` UInt32,
+ `emb` Array(Float32),
+ VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
+ CONSTRAINT emb_len CHECK length(emb) = 768)
+ENGINE = ReplacingMergeTree ORDER BY id
+```''')
+
+
+MYSCALE_TABLES: Dict[str, TableConfig] = {
+ 'Wikipedia': TableConfig(
+ database="wiki",
+ table="Wikipedia",
+ table_contents="Snapshort from Wikipedia for 2022. All in English.",
+ hint=hint_wiki,
+ hint_sql=hint_sql_wiki,
+        # doc_prompt is used by the QA-with-sources chain.
+ doc_prompt=PromptTemplate(
+ input_variables=["page_content", "url", "title", "ref_id", "views"],
+ template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
+ ),
+ metadata_col_attributes=[
+ AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
+ AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
+ AttributeInfo(name="views", description="number of views", type="float")
+ ],
+ must_have_col_names=['id', 'title', 'url', 'text', 'views'],
+ vector_col_name="emb",
+ text_col_name="text",
+ metadata_col_name="metadata",
+ emb_model=lambda: SentenceTransformerEmbeddings(
+ model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+ ),
+ tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
+ ),
+ 'ArXiv Papers': TableConfig(
+ database="default",
+ table="ChatArXiv",
+ table_contents="Snapshort from Wikipedia for 2022. All in English.",
+ hint=hint_arxiv,
+ hint_sql=hint_sql_arxiv,
+ doc_prompt=PromptTemplate(
+ input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"],
+ template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t"
+ "Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
+ ),
+ metadata_col_attributes=[
+ AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"),
+ AttributeInfo(name="authors", description="List of author names", type="list[string]"),
+ AttributeInfo(name="title", description="Title of the paper", type="string"),
+ AttributeInfo(name="categories", description="arxiv categories to this paper", type="list[string]"),
+ AttributeInfo(name="length(categories)", description="length of arxiv categories to this paper", type="int")
+ ],
+ must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
+ vector_col_name="vector",
+ text_col_name="abstract",
+ metadata_col_name="metadata",
+ emb_model=lambda: HuggingFaceInstructEmbeddings(
+ model_name='hkunlp/instructor-xl',
+ embed_instruction="Represent the question for retrieving supporting scientific papers: "
+ ),
+ tool_desc=(
+ "search_among_scientific_papers",
+ "Searches among scientific papers from ArXiv and returns research papers"
+ )
+ )
+}
+
+ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()]
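
Downstream builders (`build_chains_and_retrievers` and friends) are expected to consume this mapping roughly like the following sketch, assuming `TableConfig` exposes the fields set above as attributes:

```python
from backend.constants.myscale_tables import MYSCALE_TABLES

config = MYSCALE_TABLES["ArXiv Papers"]
embeddings = config.emb_model()            # lazily instantiate the embedding model
tool_name, tool_description = config.tool_desc
print(config.database, config.table, config.must_have_col_names)
print(tool_name, "->", tool_description)
```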
diff --git a/app/prompts/arxiv_prompt.py b/app/backend/constants/prompts.py
similarity index 67%
rename from app/prompts/arxiv_prompt.py
rename to app/backend/constants/prompts.py
index 5b8b661..3f4152f 100644
--- a/app/prompts/arxiv_prompt.py
+++ b/app/backend/constants/prompts.py
@@ -1,15 +1,33 @@
-combine_prompt_template = (
- "You are a helpful document assistant. Your task is to provide information and answer any questions "
- + "related to documents given below. You should use the sections, title and abstract of the selected documents as your source of information "
- + "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
- + "relevant information in the given sections, you will need to let the user know that the source does not contain "
- + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
- + "corresponding section name and page that you refer to when answering. The following is the related information "
- + "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
- + "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
- )
-
-_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
+from langchain.prompts import ChatPromptTemplate, \
+ SystemMessagePromptTemplate, HumanMessagePromptTemplate
+
+DEFAULT_SYSTEM_PROMPT = (
+ "Do your best to answer the questions. "
+ "Feel free to use any tools available to look up "
+ "relevant information. Please keep all details in query "
+ "when calling search functions."
+)
+
+COMBINE_PROMPT_TEMPLATE = (
+ "You are a helpful document assistant. "
+ "Your task is to provide information and answer any questions related to documents given below. "
+ "You should use the sections, title and abstract of the selected documents as your source of information "
+ "and try to provide concise and accurate answers to any questions asked by the user. "
+ "If you are unable to find relevant information in the given sections, "
+ "you will need to let the user know that the source does not contain relevant information but still try to "
+ "provide an answer based on your general knowledge. You must refer to the corresponding section name and page "
+ "that you refer to when answering. "
+ "The following is the related information about the document that will help you answer users' questions, "
+ "you MUST answer it using question's language:\n\n {summaries} "
+ "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
+)
+
+COMBINE_PROMPT = ChatPromptTemplate.from_strings(
+ string_messages=[(SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE),
+ (HumanMessagePromptTemplate, '{question}')])
+
+MYSCALE_PROMPT = """
+You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest rows, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.
@@ -43,7 +61,7 @@
PRIMARY KEY id
Question: What is Feature Pyramid Network?
-SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
+SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
======== table info ========
diff --git a/app/backend/constants/streamlit_keys.py b/app/backend/constants/streamlit_keys.py
new file mode 100644
index 0000000..234f904
--- /dev/null
+++ b/app/backend/constants/streamlit_keys.py
@@ -0,0 +1,35 @@
+DATA_INITIALIZE_NOT_STARTED = "data_initialize_not_started"
+DATA_INITIALIZE_STARTED = "data_initialize_started"
+DATA_INITIALIZE_COMPLETED = "data_initialize_completed"
+
+
+CHAT_SESSION = "sel_sess"
+CHAT_KNOWLEDGE_TABLE = "private_kb"
+
+CHAT_SESSION_MANAGER = "session_manager"
+CHAT_CURRENT_USER_SESSIONS = "current_sessions"
+
+EL_SESSION_SELECTOR = "el_session_selector"
+
+# All personal knowledge bases owned by a specific user.
+USER_PERSONAL_KNOWLEDGE_BASES = "user_tools"
+# All personal files owned by a specific user.
+USER_PRIVATE_FILES = "user_files"
+# Public and personal knowledge bases combined.
+AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users"
+
+EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove"
+
+# Files waiting to be uploaded.
+EL_UPLOAD_FILES = "el_upload_files"
+EL_UPLOAD_FILES_STATUS = "el_upload_files_status"
+
+# Use these files to build a private knowledge base.
+EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files"
+# Name for the personal knowledge base being built.
+EL_PERSONAL_KB_NAME = "el_personal_kb_name"
+# Description for the personal knowledge base being built.
+EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description"
+
+# Knowledge bases selected by the user.
+EL_SELECTED_KBS = "el_selected_kbs"
diff --git a/app/backend/constants/variables.py b/app/backend/constants/variables.py
new file mode 100644
index 0000000..5f29b2f
--- /dev/null
+++ b/app/backend/constants/variables.py
@@ -0,0 +1,58 @@
+from backend.types.global_config import GlobalConfig
+
+# ***** str variables ***** #
+EMBEDDING_MODEL_PREFIX = "embedding_model"
+CHAINS_RETRIEVERS_MAPPING = "sel_map_obj"
+LANGCHAIN_RETRIEVER = "langchain_retriever"
+VECTOR_SQL_RETRIEVER = "vecsql_retriever"
+TABLE_EMBEDDINGS_MAPPING = "embeddings"
+RETRIEVER_TOOLS = "tools"
+DATA_INITIALIZE_STATUS = "data_initialized"
+UI_INITIALIZED = "ui_initialized"
+JUMP_QUERY_ASK = "jump_query_ask"
+USER_NAME = "user_name"
+USER_INFO = "user_info"
+
+DIVIDER_HTML = """
+
+"""
+
+DIVIDER_THIN_HTML = """
+
+"""
+
+
+class RetrieverButtons:
+ vector_sql_query_from_db = "vector_sql_query_from_db"
+ vector_sql_query_with_llm = "vector_sql_query_with_llm"
+ self_query_from_db = "self_query_from_db"
+ self_query_with_llm = "self_query_with_llm"
+
+
+GLOBAL_CONFIG = GlobalConfig()
+
+
+def update_global_config(new_config: GlobalConfig):
+    """Copy fields onto the shared GLOBAL_CONFIG instance in place, so any
+    module already holding a reference to it sees the updated values."""
+    global GLOBAL_CONFIG
+ GLOBAL_CONFIG.openai_api_base = new_config.openai_api_base
+ GLOBAL_CONFIG.openai_api_key = new_config.openai_api_key
+ GLOBAL_CONFIG.auth0_client_id = new_config.auth0_client_id
+ GLOBAL_CONFIG.auth0_domain = new_config.auth0_domain
+ GLOBAL_CONFIG.myscale_user = new_config.myscale_user
+ GLOBAL_CONFIG.myscale_password = new_config.myscale_password
+ GLOBAL_CONFIG.myscale_host = new_config.myscale_host
+ GLOBAL_CONFIG.myscale_port = new_config.myscale_port
+ GLOBAL_CONFIG.query_model = new_config.query_model
+ GLOBAL_CONFIG.chat_model = new_config.chat_model
+ GLOBAL_CONFIG.untrusted_api = new_config.untrusted_api
+ GLOBAL_CONFIG.myscale_enable_https = new_config.myscale_enable_https
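+
+
+# Typical wiring (sketch; in this app the values come from Streamlit secrets):
+#
+#   update_global_config(GlobalConfig(
+#       myscale_host=st.secrets["MYSCALE_HOST"],
+#       myscale_port=st.secrets["MYSCALE_PORT"],
+#       myscale_user=st.secrets["MYSCALE_USER"],
+#       myscale_password=st.secrets["MYSCALE_PASSWORD"],
+#       openai_api_base=st.secrets["OPENAI_API_BASE"],
+#       openai_api_key=st.secrets["OPENAI_API_KEY"],
+#   ))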
diff --git a/app/backend/construct/__init__.py b/app/backend/construct/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/construct/build_agents.py b/app/backend/construct/build_agents.py
new file mode 100644
index 0000000..b521e69
--- /dev/null
+++ b/app/backend/construct/build_agents.py
@@ -0,0 +1,82 @@
+import os
+from typing import Sequence, List
+
+import streamlit as st
+from langchain.agents import AgentExecutor
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.tools import BaseTool
+
+from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
+from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
+from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
+from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
+from logger import logger
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts.chat import MessagesPlaceholder
+from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
+from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
+from langchain.schema.messages import SystemMessage
+from langchain.memory import SQLChatMessageHistory
+
+
+def create_agent_executor(
+ agent_name: str,
+ session_id: str,
+ llm: BaseLanguageModel,
+ tools: Sequence[BaseTool],
+ system_prompt: str,
+ **kwargs
+) -> AgentExecutor:
+ agent_name = agent_name.replace(" ", "_")
+    conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
+    protocol = "https" if GLOBAL_CONFIG.myscale_enable_https else "http"
+    chat_memory = SQLChatMessageHistory(
+        session_id,
+        connection_string=f'{conn_str}/chat?protocol={protocol}',
+        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
+ memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
+
+ prompt = OpenAIFunctionsAgent.create_prompt(
+ system_message=SystemMessage(content=system_prompt),
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
+ )
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+ return AgentExecutor(
+ agent=agent,
+ tools=tools,
+ memory=memory,
+ verbose=True,
+ return_intermediate_steps=True,
+ **kwargs
+ )
+
+
+def build_agents(
+ session_id: str,
+ tool_names: List[str],
+ model: str = "gpt-3.5-turbo-0125",
+ temperature: float = 0.6,
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT
+):
+ chat_llm = ChatOpenAI(
+ model_name=model,
+ temperature=temperature,
+ base_url=GLOBAL_CONFIG.openai_api_base,
+ api_key=GLOBAL_CONFIG.openai_api_key,
+ streaming=True
+ )
+ tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
+ selected_tools = [tools[k] for k in tool_names]
+ logger.info(f"create agent, use tools: {selected_tools}")
+ agent = create_agent_executor(
+ agent_name="chat_memory",
+ session_id=session_id,
+ llm=chat_llm,
+ tools=selected_tools,
+ system_prompt=system_prompt
+ )
+ return agent
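+
+
+# Example (sketch; tool names follow the "<table> + <retriever type>" pattern
+# produced by update_retriever_tools, e.g. "ArXiv Papers + Self Querying",
+# and session ids follow the "<user>?<session>" convention used by the UI):
+#
+#   agent = build_agents(
+#       session_id="user_a?default",
+#       tool_names=["ArXiv Papers + Self Querying"],
+#   )
+#   result = agent({"input": "Introduce some applications of GANs."})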
diff --git a/app/backend/construct/build_all.py b/app/backend/construct/build_all.py
new file mode 100644
index 0000000..ebfba90
--- /dev/null
+++ b/app/backend/construct/build_all.py
@@ -0,0 +1,95 @@
+from logger import logger
+from typing import Dict, Any, Union
+
+import streamlit as st
+
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING
+from backend.construct.build_chains import build_retrieval_qa_with_sources_chain
+from backend.construct.build_retriever_tool import create_retriever_tool
+from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever
+from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn
+
+from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \
+ SentenceTransformerEmbeddings
+
+
+@st.cache_resource
+def load_embedding_model_for_table(table_name: str) -> \
+ Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]:
+ with st.spinner(f"Loading embedding models for [{table_name}] ..."):
+ embeddings = MYSCALE_TABLES[table_name].emb_model()
+ return embeddings
+
+
+@st.cache_resource
+def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]:
+ embedding_models = {}
+ for table in MYSCALE_TABLES:
+ embedding_models[table] = load_embedding_model_for_table(table)
+ return embedding_models
+
+
+@st.cache_resource
+def update_retriever_tools():
+ retrievers_tools = {}
+ for table in MYSCALE_TABLES:
+ logger.info(f"Updating retriever tools [, ] for table {table}")
+ retrievers_tools.update(
+ {
+ f"{table} + Self Querying": create_retriever_tool(
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"],
+ *MYSCALE_TABLES[table].tool_desc
+ ),
+ f"{table} + Vector SQL": create_retriever_tool(
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"],
+ *MYSCALE_TABLES[table].tool_desc
+ ),
+ })
+ return retrievers_tools
+
+
+@st.cache_resource
+def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers:
+ metadata_col_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes
+
+ self_query_retriever = build_self_query_retriever(table_name)
+ self_query_chain = build_retrieval_qa_with_sources_chain(
+ table_name=table_name,
+ retriever=self_query_retriever,
+ chain_name="Self Query Retriever"
+ )
+
+ vector_sql_retriever = build_vector_sql_db_chain_retriever(table_name)
+ vector_sql_chain = build_retrieval_qa_with_sources_chain(
+ table_name=table_name,
+ retriever=vector_sql_retriever,
+ chain_name="Vector SQL DB Retriever"
+ )
+
+ metadata_columns = [
+ MetadataColumn(
+ name=attribute.name,
+ desc=attribute.description,
+ type=attribute.type
+ )
+ for attribute in metadata_col_attributes
+ ]
+ return ChainsAndRetrievers(
+ metadata_columns=metadata_columns,
+ # for self query
+ retriever=self_query_retriever,
+ chain=self_query_chain,
+ # for vector sql
+ sql_retriever=vector_sql_retriever,
+ sql_chain=vector_sql_chain
+ )
+
+
+@st.cache_resource
+def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]:
+ chains_and_retrievers = {}
+ for table in MYSCALE_TABLES:
+ logger.info(f"Building chains, retrievers for table {table}")
+ chains_and_retrievers[table] = build_chains_retriever_for_table(table).to_dict()
+ return chains_and_retrievers
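+
+
+# Intended wiring (sketch): results are cached by @st.cache_resource and kept
+# in Streamlit session state, keyed by the constants in
+# backend.constants.variables, so the UI can look everything up per table:
+#
+#   st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
+#   st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
+#   st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()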
diff --git a/app/backend/construct/build_chains.py b/app/backend/construct/build_chains.py
new file mode 100644
index 0000000..43b450d
--- /dev/null
+++ b/app/backend/construct/build_chains.py
@@ -0,0 +1,39 @@
+from langchain.chains import LLMChain
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
+from langchain.schema import BaseRetriever
+import streamlit as st
+
+from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
+from backend.chains.stuff_documents import CustomStuffDocumentChain
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.prompts import COMBINE_PROMPT
+from backend.constants.variables import GLOBAL_CONFIG
+
+
+def build_retrieval_qa_with_sources_chain(
+ table_name: str,
+ retriever: BaseRetriever,
+ chain_name: str = ""
+) -> CustomRetrievalQAWithSourcesChain:
+ with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'):
+ # Assign ref_id for documents
+ custom_stuff_document_chain = CustomStuffDocumentChain(
+ llm_chain=LLMChain(
+ prompt=COMBINE_PROMPT,
+ llm=ChatOpenAI(
+ model_name=GLOBAL_CONFIG.chat_model,
+ openai_api_key=GLOBAL_CONFIG.openai_api_key,
+ temperature=0.6
+ ),
+ ),
+ document_prompt=MYSCALE_TABLES[table_name].doc_prompt,
+ document_variable_name="summaries",
+ )
+ chain = CustomRetrievalQAWithSourcesChain(
+ retriever=retriever,
+ combine_documents_chain=custom_stuff_document_chain,
+ return_source_documents=True,
+ max_tokens_limit=12000,
+ )
+ return chain
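+
+
+# The resulting chain answers with `Doc #<ref_id>` citations: the custom
+# stuff-documents chain assigns a ref_id to each retrieved document, and
+# CustomRetrievalQAWithSourcesChain parses those markers back into source
+# references (see COMBINE_PROMPT in backend/constants/prompts.py).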
diff --git a/app/backend/construct/build_chat_bot.py b/app/backend/construct/build_chat_bot.py
new file mode 100644
index 0000000..a79f4e4
--- /dev/null
+++ b/app/backend/construct/build_chat_bot.py
@@ -0,0 +1,36 @@
+from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable
+from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER
+import streamlit as st
+
+from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING
+from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
+from backend.chat_bot.session_manager import SessionManager
+
+
+def build_chat_knowledge_table():
+ if CHAT_KNOWLEDGE_TABLE not in st.session_state:
+ st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable(
+ host=GLOBAL_CONFIG.myscale_host,
+ port=GLOBAL_CONFIG.myscale_port,
+ username=GLOBAL_CONFIG.myscale_user,
+ password=GLOBAL_CONFIG.myscale_password,
+ # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"],
+ embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"],
+ parser_api_key=GLOBAL_CONFIG.untrusted_api,
+ )
+
+
+def initialize_session_manager():
+ if CHAT_SESSION not in st.session_state:
+ st.session_state[CHAT_SESSION] = {
+ "session_id": "default",
+ "system_prompt": DEFAULT_SYSTEM_PROMPT,
+ }
+ if CHAT_SESSION_MANAGER not in st.session_state:
+ st.session_state[CHAT_SESSION_MANAGER] = SessionManager(
+ st.session_state,
+ host=GLOBAL_CONFIG.myscale_host,
+ port=GLOBAL_CONFIG.myscale_port,
+ username=GLOBAL_CONFIG.myscale_user,
+ password=GLOBAL_CONFIG.myscale_password,
+ )
diff --git a/app/backend/construct/build_retriever_tool.py b/app/backend/construct/build_retriever_tool.py
new file mode 100644
index 0000000..b1ed27d
--- /dev/null
+++ b/app/backend/construct/build_retriever_tool.py
@@ -0,0 +1,45 @@
+import json
+from typing import List
+
+from langchain.pydantic_v1 import BaseModel, Field
+from langchain.schema import BaseRetriever, Document
+from langchain.tools import Tool
+
+from backend.chat_bot.json_decoder import CustomJSONEncoder
+
+
+class RetrieverInput(BaseModel):
+ query: str = Field(description="query to look up in retriever")
+
+
+def create_retriever_tool(
+ retriever: BaseRetriever,
+ tool_name: str,
+ description: str
+) -> Tool:
+ """Create a tool to do retrieval of documents.
+
+ Args:
+ retriever: The retriever to use for the retrieval
+ tool_name: The name for the tool. This will be passed to the language model,
+ so should be unique and somewhat descriptive.
+ description: The description for the tool. This will be passed to the language
+ model, so should be descriptive.
+
+ Returns:
+ Tool class to pass to an agent
+ """
+    def wrap(func):
+        def wrapped_retrieve(*args, **kwargs):
+            docs: List[Document] = func(*args, **kwargs)
+            return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
+
+        return wrapped_retrieve
+
+    def wrap_async(func):
+        async def wrapped_aretrieve(*args, **kwargs):
+            docs: List[Document] = await func(*args, **kwargs)
+            return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
+
+        return wrapped_aretrieve
+
+    return Tool(
+        name=tool_name,
+        description=description,
+        func=wrap(retriever.get_relevant_documents),
+        # wrap the async path too, so both paths return the same JSON payload
+        coroutine=wrap_async(retriever.aget_relevant_documents),
+        args_schema=RetrieverInput,
+    )
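+
+
+# Usage sketch (names assumed): the tool returns a JSON string of Document
+# dicts, which callers can json.loads for display.
+#
+#   tool = create_retriever_tool(
+#       some_retriever,
+#       "ArXiv Papers + Self Querying",
+#       "Searches among arXiv papers and returns related documents",
+#   )
+#   docs_json = tool.func("What is PageRank?")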
diff --git a/app/backend/construct/build_retrievers.py b/app/backend/construct/build_retrievers.py
new file mode 100644
index 0000000..6c3918a
--- /dev/null
+++ b/app/backend/construct/build_retrievers.py
@@ -0,0 +1,120 @@
+import streamlit as st
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts.prompt import PromptTemplate
+from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain.retrievers.self_query.myscale import MyScaleTranslator
+from langchain.utilities.sql_database import SQLDatabase
+from langchain.vectorstores import MyScaleSettings
+from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
+from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
+from sqlalchemy import create_engine, MetaData
+
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.prompts import MYSCALE_PROMPT
+from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG
+from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser
+from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson
+from logger import logger
+
+
+@st.cache_resource
+def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
+ with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
+ myscale_connection = {
+ "host": GLOBAL_CONFIG.myscale_host,
+ "port": GLOBAL_CONFIG.myscale_port,
+ "username": GLOBAL_CONFIG.myscale_user,
+ "password": GLOBAL_CONFIG.myscale_password,
+ }
+ myscale_settings = MyScaleSettings(
+ **myscale_connection,
+ database=MYSCALE_TABLES[table_name].database,
+ table=MYSCALE_TABLES[table_name].table,
+ column_map={
+ "id": "id",
+ "text": MYSCALE_TABLES[table_name].text_col_name,
+ "vector": MYSCALE_TABLES[table_name].vector_col_name,
+ # TODO refine MyScaleDB metadata in langchain.
+ "metadata": MYSCALE_TABLES[table_name].metadata_col_name
+ }
+ )
+ myscale_vector_store = MyScaleWithoutMetadataJson(
+ embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
+ config=myscale_settings,
+ must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names
+ )
+
+ with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
+ retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm(
+ llm=ChatOpenAI(
+ model_name=GLOBAL_CONFIG.query_model,
+ base_url=GLOBAL_CONFIG.openai_api_base,
+ api_key=GLOBAL_CONFIG.openai_api_key,
+ temperature=0
+ ),
+ vectorstore=myscale_vector_store,
+ document_contents=MYSCALE_TABLES[table_name].table_contents,
+ metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes,
+ use_original_query=False,
+ structured_query_translator=MyScaleTranslator()
+ )
+ return retriever
+
+
+@st.cache_resource
+def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
+ """Get a group of relative docs from MyScaleDB"""
+ with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
+        protocol = "https" if GLOBAL_CONFIG.myscale_enable_https else "http"
+        engine = create_engine(
+            f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@'
+            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
+            f'/{MYSCALE_TABLES[table_name].database}?protocol={protocol}'
+        )
+ metadata = MetaData(bind=engine)
+ logger.info(f"{table_name} metadata is : {metadata}")
+ prompt = PromptTemplate(
+ input_variables=["input", "table_info", "top_k"],
+ template=MYSCALE_PROMPT,
+ )
+        # The custom output parser rewrites the generated SQL so that custom columns can be queried.
+ output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
+ model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
+            # rewrite the SELECT list to these must-have columns
+ must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names
+ )
+
+        # `db_chain` generates a SQL statement from the question
+ vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
+ llm=ChatOpenAI(
+ model_name=GLOBAL_CONFIG.query_model,
+ base_url=GLOBAL_CONFIG.openai_api_base,
+ api_key=GLOBAL_CONFIG.openai_api_key,
+ temperature=0
+ ),
+ prompt=prompt,
+ top_k=10,
+ return_direct=True,
+ db=SQLDatabase(
+ engine,
+ None,
+ metadata,
+ include_tables=[MYSCALE_TABLES[table_name].table],
+ max_string_length=1024
+ ),
+            sql_cmd_parser=output_parser,  # TODO: needs a `langchain` update to fix the return type.
+ native_format=True
+ )
+
+        # the retriever fetches a group of documents through `db_chain`
+ vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever(
+ sql_db_chain=vector_sql_db_chain,
+ page_content_key=MYSCALE_TABLES[table_name].text_col_name
+ )
+ return vector_sql_db_chain_retriever
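+
+
+# End-to-end flow (sketch): question -> LLM writes a vector SQL statement
+# guided by MYSCALE_PROMPT -> VectorSQLRetrieveOutputParser swaps the SELECT
+# list for the table's must-have columns and expands NeuralArray(...) into a
+# real embedding -> MyScaleDB executes the query -> rows come back as Documents.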
diff --git a/app/backend/retrievers/__init__.py b/app/backend/retrievers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/retrievers/self_query.py b/app/backend/retrievers/self_query.py
new file mode 100644
index 0000000..d30e06a
--- /dev/null
+++ b/app/backend/retrievers/self_query.py
@@ -0,0 +1,89 @@
+from typing import List
+
+import pandas as pd
+import streamlit as st
+from langchain.retrievers import SelfQueryRetriever
+from langchain_core.documents import Document
+from langchain_core.runnables import RunnableConfig
+
+from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
+from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler
+from ui.utils import display
+from logger import logger
+
+
+def process_self_query(selected_table, query_type):
+ place_holder = st.empty()
+ logger.info(
+ f"button-1: {RetrieverButtons.self_query_from_db}, "
+ f"button-2: {RetrieverButtons.self_query_with_llm}, "
+ f"content: {st.session_state.query_self}"
+ )
+ with place_holder.expander('🪵 Chat Log', expanded=True):
+ try:
+ if query_type == RetrieverButtons.self_query_from_db:
+ callback = CustomSelfQueryRetrieverCallBackHandler()
+ retriever: SelfQueryRetriever = \
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
+ config: RunnableConfig = {"callbacks": [callback]}
+
+ relevant_docs = retriever.invoke(
+ input=st.session_state.query_self,
+ config=config
+ )
+
+ callback.progress_bar.progress(
+ value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅")
+
+ st.markdown(f"### Self Query Results from `{selected_table}` \n"
+ f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
+ display(
+ dataframe=pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
+ ),
+ columns_=MYSCALE_TABLES[selected_table].must_have_col_names
+ )
+ elif query_type == RetrieverButtons.self_query_with_llm:
+ # callback = CustomSelfQueryRetrieverCallBackHandler()
+ callback = ChatDataSelfAskCallBackHandler()
+ chain: CustomRetrievalQAWithSourcesChain = \
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
+ chain_results = chain(st.session_state.query_self, callbacks=[callback])
+ callback.progress_bar.progress(
+ value=1.0,
+ text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅"
+ )
+
+ documents_reference: List[Document] = chain_results["source_documents"]
+ st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
+ f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
+ display(
+ pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
+ )
+ )
+ st.markdown(
+ f"### Answer from LLM \n"
+ f"> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
+ )
+ st.write(chain_results['answer'])
+                st.markdown(
+                    f"### References from `{selected_table}`\n"
+                    f"> This shows which documents the LLM used \n\n"
+                )
+                if len(chain_results['sources']) == 0:
+                    st.write("No documents were used by the LLM.")
+ else:
+ display(
+ dataframe=pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
+ ),
+ columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
+ index='ref_id'
+ )
+ st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
+ except Exception as e:
+ st.write('Oops 😵 Something bad happened...')
+ raise e
diff --git a/app/backend/retrievers/vector_sql_output_parser.py b/app/backend/retrievers/vector_sql_output_parser.py
new file mode 100644
index 0000000..79ebca5
--- /dev/null
+++ b/app/backend/retrievers/vector_sql_output_parser.py
@@ -0,0 +1,23 @@
+from typing import Dict, Any, List
+
+from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
+
+
+class VectorSQLRetrieveOutputParser(VectorSQLOutputParser):
+ """Based on VectorSQLOutputParser
+ It also modify the SQL to get all columns
+ """
+ must_have_columns: List[str]
+
+ @property
+ def _type(self) -> str:
+ return "vector_sql_retrieve_custom"
+
+ def parse(self, text: str) -> Dict[str, Any]:
+ text = text.strip()
+ start = text.upper().find("SELECT")
+ if start >= 0:
+ end = text.upper().find("FROM")
+ text = text.replace(
+ text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
+ return super().parse(text)
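+
+
+# Worked example (sketch): with must_have_columns=["title", "id"], an LLM output
+#   SELECT abstract, id FROM ChatPaper ORDER BY DISTANCE(...) LIMIT 10
+# has its SELECT list rewritten to
+#   SELECT title, id FROM ChatPaper ORDER BY DISTANCE(...) LIMIT 10
+# before VectorSQLOutputParser.parse expands NeuralArray(...) into embeddings.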
diff --git a/app/backend/retrievers/vector_sql_query.py b/app/backend/retrievers/vector_sql_query.py
new file mode 100644
index 0000000..d7d97ae
--- /dev/null
+++ b/app/backend/retrievers/vector_sql_query.py
@@ -0,0 +1,95 @@
+from typing import List
+
+import pandas as pd
+import streamlit as st
+from langchain.schema import Document
+from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
+
+from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
+from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler
+from ui.utils import display
+from logger import logger
+
+
+def process_sql_query(selected_table: str, query_type: str):
+ place_holder = st.empty()
+ logger.info(
+ f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, "
+ f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, "
+ f"table: {selected_table}, "
+ f"content: {st.session_state.query_sql}"
+ )
+ with place_holder.expander('🪵 Query Log', expanded=True):
+ try:
+ if query_type == RetrieverButtons.vector_sql_query_from_db:
+ callback = VectorSQLSearchDBCallBackHandler()
+ vector_sql_retriever: VectorSQLDatabaseChainRetriever = \
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"]
+ relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents(
+ query=st.session_state.query_sql,
+ callbacks=[callback]
+ )
+
+ callback.progress_bar.progress(
+ value=1.0,
+ text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅"
+ )
+
+ st.markdown(f"### Vector Search Results from `{selected_table}` \n"
+ f"> Here we get documents from MyScaleDB with given sql statement \n\n")
+ display(
+ pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
+ )
+ )
+ elif query_type == RetrieverButtons.vector_sql_query_with_llm:
+ callback = VectorSQLSearchLLMCallBackHandler(table=selected_table)
+ vector_sql_chain: CustomRetrievalQAWithSourcesChain = \
+ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"]
+ chain_results = vector_sql_chain(
+ inputs=st.session_state.query_sql,
+ callbacks=[callback]
+ )
+
+ callback.progress_bar.progress(
+ value=1.0,
+ text="[Question -> LLM -> SQL Statement -> MyScaleDB -> "
+ "(Question,Results) -> LLM -> Results] Done! ✅"
+ )
+
+ documents_reference: List[Document] = chain_results["source_documents"]
+ st.markdown(f"### Vector Search Results from `{selected_table}` \n"
+ f"> Here we get documents from MyScaleDB with given sql statement \n\n")
+ display(
+ pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
+ )
+ )
+ st.markdown(
+ f"### Answer from LLM \n"
+ f"> The response of the LLM when given the vector search results. \n\n"
+ )
+ st.write(chain_results['answer'])
+                st.markdown(
+                    f"### References from `{selected_table}`\n"
+                    f"> This shows which documents the LLM used \n\n"
+                )
+                if len(chain_results['sources']) == 0:
+                    st.write("No documents were used by the LLM.")
+ else:
+ display(
+ dataframe=pd.DataFrame(
+ [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
+ ),
+ columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
+ index='ref_id'
+ )
+ else:
+ raise NotImplementedError(f"Unsupported query type: {query_type}")
+ st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
+ except Exception as e:
+ st.write('Oops 😵 Something bad happened...')
+ raise e
+
diff --git a/app/backend/types/__init__.py b/app/backend/types/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/types/chains_and_retrievers.py b/app/backend/types/chains_and_retrievers.py
new file mode 100644
index 0000000..91e3082
--- /dev/null
+++ b/app/backend/types/chains_and_retrievers.py
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List
+from langchain.retrievers import SelfQueryRetriever
+from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
+
+from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
+
+
+@dataclass
+class MetadataColumn:
+ name: str
+ desc: str
+ type: str
+
+
+@dataclass
+class ChainsAndRetrievers:
+ metadata_columns: List[MetadataColumn]
+ retriever: SelfQueryRetriever
+ chain: CustomRetrievalQAWithSourcesChain
+ sql_retriever: VectorSQLDatabaseChainRetriever
+ sql_chain: CustomRetrievalQAWithSourcesChain
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "metadata_columns": self.metadata_columns,
+ "retriever": self.retriever,
+ "chain": self.chain,
+ "sql_retriever": self.sql_retriever,
+ "sql_chain": self.sql_chain
+ }
+
+
diff --git a/app/backend/types/global_config.py b/app/backend/types/global_config.py
new file mode 100644
index 0000000..ffa624f
--- /dev/null
+++ b/app/backend/types/global_config.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class GlobalConfig:
+ openai_api_base: Optional[str] = ""
+ openai_api_key: Optional[str] = ""
+
+ auth0_client_id: Optional[str] = ""
+ auth0_domain: Optional[str] = ""
+
+ myscale_user: Optional[str] = ""
+ myscale_password: Optional[str] = ""
+ myscale_host: Optional[str] = ""
+ myscale_port: Optional[int] = 443
+
+ query_model: Optional[str] = ""
+ chat_model: Optional[str] = ""
+
+ untrusted_api: Optional[str] = ""
+ myscale_enable_https: Optional[bool] = True
diff --git a/app/backend/types/table_config.py b/app/backend/types/table_config.py
new file mode 100644
index 0000000..61087ba
--- /dev/null
+++ b/app/backend/types/table_config.py
@@ -0,0 +1,25 @@
+from dataclasses import dataclass
+from typing import Callable, List
+
+from langchain.chains.query_constructor.schema import AttributeInfo
+from langchain.prompts import PromptTemplate
+
+
+@dataclass
+class TableConfig:
+ database: str
+ table: str
+ table_contents: str
+ # column names
+ must_have_col_names: List[str]
+ vector_col_name: str
+ text_col_name: str
+ metadata_col_name: str
+ # hint for UI
+ hint: Callable
+ hint_sql: Callable
+ # for langchain
+ doc_prompt: PromptTemplate
+ metadata_col_attributes: List[AttributeInfo]
+ emb_model: Callable
+ tool_desc: tuple
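+
+
+# Example entry (sketch; the real instances live in backend/constants/myscale_tables.py,
+# and every field below other than database/table is illustrative):
+#
+#   MYSCALE_TABLES["ArXiv Papers"] = TableConfig(
+#       database="default",
+#       table="ChatArXiv",
+#       table_contents="Scientific papers indexed from arXiv",
+#       must_have_col_names=["title", "id", "authors"],
+#       vector_col_name="vector",
+#       text_col_name="abstract",
+#       metadata_col_name="metadata",
+#       ...,
+#   )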
diff --git a/app/backend/vector_store/__init__.py b/app/backend/vector_store/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/backend/vector_store/myscale_without_metadata.py b/app/backend/vector_store/myscale_without_metadata.py
new file mode 100644
index 0000000..f446f64
--- /dev/null
+++ b/app/backend/vector_store/myscale_without_metadata.py
@@ -0,0 +1,52 @@
+from typing import Any, Optional, List
+
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+from langchain.vectorstores.myscale import MyScale, MyScaleSettings
+
+from logger import logger
+
+
+class MyScaleWithoutMetadataJson(MyScale):
+    def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None,
+                 must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None:
+        try:
+            super().__init__(embedding, config, **kwargs)
+        except Exception as e:
+            # tolerate setup errors here; real failures surface when querying
+            logger.error(e)
+        # avoid the shared mutable-default pitfall of `must_have_cols=[]`
+        self.must_have_cols: List[str] = must_have_cols or []
+
+ def _build_qstr(
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
+ ) -> str:
+ q_emb_str = ",".join(map(str, q_emb))
+ if where_str:
+ where_str = f"PREWHERE {where_str}"
+ else:
+ where_str = ""
+
+ q_str = f"""
+ SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
+ FROM {self.config.database}.{self.config.table}
+ {where_str}
+ ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
+ AS dist {self.dist_order}
+ LIMIT {topk}
+ """
+ return q_str
+
+ def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None,
+ **kwargs: Any) -> List[Document]:
+ q_str = self._build_qstr(embedding, k, where_str)
+ try:
+ return [
+ Document(
+ page_content=r[self.config.column_map["text"]],
+ metadata={k: r[k] for k in self.must_have_cols},
+ )
+ for r in self.client.query(q_str).named_results()
+ ]
+ except Exception as e:
+ logger.error(
+ f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+ return []
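+
+
+# The generated query looks roughly like (sketch, topk=4):
+#
+#   SELECT abstract, dist, title, id, authors
+#   FROM default.ChatArXiv
+#   PREWHERE <filters>
+#   ORDER BY distance(vector, [<query embedding>]) AS dist ASC
+#   LIMIT 4
+#
+# PREWHERE filters rows before other columns are read, which is cheaper in
+# ClickHouse/MyScaleDB than a plain WHERE at this point in the query.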
diff --git a/app/callbacks/arxiv_callbacks.py b/app/callbacks/arxiv_callbacks.py
deleted file mode 100644
index efa3caa..0000000
--- a/app/callbacks/arxiv_callbacks.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import streamlit as st
-import json
-import textwrap
-from typing import Dict, Any, List
-from sql_formatter.core import format_sql
-from langchain.callbacks.streamlit.streamlit_callback_handler import (
- LLMThought,
- StreamlitCallbackHandler,
-)
-from langchain.schema.output import LLMResult
-
-
-class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
- def __init__(self) -> None:
- self.progress_bar = st.progress(value=0.0, text="Working...")
- self.tokens_stream = ""
-
- def on_llm_start(self, serialized, prompts, **kwargs) -> None:
- pass
-
- def on_text(self, text: str, **kwargs) -> None:
- self.progress_bar.progress(value=0.2, text="Asking LLM...")
-
- def on_chain_end(self, outputs, **kwargs) -> None:
- self.progress_bar.progress(value=0.6, text="Searching in DB...")
- if "repr" in outputs:
- st.markdown("### Generated Filter")
- st.markdown(
- f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
-
- def on_chain_start(self, serialized, inputs, **kwargs) -> None:
- pass
-
-
-class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
- def __init__(self) -> None:
- self.progress_bar = st.progress(value=0.0, text="Searching DB...")
- self.status_bar = st.empty()
- self.prog_value = 0.0
- self.prog_map = {
- "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
- "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
- "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
- }
-
- def on_llm_start(self, serialized, prompts, **kwargs) -> None:
- pass
-
- def on_text(self, text: str, **kwargs) -> None:
- pass
-
- def on_chain_start(self, serialized, inputs, **kwargs) -> None:
- cid = ".".join(serialized["id"])
- if cid != "langchain.chains.llm.LLMChain":
- self.progress_bar.progress(
- value=self.prog_map[cid], text=f"Running Chain `{cid}`..."
- )
- self.prog_value = self.prog_map[cid]
- else:
- self.prog_value += 0.1
- self.progress_bar.progress(
- value=self.prog_value, text=f"Running Chain `{cid}`..."
- )
-
- def on_chain_end(self, outputs, **kwargs) -> None:
- pass
-
-
-class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
- def __init__(self) -> None:
- self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
- self.status_bar = st.empty()
- self.prog_value = 0
- self.prog_interval = 0.2
-
- def on_llm_start(self, serialized, prompts, **kwargs) -> None:
- pass
-
- def on_llm_end(
- self,
- response: LLMResult,
- *args,
- **kwargs,
- ):
- text = response.generations[0][0].text
- if text.replace(" ", "").upper().startswith("SELECT"):
- st.write("We generated Vector SQL for you:")
- st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
- print(f"Vector SQL: {text}")
- self.prog_value += self.prog_interval
- self.progress_bar.progress(
- value=self.prog_value, text="Searching in DB...")
-
- def on_chain_start(self, serialized, inputs, **kwargs) -> None:
- cid = ".".join(serialized["id"])
- self.prog_value += self.prog_interval
- self.progress_bar.progress(
- value=self.prog_value, text=f"Running Chain `{cid}`..."
- )
-
- def on_chain_end(self, outputs, **kwargs) -> None:
- pass
-
-
-class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
- def __init__(self) -> None:
- self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
- self.status_bar = st.empty()
- self.prog_value = 0
- self.prog_interval = 0.1
-
-
-class LLMThoughtWithKB(LLMThought):
- def on_tool_end(
- self,
- output: str,
- color=None,
- observation_prefix=None,
- llm_prefix=None,
- **kwargs: Any,
- ) -> None:
- try:
- self._container.markdown(
- "\n\n".join(
- ["### Retrieved Documents:"]
- + [
- f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
- for i, r in enumerate(json.loads(output))
- ]
- )
- )
- except Exception as e:
- super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
-
-
-class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
- def on_llm_start(
- self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
- ) -> None:
- if self._current_thought is None:
- self._current_thought = LLMThoughtWithKB(
- parent_container=self._parent_container,
- expanded=self._expand_new_thoughts,
- collapse_on_complete=self._collapse_completed_thoughts,
- labeler=self._thought_labeler,
- )
-
- self._current_thought.on_llm_start(serialized, prompts)
diff --git a/app/chains/arxiv_chains.py b/app/chains/arxiv_chains.py
deleted file mode 100644
index 359c5f5..0000000
--- a/app/chains/arxiv_chains.py
+++ /dev/null
@@ -1,197 +0,0 @@
-import logging
-import inspect
-from typing import Dict, Any, Optional, List, Tuple
-
-
-from langchain.callbacks.manager import (
- AsyncCallbackManagerForChainRun,
- CallbackManagerForChainRun,
-)
-from langchain.embeddings.base import Embeddings
-from langchain.callbacks.manager import Callbacks
-from langchain.schema.prompt_template import format_document
-from langchain.docstore.document import Document
-from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
-from langchain.vectorstores.myscale import MyScale, MyScaleSettings
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-
-from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
-
-logger = logging.getLogger()
-
-
-class MyScaleWithoutMetadataJson(MyScale):
- def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
- super().__init__(embedding, config, **kwargs)
- self.must_have_cols: List[str] = must_have_cols
-
- def _build_qstr(
- self, q_emb: List[float], topk: int, where_str: Optional[str] = None
- ) -> str:
- q_emb_str = ",".join(map(str, q_emb))
- if where_str:
- where_str = f"PREWHERE {where_str}"
- else:
- where_str = ""
-
- q_str = f"""
- SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
- FROM {self.config.database}.{self.config.table}
- {where_str}
- ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
- AS dist {self.dist_order}
- LIMIT {topk}
- """
- return q_str
-
- def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
- q_str = self._build_qstr(embedding, k, where_str)
- try:
- return [
- Document(
- page_content=r[self.config.column_map["text"]],
- metadata={k: r[k] for k in self.must_have_cols},
- )
- for r in self.client.query(q_str).named_results()
- ]
- except Exception as e:
- logger.error(
- f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
- return []
-
-
-class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
- """Based on VectorSQLOutputParser
- It also modify the SQL to get all columns
- """
- must_have_columns: List[str]
-
- @property
- def _type(self) -> str:
- return "vector_sql_retrieve_custom"
-
- def parse(self, text: str) -> Dict[str, Any]:
- text = text.strip()
- start = text.upper().find("SELECT")
- if start >= 0:
- end = text.upper().find("FROM")
- text = text.replace(
- text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
- return super().parse(text)
-
-
-class ArXivStuffDocumentChain(StuffDocumentsChain):
- """Combine arxiv documents with PDF reference number"""
-
- def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
- """Construct inputs from kwargs and docs.
-
- Format and the join all the documents together into one input with name
- `self.document_variable_name`. The pluck any additional variables
- from **kwargs.
-
- Args:
- docs: List of documents to format and then join into single input
- **kwargs: additional inputs to chain, will pluck any other required
- arguments from here.
-
- Returns:
- dictionary of inputs to LLMChain
- """
- # Format each document according to the prompt
- doc_strings = []
- for doc_id, doc in enumerate(docs):
- # add temp reference number in metadata
- doc.metadata.update({'ref_id': doc_id})
- doc.page_content = doc.page_content.replace('\n', ' ')
- doc_strings.append(format_document(doc, self.document_prompt))
- # Join the documents together to put them in the prompt.
- inputs = {
- k: v
- for k, v in kwargs.items()
- if k in self.llm_chain.prompt.input_variables
- }
- inputs[self.document_variable_name] = self.document_separator.join(
- doc_strings)
- return inputs
-
- def combine_docs(
- self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
- ) -> Tuple[str, dict]:
- """Stuff all documents into one prompt and pass to LLM.
-
- Args:
- docs: List of documents to join together into one variable
- callbacks: Optional callbacks to pass along
- **kwargs: additional parameters to use to get inputs to LLMChain.
-
- Returns:
- The first element returned is the single string output. The second
- element returned is a dictionary of other keys to return.
- """
- inputs = self._get_inputs(docs, **kwargs)
- # Call predict on the LLM.
- output = self.llm_chain.predict(callbacks=callbacks, **inputs)
- return output, {}
-
- @property
- def _chain_type(self) -> str:
- return "referenced_stuff_documents_chain"
-
-
-class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
- """QA with source chain for Chat ArXiv app with references
-
- This chain will automatically assign reference number to the article,
- Then parse it back to titles or anything else.
- """
-
- def _call(
- self,
- inputs: Dict[str, Any],
- run_manager: Optional[CallbackManagerForChainRun] = None,
- ) -> Dict[str, str]:
- _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
- accepts_run_manager = (
- "run_manager" in inspect.signature(self._get_docs).parameters
- )
- if accepts_run_manager:
- docs = self._get_docs(inputs, run_manager=_run_manager)
- else:
- docs = self._get_docs(inputs) # type: ignore[call-arg]
-
- answer = self.combine_documents_chain.run(
- input_documents=docs, callbacks=_run_manager.get_child(), **inputs
- )
- # parse source with ref_id
- sources = []
- ref_cnt = 1
- for d in docs:
- ref_id = d.metadata['ref_id']
- if f"Doc #{ref_id}" in answer:
- answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
- if f"#{ref_id}" in answer:
- title = d.metadata['title'].replace('\n', '')
- d.metadata['ref_id'] = ref_cnt
- answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
- sources.append(d)
- ref_cnt += 1
-
- result: Dict[str, Any] = {
- self.answer_key: answer,
- self.sources_answer_key: sources,
- }
- if self.return_source_documents:
- result["source_documents"] = docs
- return result
-
- async def _acall(
- self,
- inputs: Dict[str, Any],
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
- ) -> Dict[str, Any]:
- raise NotImplementedError
-
- @property
- def _chain_type(self) -> str:
- return "arxiv_qa_with_sources_chain"
diff --git a/app/chat.py b/app/chat.py
deleted file mode 100644
index e1f6aee..0000000
--- a/app/chat.py
+++ /dev/null
@@ -1,402 +0,0 @@
-import json
-import pandas as pd
-from os import environ
-from time import sleep
-import datetime
-import streamlit as st
-from lib.sessions import SessionManager
-from lib.private_kb import PrivateKnowledgeBase
-from langchain.schema import HumanMessage, FunctionMessage
-from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
-from lib.json_conv import CustomJSONDecoder
-
-from lib.helper import (
- build_agents,
- MYSCALE_HOST,
- MYSCALE_PASSWORD,
- MYSCALE_PORT,
- MYSCALE_USER,
- DEFAULT_SYSTEM_PROMPT,
- UNSTRUCTURED_API,
-)
-from login import back_to_main
-
-environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
-
-TOOL_NAMES = {
- "langchain_retriever_tool": "Self-querying retriever",
- "vecsql_retriever_tool": "Vector SQL",
-}
-
-
-def on_chat_submit():
- with st.session_state.next_round.container():
- with st.chat_message("user"):
- st.write(st.session_state.chat_input)
- with st.chat_message("assistant"):
- container = st.container()
- st_callback = ChatDataAgentCallBackHandler(
- container, collapse_completed_thoughts=False
- )
- ret = st.session_state.agent(
- {"input": st.session_state.chat_input}, callbacks=[st_callback]
- )
- print(ret)
-
-
-def clear_history():
- if "agent" in st.session_state:
- st.session_state.agent.memory.clear()
-
-
-def back_to_main():
- if "user_info" in st.session_state:
- del st.session_state.user_info
- if "user_name" in st.session_state:
- del st.session_state.user_name
- if "jump_query_ask" in st.session_state:
- del st.session_state.jump_query_ask
- if "sel_sess" in st.session_state:
- del st.session_state.sel_sess
- if "current_sessions" in st.session_state:
- del st.session_state.current_sessions
-
-
-def on_session_change_submit():
- if "session_manager" in st.session_state and "session_editor" in st.session_state:
- print(st.session_state.session_editor)
- try:
- for elem in st.session_state.session_editor["added_rows"]:
- if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
- if elem["session_id"] != "" and "?" not in elem["session_id"]:
- st.session_state.session_manager.add_session(
- user_id=st.session_state.user_name,
- session_id=f"{st.session_state.user_name}?{elem['session_id']}",
- system_prompt=elem["system_prompt"],
- )
- else:
- raise KeyError(
- "`session_id` should NOT be neither empty nor contain question marks."
- )
- else:
- raise KeyError(
- "You should fill both `session_id` and `system_prompt` to add a column!"
- )
- for elem in st.session_state.session_editor["deleted_rows"]:
- st.session_state.session_manager.remove_session(
- session_id=f"{st.session_state.user_name}?{st.session_state.current_sessions[elem]['session_id']}",
- )
- refresh_sessions()
- except Exception as e:
- sleep(2)
- st.error(f"{type(e)}: {str(e)}")
- finally:
- st.session_state.session_editor["added_rows"] = []
- st.session_state.session_editor["deleted_rows"] = []
- refresh_agent()
-
-
-def build_session_manager():
- return SessionManager(
- st.session_state,
- host=MYSCALE_HOST,
- port=MYSCALE_PORT,
- username=MYSCALE_USER,
- password=MYSCALE_PASSWORD,
- )
-
-
-def refresh_sessions():
- st.session_state[
- "current_sessions"
- ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
- if (
- type(st.session_state.current_sessions) is not dict
- and len(st.session_state.current_sessions) <= 0
- ):
- st.session_state.session_manager.add_session(
- st.session_state.user_name,
- f"{st.session_state.user_name}?default",
- DEFAULT_SYSTEM_PROMPT,
- )
- st.session_state[
- "current_sessions"
- ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
- st.session_state["user_files"] = st.session_state.private_kb.list_files(
- st.session_state.user_name
- )
- st.session_state["user_tools"] = st.session_state.private_kb.list_tools(
- st.session_state.user_name
- )
- st.session_state["tools_with_users"] = {
- **st.session_state.tools,
- **st.session_state.private_kb.as_tools(st.session_state.user_name),
- }
- try:
- dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
- "default"
- if "" not in st.session_state
- else st.session_state.sel_session["session_id"]
- )
- except ValueError:
- dfl_indx = 0
- st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
-
-
-def build_kb_as_tool():
- if (
- "b_tool_name" in st.session_state
- and "b_tool_desc" in st.session_state
- and "b_tool_files" in st.session_state
- and len(st.session_state.b_tool_name) > 0
- and len(st.session_state.b_tool_desc) > 0
- and len(st.session_state.b_tool_files) > 0
- ):
- st.session_state.private_kb.create_tool(
- st.session_state.user_name,
- st.session_state.b_tool_name,
- st.session_state.b_tool_desc,
- [f["file_name"] for f in st.session_state.b_tool_files],
- )
- refresh_sessions()
- else:
- st.session_state.tool_status.error(
- "You should fill all fields to build up a tool!"
- )
- sleep(2)
-
-
-def remove_kb():
- if "r_tool_names" in st.session_state and len(st.session_state.r_tool_names) > 0:
- st.session_state.private_kb.remove_tools(
- st.session_state.user_name,
- [f["tool_name"] for f in st.session_state.r_tool_names],
- )
- refresh_sessions()
- else:
- st.session_state.tool_status.error(
- "You should specify at least one tool to delete!"
- )
- sleep(2)
-
-
-def refresh_agent():
- with st.spinner("Initializing session..."):
- print(
- f"??? Changed to ",
- f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
- )
- st.session_state["agent"] = build_agents(
- f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
- ["LangChain Self Query Retriever For Wikipedia"]
- if "selected_tools" not in st.session_state
- else st.session_state.selected_tools,
- system_prompt=DEFAULT_SYSTEM_PROMPT
- if "sel_sess" not in st.session_state
- else st.session_state.sel_sess["system_prompt"],
- )
-
-
-def add_file():
- if (
- "uploaded_files" not in st.session_state
- or len(st.session_state.uploaded_files) == 0
- ):
- st.session_state.tool_status.error("Please upload files!", icon="⚠️")
- sleep(2)
- return
- try:
- st.session_state.tool_status.info("Uploading...")
- st.session_state.private_kb.add_by_file(
- st.session_state.user_name, st.session_state.uploaded_files
- )
- refresh_sessions()
- except ValueError as e:
- st.session_state.tool_status.error("Failed to upload! " + str(e))
- sleep(2)
-
-
-def clear_files():
- st.session_state.private_kb.clear(st.session_state.user_name)
- refresh_sessions()
-
-
-def chat_page():
- if "sel_sess" not in st.session_state:
- st.session_state["sel_sess"] = {
- "session_id": "default",
- "system_prompt": DEFAULT_SYSTEM_PROMPT,
- }
- if "private_kb" not in st.session_state:
- st.session_state["private_kb"] = PrivateKnowledgeBase(
- host=MYSCALE_HOST,
- port=MYSCALE_PORT,
- username=MYSCALE_USER,
- password=MYSCALE_PASSWORD,
- embedding=st.session_state.embeddings["Wikipedia"],
- parser_api_key=UNSTRUCTURED_API,
- )
- if "session_manager" not in st.session_state:
- st.session_state["session_manager"] = build_session_manager()
- with st.sidebar:
- with st.expander("Session Management"):
- if "current_sessions" not in st.session_state:
- refresh_sessions()
- st.info(
- "Here you can set up your session! \n\nYou can **change your prompt** here!",
- icon="🤖",
- )
- st.info(
- (
- "**Add columns by clicking the empty row**.\n"
- "And **delete columns by selecting rows with a press on `DEL` Key**"
- ),
- icon="💡",
- )
- st.info(
- "Don't forget to **click `Submit Change` to save your change**!",
- icon="📒",
- )
- st.data_editor(
- st.session_state.current_sessions,
- num_rows="dynamic",
- key="session_editor",
- use_container_width=True,
- )
- st.button("Submit Change!", on_click=on_session_change_submit)
- with st.expander("Session Selection", expanded=True):
- st.info(
- "If no session is attach to your account, then we will add a default session to you!",
- icon="❤️",
- )
- try:
- dfl_indx = [
- x["session_id"] for x in st.session_state.current_sessions
- ].index(
- "default"
- if "" not in st.session_state
- else st.session_state.sel_session["session_id"]
- )
- except Exception as e:
- print("*** ", str(e))
- dfl_indx = 0
- st.selectbox(
- "Choose a session to chat:",
- options=st.session_state.current_sessions,
- index=dfl_indx,
- key="sel_sess",
- format_func=lambda x: x["session_id"],
- on_change=refresh_agent,
- )
- print(st.session_state.sel_sess)
- with st.expander("Tool Settings", expanded=True):
- st.info(
- "We provides you several knowledge base tools for you. We are building more tools!",
- icon="🔧",
- )
- st.session_state["tool_status"] = st.empty()
- tab_kb, tab_file = st.tabs(
- [
- "Knowledge Bases",
- "File Upload",
- ]
- )
- with tab_kb:
- st.markdown("#### Build You Own Knowledge")
- st.multiselect(
- "Select Files to Build up",
- st.session_state.user_files,
- placeholder="You should upload files first",
- key="b_tool_files",
- format_func=lambda x: x["file_name"],
- )
- st.text_input(
- "Tool Name", "get_relevant_documents", key="b_tool_name")
- st.text_input(
- "Tool Description",
- "Searches among user's private files and returns related documents",
- key="b_tool_desc",
- )
- st.button("Build!", on_click=build_kb_as_tool)
- st.markdown("### Knowledge Base Selection")
- if (
- "user_tools" in st.session_state
- and len(st.session_state.user_tools) > 0
- ):
- st.markdown("***User Created Knowledge Bases***")
- st.dataframe(st.session_state.user_tools)
- st.multiselect(
- "Select a Knowledge Base Tool",
- st.session_state.tools.keys()
- if "tools_with_users" not in st.session_state
- else st.session_state.tools_with_users,
- default=["Wikipedia + Self Querying"],
- key="selected_tools",
- on_change=refresh_agent,
- )
- st.markdown("### Delete Knowledge Base")
- st.multiselect(
- "Choose Knowledge Base to Remove",
- st.session_state.user_tools,
- format_func=lambda x: x["tool_name"],
- key="r_tool_names",
- )
- st.button("Delete", on_click=remove_kb)
- with tab_file:
- st.info(
- (
- "We adopted [Unstructured API](https://unstructured.io/api-key) "
- "here and we only store the processed texts from your documents. "
- "For privacy concerns, please refer to "
- "[our policy issue](https://myscale.com/privacy/)."
- ),
- icon="📃",
- )
- st.file_uploader(
- "Upload files", key="uploaded_files", accept_multiple_files=True
- )
- st.markdown("### Uploaded Files")
- st.dataframe(
- st.session_state.private_kb.list_files(
- st.session_state.user_name),
- use_container_width=True,
- )
- col_1, col_2 = st.columns(2)
- with col_1:
- st.button("Add Files", on_click=add_file)
- with col_2:
- st.button("Clear Files and All Tools",
- on_click=clear_files)
-
- st.button("Clear Chat History", on_click=clear_history)
- st.button("Logout", on_click=back_to_main)
- if "agent" not in st.session_state:
- refresh_agent()
- print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
- for msg in st.session_state.agent.memory.chat_memory.messages:
- speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
- if isinstance(msg, FunctionMessage):
- with st.chat_message("Knowledge Base", avatar="📖"):
- st.write(
- f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
- )
- st.write("Retrieved from knowledge base:")
- try:
- st.dataframe(
- pd.DataFrame.from_records(
- json.loads(msg.content, cls=CustomJSONDecoder)
- ),
- use_container_width=True,
- )
- except:
- st.write(msg.content)
- else:
- if len(msg.content) > 0:
- with st.chat_message(speaker):
- print(type(msg), msg.dict())
- st.write(
- f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
- )
- st.write(f"{msg.content}")
- st.session_state["next_round"] = st.empty()
- st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
diff --git a/app/lib/helper.py b/app/lib/helper.py
deleted file mode 100644
index 9d04210..0000000
--- a/app/lib/helper.py
+++ /dev/null
@@ -1,573 +0,0 @@
-
-import json
-import time
-import hashlib
-from typing import Dict, Any, List, Tuple
-import re
-from os import environ
-import streamlit as st
-from langchain.schema import BaseRetriever
-from langchain.tools import Tool
-from langchain.pydantic_v1 import BaseModel, Field
-
-from sqlalchemy import Column, Text, create_engine, MetaData
-from langchain.agents import AgentExecutor
-try:
- from sqlalchemy.orm import declarative_base
-except ImportError:
- from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker
-from clickhouse_sqlalchemy import (
- types, engines
-)
-from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
-from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
-from langchain.utilities.sql_database import SQLDatabase
-from langchain.chains import LLMChain
-from sqlalchemy import create_engine, MetaData
-from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
- SystemMessagePromptTemplate, HumanMessagePromptTemplate
-from langchain.prompts.prompt import PromptTemplate
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import BaseRetriever, Document
-from langchain import OpenAI
-from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
-from langchain.retrievers.self_query.base import SelfQueryRetriever
-from langchain.retrievers.self_query.myscale import MyScaleTranslator
-from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
-from langchain.vectorstores import MyScaleSettings
-from chains.arxiv_chains import MyScaleWithoutMetadataJson
-from langchain.prompts.prompt import PromptTemplate
-from langchain.prompts.chat import MessagesPlaceholder
-from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
-from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
-from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage, \
- SystemMessage, ChatMessage, ToolMessage
-from langchain.memory import SQLChatMessageHistory
-from langchain.memory.chat_message_histories.sql import \
- DefaultMessageConverter
-from langchain.schema.messages import BaseMessage
-# from langchain.agents.agent_toolkits import create_retriever_tool
-from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
-from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
-from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
-from .json_conv import CustomJSONEncoder
-
-environ['TOKENIZERS_PARALLELISM'] = 'true'
-environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
-
-query_model_name = "gpt-3.5-turbo-0125"
-chat_model_name = "gpt-3.5-turbo-0125"
-
-
-OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
-OPENAI_API_BASE = st.secrets['OPENAI_API_BASE']
-MYSCALE_USER = st.secrets['MYSCALE_USER']
-MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
-MYSCALE_HOST = st.secrets['MYSCALE_HOST']
-MYSCALE_PORT = st.secrets['MYSCALE_PORT']
-UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API']
-
-COMBINE_PROMPT = ChatPromptTemplate.from_strings(
- string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
- (HumanMessagePromptTemplate, '{question}')])
-DEFAULT_SYSTEM_PROMPT = (
- "Do your best to answer the questions. "
- "Feel free to use any tools available to look up "
- "relevant information. Please keep all details in query "
- "when calling search functions."
-)
-
-
-def hint_arxiv():
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
- "For example: \n\n"
- "*If you want to search papers with complex filters*:\n\n"
- "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
- "*If you want to ask questions based on papers in database*:\n\n"
- "- What is PageRank?\n"
- "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
- "- Introduce some applications of GANs published around 2019.\n"
- "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
- "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
- "- Is it possible to synthesize room temperature super conductive material?")
-
-
-def hint_sql_arxiv():
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
- st.markdown('''```sql
-CREATE TABLE default.ChatArXiv (
- `abstract` String,
- `id` String,
- `vector` Array(Float32),
- `metadata` Object('JSON'),
- `pubdate` DateTime,
- `title` String,
- `categories` Array(String),
- `authors` Array(String),
- `comment` String,
- `primary_category` String,
- VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
- CONSTRAINT vec_len CHECK length(vector) = 768)
-ENGINE = ReplacingMergeTree ORDER BY id
-```''')
-
-
-def hint_wiki():
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
- "For example: \n\n"
- "- Which company did Elon Musk found?\n"
- "- What is Iron Gwazi?\n"
- "- What is a Ring in mathematics?\n"
- "- 苹果的发源地是那里?\n")
-
-
-def hint_sql_wiki():
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
- st.markdown('''```sql
-CREATE TABLE wiki.Wikipedia (
- `id` String,
- `title` String,
- `text` String,
- `url` String,
- `wiki_id` UInt64,
- `views` Float32,
- `paragraph_id` UInt64,
- `langs` UInt32,
- `emb` Array(Float32),
- VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
- CONSTRAINT emb_len CHECK length(emb) = 768)
-ENGINE = ReplacingMergeTree ORDER BY id
-```''')
-
-
-sel_map = {
- 'Wikipedia': {
- "database": "wiki",
- "table": "Wikipedia",
- "hint": hint_wiki,
- "hint_sql": hint_sql_wiki,
- "doc_prompt": PromptTemplate(
- input_variables=["page_content",
- "url", "title", "ref_id", "views"],
- template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
- "metadata_cols": [
- AttributeInfo(
- name="title",
- description="title of the wikipedia page",
- type="string",
- ),
- AttributeInfo(
- name="text",
- description="paragraph from this wiki page",
- type="string",
- ),
- AttributeInfo(
- name="views",
- description="number of views",
- type="float"
- ),
- ],
- "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
- "vector_col": "emb",
- "text_col": "text",
- "metadata_col": "metadata",
- "emb_model": lambda: SentenceTransformerEmbeddings(
- model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',),
- "tool_desc": ("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages"),
- },
- 'ArXiv Papers': {
- "database": "default",
- "table": "ChatArXiv",
- "hint": hint_arxiv,
- "hint_sql": hint_sql_arxiv,
- "doc_prompt": PromptTemplate(
- input_variables=["page_content", "id", "title", "ref_id",
- "authors", "pubdate", "categories"],
- template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
- "metadata_cols": [
- AttributeInfo(
- name=VirtualColumnName(name="pubdate"),
- description="The year the paper is published",
- type="timestamp",
- ),
- AttributeInfo(
- name="authors",
- description="List of author names",
- type="list[string]",
- ),
- AttributeInfo(
- name="title",
- description="Title of the paper",
- type="string",
- ),
- AttributeInfo(
- name="categories",
- description="arxiv categories to this paper",
- type="list[string]"
- ),
- AttributeInfo(
- name="length(categories)",
- description="length of arxiv categories to this paper",
- type="int"
- ),
- ],
- "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
- "vector_col": "vector",
- "text_col": "abstract",
- "metadata_col": "metadata",
- "emb_model": lambda: HuggingFaceInstructEmbeddings(
- model_name='hkunlp/instructor-xl',
- embed_instruction="Represent the question for retrieving supporting scientific papers: "),
- "tool_desc": ("search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers"),
- }
-}
-
-
-def build_embedding_model(_sel):
- """Build embedding model
- """
- with st.spinner("Loading Model..."):
- embeddings = sel_map[_sel]["emb_model"]()
- return embeddings
-
-
-def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
- """build chains and retrievers
-
- :param _sel: selected knowledge base
- :type _sel: str
- :return: _description_
- :rtype: Dict[str, Any]
- """
- metadata_field_info = sel_map[_sel]["metadata_cols"]
- retriever = build_self_query(_sel)
- chain = build_qa_chain(_sel, retriever, name="Self Query Retriever")
- sql_retriever = build_vector_sql(_sel)
- sql_chain = build_qa_chain(_sel, sql_retriever, name="Vector SQL")
-
- return {
- "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
- "retriever": retriever,
- "chain": chain,
- "sql_retriever": sql_retriever,
- "sql_chain": sql_chain
- }
-
-
-def build_self_query(_sel: str) -> SelfQueryRetriever:
- """Build self querying retriever
-
- :param _sel: selected knowledge base
- :type _sel: str
- :return: retriever used by chains
- :rtype: SelfQueryRetriever
- """
- with st.spinner(f"Connecting DB for {_sel}..."):
- myscale_connection = {
- "host": MYSCALE_HOST,
- "port": MYSCALE_PORT,
- "username": MYSCALE_USER,
- "password": MYSCALE_PASSWORD,
- }
- config = MyScaleSettings(**myscale_connection,
- database=sel_map[_sel]["database"],
- table=sel_map[_sel]["table"],
- column_map={
- "id": "id",
- "text": sel_map[_sel]["text_col"],
- "vector": sel_map[_sel]["vector_col"],
- "metadata": sel_map[_sel]["metadata_col"]
- })
- doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
- must_have_cols=sel_map[_sel]['must_have_cols'])
-
- with st.spinner(f"Building Self Query Retriever for {_sel}..."):
- metadata_field_info = sel_map[_sel]["metadata_cols"]
- retriever = SelfQueryRetriever.from_llm(
- OpenAI(model_name=query_model_name,
- openai_api_key=OPENAI_API_KEY, temperature=0),
- doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
- use_original_query=False, structured_query_translator=MyScaleTranslator())
- return retriever
-
-
-def build_vector_sql(_sel: str) -> VectorSQLDatabaseChainRetriever:
- """Build Vector SQL Database Retriever
-
- :param _sel: selected knowledge base
- :type _sel: str
- :return: retriever used by chains
- :rtype: VectorSQLDatabaseChainRetriever
- """
- with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
- engine = create_engine(
- f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
- metadata = MetaData(bind=engine)
- PROMPT = PromptTemplate(
- input_variables=["input", "table_info", "top_k"],
- template=_myscale_prompt,
- )
- output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
- model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
- sql_query_chain = VectorSQLDatabaseChain.from_llm(
- llm=OpenAI(model_name=query_model_name,
- openai_api_key=OPENAI_API_KEY, temperature=0),
- prompt=PROMPT,
- top_k=10,
- return_direct=True,
- db=SQLDatabase(engine, None, metadata, max_string_length=1024),
- sql_cmd_parser=output_parser,
- native_format=True
- )
- sql_retriever = VectorSQLDatabaseChainRetriever(
- sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
- return sql_retriever
-
-
-def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str = "Self-query") -> ArXivQAwithSourcesChain:
- """_summary_
-
- :param _sel: selected knowledge base
- :type _sel: str
- :param retriever: retriever used by chains
- :type retriever: BaseRetriever
- :param name: display name, defaults to "Self-query"
- :type name: str, optional
- :return: QA chain interacts with user
- :rtype: ArXivQAwithSourcesChain
- """
- with st.spinner(f'Building QA Chain with {name} for {_sel}...'):
- chain = ArXivQAwithSourcesChain(
- retriever=retriever,
- combine_documents_chain=ArXivStuffDocumentChain(
- llm_chain=LLMChain(
- prompt=COMBINE_PROMPT,
- llm=ChatOpenAI(model_name=chat_model_name,
- openai_api_key=OPENAI_API_KEY, temperature=0.6),
- ),
- document_prompt=sel_map[_sel]["doc_prompt"],
- document_variable_name="summaries",
-
- ),
- return_source_documents=True,
- max_tokens_limit=12000,
- )
- return chain
-
-
-@st.cache_resource
-def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """build all resources
-
- :return: sel_map_obj
- :rtype: Dict[str, Any]
- """
- sel_map_obj = {}
- embeddings = {}
- for k in sel_map:
- embeddings[k] = build_embedding_model(k)
- st.session_state[f'emb_model_{k}'] = embeddings[k]
- sel_map_obj[k] = build_chains_retrievers(k)
- return sel_map_obj, embeddings
-
-
-def create_message_model(table_name, DynamicBase): # type: ignore
- """
- Create a message model for a given table name.
-
- Args:
- table_name: The name of the table to use.
- DynamicBase: The base class to use for the model.
-
- Returns:
- The model class.
-
- """
-
- # Model decleared inside a function to have a dynamic table name
- class Message(DynamicBase):
- __tablename__ = table_name
- id = Column(types.Float64)
- session_id = Column(Text)
- user_id = Column(Text)
- msg_id = Column(Text, primary_key=True)
- type = Column(Text)
- addtionals = Column(Text)
- message = Column(Text)
- __table_args__ = (
- engines.ReplacingMergeTree(
- partition_by='session_id',
- order_by=('id', 'msg_id')),
- {'comment': 'Store Chat History'}
- )
-
- return Message
-
-
-def _message_from_dict(message: dict) -> BaseMessage:
- _type = message["type"]
- if _type == "human":
- return HumanMessage(**message["data"])
- elif _type == "ai":
- return AIMessage(**message["data"])
- elif _type == "system":
- return SystemMessage(**message["data"])
- elif _type == "chat":
- return ChatMessage(**message["data"])
- elif _type == "function":
- return FunctionMessage(**message["data"])
- elif _type == "tool":
- return ToolMessage(**message["data"])
- elif _type == "AIMessageChunk":
- message["data"]["type"] = "ai"
- return AIMessage(**message["data"])
- else:
- raise ValueError(f"Got unexpected message type: {_type}")
-
-
-class DefaultClickhouseMessageConverter(DefaultMessageConverter):
- """The default message converter for SQLChatMessageHistory."""
-
- def __init__(self, table_name: str):
- self.model_class = create_message_model(table_name, declarative_base())
-
- def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
- tstamp = time.time()
- msg_id = hashlib.sha256(
- f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
- user_id, _ = session_id.split("?")
- return self.model_class(
- id=tstamp,
- msg_id=msg_id,
- user_id=user_id,
- session_id=session_id,
- type=message.type,
- addtionals=json.dumps(message.additional_kwargs),
- message=json.dumps({
- "type": message.type,
- "additional_kwargs": {"timestamp": tstamp},
- "data": message.dict()})
- )
-
- def from_sql_model(self, sql_message: Any) -> BaseMessage:
- msg_dump = json.loads(sql_message.message)
- msg = _message_from_dict(msg_dump)
- msg.additional_kwargs = msg_dump["additional_kwargs"]
- return msg
-
- def get_sql_model_class(self) -> Any:
- return self.model_class
-
-
-def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs):
- name = name.replace(" ", "_")
- conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
- chat_memory = SQLChatMessageHistory(
- session_id,
- connection_string=f'{conn_str}/chat?protocol=https',
- custom_message_converter=DefaultClickhouseMessageConverter(name))
- memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
-
- _system_message = SystemMessage(
- content=system_prompt
- )
- prompt = OpenAIFunctionsAgent.create_prompt(
- system_message=_system_message,
- extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
- )
- agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
- return AgentExecutor(
- agent=agent,
- tools=tools,
- memory=memory,
- verbose=True,
- return_intermediate_steps=True,
- **kwargs
- )
-
-
-class RetrieverInput(BaseModel):
- query: str = Field(description="query to look up in retriever")
-
-
-def create_retriever_tool(
- retriever: BaseRetriever, name: str, description: str
-) -> Tool:
- """Create a tool to do retrieval of documents.
-
- Args:
- retriever: The retriever to use for the retrieval
- name: The name for the tool. This will be passed to the language model,
- so should be unique and somewhat descriptive.
- description: The description for the tool. This will be passed to the language
- model, so should be descriptive.
-
- Returns:
- Tool class to pass to an agent
- """
- def wrap(func):
- def wrapped_retrieve(*args, **kwargs):
- docs: List[Document] = func(*args, **kwargs)
- return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
- return wrapped_retrieve
-
- return Tool(
- name=name,
- description=description,
- func=wrap(retriever.get_relevant_documents),
- coroutine=retriever.aget_relevant_documents,
- args_schema=RetrieverInput,
- )
-
-
-@st.cache_resource
-def build_tools():
- """build all resources
-
- :return: sel_map_obj
- :rtype: Dict[str, Any]
- """
- sel_map_obj = {}
- for k in sel_map:
- if f'emb_model_{k}' not in st.session_state:
- st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
- if "sel_map_obj" not in st.session_state:
- st.session_state["sel_map_obj"] = {}
- if k not in st.session_state.sel_map_obj:
- st.session_state["sel_map_obj"][k] = {}
- if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
- st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
- sel_map_obj.update({
- f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
- f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
- })
- return sel_map_obj
-
-
-def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
- chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
- openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
- )
- tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
- sel_tools = [tools[k] for k in tool_names]
- agent = create_agent_executor(
- "chat_memory",
- session_id,
- chat_llm,
- tools=sel_tools,
- system_prompt=system_prompt
- )
- return agent
-
-
-def display(dataframe, columns_=None, index=None):
- if len(dataframe) > 0:
- if index:
- dataframe.set_index(index)
- if columns_:
- st.dataframe(dataframe[columns_])
- else:
- st.dataframe(dataframe)
- else:
- st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
diff --git a/app/lib/private_kb.py b/app/lib/private_kb.py
deleted file mode 100644
index b5c98e0..0000000
--- a/app/lib/private_kb.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import pandas as pd
-import hashlib
-import requests
-from typing import List, Optional
-from datetime import datetime
-from langchain.schema.embeddings import Embeddings
-from streamlit.runtime.uploaded_file_manager import UploadedFile
-from clickhouse_connect import get_client
-from multiprocessing.pool import ThreadPool
-from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
-from .helper import create_retriever_tool
-
-parser_url = "https://api.unstructured.io/general/v0/general"
-
-
-def parse_files(api_key, user_id, files: List[UploadedFile]):
- def parse_file(file: UploadedFile):
- headers = {
- "accept": "application/json",
- "unstructured-api-key": api_key,
- }
- data = {"strategy": "auto", "ocr_languages": ["eng"]}
- file_hash = hashlib.sha256(file.read()).hexdigest()
- file_data = {"files": (file.name, file.getvalue(), file.type)}
- response = requests.post(
- parser_url, headers=headers, data=data, files=file_data
- )
- json_response = response.json()
- if response.status_code != 200:
- raise ValueError(str(json_response))
- texts = [
- {
- "text": t["text"],
- "file_name": t["metadata"]["filename"],
- "entity_id": hashlib.sha256(
- (file_hash + t["text"]).encode()
- ).hexdigest(),
- "user_id": user_id,
- "created_by": datetime.now(),
- }
- for t in json_response
- if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
- ]
- return texts
-
- with ThreadPool(8) as p:
- rows = []
- for r in p.imap_unordered(parse_file, files):
- rows.extend(r)
- return rows
-
-
-def extract_embedding(embeddings: Embeddings, texts):
- if len(texts) > 0:
- embs = embeddings.embed_documents(
- [t["text"] for _, t in enumerate(texts)])
- for i, _ in enumerate(texts):
- texts[i]["vector"] = embs[i]
- return texts
- raise ValueError("No texts extracted!")
-
-
-class PrivateKnowledgeBase:
- def __init__(
- self,
- host,
- port,
- username,
- password,
- embedding: Embeddings,
- parser_api_key,
- db="chat",
- kb_table="private_kb",
- tool_table="private_tool",
- ) -> None:
- super().__init__()
- kb_schema_ = f"""
- CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
- entity_id String,
- file_name String,
- text String,
- user_id String,
- created_by DateTime,
- vector Array(Float32),
- CONSTRAINT cons_vec_len CHECK length(vector) = 768,
- VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
- ) ENGINE = ReplacingMergeTree ORDER BY entity_id
- """
- tool_schema_ = f"""
- CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
- tool_id String,
- tool_name String,
- file_names Array(String),
- user_id String,
- created_by DateTime,
- tool_description String
- ) ENGINE = ReplacingMergeTree ORDER BY tool_id
- """
- self.kb_table = kb_table
- self.tool_table = tool_table
- config = MyScaleSettings(
- host=host,
- port=port,
- username=username,
- password=password,
- database=db,
- table=kb_table,
- )
- client = get_client(
- host=config.host,
- port=config.port,
- username=config.username,
- password=config.password,
- )
- client.command("SET allow_experimental_object_type=1")
- client.command(kb_schema_)
- client.command(tool_schema_)
- self.parser_api_key = parser_api_key
- self.vstore = MyScaleWithoutJSON(
- embedding=embedding,
- config=config,
- must_have_cols=["file_name", "text", "created_by"],
- )
-
- def list_files(self, user_id, tool_name=None):
- query = f"""
- SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
- arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
- FROM {self.vstore.config.database}.{self.kb_table}
- WHERE user_id = '{user_id}' GROUP BY file_name
- """
- return [r for r in self.vstore.client.query(query).named_results()]
-
- def add_by_file(
- self, user_id, files: List[UploadedFile], **kwargs
- ):
- data = parse_files(self.parser_api_key, user_id, files)
- data = extract_embedding(self.vstore.embeddings, data)
- self.vstore.client.insert_df(
- self.kb_table,
- pd.DataFrame(data),
- database=self.vstore.config.database,
- )
-
- def clear(self, user_id):
- self.vstore.client.command(
- f"DELETE FROM {self.vstore.config.database}.{self.kb_table} "
- f"WHERE user_id='{user_id}'"
- )
- query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
- WHERE user_id = '{user_id}'"""
- self.vstore.client.command(query)
-
- def create_tool(
- self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
- ):
- self.vstore.client.insert_df(
- self.tool_table,
- pd.DataFrame(
- [
- {
- "tool_id": hashlib.sha256(
- (user_id + tool_name).encode("utf-8")
- ).hexdigest(),
- "tool_name": tool_name,
- "file_names": files,
- "user_id": user_id,
- "created_by": datetime.now(),
- "tool_description": tool_description,
- }
- ]
- ),
- database=self.vstore.config.database,
- )
-
- def list_tools(self, user_id, tool_name=None):
- extended_where = f"AND tool_name = '{tool_name}'" if tool_name else ""
- query = f"""
- SELECT tool_name, tool_description, length(file_names)
- FROM {self.vstore.config.database}.{self.tool_table}
- WHERE user_id = '{user_id}' {extended_where}
- """
- return [r for r in self.vstore.client.query(query).named_results()]
-
- def remove_tools(self, user_id, tool_names):
- tool_names = ",".join([f"'{t}'" for t in tool_names])
- query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
- WHERE user_id = '{user_id}' AND tool_name IN [{tool_names}]"""
- self.vstore.client.command(query)
-
- def as_tools(self, user_id, tool_name=None):
- tools = self.list_tools(user_id=user_id, tool_name=tool_name)
- retrievers = {
- t["tool_name"]: create_retriever_tool(
- self.vstore.as_retriever(
- search_kwargs={
- "where_str": (
- f"user_id='{user_id}' "
- f"""AND file_name IN (
- SELECT arrayJoin(file_names) FROM (
- SELECT file_names
- FROM {self.vstore.config.database}.{self.tool_table}
- WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}')
- )"""
- )
- },
- ),
- name=t["tool_name"],
- description=t["tool_description"],
- )
- for t in tools
- }
- return retrievers
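The deleted `private_kb.py` derives `entity_id` from `sha256(file_hash + text)`; since the table is a `ReplacingMergeTree` ordered by `entity_id`, re-uploading the same file cannot duplicate paragraphs. A self-contained sketch of that keying scheme (inputs are made up):

```python
import hashlib

file_bytes = b"raw bytes of an uploaded file"   # hypothetical upload content
paragraph = "a parsed NarrativeText paragraph"  # hypothetical parser output

file_hash = hashlib.sha256(file_bytes).hexdigest()
entity_id = hashlib.sha256((file_hash + paragraph).encode()).hexdigest()
# identical (file, paragraph) pairs always map to the same entity_id,
# so ReplacingMergeTree deduplicates them during background merges
```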
diff --git a/app/lib/schemas.py b/app/lib/schemas.py
deleted file mode 100644
index ace331f..0000000
--- a/app/lib/schemas.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from sqlalchemy import Column, Text
-from clickhouse_sqlalchemy import types, engines
-
-
-def create_message_model(table_name, DynamicBase): # type: ignore
- """
- Create a message model for a given table name.
-
- Args:
- table_name: The name of the table to use.
- DynamicBase: The base class to use for the model.
-
- Returns:
- The model class.
-
- """
-
- # Model decleared inside a function to have a dynamic table name
- class Message(DynamicBase):
- __tablename__ = table_name
- id = Column(types.Float64)
- session_id = Column(Text)
- user_id = Column(Text)
- msg_id = Column(Text, primary_key=True)
- type = Column(Text)
- addtionals = Column(Text)
- message = Column(Text)
- __table_args__ = (
- engines.ReplacingMergeTree(
- partition_by='session_id',
- order_by=('id', 'msg_id')),
- {'comment': 'Store Chat History'}
- )
-
- return Message
-
-
-def create_session_table(table_name, DynamicBase): # type: ignore
- # Model decleared inside a function to have a dynamic table name
- class Session(DynamicBase):
- __tablename__ = table_name
- user_id = Column(Text)
- session_id = Column(Text, primary_key=True)
- system_prompt = Column(Text)
- create_by = Column(types.DateTime)
- additionals = Column(Text)
- __table_args__ = (
- engines.ReplacingMergeTree(
- order_by=('session_id')),
- {'comment': 'Store Session and Prompts'}
- )
- return Session
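Both factories return a model bound to a fresh `declarative_base()`, so each table carries its own metadata. A sketch of how `sessions.py` (next file) materializes one of these models, assuming `clickhouse-sqlalchemy` is installed and using placeholder credentials:

```python
from sqlalchemy import create_engine
try:
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4
    from sqlalchemy.ext.declarative import declarative_base

# create_session_table is the factory defined above; the DSN is a placeholder
Session = create_session_table("sessions", declarative_base())
engine = create_engine("clickhouse://user:password@host:443/chat?protocol=https")
Session.metadata.create_all(engine)  # CREATE TABLE IF NOT EXISTS on first run
```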
diff --git a/app/lib/sessions.py b/app/lib/sessions.py
deleted file mode 100644
index 5cefd63..0000000
--- a/app/lib/sessions.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import json
-try:
- from sqlalchemy.orm import declarative_base
-except ImportError:
- from sqlalchemy.ext.declarative import declarative_base
-from langchain.schema import BaseChatMessageHistory
-from datetime import datetime
-from sqlalchemy import Column, Text, orm, create_engine
-from .schemas import create_message_model, create_session_table
-
-
-def get_sessions(engine, model_class, user_id):
- with orm.sessionmaker(engine)() as session:
- result = (
- session.query(model_class)
- .where(
- model_class.session_id == user_id
- )
- .order_by(model_class.create_by.desc())
- )
- return json.loads(result)
-
-
-class SessionManager:
- def __init__(self, session_state, host, port, username, password,
- db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
- conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
- self.engine = create_engine(conn_str, echo=False)
- self.sess_model_class = create_session_table(
- sess_table, declarative_base())
- self.sess_model_class.metadata.create_all(self.engine)
- self.msg_model_class = create_message_model(
- msg_table, declarative_base())
- self.msg_model_class.metadata.create_all(self.engine)
- self.Session = orm.sessionmaker(self.engine)
- self.session_state = session_state
-
- def list_sessions(self, user_id):
- with self.Session() as session:
- result = (
- session.query(self.sess_model_class)
- .where(
- self.sess_model_class.user_id == user_id
- )
- .order_by(self.sess_model_class.create_by.desc())
- )
- sessions = []
- for r in result:
- sessions.append({
- "session_id": r.session_id.split("?")[-1],
- "system_prompt": r.system_prompt,
- })
- return sessions
-
- def modify_system_prompt(self, session_id, sys_prompt):
- with self.Session() as session:
- session.update(self.sess_model_class).where(
- self.sess_model_class == session_id).value(system_prompt=sys_prompt)
- session.commit()
-
- def add_session(self, user_id, session_id, system_prompt, **kwargs):
- with self.Session() as session:
- elem = self.sess_model_class(
- user_id=user_id, session_id=session_id, system_prompt=system_prompt,
- create_by=datetime.now(), additionals=json.dumps(kwargs)
- )
- session.add(elem)
- session.commit()
-
- def remove_session(self, session_id):
- with self.Session() as session:
- session.query(self.sess_model_class).where(
- self.sess_model_class.session_id == session_id).delete()
- # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
- if "agent" in self.session_state:
- self.session_state.agent.memory.chat_memory.clear()
- if "file_analyzer" in self.session_state:
- self.session_state.file_analyzer.clear_files()
diff --git a/app/logger.py b/app/logger.py
new file mode 100644
index 0000000..a19a1c7
--- /dev/null
+++ b/app/logger.py
@@ -0,0 +1,18 @@
+import logging
+
+
+def setup_logger():
+ logger_ = logging.getLogger('chat-data')
+ logger_.setLevel(logging.INFO)
+ if not logger_.handlers:
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.INFO)
+ formatter = logging.Formatter(
+ '%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]'
+ )
+ console_handler.setFormatter(formatter)
+ logger_.addHandler(console_handler)
+ return logger_
+
+
+logger = setup_logger()
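Because `setup_logger()` only attaches a handler when none exists, importing the module-level `logger` from several files will not duplicate console output. A minimal usage sketch:

```python
from logger import logger  # the module added above

logger.info("home page rendered")          # one console handler, configured once
logger.warning("falling back to defaults")
```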
diff --git a/app/login.py b/app/login.py
deleted file mode 100644
index 63231e2..0000000
--- a/app/login.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import streamlit as st
-from auth0_component import login_button
-
-AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
-AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
-
-
-def login():
- if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
- return True
- st.subheader(
- "🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
- st.write("You can now chat with ArXiv and Wikipedia! 🌟\n")
- st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love ❤️ for AI!")
- st.write(
- "Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
- st.write(
- "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
- st.divider()
- col1, col2 = st.columns(2, gap='large')
- with col1.container():
- st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
- st.write("In this demo, you will be able to see how those retrievers "
- "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
- st.session_state["jump_query_ask"] = st.button("Query / Ask")
- with col2.container():
- # st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
- st.write("Now with the power of LangChain's Conversantional Agents, we are able to build "
- "an RAG-enabled chatbot within one MyScale instance! ")
- st.write("Log in to Chat with RAG!")
- st.write("Recommended to use the standalone version of Chat-Data, available [here](https://myscale-chatdata.hf.space/).")
- login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
- st.divider()
- st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
- "- [Terms of Sevice](https://myscale.com/terms/)")
- if st.session_state.auth0 is not None:
- st.session_state.user_info = dict(st.session_state.auth0)
- if 'email' in st.session_state.user_info:
- email = st.session_state.user_info["email"]
- else:
- email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
- st.session_state["user_name"] = email
- del st.session_state.auth0
- st.experimental_rerun()
- if st.session_state.jump_query_ask:
- st.experimental_rerun()
-
-
-def back_to_main():
- if "user_info" in st.session_state:
- del st.session_state.user_info
- if "user_name" in st.session_state:
- del st.session_state.user_name
- if "jump_query_ask" in st.session_state:
- del st.session_state.jump_query_ask
diff --git a/app/requirements.txt b/app/requirements.txt
index 3acacf7..cc60e48 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -1,15 +1,17 @@
-langchain @ git+https://github.com/myscale/langchain.git@preview#egg=langchain&subdirectory=libs/langchain
-langchain-experimental @ git+https://github.com/myscale/langchain.git@preview#egg=langchain-experimental&subdirectory=libs/experimental
-# https://github.com/PromtEngineer/localGPT/issues/722
-sentence_transformers==2.2.2
+langchain==0.2.1
+langchain-community==0.2.1
+langchain-core==0.2.1
+langchain-experimental==0.0.59
+langchain-openai==0.1.7
+sentence-transformers==2.2.2
InstructorEmbedding
pandas
-sentence_transformers
-streamlit==1.25
+streamlit
+streamlit-extras
streamlit-auth0-component
altair==4.2.2
clickhouse-connect
-openai==0.28
+openai==1.35.3
lark
tiktoken
sql-formatter
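Note that the bump from `openai==0.28` to `openai==1.35.3` is a breaking change: 1.x drops the module-level `openai.ChatCompletion.create` style in favor of a client object. A minimal sketch of the 1.x call style (model and prompt are illustrative):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY (and OPENAI_BASE_URL) from the environment
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is PageRank?"}],
)
print(resp.choices[0].message.content)
```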
diff --git a/app/ui/__init__.py b/app/ui/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/ui/chat_page.py b/app/ui/chat_page.py
new file mode 100644
index 0000000..44e659d
--- /dev/null
+++ b/app/ui/chat_page.py
@@ -0,0 +1,196 @@
+import datetime
+import json
+
+import pandas as pd
+import streamlit as st
+from langchain_core.messages import HumanMessage, FunctionMessage
+from streamlit.delta_generator import DeltaGenerator
+
+from backend.chat_bot.json_decoder import CustomJSONDecoder
+from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \
+ EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \
+ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
+ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
+ CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS
+from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS
+from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager
+from backend.chat_bot.chat import refresh_sessions, on_session_change_submit, refresh_agent, \
+ create_private_knowledge_base_as_tool, \
+ remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit
+
+
+def render_session_manager():
+ with st.expander("🤖 Session Management"):
+ if CHAT_CURRENT_USER_SESSIONS not in st.session_state:
+ refresh_sessions()
+ st.markdown("Here you can update `session_id` and `system_prompt`")
+ st.markdown("- Click empty row to add a new item")
+ st.markdown("- If needs to delete an item, just click it and press `DEL` key")
+ st.markdown("- Don't forget to submit your change.")
+
+ st.data_editor(
+ data=st.session_state[CHAT_CURRENT_USER_SESSIONS],
+ num_rows="dynamic",
+ key="session_editor",
+ use_container_width=True,
+ )
+ st.button("⏫ Submit", on_click=on_session_change_submit, type="primary")
+
+
+def render_session_selection():
+ with st.expander("✅ Session Selection", expanded=True):
+ st.selectbox(
+ "Choose a `session` to chat",
+ options=st.session_state[CHAT_CURRENT_USER_SESSIONS],
+ index=None,
+ key=EL_SESSION_SELECTOR,
+ format_func=lambda x: x["session_id"],
+ on_change=refresh_agent,
+ )
+
+
+def render_files_manager():
+ with st.expander("📃 **Upload your personal files**", expanded=False):
+ st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).")
+ st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).")
+ st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True)
+ # st.markdown("### Uploaded Files")
+ st.dataframe(
+ data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]),
+ use_container_width=True,
+ )
+ st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty()
+ col_1, col_2 = st.columns(2)
+ with col_1:
+ st.button(label="Upload files", on_click=add_file)
+ with col_2:
+ st.button(label="Clear all files and tools", on_click=clear_files)
+
+
+def _render_create_personal_knowledge_bases(div: DeltaGenerator):
+ with div:
+ st.markdown("- If you haven't upload your personal files, please upload them first.")
+ st.markdown("- Select some **files** to build your `personal knowledge base`.")
+ st.markdown("- Once the your `personal knowledge base` is built, "
+ "it will answer your questions using information from your personal **files**.")
+ st.multiselect(
+ label="⚡️Select some files to build a **personal knowledge base**",
+ options=st.session_state[USER_PRIVATE_FILES],
+ placeholder="You should upload some files first",
+ key=EL_BUILD_KB_WITH_FILES,
+ format_func=lambda x: x["file_name"],
+ )
+ st.text_input(
+ label="⚡️Personal knowledge base name",
+ value="get_relevant_documents",
+ key=EL_PERSONAL_KB_NAME
+ )
+ st.text_input(
+ label="⚡️Personal knowledge base description",
+ value="Searches from some personal files.",
+ key=EL_PERSONAL_KB_DESCRIPTION,
+ )
+ st.button(
+ label="Build 🔧",
+ on_click=create_private_knowledge_base_as_tool
+ )
+
+
+def _render_remove_personal_knowledge_bases(div: DeltaGenerator):
+ with div:
+ st.markdown("> Here is all your personal knowledge bases.")
+ if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0:
+ st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES])
+ else:
+ st.warning("You don't have any personal knowledge bases, please create a new one.")
+ st.multiselect(
+ label="Choose a personal knowledge base to delete",
+ placeholder="Choose a personal knowledge base to delete",
+ options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES],
+ format_func=lambda x: x["tool_name"],
+ key=EL_PERSONAL_KB_NEEDS_REMOVE,
+ )
+ st.button("Delete", on_click=remove_private_knowledge_bases, type="primary")
+
+
+def render_personal_tools_build():
+ with st.expander("🔨 **Build your personal knowledge base**", expanded=True):
+ create_new_kb, kb_manager = st.tabs(["Create personal knowledge base", "Personal knowledge base management"])
+ _render_create_personal_knowledge_bases(create_new_kb)
+ _render_remove_personal_knowledge_bases(kb_manager)
+
+
+def render_knowledge_base_selector():
+ with st.expander("🙋 **Select some knowledge bases to query**", expanded=True):
+ st.markdown("- Knowledge bases come in two types: `public` and `private`.")
+ st.markdown("- All users can access our `public` knowledge bases.")
+ st.markdown("- Only you can access your `personal` knowledge bases.")
+ options = st.session_state[RETRIEVER_TOOLS].keys()
+ if AVAILABLE_RETRIEVAL_TOOLS in st.session_state:
+ options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS]
+ st.multiselect(
+ label="Select some knowledge base tool",
+ placeholder="Please select some knowledge bases to query",
+ options=options,
+ default=["Wikipedia + Self Querying"],
+ key=EL_SELECTED_KBS,
+ on_change=refresh_agent,
+ )
+
+
+def chat_page():
+ # initialize resources
+ build_chat_knowledge_table()
+ initialize_session_manager()
+
+ # render sidebar
+ with st.sidebar:
+ left, middle, right = st.columns([1, 1, 2])
+ with left:
+ st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main)
+ with right:
+ st.markdown(f"👤 `{st.session_state[USER_NAME]}`")
+ st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
+ render_session_manager()
+ render_session_selection()
+ render_files_manager()
+ render_personal_tools_build()
+ render_knowledge_base_selector()
+
+ # render chat history
+ if "agent" not in st.session_state:
+ refresh_agent()
+ for msg in st.session_state.agent.memory.chat_memory.messages:
+ speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
+ if isinstance(msg, FunctionMessage):
+ with st.chat_message(name="from knowledge base", avatar="📚"):
+ st.write(
+ f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
+ )
+ st.write("Retrieved from knowledge base:")
+ try:
+ st.dataframe(
+ pd.DataFrame.from_records(
+ json.loads(msg.content, cls=CustomJSONDecoder)
+ ),
+ use_container_width=True,
+ )
+ except Exception as e:
+ st.warning(e)
+ st.write(msg.content)
+ else:
+ if len(msg.content) > 0:
+ with st.chat_message(speaker):
+ # print(type(msg), msg.dict())
+ st.write(
+ f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
+ )
+ st.write(f"{msg.content}")
+ st.session_state["next_round"] = st.empty()
+    # NOTE: _bottom is Streamlit's private bottom-pinned container (used by st.chat_input); it may change between releases
+    from streamlit import _bottom
+ with _bottom:
+ col1, col2 = st.columns([1, 16])
+ with col1:
+ st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary")
+ with col2:
+ st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
diff --git a/app/ui/home.py b/app/ui/home.py
new file mode 100644
index 0000000..e2d71ca
--- /dev/null
+++ b/app/ui/home.py
@@ -0,0 +1,156 @@
+import base64
+
+from streamlit_extras.add_vertical_space import add_vertical_space
+from streamlit_extras.card import card
+from streamlit_extras.colored_header import colored_header
+from streamlit_extras.mention import mention
+from streamlit_extras.tags import tagger_component
+
+from logger import logger
+import os
+
+import streamlit as st
+from auth0_component import login_button
+
+from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML
+from streamlit_extras.let_it_rain import rain
+
+
+def render_home():
+ render_home_header()
+ # st.divider()
+ # st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+ add_vertical_space(5)
+ render_home_content()
+ # st.divider()
+ st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+ render_home_footer()
+
+
+def render_home_header():
+ logger.info("render home header")
+ st.header("ChatData - Your Intelligent Assistant")
+ st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+ st.markdown("> [ChatData](https://github.com/myscale/ChatData) \
+ is developed by [MyScale](https://myscale.com/), \
+ it's an integration of [LangChain](https://www.langchain.com/) \
+ and [MyScaleDB](https://github.com/myscale/myscaledb)")
+
+ tagger_component(
+ "Keywords:",
+ ["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"],
+ color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"],
+ )
+ text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4])
+ with text:
+ st.markdown("Related:")
+ with col1.container():
+ mention(
+ label="streamlit",
+ icon="streamlit",
+ url="https://streamlit.io/",
+ write=True
+ )
+ with col2.container():
+ mention(
+ label="langchain",
+ icon="🦜🔗",
+ url="https://www.langchain.com/",
+ write=True
+ )
+ with col3.container():
+ mention(
+ label="streamlit-extras",
+ icon="🪢",
+ url="https://github.com/arnaudmiribel/streamlit-extras",
+ write=True
+ )
+
+
+def _render_self_query_chain_content():
+ col1, col2 = st.columns([1, 1], gap='large')
+ with col1.container():
+ st.image(image='../assets/home_page_background_1.png',
+ caption=None,
+ width=None,
+ use_column_width=True,
+ clamp=False,
+ channels="RGB",
+ output_format="PNG")
+ with col2.container():
+ st.header("VectorSearch & SelfQuery with Sources")
+ st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.")
+ st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options:
+
+1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content.
+
+2. `Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""")
+ add_vertical_space(3)
+ _, middle, _ = st.columns([2, 1, 2], gap='small')
+ with middle.container():
+ st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary")
+
+
+def _render_chat_bot_content():
+ col1, col2 = st.columns(2, gap='large')
+ with col1.container():
+ st.image(image='../assets/home_page_background_2.png',
+ caption=None,
+ width=None,
+ use_column_width=True,
+ clamp=False,
+ channels="RGB",
+ output_format="PNG")
+ with col2.container():
+ st.header("Chat Bot")
+ st.info("Now you can try our chatbot, this chatbot is built with MyScale and LangChain.")
+ st.markdown("- You need to log in. We use `user_name` to identify each customer.")
+ st.markdown("- You can upload your own PDF files and build your own knowledge base. \
+ (This is just a sample application. Please do not upload important or confidential files.)")
+ st.markdown("- A default session will be assigned as your initial chat session. \
+ You can create and switch to other sessions to jump between different chat conversations.")
+ add_vertical_space(1)
+ _, middle, _ = st.columns([1, 2, 1], gap='small')
+ with middle.container():
+ if USER_NAME not in st.session_state:
+ login_button(clientId=os.environ["AUTH0_CLIENT_ID"],
+ domain=os.environ["AUTH0_DOMAIN"],
+ key="auth0")
+ # if user_info:
+ # user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null")
+ # st.session_state[USER_NAME] = user_name
+ # print(user_info)
+
+
+def render_home_content():
+ logger.info("render home content")
+ _render_self_query_chain_content()
+ add_vertical_space(3)
+ _render_chat_bot_content()
+
+
+def render_home_footer():
+ logger.info("render home footer")
+ st.write(
+ "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!"
+ )
+ st.write(
+ "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
+ st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)")
+
+ # st.write(
+ # "Recommended to use the standalone version of Chat-Data, "
+ # "available [here](https://myscale-chatdata.hf.space/)."
+ # )
+
+ if st.session_state.auth0 is not None:
+ st.session_state[USER_INFO] = dict(st.session_state.auth0)
+ if 'email' in st.session_state[USER_INFO]:
+ email = st.session_state[USER_INFO]["email"]
+ else:
+ email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}"
+ st.session_state["user_name"] = email
+ del st.session_state.auth0
+ st.rerun()
+ if st.session_state.jump_query_ask:
+ st.rerun()
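`render_home_footer` derives the user name from the Auth0 payload, preferring the email claim and otherwise composing `nickname@sub`. A sketch of that branch as a standalone helper (the payload is hypothetical):

```python
def resolve_user_name(user_info: dict) -> str:
    """Prefer the Auth0 email; otherwise fall back to 'nickname@sub'."""
    if "email" in user_info:
        return user_info["email"]
    return f"{user_info['nickname']}@{user_info['sub']}"

print(resolve_user_name({"nickname": "alice", "sub": "auth0|123"}))  # alice@auth0|123
```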
diff --git a/app/ui/retrievers.py b/app/ui/retrievers.py
new file mode 100644
index 0000000..4fb542c
--- /dev/null
+++ b/app/ui/retrievers.py
@@ -0,0 +1,97 @@
+import streamlit as st
+from streamlit_extras.add_vertical_space import add_vertical_space
+
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
+from backend.retrievers.self_query import process_self_query
+from backend.retrievers.vector_sql_query import process_sql_query
+from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO
+
+
+def back_to_main():
+ if USER_INFO in st.session_state:
+ del st.session_state[USER_INFO]
+ if USER_NAME in st.session_state:
+ del st.session_state[USER_NAME]
+ if JUMP_QUERY_ASK in st.session_state:
+ del st.session_state[JUMP_QUERY_ASK]
+
+
+def _render_table_selector() -> str:
+ col1, col2 = st.columns(2)
+ with col1:
+ selected_table = st.selectbox(
+ label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
+ options=MYSCALE_TABLES.keys(),
+ )
+ MYSCALE_TABLES[selected_table].hint()
+ with col2:
+ add_vertical_space(1)
+ st.info(f"Here is your selected public knowledge base schema in MyScaleDB",
+ icon='📚')
+ MYSCALE_TABLES[selected_table].hint_sql()
+
+ return selected_table
+
+
+def render_retrievers():
+ st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
+ st.subheader('Please choose a public knowledge base to search.')
+ selected_table = _render_table_selector()
+
+ tab_sql, tab_self_query = st.tabs(
+ tabs=['Vector SQL', 'Self-querying Retriever']
+ )
+
+ with tab_sql:
+ render_tab_sql(selected_table)
+
+ with tab_self_query:
+ render_tab_self_query(selected_table)
+
+
+def render_tab_sql(selected_table: str):
+ st.warning(
+ "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
+ "the metadata we provide. This table allows filters to be established on the following metadata fields:",
+ icon="⚠️")
+ st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
+
+ cols = st.columns([8, 3, 3, 2])
+ cols[0].text_input("Input your question:", key='query_sql')
+ with cols[1].container():
+ add_vertical_space(2)
+ st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)
+ with cols[2].container():
+ add_vertical_space(2)
+ st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)
+
+ if st.session_state[RetrieverButtons.vector_sql_query_from_db]:
+ process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db)
+
+ if st.session_state[RetrieverButtons.vector_sql_query_with_llm]:
+ process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm)
+
+
+def render_tab_self_query(selected_table):
+ st.warning(
+ "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
+ "the metadata we provide. This table allows filters to be established on the following metadata fields:",
+ icon="⚠️")
+ st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
+
+ cols = st.columns([8, 3, 3, 2])
+ cols[0].text_input("Input your question:", key='query_self')
+
+ with cols[1].container():
+ add_vertical_space(2)
+ st.button("Retrieve from MyScaleDB ➡️", key='search_self')
+ with cols[2].container():
+ add_vertical_space(2)
+ st.button("Retrieve and answer with LLM ➡️", key='ask_self')
+
+ if st.session_state.search_self:
+ process_self_query(selected_table, RetrieverButtons.self_query_from_db)
+
+ if st.session_state.ask_self:
+ process_self_query(selected_table, RetrieverButtons.self_query_with_llm)
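Both tabs lean on a Streamlit behavior worth calling out: passing `key=` to `st.button` mirrors the button's return value into `st.session_state[key]`, and that value is `True` only during the rerun triggered by the click. A minimal sketch:

```python
import streamlit as st

st.button("Retrieve from MyScaleDB ➡️", key="search_self")
if st.session_state["search_self"]:  # True only on the click-triggered rerun
    st.write("run the retrieval pipeline here")
```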
diff --git a/app/ui/utils.py b/app/ui/utils.py
new file mode 100644
index 0000000..92b1177
--- /dev/null
+++ b/app/ui/utils.py
@@ -0,0 +1,18 @@
+import streamlit as st
+
+
+def display(dataframe, columns_=None, index=None):
+ if len(dataframe) > 0:
+ if index:
+            dataframe = dataframe.set_index(index)  # set_index returns a copy; assign it back
+ if columns_:
+ st.dataframe(dataframe[columns_])
+ else:
+ st.dataframe(dataframe)
+ else:
+ st.write(
+ "Sorry 😵 we didn't find any articles related to your query.\n\n"
+ "Maybe the LLM is too naughty that does not follow our instruction... \n\n"
+ "Please try again and use verbs that may match the datatype.",
+ unsafe_allow_html=True
+ )
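A usage sketch for `display()`; the frame is fabricated for illustration:

```python
import pandas as pd

df = pd.DataFrame({
    "id": ["2301.00001"],  # hypothetical arXiv-style id
    "title": ["A Sample Paper"],
    "views": [123.0],
})
display(df, columns_=["title", "views"], index="id")  # renders the selected columns via st.dataframe
```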
diff --git a/assets/chatdata-homepage.png b/assets/chatdata-homepage.png
deleted file mode 100644
index 1580b7f..0000000
Binary files a/assets/chatdata-homepage.png and /dev/null differ
diff --git a/assets/home.png b/assets/home.png
new file mode 100644
index 0000000..22f1eaf
Binary files /dev/null and b/assets/home.png differ
diff --git a/assets/home_page_background_1.png b/assets/home_page_background_1.png
new file mode 100644
index 0000000..1be1602
Binary files /dev/null and b/assets/home_page_background_1.png differ
diff --git a/assets/home_page_background_2.png b/assets/home_page_background_2.png
new file mode 100644
index 0000000..bf18230
Binary files /dev/null and b/assets/home_page_background_2.png differ