diff --git a/.gitignore b/.gitignore index 57e26ae..0d1b862 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,8 @@ dist/ downloads/ eggs/ .eggs/ -lib/ +.idea/ + lib64/ parts/ sdist/ @@ -163,5 +164,5 @@ cython_debug/ # dataset files data/ .streamlit/ -*.ipynb +#*.ipynb .DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 8a9d1ed..977be5b 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ In conclusion, with ChatData, you can effortlessly navigate through vast amounts ➡️ Dive in and experience ChatData on [Hugging Face](https://huggingface.co/spaces/myscale/ChatData)🤗 -![ChatData Homepage](assets/chatdata-homepage.png) +![ChatData Homepage](assets/home.png) ### Data schema @@ -117,15 +117,6 @@ And for overall table schema, please refer to [table creation section in docs/se If you want to use this database with `langchain.chains.sql_database.base.SQLDatabaseChain` or `langchain.retrievers.SQLDatabaseRetriever`, please follow guides on [data preparation section](docs/vector-sql.md#prepare-the-database) and [chain creation section](docs/vector-sql.md#create-the-sqldatabasechain) in docs/vector-sql.md -### How to run ChatData - - - -```bash -python3 -m pip install requirements.txt -python3 -m streamlit run app.py -``` - ### Where can I get those arXiv data? - [From parquet files on S3](docs/self-query.md#insert-data) @@ -167,18 +158,12 @@ cd app/ 2. Create an virtual environment ```bash -python3 -m venv .venv -source .venv/bin/activate +python3 -m venv venv +source venv/bin/activate ``` 3. Install dependencies -> This app is currently using [MyScale's technical preview of LangChain](https://github.com/myscale/langchain/tree/preview). 
-> ->> It contains improved SQLDatabaseChain in [this PR](https://github.com/hwchase17/langchain/pull/7454) ->> ->> It contains [improved prompts](https://github.com/hwchase17/langchain/pull/6737#discussion_r1243527112) for comparators `LIKE` and `CONTAIN` in [MyScale self-query retriever](https://github.com/hwchase17/langchain/pull/6143). - ```bash python3 -m pip install -r requirements.txt ``` diff --git a/app/.streamlit/config.toml b/app/.streamlit/config.toml index 35fc398..37b9ad3 100644 --- a/app/.streamlit/config.toml +++ b/app/.streamlit/config.toml @@ -1,6 +1,2 @@ [theme] -primaryColor="#523EFD" -backgroundColor="#FFFFFF" -secondaryBackgroundColor="#D4CEFF" -textColor="#262730" -font="sans serif" \ No newline at end of file +base="dark" diff --git a/app/.streamlit/secrets.example.toml b/app/.streamlit/secrets.example.toml index 9c5d344..add2dbf 100644 --- a/app/.streamlit/secrets.example.toml +++ b/app/.streamlit/secrets.example.toml @@ -1,7 +1,8 @@ -MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale +MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com" # read-only database provided by MyScale MYSCALE_PORT = 443 MYSCALE_USER = "chatdata" MYSCALE_PASSWORD = "myscale_rocks" +MYSCALE_ENABLE_HTTPS = true OPENAI_API_BASE = "https://api.openai.com/v1" OPENAI_API_KEY = "" UNSTRUCTURED_API = "" # optional if you don't upload documents diff --git a/app/app.py b/app/app.py index ecae84a..df1bbfd 100644 --- a/app/app.py +++ b/app/app.py @@ -1,133 +1,86 @@ -import pandas as pd -from os import environ +import os +import time + import streamlit as st -from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ - ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ - ChatDataSQLAskCallBackHandler +from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \ + DATA_INITIALIZE_STARTED +from backend.constants.variables import 
DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \ + TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config +from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools +from backend.types.global_config import GlobalConfig +from logger import logger +from ui.chat_page import chat_page +from ui.home import render_home +from ui.retrievers import render_retrievers -from chat import chat_page -from login import login, back_to_main -from lib.helper import build_tools, build_all, sel_map, display +# warnings.filterwarnings("ignore", category=UserWarning) -environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE'] +def prepare_environment(): + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + os.environ["LANGCHAIN_TRACING_V2"] = "false" + # os.environ["LANGCHAIN_API_KEY"] = "" + os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE'] + os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY'] + os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID'] + os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN'] + + update_global_config(GlobalConfig( + openai_api_base=st.secrets['OPENAI_API_BASE'], + openai_api_key=st.secrets['OPENAI_API_KEY'], + auth0_client_id=st.secrets['AUTH0_CLIENT_ID'], + auth0_domain=st.secrets['AUTH0_DOMAIN'], + myscale_user=st.secrets['MYSCALE_USER'], + myscale_password=st.secrets['MYSCALE_PASSWORD'], + myscale_host=st.secrets['MYSCALE_HOST'], + myscale_port=st.secrets['MYSCALE_PORT'], + query_model="gpt-3.5-turbo-0125", + chat_model="gpt-3.5-turbo-0125", + untrusted_api=st.secrets['UNSTRUCTURED_API'], + myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True), + )) -st.set_page_config(page_title="ChatData", - page_icon="https://myscale.com/favicon.ico") -st.markdown( - f""" - """, - unsafe_allow_html=True, -) -st.header("ChatData") -if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state: - 
st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all() - st.session_state["tools"] = build_tools() +# when refresh browser, all session keys will be cleaned. +def initialize_session_state(): + if DATA_INITIALIZE_STATUS not in st.session_state: + st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED + logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}") + if JUMP_QUERY_ASK not in st.session_state: + st.session_state[JUMP_QUERY_ASK] = False + logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}") -if login(): - if "user_name" in st.session_state: - chat_page() - elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask: - sel = st.selectbox('Choose the knowledge base you want to ask with:', - options=['ArXiv Papers', 'Wikipedia']) - sel_map[sel]['hint']() - tab_sql, tab_self_query = st.tabs( - ['Vector SQL', 'Self-Query Retrievers']) - with tab_sql: - sel_map[sel]['hint_sql']() - st.text_input("Ask a question:", key='query_sql') - cols = st.columns([1, 1, 1, 4]) - cols[0].button("Query", key='search_sql') - cols[1].button("Ask", key='ask_sql') - cols[2].button("Back", key='back_sql', on_click=back_to_main) - plc_hldr = st.empty() - if st.session_state.search_sql: - plc_hldr = st.empty() - print(st.session_state.query_sql) - with plc_hldr.expander('Query Log', expanded=True): - callback = ChatDataSQLSearchCallBackHandler() - try: - docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents( - st.session_state.query_sql, callbacks=[callback]) - callback.progress_bar.progress(value=1.0, text="Done!") - docs = pd.DataFrame( - [{**d.metadata, 'abstract': d.page_content} for d in docs]) - display(docs) - except Exception as e: - st.write('Oops 😵 Something bad happened...') - raise e +def initialize_chat_data(): + if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED: + start_time = time.time() + st.session_state[DATA_INITIALIZE_STATUS] = 
DATA_INITIALIZE_STARTED + st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models() + st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers() + st.session_state[RETRIEVER_TOOLS] = update_retriever_tools() + # mark data initialization finished. + st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED + end_time = time.time() + logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, " + f"session state keys: {list(st.session_state.keys())}") - if st.session_state.ask_sql: - plc_hldr = st.empty() - print(st.session_state.query_sql) - with plc_hldr.expander('Chat Log', expanded=True): - callback = ChatDataSQLAskCallBackHandler() - try: - ret = st.session_state.sel_map_obj[sel]["sql_chain"]( - st.session_state.query_sql, callbacks=[callback]) - callback.progress_bar.progress(value=1.0, text="Done!") - st.markdown( - f"### Answer from LLM\n{ret['answer']}\n### References") - docs = ret['sources'] - docs = pd.DataFrame( - [{**d.metadata, 'abstract': d.page_content} for d in docs]) - display( - docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') - except Exception as e: - st.write('Oops 😵 Something bad happened...') - raise e - with tab_self_query: - st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') - st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"]) - st.text_input("Ask a question:", key='query_self') - cols = st.columns([1, 1, 1, 4]) - cols[0].button("Query", key='search_self') - cols[1].button("Ask", key='ask_self') - cols[2].button("Back", key='back_self', on_click=back_to_main) - plc_hldr = st.empty() - if st.session_state.search_self: - plc_hldr = st.empty() - print(st.session_state.query_self) - with plc_hldr.expander('Query Log', expanded=True): - call_back = None - callback = ChatDataSelfSearchCallBackHandler() - try: - docs = 
st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents( - st.session_state.query_self, callbacks=[callback]) - print(docs) - callback.progress_bar.progress(value=1.0, text="Done!") - docs = pd.DataFrame( - [{**d.metadata, 'abstract': d.page_content} for d in docs]) - display(docs, sel_map[sel]["must_have_cols"]) - except Exception as e: - st.write('Oops 😵 Something bad happened...') - raise e +st.set_page_config( + page_title="ChatData", + page_icon="https://myscale.com/favicon.ico", + initial_sidebar_state="expanded", + layout="wide", +) + +prepare_environment() +initialize_session_state() +initialize_chat_data() - if st.session_state.ask_self: - plc_hldr = st.empty() - print(st.session_state.query_self) - with plc_hldr.expander('Chat Log', expanded=True): - call_back = None - callback = ChatDataSelfAskCallBackHandler() - try: - ret = st.session_state.sel_map_obj[sel]["chain"]( - st.session_state.query_self, callbacks=[callback]) - callback.progress_bar.progress(value=1.0, text="Done!") - st.markdown( - f"### Answer from LLM\n{ret['answer']}\n### References") - docs = ret['sources'] - docs = pd.DataFrame( - [{**d.metadata, 'abstract': d.page_content} for d in docs]) - display( - docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') - except Exception as e: - st.write('Oops 😵 Something bad happened...') - raise e +if USER_NAME in st.session_state: + chat_page() +else: + if st.session_state[JUMP_QUERY_ASK]: + render_retrievers() + else: + render_home() diff --git a/app/backend/__init__.py b/app/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/callbacks/__init__.py b/app/backend/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/callbacks/arxiv_callbacks.py b/app/backend/callbacks/arxiv_callbacks.py new file mode 100644 index 0000000..49f8fd1 --- /dev/null +++ b/app/backend/callbacks/arxiv_callbacks.py @@ -0,0 +1,46 @@ +import json +import textwrap +from typing 
import Dict, Any, List + +from langchain.callbacks.streamlit.streamlit_callback_handler import ( + LLMThought, + StreamlitCallbackHandler, +) + + +class LLMThoughtWithKnowledgeBase(LLMThought): + def on_tool_end( + self, + output: str, + color=None, + observation_prefix=None, + llm_prefix=None, + **kwargs: Any, + ) -> None: + try: + self._container.markdown( + "\n\n".join( + ["### Retrieved Documents:"] + + [ + f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}" + for i, r in enumerate(json.loads(output)) + ] + ) + ) + except Exception as e: + super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs) + + +class ChatDataAgentCallBackHandler(StreamlitCallbackHandler): + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + if self._current_thought is None: + self._current_thought = LLMThoughtWithKnowledgeBase( + parent_container=self._parent_container, + expanded=self._expand_new_thoughts, + collapse_on_complete=self._collapse_completed_thoughts, + labeler=self._thought_labeler, + ) + + self._current_thought.on_llm_start(serialized, prompts) diff --git a/app/backend/callbacks/llm_thought_with_table.py b/app/backend/callbacks/llm_thought_with_table.py new file mode 100644 index 0000000..56fb0a1 --- /dev/null +++ b/app/backend/callbacks/llm_thought_with_table.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, List + +import streamlit as st +from langchain_core.outputs import LLMResult +from streamlit.external.langchain import StreamlitCallbackHandler + + +class ChatDataSelfQueryCallBack(StreamlitCallbackHandler): + def __init__(self): + super().__init__(st.container()) + self._current_thought = None + self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...") + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + self.progress_bar.progress(value=0.35, text="Communicate with LLM...") + pass + + def 
on_chain_end(self, outputs, **kwargs) -> None: + if len(kwargs['tags']) == 0: + self.progress_bar.progress(value=0.75, text="Searching in DB...") + + def on_chain_start(self, serialized, inputs, **kwargs) -> None: + + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + st.markdown("### Generate filter by LLM \n" + "> Here we get `query_constructor` results \n\n") + + self.progress_bar.progress(value=0.5, text="Generate filter by LLM...") + for item in response.generations: + st.markdown(f"{item[0].text}") + + pass diff --git a/app/backend/callbacks/self_query_callbacks.py b/app/backend/callbacks/self_query_callbacks.py new file mode 100644 index 0000000..b8237b8 --- /dev/null +++ b/app/backend/callbacks/self_query_callbacks.py @@ -0,0 +1,57 @@ +from typing import Dict, Any, List + +import streamlit as st +from langchain.callbacks.streamlit.streamlit_callback_handler import ( + StreamlitCallbackHandler, +) +from langchain.schema.output import LLMResult + + +class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler): + def __init__(self): + super().__init__(st.container()) + self._current_thought = None + self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...") + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + self.progress_bar.progress(value=0.35, text="Communicate with LLM...") + pass + + def on_chain_end(self, outputs, **kwargs) -> None: + if len(kwargs['tags']) == 0: + self.progress_bar.progress(value=0.75, text="Searching in DB...") + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + st.markdown("### Generate filter by LLM \n" + "> Here we get `query_constructor` results \n\n") + self.progress_bar.progress(value=0.5, text="Generate filter by LLM...") + for item in response.generations: + st.markdown(f"{item[0].text}") + pass + + +class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler): + def __init__(self) -> 
None: + super().__init__(st.container()) + self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...") + + def on_llm_start(self, serialized, prompts, **kwargs) -> None: + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + + if len(kwargs['tags']) != 0: + self.progress_bar.progress(value=0.5, text="We got filter info from LLM...") + st.markdown("### Generate filter by LLM \n" + "> Here we get `query_constructor` results \n\n") + for item in response.generations: + st.markdown(f"{item[0].text}") + pass + + def on_chain_start(self, serialized, inputs, **kwargs) -> None: + cid = ".".join(serialized["id"]) + if cid.endswith(".CustomStuffDocumentChain"): + self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...") diff --git a/app/backend/callbacks/vector_sql_callbacks.py b/app/backend/callbacks/vector_sql_callbacks.py new file mode 100644 index 0000000..b83cf76 --- /dev/null +++ b/app/backend/callbacks/vector_sql_callbacks.py @@ -0,0 +1,53 @@ +import streamlit as st +from langchain.callbacks.streamlit.streamlit_callback_handler import ( + StreamlitCallbackHandler, +) +from langchain.schema.output import LLMResult +from sql_formatter.core import format_sql + + +class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler): + def __init__(self) -> None: + self.progress_bar = st.progress(value=0.0, text="Writing SQL...") + self.status_bar = st.empty() + self.prog_value = 0 + self.prog_interval = 0.2 + + def on_llm_start(self, serialized, prompts, **kwargs) -> None: + pass + + def on_llm_end( + self, + response: LLMResult, + *args, + **kwargs, + ): + text = response.generations[0][0].text + if text.replace(" ", "").upper().startswith("SELECT"): + st.markdown("### Generated Vector Search SQL Statement \n" + "> This sql statement is generated by LLM \n\n") + st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""") + self.prog_value += self.prog_interval + self.progress_bar.progress( + 
value=self.prog_value, text="Searching in DB...") + + def on_chain_start(self, serialized, inputs, **kwargs) -> None: + cid = ".".join(serialized["id"]) + self.prog_value += self.prog_interval + self.progress_bar.progress( + value=self.prog_value, text=f"Running Chain `{cid}`..." + ) + + def on_chain_end(self, outputs, **kwargs) -> None: + pass + + +class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler): + def __init__(self, table: str) -> None: + self.progress_bar = st.progress(value=0.0, text="Writing SQL...") + self.status_bar = st.empty() + self.prog_value = 0 + self.prog_interval = 0.1 + self.table = table + + diff --git a/app/backend/chains/__init__.py b/app/backend/chains/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/chains/retrieval_qa_with_sources.py b/app/backend/chains/retrieval_qa_with_sources.py new file mode 100644 index 0000000..06b1f39 --- /dev/null +++ b/app/backend/chains/retrieval_qa_with_sources.py @@ -0,0 +1,70 @@ +import inspect +from typing import Dict, Any, Optional, List + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) +from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain +from langchain.docstore.document import Document + +from logger import logger + + +class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain): + """QA with source chain for Chat ArXiv app with references + + This chain will automatically assign reference number to the article, + Then parse it back to titles or anything else. 
+ """ + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m") + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + accepts_run_manager = ( + "run_manager" in inspect.signature(self._get_docs).parameters + ) + if accepts_run_manager: + docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager) + else: + docs: List[Document] = self._get_docs(inputs) # type: ignore[call-arg] + + answer = self.combine_documents_chain.run( + input_documents=docs, callbacks=_run_manager.get_child(), **inputs + ) + # parse source with ref_id + sources = [] + ref_cnt = 1 + for d in docs: + ref_id = d.metadata['ref_id'] + if f"Doc #{ref_id}" in answer: + answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}") + if f"#{ref_id}" in answer: + title = d.metadata['title'].replace('\n', '') + d.metadata['ref_id'] = ref_cnt + answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]") + sources.append(d) + ref_cnt += 1 + + result: Dict[str, Any] = { + self.answer_key: answer, + self.sources_answer_key: sources, + } + if self.return_source_documents: + result["source_documents"] = docs + return result + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + raise NotImplementedError + + @property + def _chain_type(self) -> str: + return "custom_retrieval_qa_with_sources_chain" diff --git a/app/backend/chains/stuff_documents.py b/app/backend/chains/stuff_documents.py new file mode 100644 index 0000000..0f6c762 --- /dev/null +++ b/app/backend/chains/stuff_documents.py @@ -0,0 +1,65 @@ +from typing import Any, List, Tuple + +from langchain.callbacks.manager import Callbacks +from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.docstore.document import Document +from langchain.schema.prompt_template import 
format_document + + +class CustomStuffDocumentChain(StuffDocumentsChain): + """Combine arxiv documents with PDF reference number""" + + def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: + """Construct inputs from kwargs and docs. + + Format and then join all the documents together into one input with name + `self.document_variable_name`. Then pluck any additional variables + from **kwargs. + + Args: + docs: List of documents to format and then join into single input + **kwargs: additional inputs to chain, will pluck any other required + arguments from here. + + Returns: + dictionary of inputs to LLMChain + """ + # Format each document according to the prompt + doc_strings = [] + for doc_id, doc in enumerate(docs): + # add temp reference number in metadata + doc.metadata.update({'ref_id': doc_id}) + doc.page_content = doc.page_content.replace('\n', ' ') + doc_strings.append(format_document(doc, self.document_prompt)) + # Join the documents together to put them in the prompt. + inputs = { + k: v + for k, v in kwargs.items() + if k in self.llm_chain.prompt.input_variables + } + inputs[self.document_variable_name] = self.document_separator.join( + doc_strings) + return inputs + + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: + """Stuff all documents into one prompt and pass to LLM. + + Args: + docs: List of documents to join together into one variable + callbacks: Optional callbacks to pass along + **kwargs: additional parameters to use to get inputs to LLMChain. + + Returns: + The first element returned is the single string output. The second + element returned is a dictionary of other keys to return. + """ + inputs = self._get_inputs(docs, **kwargs) + # Call predict on the LLM. 
+ output = self.llm_chain.predict(callbacks=callbacks, **inputs) + return output, {} + + @property + def _chain_type(self) -> str: + return "custom_stuff_document_chain" diff --git a/app/backend/chat_bot/__init__.py b/app/backend/chat_bot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/chat_bot/chat.py b/app/backend/chat_bot/chat.py new file mode 100644 index 0000000..cfa1aef --- /dev/null +++ b/app/backend/chat_bot/chat.py @@ -0,0 +1,225 @@ +import time + +from os import environ +from time import sleep +import streamlit as st + +from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT +from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \ + CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \ + EL_BUILD_KB_WITH_FILES, \ + EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \ + USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \ + EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES +from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS +from backend.construct.build_agents import build_agents +from backend.chat_bot.session_manager import SessionManager +from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler + +from logger import logger + +environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"] + +TOOL_NAMES = { + "langchain_retriever_tool": "Self-querying retriever", + "vecsql_retriever_tool": "Vector SQL", +} + + +def on_chat_submit(): + with st.session_state.next_round.container(): + with st.chat_message("user"): + st.write(st.session_state.chat_input) + with st.chat_message("assistant"): + container = st.container() + st_callback = ChatDataAgentCallBackHandler( + container, collapse_completed_thoughts=False + ) + ret = st.session_state.agent( + {"input": st.session_state.chat_input}, callbacks=[st_callback] + ) + logger.info(f"ret:{ret}") + + +def clear_history(): + if 
"agent" in st.session_state: + st.session_state.agent.memory.clear() + + +def back_to_main(): + if USER_INFO in st.session_state: + del st.session_state[USER_INFO] + if USER_NAME in st.session_state: + del st.session_state[USER_NAME] + if JUMP_QUERY_ASK in st.session_state: + del st.session_state[JUMP_QUERY_ASK] + if EL_SESSION_SELECTOR in st.session_state: + del st.session_state[EL_SESSION_SELECTOR] + if CHAT_CURRENT_USER_SESSIONS in st.session_state: + del st.session_state[CHAT_CURRENT_USER_SESSIONS] + + +def refresh_sessions(): + chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER] + current_user_name = st.session_state[USER_NAME] + current_user_sessions = chat_session_manager.list_sessions(current_user_name) + + if not isinstance(current_user_sessions, dict) or not current_user_sessions: + # generate a default session for current user. + chat_session_manager.add_session( + user_id=current_user_name, + session_id=f"{current_user_name}?default", + system_prompt=DEFAULT_SYSTEM_PROMPT, + ) + st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name) + current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS] + else: + st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions + + # load current user files. + st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files( + current_user_name + ) + # load current user private knowledge bases. 
+ st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \ + st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name) + logger.info(f"current user name: {current_user_name}, " + f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, " + f"user private files: {st.session_state[USER_PRIVATE_FILES]}") + st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = { + # public retrieval tools + **st.session_state[RETRIEVER_TOOLS], + # private retrieval tools + **st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name), + } + # print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}") + print(f"current_user_sessions is {current_user_sessions}") + st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0] + + +# process for session add and delete. +def on_session_change_submit(): + if "session_manager" in st.session_state and "session_editor" in st.session_state: + try: + for elem in st.session_state.session_editor["added_rows"]: + if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem: + if elem["session_id"] != "" and "?" not in elem["session_id"]: + st.session_state.session_manager.add_session( + user_id=st.session_state.user_name, + session_id=f"{st.session_state.user_name}?{elem['session_id']}", + system_prompt=elem["system_prompt"], + ) + else: + st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="❌") + raise KeyError( + "`session_id` shouldn't be neither empty nor contain char `?`." + ) + else: + st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="❌") + raise KeyError( + "You should fill both `session_id` and `system_prompt` to add a column!" 
+ ) + for elem in st.session_state.session_editor["deleted_rows"]: + user_name = st.session_state[USER_NAME] + session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id'] + user_with_session_id = f"{user_name}?{session_id}" + st.session_state.session_manager.remove_session(session_id=user_with_session_id) + st.toast(f"session `{user_with_session_id}` removed.", icon="✅") + + refresh_sessions() + except Exception as e: + sleep(2) + st.error(f"{type(e)}: {str(e)}") + finally: + st.session_state.session_editor["added_rows"] = [] + st.session_state.session_editor["deleted_rows"] = [] + refresh_agent() + + +def create_private_knowledge_base_as_tool(): + current_user_name = st.session_state[USER_NAME] + + if ( + EL_PERSONAL_KB_NAME in st.session_state + and EL_PERSONAL_KB_DESCRIPTION in st.session_state + and EL_BUILD_KB_WITH_FILES in st.session_state + and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0 + and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0 + and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0 + ): + st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base( + user_id=current_user_name, + tool_name=st.session_state[EL_PERSONAL_KB_NAME], + tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION], + files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]], + ) + refresh_sessions() + else: + st.session_state[EL_UPLOAD_FILES_STATUS].error( + "You should fill all fields to build up a tool!" + ) + sleep(2) + + +def remove_private_knowledge_bases(): + if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]: + private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE] + private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove] + # remove these private knowledge bases. 
+ st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases( + user_id=st.session_state[USER_NAME], + private_knowledge_bases=private_knowledge_base_names + ) + refresh_sessions() + else: + st.session_state[EL_UPLOAD_FILES_STATUS].error( + "You should specify at least one private knowledge base to delete!" + ) + time.sleep(2) + + +def refresh_agent(): + with st.spinner("Initializing session..."): + user_name = st.session_state[USER_NAME] + session_id = st.session_state[EL_SESSION_SELECTOR]['session_id'] + user_with_session_id = f"{user_name}?{session_id}" + + if EL_SELECTED_KBS in st.session_state: + selected_knowledge_bases = st.session_state[EL_SELECTED_KBS] + else: + selected_knowledge_bases = ["Wikipedia + Vector SQL"] + + logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}") + if EL_SESSION_SELECTOR in st.session_state: + system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"] + else: + system_prompt = DEFAULT_SYSTEM_PROMPT + + st.session_state["agent"] = build_agents( + session_id=user_with_session_id, + tool_names=selected_knowledge_bases, + system_prompt=system_prompt + ) + + +def add_file(): + user_name = st.session_state[USER_NAME] + if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0: + st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️") + sleep(2) + return + try: + st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...") + st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file( + user_id=user_name, + files=st.session_state[EL_UPLOAD_FILES] + ) + refresh_sessions() + except ValueError as e: + st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! 
import hashlib
import json
import time
from typing import Any

from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
from langchain.schema.messages import ToolMessage
from sqlalchemy.orm import declarative_base

from backend.chat_bot.tools import create_message_history_table

# Maps the serialized `type` tag back to the message class that revives it.
_MESSAGE_TYPES = {
    "human": HumanMessage,
    "ai": AIMessage,
    "system": SystemMessage,
    "chat": ChatMessage,
    "function": FunctionMessage,
    "tool": ToolMessage,
}


def _message_from_dict(message: dict) -> BaseMessage:
    """Revive a LangChain message from its JSON-serialized dict form.

    :param message: dict with a ``type`` tag and a ``data`` payload.
    :raises ValueError: if ``type`` is not a known message kind.
    """
    _type = message["type"]
    if _type == "AIMessageChunk":
        # Streamed chunks are persisted as plain AI messages.
        message["data"]["type"] = "ai"
        return AIMessage(**message["data"])
    message_cls = _MESSAGE_TYPES.get(_type)
    if message_cls is None:
        raise ValueError(f"Got unexpected message type: {_type}")
    return message_cls(**message["data"])


class DefaultClickhouseMessageConverter(DefaultMessageConverter):
    """Message converter for SQLChatMessageHistory backed by ClickHouse/MyScale."""

    def __init__(self, table_name: str):
        super().__init__(table_name)
        # Use our ClickHouse-specific ORM model instead of the SQL default.
        self.model_class = create_message_history_table(table_name, declarative_base())

    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        """Serialize a message into an ORM row keyed by a content hash."""
        time_stamp = time.time()
        msg_id = hashlib.sha256(
            f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest()
        # Session ids look like "<user_id>?<session name>". partition() keeps
        # everything before the first "?" and, unlike split()-unpacking, does
        # not raise when the separator is missing or appears more than once.
        user_id, _, _ = session_id.partition("?")
        return self.model_class(
            id=time_stamp,
            msg_id=msg_id,
            user_id=user_id,
            session_id=session_id,
            type=message.type,
            addtionals=json.dumps(message.additional_kwargs),
            message=json.dumps({
                "type": message.type,
                "additional_kwargs": {"timestamp": time_stamp},
                "data": message.dict()})
        )

    def from_sql_model(self, sql_message: Any) -> BaseMessage:
        """Deserialize an ORM row back into a LangChain message."""
        msg_dump = json.loads(sql_message.message)
        msg = _message_from_dict(msg_dump)
        # Restore the stored kwargs (carries the original timestamp).
        msg.additional_kwargs = msg_dump["additional_kwargs"]
        return msg

    def get_sql_model_class(self) -> Any:
        return self.model_class
class ChatBotKnowledgeTable:
    """Per-user private knowledge store on top of MyScale.

    Manages two tables:
      * ``{db}.{kb_table}``   — parsed paragraphs plus their embedding vectors;
      * ``{db}.{tool_table}`` — user-defined "private knowledge bases" (named
        groups of uploaded files exposed to the agent as retrieval tools).
    """

    def __init__(self, host, port, username, password,
                 embedding: Embeddings, parser_api_key: str, db="chat",
                 kb_table="private_kb", tool_table="private_tool") -> None:
        super().__init__()
        # Paragraph-level storage; vectors are fixed at 768 dims and indexed
        # with a cosine MSTG index.
        personal_files_schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
            entity_id String,
            file_name String,
            text String,
            user_id String,
            created_by DateTime,
            vector Array(Float32),
            CONSTRAINT cons_vec_len CHECK length(vector) = 768,
            VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
        ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """

        # `tool_name` represents a private knowledge base name.
        private_knowledge_base_schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
            tool_id String,
            tool_name String,
            file_names Array(String),
            user_id String,
            created_by DateTime,
            tool_description String
        ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.personal_files_table = kb_table
        self.private_knowledge_base_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        self.client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(personal_files_schema_)
        self.client.command(private_knowledge_base_schema_)
        self.parser_api_key = parser_api_key
        self.vector_store = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    def list_files(self, user_id: str):
        """List uploaded files of *user_id* with paragraph count and max length.

        NOTE(review): user ids are interpolated into SQL throughout this class;
        they come from the auth layer, but parameterized queries would be safer
        against injection — confirm the upstream inputs.
        """
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
               arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vector_store.config.database}.{self.personal_files_table}
        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    def add_by_file(self, user_id, files: List[UploadedFile]):
        """Parse uploaded files, embed their paragraphs and store the rows."""
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vector_store.embeddings, data)
        self.vector_store.client.insert_df(
            table=self.personal_files_table,
            df=pd.DataFrame(data),
            database=self.vector_store.config.database,
        )

    def clear(self, user_id: str):
        """Remove every file row and knowledge base belonging to *user_id*."""
        self.vector_store.client.command(
            f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
            f"WHERE user_id='{user_id}'"
        )
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}'"""
        self.vector_store.client.command(query)

    def create_private_knowledge_base(
        self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
    ):
        """Register a named knowledge base built from *files* for *user_id*."""
        self.vector_store.client.insert_df(
            self.private_knowledge_base_table,
            pd.DataFrame(
                [
                    {
                        # Deterministic id so re-creating the same KB replaces it
                        # (ReplacingMergeTree orders by tool_id).
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        "tool_name": tool_name,  # the user's knowledge-base name
                        "file_names": files,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vector_store.config.database,
        )

    def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
        """List knowledge bases of *user_id*, optionally filtered by name."""
        extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
        WHERE user_id = '{user_id}' {extended_where}
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
        """Delete the given knowledge bases (names de-duplicated first)."""
        unique_list = list(set(private_knowledge_bases))
        unique_list = ",".join([f"'{t}'" for t in unique_list])
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
        self.vector_store.client.command(query)

    def as_retrieval_tools(self, user_id, tool_name=None):
        """Build one LangChain retriever tool per knowledge base of *user_id*.

        Each tool restricts vector search to the files registered with that
        knowledge base.
        """
        logger.info(f"Building retrieval tools for user {user_id}")
        private_knowledge_bases = self.list_private_knowledge_bases(
            user_id=user_id, private_knowledge_base=tool_name)
        retrievers = {}
        for private_kb in private_knowledge_bases:
            # Fix: query the configured database/table instead of the
            # previously hard-coded `chat.private_tool`, so non-default
            # table names passed to __init__ keep working.
            file_names_sql = f"""
            SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
            )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = [line[0] for line in res.result_rows]
            file_names = ', '.join(f"'{item}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {file_names}")
            retrievers[private_kb["tool_name"]] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
                ),
                tool_name=private_kb["tool_name"],
                description=private_kb["tool_description"],
            )
        return retrievers
def get_sessions(engine, model_class, user_id):
    """Return all stored sessions of *user_id* as dicts, newest first.

    NOTE(review): the original filtered ``session_id == user_id`` and then
    called ``json.loads`` on the Query object itself, which always raises
    TypeError — this helper looks unused; confirm before relying on it.
    """
    with orm.sessionmaker(engine)() as session:
        result = (
            session.query(model_class)
            .where(model_class.user_id == user_id)
            .order_by(model_class.create_by.desc())
        )
        return [
            {"session_id": r.session_id, "system_prompt": r.system_prompt}
            for r in result
        ]


class SessionManager:
    """CRUD for chat sessions and their message history in ClickHouse/MyScale."""

    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db='chat',
            session_table='sessions',
            msg_table='chat_memory'
    ) -> None:
        # The connection protocol must match how the MyScale cluster is exposed.
        protocol = 'https' if GLOBAL_CONFIG.myscale_enable_https else 'http'
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol={protocol}'
        self.engine = create_engine(conn_str, echo=False)
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        """List sessions of *user_id*, newest first."""
        with self.session_orm() as session:
            result = (
                session.query(self.session_model_class)
                .where(
                    self.session_model_class.user_id == user_id
                )
                .order_by(self.session_model_class.create_by.desc())
            )
            sessions = []
            for r in result:
                sessions.append({
                    # session_id is stored as "<user_id>?<name>"; expose the name.
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                })
            return sessions

    def modify_system_prompt(self, session_id, sys_prompt):
        """Update the system prompt of an existing session, if present."""
        with self.session_orm() as session:
            obj = session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                logger.warning(f"Session {session_id} not found")

    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        """Insert a new session row; extra kwargs are stored as JSON."""
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id: str):
        """Delete a session together with its chat history."""
        with self.session_orm() as session:
            # remove the session row
            session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).delete()
            # remove the related chat history
            session.query(self.msg_model_class).where(
                self.msg_model_class.session_id == session_id).delete()
            # Fix: the deletes were never committed, so removal did not persist.
            session.commit()
def parse_files(api_key, user_id, files: "List[UploadedFile]"):
    """Parse uploaded files into paragraph dicts via the Unstructured API.

    Only narrative paragraphs longer than ten words are kept; each gets a
    deterministic ``entity_id`` derived from the file content and the text.
    Files are processed concurrently on a small thread pool.
    """

    def parse_file(file: "UploadedFile"):
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        file_hash = hashlib.sha256(file.read()).hexdigest()
        file_data = {"files": (file.name, file.getvalue(), file.type)}
        response = requests.post(
            url="https://api.unstructured.io/general/v0/general",
            headers=headers,
            data=data,
            files=file_data
        )
        json_response = response.json()
        if response.status_code != 200:
            raise ValueError(str(json_response))
        return [
            {
                "text": element["text"],
                "file_name": element["metadata"]["filename"],
                "entity_id": hashlib.sha256(
                    (file_hash + element["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for element in json_response
            if element["type"] == "NarrativeText" and len(element["text"].split(" ")) > 10
        ]

    with ThreadPool(8) as pool:
        rows = []
        for parsed in pool.imap_unordered(parse_file, files):
            rows.extend(parsed)
        return rows


def extract_embedding(embeddings: "Embeddings", texts):
    """Attach an embedding vector to every parsed paragraph in *texts*.

    Mutates and returns *texts*; raises ValueError when nothing was extracted.
    """
    if not texts:
        raise ValueError("No texts extracted!")
    vectors = embeddings.embed_documents([item["text"] for item in texts])
    for item, vector in zip(texts, vectors):
        item["vector"] = vector
    return texts


def create_message_history_table(table_name: str, base_class):
    """Declare the chat-history ORM model bound to *base_class*."""

    class Message(base_class):
        __tablename__ = table_name
        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # Intentionally kept misspelled ("addtionals") — existing rows and
        # callers already use this column name.
        addtionals = Column(Text)
        message = Column(Text)
        __table_args__ = (
            engines.MergeTree(
                partition_by='session_id',
                order_by=('id', 'msg_id')
            ),
            {'comment': 'Store Chat History'}
        )

    return Message


def create_session_table(table_name: str, DynamicBase):
    """Declare the session ORM model bound to *DynamicBase*."""

    class Session(DynamicBase):
        __tablename__ = table_name
        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # creation timestamp
        create_by = Column(types.DateTime)
        # Correctly spelled here, unlike the history table's `addtionals`.
        additionals = Column(Text)
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            {'comment': 'Store Session and Prompts'}
        )

    return Session
from typing import Dict, List
import streamlit as st
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate

from backend.types.table_config import TableConfig


def hint_arxiv():
    """Render sample queries for the arXiv papers table."""
    st.markdown("Here we provide some query samples.")
    st.markdown("- If you want to search papers with filters")
    st.markdown("1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
                "than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```")
    st.markdown("2. ```What is a Bayesian network? Please use articles published later than Feb 2018```")
    st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
    st.markdown("1. ```Did Geoffrey Hinton wrote paper about Capsule Neural Networks?```")
    st.markdown("2. ```Introduce some applications of GANs published around 2019.```")
    st.markdown("3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```")


def hint_sql_arxiv():
    """Show the DDL of the arXiv table so users understand the schema."""
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


def hint_wiki():
    """Render sample queries for the Wikipedia table."""
    st.markdown("Here we provide some query samples.")
    st.markdown("1. ```Which company did Elon Musk found?```")
    st.markdown("2. ```What is Iron Gwazi?```")
    st.markdown("3. ```苹果的发源地是哪里?```")
    st.markdown("4. ```What is a Ring in mathematics?```")
    st.markdown("5. ```The producer of Rick and Morty.```")
    st.markdown("6. ```How low is the temperature on Pluto?```")


def hint_sql_wiki():
    """Show the DDL of the Wikipedia table so users understand the schema."""
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
    `id` String,
    `title` String,
    `text` String,
    `url` String,
    `wiki_id` UInt64,
    `views` Float32,
    `paragraph_id` UInt64,
    `langs` UInt32,
    `emb` Array(Float32),
    VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


# Registry of the public read-only tables exposed by the demo, keyed by the
# display name shown in the UI.
MYSCALE_TABLES: Dict[str, TableConfig] = {
    'Wikipedia': TableConfig(
        database="wiki",
        table="Wikipedia",
        # Fix: "Snapshort" -> "Snapshot".
        table_contents="Snapshot from Wikipedia for 2022. All in English.",
        hint=hint_wiki,
        hint_sql=hint_sql_wiki,
        # doc_prompt is used by the QA-with-sources chain.
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "url", "title", "ref_id", "views"],
            template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
            AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
            AttributeInfo(name="views", description="number of views", type="float")
        ],
        must_have_col_names=['id', 'title', 'url', 'text', 'views'],
        vector_col_name="emb",
        text_col_name="text",
        metadata_col_name="metadata",
        emb_model=lambda: SentenceTransformerEmbeddings(
            model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        ),
        tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
    ),
    'ArXiv Papers': TableConfig(
        database="default",
        table="ChatArXiv",
        # Fix: this previously copy-pasted the Wikipedia description, which
        # misdescribed the table to the app (and any prompt using it).
        table_contents="Snapshot of scientific papers from arXiv. All in English.",
        hint=hint_arxiv,
        hint_sql=hint_sql_arxiv,
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"],
            template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t"
                     "Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"),
            AttributeInfo(name="authors", description="List of author names", type="list[string]"),
            AttributeInfo(name="title", description="Title of the paper", type="string"),
            AttributeInfo(name="categories", description="arxiv categories to this paper", type="list[string]"),
            AttributeInfo(name="length(categories)", description="length of arxiv categories to this paper", type="int")
        ],
        must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
        vector_col_name="vector",
        text_col_name="abstract",
        metadata_col_name="metadata",
        emb_model=lambda: HuggingFaceInstructEmbeddings(
            model_name='hkunlp/instructor-xl',
            embed_instruction="Represent the question for retrieving supporting scientific papers: "
        ),
        tool_desc=(
            "search_among_scientific_papers",
            "Searches among scientific papers from ArXiv and returns research papers"
        )
    )
}

# Physical table names of every configured source.
ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()]
"""Prompt constants shared by the chat agents and QA chains."""
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# System prompt handed to every new agent session by default.
DEFAULT_SYSTEM_PROMPT = (
    "Do your best to answer the questions. "
    "Feel free to use any tools available to look up "
    "relevant information. Please keep all details in query "
    "when calling search functions."
)

# System message of the QA-with-sources combine step; `{summaries}` is filled
# with the retrieved documents and answers must cite them via `Doc #`.
COMBINE_PROMPT_TEMPLATE = (
    "You are a helpful document assistant. "
    "Your task is to provide information and answer any questions related to documents given below. "
    "You should use the sections, title and abstract of the selected documents as your source of information "
    "and try to provide concise and accurate answers to any questions asked by the user. "
    "If you are unable to find relevant information in the given sections, "
    "you will need to let the user know that the source does not contain relevant information but still try to "
    "provide an answer based on your general knowledge. You must refer to the corresponding section name and page "
    "that you refer to when answering. "
    "The following is the related information about the document that will help you answer users' questions, "
    "you MUST answer it using question's language:\n\n {summaries} "
    "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
)

COMBINE_PROMPT = ChatPromptTemplate.from_strings(
    string_messages=[
        (SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE),
        (HumanMessagePromptTemplate, '{question}'),
    ]
)
"""Session-state and widget keys used across the Streamlit UI."""

# --- data-initialization lifecycle -----------------------------------------
# NOTE: the first name is misspelled ("STATED") but is imported elsewhere
# (e.g. app.py), so it must stay as-is.
DATA_INITIALIZE_NOT_STATED = "data_initialize_not_started"
DATA_INITIALIZE_STARTED = "data_initialize_started"
DATA_INITIALIZE_COMPLETED = "data_initialize_completed"

# --- chat session state -----------------------------------------------------
CHAT_SESSION = "sel_sess"
CHAT_KNOWLEDGE_TABLE = "private_kb"
CHAT_SESSION_MANAGER = "session_manager"
CHAT_CURRENT_USER_SESSIONS = "current_sessions"

EL_SESSION_SELECTOR = "el_session_selector"

# --- per-user knowledge state ----------------------------------------------
# all personal knowledge bases under a specific user
USER_PERSONAL_KNOWLEDGE_BASES = "user_tools"
# all personal files under a specific user
USER_PRIVATE_FILES = "user_files"
# public and personal knowledge bases combined
AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users"

EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove"

# --- upload widgets ---------------------------------------------------------
# files pending upload
EL_UPLOAD_FILES = "el_upload_files"
EL_UPLOAD_FILES_STATUS = "el_upload_files_status"

# --- knowledge-base builder widgets ----------------------------------------
# files used to build a private knowledge base
EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files"
# name of the personal knowledge base being built
EL_PERSONAL_KB_NAME = "el_personal_kb_name"
# description of the personal knowledge base being built
EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description"

# knowledge bases selected by the user
EL_SELECTED_KBS = "el_selected_kbs"
+""" + +DIVIDER_THIN_HTML = """ +
class RetrieverButtons:
    """Widget keys for the retriever-demo action buttons."""
    vector_sql_query_from_db = "vector_sql_query_from_db"
    vector_sql_query_with_llm = "vector_sql_query_with_llm"
    self_query_from_db = "self_query_from_db"
    self_query_with_llm = "self_query_with_llm"


# Shared, mutable application configuration singleton.
GLOBAL_CONFIG = GlobalConfig()

# Fields copied by update_global_config; kept explicit so a new GlobalConfig
# field is not silently dropped.
_CONFIG_FIELDS = (
    "openai_api_base",
    "openai_api_key",
    "auth0_client_id",
    "auth0_domain",
    "myscale_user",
    "myscale_password",
    "myscale_host",
    "myscale_port",
    "query_model",
    "chat_model",
    "untrusted_api",
    "myscale_enable_https",
)


def update_global_config(new_config: GlobalConfig) -> None:
    """Copy every known field of *new_config* onto the shared GLOBAL_CONFIG.

    The singleton is mutated in place (not rebound), so existing references
    held by other modules observe the new values.
    """
    for field in _CONFIG_FIELDS:
        setattr(GLOBAL_CONFIG, field, getattr(new_config, field))
from sqlalchemy.ext.declarative import declarative_base +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import MessagesPlaceholder +from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.schema.messages import SystemMessage +from langchain.memory import SQLChatMessageHistory + + +def create_agent_executor( + agent_name: str, + session_id: str, + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + system_prompt: str, + **kwargs +) -> AgentExecutor: + agent_name = agent_name.replace(" ", "_") + conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}' + chat_memory = SQLChatMessageHistory( + session_id, + connection_string=f'{conn_str}/chat?protocol=http' if GLOBAL_CONFIG.myscale_enable_https == False else f'{conn_str}/chat?protocol=https', + custom_message_converter=DefaultClickhouseMessageConverter(agent_name)) + memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory) + + prompt = OpenAIFunctionsAgent.create_prompt( + system_message=SystemMessage(content=system_prompt), + extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], + ) + agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) + return AgentExecutor( + agent=agent, + tools=tools, + memory=memory, + verbose=True, + return_intermediate_steps=True, + **kwargs + ) + + +def build_agents( + session_id: str, + tool_names: List[str], + model: str = "gpt-3.5-turbo-0125", + temperature: float = 0.6, + system_prompt: str = DEFAULT_SYSTEM_PROMPT +): + chat_llm = ChatOpenAI( + model_name=model, + temperature=temperature, + base_url=GLOBAL_CONFIG.openai_api_base, + api_key=GLOBAL_CONFIG.openai_api_key, + streaming=True + ) + tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS)) + 
selected_tools = [tools[k] for k in tool_names] + logger.info(f"create agent, use tools: {selected_tools}") + agent = create_agent_executor( + agent_name="chat_memory", + session_id=session_id, + llm=chat_llm, + tools=selected_tools, + system_prompt=system_prompt + ) + return agent diff --git a/app/backend/construct/build_all.py b/app/backend/construct/build_all.py new file mode 100644 index 0000000..ebfba90 --- /dev/null +++ b/app/backend/construct/build_all.py @@ -0,0 +1,95 @@ +from logger import logger +from typing import Dict, Any, Union + +import streamlit as st + +from backend.constants.myscale_tables import MYSCALE_TABLES +from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING +from backend.construct.build_chains import build_retrieval_qa_with_sources_chain +from backend.construct.build_retriever_tool import create_retriever_tool +from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever +from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn + +from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \ + SentenceTransformerEmbeddings + + +@st.cache_resource +def load_embedding_model_for_table(table_name: str) -> \ + Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]: + with st.spinner(f"Loading embedding models for [{table_name}] ..."): + embeddings = MYSCALE_TABLES[table_name].emb_model() + return embeddings + + +@st.cache_resource +def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]: + embedding_models = {} + for table in MYSCALE_TABLES: + embedding_models[table] = load_embedding_model_for_table(table) + return embedding_models + + +@st.cache_resource +def update_retriever_tools(): + retrievers_tools = {} + for table in MYSCALE_TABLES: + logger.info(f"Updating retriever tools [, ] for table {table}") + retrievers_tools.update( + { + f"{table} + Self 
Querying": create_retriever_tool( + st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"], + *MYSCALE_TABLES[table].tool_desc + ), + f"{table} + Vector SQL": create_retriever_tool( + st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"], + *MYSCALE_TABLES[table].tool_desc + ), + }) + return retrievers_tools + + +@st.cache_resource +def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers: + metadata_col_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes + + self_query_retriever = build_self_query_retriever(table_name) + self_query_chain = build_retrieval_qa_with_sources_chain( + table_name=table_name, + retriever=self_query_retriever, + chain_name="Self Query Retriever" + ) + + vector_sql_retriever = build_vector_sql_db_chain_retriever(table_name) + vector_sql_chain = build_retrieval_qa_with_sources_chain( + table_name=table_name, + retriever=vector_sql_retriever, + chain_name="Vector SQL DB Retriever" + ) + + metadata_columns = [ + MetadataColumn( + name=attribute.name, + desc=attribute.description, + type=attribute.type + ) + for attribute in metadata_col_attributes + ] + return ChainsAndRetrievers( + metadata_columns=metadata_columns, + # for self query + retriever=self_query_retriever, + chain=self_query_chain, + # for vector sql + sql_retriever=vector_sql_retriever, + sql_chain=vector_sql_chain + ) + + +@st.cache_resource +def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]: + chains_and_retrievers = {} + for table in MYSCALE_TABLES: + logger.info(f"Building chains, retrievers for table {table}") + chains_and_retrievers[table] = build_chains_retriever_for_table(table).to_dict() + return chains_and_retrievers diff --git a/app/backend/construct/build_chains.py b/app/backend/construct/build_chains.py new file mode 100644 index 0000000..43b450d --- /dev/null +++ b/app/backend/construct/build_chains.py @@ -0,0 +1,39 @@ +from langchain.chains import LLMChain +from langchain.chat_models import 
ChatOpenAI +from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate +from langchain.schema import BaseRetriever +import streamlit as st + +from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain +from backend.chains.stuff_documents import CustomStuffDocumentChain +from backend.constants.myscale_tables import MYSCALE_TABLES +from backend.constants.prompts import COMBINE_PROMPT +from backend.constants.variables import GLOBAL_CONFIG + + +def build_retrieval_qa_with_sources_chain( + table_name: str, + retriever: BaseRetriever, + chain_name: str = "" +) -> CustomRetrievalQAWithSourcesChain: + with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'): + # Assign ref_id for documents + custom_stuff_document_chain = CustomStuffDocumentChain( + llm_chain=LLMChain( + prompt=COMBINE_PROMPT, + llm=ChatOpenAI( + model_name=GLOBAL_CONFIG.chat_model, + openai_api_key=GLOBAL_CONFIG.openai_api_key, + temperature=0.6 + ), + ), + document_prompt=MYSCALE_TABLES[table_name].doc_prompt, + document_variable_name="summaries", + ) + chain = CustomRetrievalQAWithSourcesChain( + retriever=retriever, + combine_documents_chain=custom_stuff_document_chain, + return_source_documents=True, + max_tokens_limit=12000, + ) + return chain diff --git a/app/backend/construct/build_chat_bot.py b/app/backend/construct/build_chat_bot.py new file mode 100644 index 0000000..a79f4e4 --- /dev/null +++ b/app/backend/construct/build_chat_bot.py @@ -0,0 +1,36 @@ +from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable +from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER +import streamlit as st + +from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING +from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT +from backend.chat_bot.session_manager import SessionManager + + +def 
build_chat_knowledge_table(): + if CHAT_KNOWLEDGE_TABLE not in st.session_state: + st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable( + host=GLOBAL_CONFIG.myscale_host, + port=GLOBAL_CONFIG.myscale_port, + username=GLOBAL_CONFIG.myscale_user, + password=GLOBAL_CONFIG.myscale_password, + # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"], + embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"], + parser_api_key=GLOBAL_CONFIG.untrusted_api, + ) + + +def initialize_session_manager(): + if CHAT_SESSION not in st.session_state: + st.session_state[CHAT_SESSION] = { + "session_id": "default", + "system_prompt": DEFAULT_SYSTEM_PROMPT, + } + if CHAT_SESSION_MANAGER not in st.session_state: + st.session_state[CHAT_SESSION_MANAGER] = SessionManager( + st.session_state, + host=GLOBAL_CONFIG.myscale_host, + port=GLOBAL_CONFIG.myscale_port, + username=GLOBAL_CONFIG.myscale_user, + password=GLOBAL_CONFIG.myscale_password, + ) diff --git a/app/backend/construct/build_retriever_tool.py b/app/backend/construct/build_retriever_tool.py new file mode 100644 index 0000000..b1ed27d --- /dev/null +++ b/app/backend/construct/build_retriever_tool.py @@ -0,0 +1,45 @@ +import json +from typing import List + +from langchain.pydantic_v1 import BaseModel, Field +from langchain.schema import BaseRetriever, Document +from langchain.tools import Tool + +from backend.chat_bot.json_decoder import CustomJSONEncoder + + +class RetrieverInput(BaseModel): + query: str = Field(description="query to look up in retriever") + + +def create_retriever_tool( + retriever: BaseRetriever, + tool_name: str, + description: str +) -> Tool: + """Create a tool to do retrieval of documents. + + Args: + retriever: The retriever to use for the retrieval + tool_name: The name for the tool. This will be passed to the language model, + so should be unique and somewhat descriptive. + description: The description for the tool. 
This will be passed to the language + model, so should be descriptive. + + Returns: + Tool class to pass to an agent + """ + def wrap(func): + def wrapped_retrieve(*args, **kwargs): + docs: List[Document] = func(*args, **kwargs) + return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder) + + return wrapped_retrieve + + return Tool( + name=tool_name, + description=description, + func=wrap(retriever.get_relevant_documents), + coroutine=retriever.aget_relevant_documents, + args_schema=RetrieverInput, + ) diff --git a/app/backend/construct/build_retrievers.py b/app/backend/construct/build_retrievers.py new file mode 100644 index 0000000..6c3918a --- /dev/null +++ b/app/backend/construct/build_retrievers.py @@ -0,0 +1,120 @@ +import streamlit as st +from langchain.chat_models import ChatOpenAI +from langchain.prompts.prompt import PromptTemplate +from langchain.retrievers.self_query.base import SelfQueryRetriever +from langchain.retrievers.self_query.myscale import MyScaleTranslator +from langchain.utilities.sql_database import SQLDatabase +from langchain.vectorstores import MyScaleSettings +from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever +from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain +from sqlalchemy import create_engine, MetaData + +from backend.constants.myscale_tables import MYSCALE_TABLES +from backend.constants.prompts import MYSCALE_PROMPT +from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG +from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser +from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson +from logger import logger + + +@st.cache_resource +def build_self_query_retriever(table_name: str) -> SelfQueryRetriever: + with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."): + myscale_connection = { + "host": GLOBAL_CONFIG.myscale_host, + "port": 
GLOBAL_CONFIG.myscale_port, + "username": GLOBAL_CONFIG.myscale_user, + "password": GLOBAL_CONFIG.myscale_password, + } + myscale_settings = MyScaleSettings( + **myscale_connection, + database=MYSCALE_TABLES[table_name].database, + table=MYSCALE_TABLES[table_name].table, + column_map={ + "id": "id", + "text": MYSCALE_TABLES[table_name].text_col_name, + "vector": MYSCALE_TABLES[table_name].vector_col_name, + # TODO refine MyScaleDB metadata in langchain. + "metadata": MYSCALE_TABLES[table_name].metadata_col_name + } + ) + myscale_vector_store = MyScaleWithoutMetadataJson( + embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], + config=myscale_settings, + must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names + ) + + with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."): + retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm( + llm=ChatOpenAI( + model_name=GLOBAL_CONFIG.query_model, + base_url=GLOBAL_CONFIG.openai_api_base, + api_key=GLOBAL_CONFIG.openai_api_key, + temperature=0 + ), + vectorstore=myscale_vector_store, + document_contents=MYSCALE_TABLES[table_name].table_contents, + metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes, + use_original_query=False, + structured_query_translator=MyScaleTranslator() + ) + return retriever + + +@st.cache_resource +def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever: + """Get a group of relative docs from MyScaleDB""" + with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'): + if GLOBAL_CONFIG.myscale_enable_https == False: + engine = create_engine( + f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' + f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' + f'/{MYSCALE_TABLES[table_name].database}?protocol=http' + ) + else: + engine = create_engine( + f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' + 
f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' + f'/{MYSCALE_TABLES[table_name].database}?protocol=https' + ) + metadata = MetaData(bind=engine) + logger.info(f"{table_name} metadata is : {metadata}") + prompt = PromptTemplate( + input_variables=["input", "table_info", "top_k"], + template=MYSCALE_PROMPT, + ) + # Custom `out_put_parser` rewrite search SQL, make it's possible to query custom column. + output_parser = VectorSQLRetrieveOutputParser.from_embeddings( + model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], + # rewrite columns needs be searched. + must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names + ) + + # `db_chain` will generate a SQL + vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm( + llm=ChatOpenAI( + model_name=GLOBAL_CONFIG.query_model, + base_url=GLOBAL_CONFIG.openai_api_base, + api_key=GLOBAL_CONFIG.openai_api_key, + temperature=0 + ), + prompt=prompt, + top_k=10, + return_direct=True, + db=SQLDatabase( + engine, + None, + metadata, + include_tables=[MYSCALE_TABLES[table_name].table], + max_string_length=1024 + ), + sql_cmd_parser=output_parser, # TODO needs update `langchain`, fix return type. 
+ native_format=True + ) + + # `retriever` can search a group of documents with `db_chain` + vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever( + sql_db_chain=vector_sql_db_chain, + page_content_key=MYSCALE_TABLES[table_name].text_col_name + ) + return vector_sql_db_chain_retriever diff --git a/app/backend/retrievers/__init__.py b/app/backend/retrievers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/retrievers/self_query.py b/app/backend/retrievers/self_query.py new file mode 100644 index 0000000..d30e06a --- /dev/null +++ b/app/backend/retrievers/self_query.py @@ -0,0 +1,89 @@ +from typing import List + +import pandas as pd +import streamlit as st +from langchain.retrievers import SelfQueryRetriever +from langchain_core.documents import Document +from langchain_core.runnables import RunnableConfig + +from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain +from backend.constants.myscale_tables import MYSCALE_TABLES +from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons +from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler +from ui.utils import display +from logger import logger + + +def process_self_query(selected_table, query_type): + place_holder = st.empty() + logger.info( + f"button-1: {RetrieverButtons.self_query_from_db}, " + f"button-2: {RetrieverButtons.self_query_with_llm}, " + f"content: {st.session_state.query_self}" + ) + with place_holder.expander('🪵 Chat Log', expanded=True): + try: + if query_type == RetrieverButtons.self_query_from_db: + callback = CustomSelfQueryRetrieverCallBackHandler() + retriever: SelfQueryRetriever = \ + st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"] + config: RunnableConfig = {"callbacks": [callback]} + + relevant_docs = retriever.invoke( + input=st.session_state.query_self, + config=config + ) + + 
callback.progress_bar.progress( + value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅") + + st.markdown(f"### Self Query Results from `{selected_table}` \n" + f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n") + display( + dataframe=pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs] + ), + columns_=MYSCALE_TABLES[selected_table].must_have_col_names + ) + elif query_type == RetrieverButtons.self_query_with_llm: + # callback = CustomSelfQueryRetrieverCallBackHandler() + callback = ChatDataSelfAskCallBackHandler() + chain: CustomRetrievalQAWithSourcesChain = \ + st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"] + chain_results = chain(st.session_state.query_self, callbacks=[callback]) + callback.progress_bar.progress( + value=1.0, + text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅" + ) + + documents_reference: List[Document] = chain_results["source_documents"] + st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n" + f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n") + display( + pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in documents_reference] + ) + ) + st.markdown( + f"### Answer from LLM \n" + f"> The response of the LLM when given the `SelfQueryRetriever` results. 
\n\n" + ) + st.write(chain_results['answer']) + st.markdown( + f"### References from `{selected_table}`\n" + f"> Here shows that which documents used by LLM \n\n" + ) + if len(chain_results['sources']) == 0: + st.write("No documents is used by LLM.") + else: + display( + dataframe=pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']] + ), + columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names, + index='ref_id' + ) + st.markdown(DIVIDER_HTML, unsafe_allow_html=True) + except Exception as e: + st.write('Oops 😵 Something bad happened...') + raise e diff --git a/app/backend/retrievers/vector_sql_output_parser.py b/app/backend/retrievers/vector_sql_output_parser.py new file mode 100644 index 0000000..79ebca5 --- /dev/null +++ b/app/backend/retrievers/vector_sql_output_parser.py @@ -0,0 +1,23 @@ +from typing import Dict, Any, List + +from langchain_experimental.sql.vector_sql import VectorSQLOutputParser + + +class VectorSQLRetrieveOutputParser(VectorSQLOutputParser): + """Based on VectorSQLOutputParser + It also modify the SQL to get all columns + """ + must_have_columns: List[str] + + @property + def _type(self) -> str: + return "vector_sql_retrieve_custom" + + def parse(self, text: str) -> Dict[str, Any]: + text = text.strip() + start = text.upper().find("SELECT") + if start >= 0: + end = text.upper().find("FROM") + text = text.replace( + text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns)) + return super().parse(text) diff --git a/app/backend/retrievers/vector_sql_query.py b/app/backend/retrievers/vector_sql_query.py new file mode 100644 index 0000000..d7d97ae --- /dev/null +++ b/app/backend/retrievers/vector_sql_query.py @@ -0,0 +1,95 @@ +from typing import List + +import pandas as pd +import streamlit as st +from langchain.schema import Document +from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever + +from 
backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain +from backend.constants.myscale_tables import MYSCALE_TABLES +from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons +from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler +from ui.utils import display +from logger import logger + + +def process_sql_query(selected_table: str, query_type: str): + place_holder = st.empty() + logger.info( + f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, " + f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, " + f"table: {selected_table}, " + f"content: {st.session_state.query_sql}" + ) + with place_holder.expander('🪵 Query Log', expanded=True): + try: + if query_type == RetrieverButtons.vector_sql_query_from_db: + callback = VectorSQLSearchDBCallBackHandler() + vector_sql_retriever: VectorSQLDatabaseChainRetriever = \ + st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"] + relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents( + query=st.session_state.query_sql, + callbacks=[callback] + ) + + callback.progress_bar.progress( + value=1.0, + text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! 
✅" + ) + + st.markdown(f"### Vector Search Results from `{selected_table}` \n" + f"> Here we get documents from MyScaleDB with given sql statement \n\n") + display( + pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs] + ) + ) + elif query_type == RetrieverButtons.vector_sql_query_with_llm: + callback = VectorSQLSearchLLMCallBackHandler(table=selected_table) + vector_sql_chain: CustomRetrievalQAWithSourcesChain = \ + st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"] + chain_results = vector_sql_chain( + inputs=st.session_state.query_sql, + callbacks=[callback] + ) + + callback.progress_bar.progress( + value=1.0, + text="[Question -> LLM -> SQL Statement -> MyScaleDB -> " + "(Question,Results) -> LLM -> Results] Done! ✅" + ) + + documents_reference: List[Document] = chain_results["source_documents"] + st.markdown(f"### Vector Search Results from `{selected_table}` \n" + f"> Here we get documents from MyScaleDB with given sql statement \n\n") + display( + pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in documents_reference] + ) + ) + st.markdown( + f"### Answer from LLM \n" + f"> The response of the LLM when given the vector search results. 
\n\n" + ) + st.write(chain_results['answer']) + st.markdown( + f"### References from `{selected_table}`\n" + f"> Here shows that which documents used by LLM \n\n" + ) + if len(chain_results['sources']) == 0: + st.write("No documents is used by LLM.") + else: + display( + dataframe=pd.DataFrame( + [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']] + ), + columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names, + index='ref_id' + ) + else: + raise NotImplementedError(f"Unsupported query type: {query_type}") + st.markdown(DIVIDER_HTML, unsafe_allow_html=True) + except Exception as e: + st.write('Oops 😵 Something bad happened...') + raise e + diff --git a/app/backend/types/__init__.py b/app/backend/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/types/chains_and_retrievers.py b/app/backend/types/chains_and_retrievers.py new file mode 100644 index 0000000..91e3082 --- /dev/null +++ b/app/backend/types/chains_and_retrievers.py @@ -0,0 +1,34 @@ +from typing import Dict +from dataclasses import dataclass +from typing import List, Any +from langchain.retrievers import SelfQueryRetriever +from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever + +from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain + + +@dataclass +class MetadataColumn: + name: str + desc: str + type: str + + +@dataclass +class ChainsAndRetrievers: + metadata_columns: List[MetadataColumn] + retriever: SelfQueryRetriever + chain: CustomRetrievalQAWithSourcesChain + sql_retriever: VectorSQLDatabaseChainRetriever + sql_chain: CustomRetrievalQAWithSourcesChain + + def to_dict(self) -> Dict[str, Any]: + return { + "metadata_columns": self.metadata_columns, + "retriever": self.retriever, + "chain": self.chain, + "sql_retriever": self.sql_retriever, + "sql_chain": self.sql_chain + } + + diff --git a/app/backend/types/global_config.py 
b/app/backend/types/global_config.py new file mode 100644 index 0000000..ffa624f --- /dev/null +++ b/app/backend/types/global_config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class GlobalConfig: + openai_api_base: Optional[str] = "" + openai_api_key: Optional[str] = "" + + auth0_client_id: Optional[str] = "" + auth0_domain: Optional[str] = "" + + myscale_user: Optional[str] = "" + myscale_password: Optional[str] = "" + myscale_host: Optional[str] = "" + myscale_port: Optional[int] = 443 + + query_model: Optional[str] = "" + chat_model: Optional[str] = "" + + untrusted_api: Optional[str] = "" + myscale_enable_https: Optional[bool] = True diff --git a/app/backend/types/table_config.py b/app/backend/types/table_config.py new file mode 100644 index 0000000..61087ba --- /dev/null +++ b/app/backend/types/table_config.py @@ -0,0 +1,25 @@ +from typing import Callable +from langchain.chains.query_constructor.schema import AttributeInfo +from langchain.prompts import PromptTemplate +from dataclasses import dataclass +from typing import List + + +@dataclass +class TableConfig: + database: str + table: str + table_contents: str + # column names + must_have_col_names: List[str] + vector_col_name: str + text_col_name: str + metadata_col_name: str + # hint for UI + hint: Callable + hint_sql: Callable + # for langchain + doc_prompt: PromptTemplate + metadata_col_attributes: List[AttributeInfo] + emb_model: Callable + tool_desc: tuple diff --git a/app/backend/vector_store/__init__.py b/app/backend/vector_store/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/backend/vector_store/myscale_without_metadata.py b/app/backend/vector_store/myscale_without_metadata.py new file mode 100644 index 0000000..f446f64 --- /dev/null +++ b/app/backend/vector_store/myscale_without_metadata.py @@ -0,0 +1,52 @@ +from typing import Any, Optional, List + +from langchain.docstore.document import Document +from 
langchain.embeddings.base import Embeddings +from langchain.vectorstores.myscale import MyScale, MyScaleSettings + +from logger import logger + + +class MyScaleWithoutMetadataJson(MyScale): + def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], + **kwargs: Any) -> None: + try: + super().__init__(embedding, config, **kwargs) + except Exception as e: + logger.error(e) + self.must_have_cols: List[str] = must_have_cols + + def _build_qstr( + self, q_emb: List[float], topk: int, where_str: Optional[str] = None + ) -> str: + q_emb_str = ",".join(map(str, q_emb)) + if where_str: + where_str = f"PREWHERE {where_str}" + else: + where_str = "" + + q_str = f""" + SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)} + FROM {self.config.database}.{self.config.table} + {where_str} + ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}]) + AS dist {self.dist_order} + LIMIT {topk} + """ + return q_str + + def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, + **kwargs: Any) -> List[Document]: + q_str = self._build_qstr(embedding, k, where_str) + try: + return [ + Document( + page_content=r[self.config.column_map["text"]], + metadata={k: r[k] for k in self.must_have_cols}, + ) + for r in self.client.query(q_str).named_results() + ] + except Exception as e: + logger.error( + f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") + return [] diff --git a/app/callbacks/arxiv_callbacks.py b/app/callbacks/arxiv_callbacks.py deleted file mode 100644 index efa3caa..0000000 --- a/app/callbacks/arxiv_callbacks.py +++ /dev/null @@ -1,148 +0,0 @@ -import streamlit as st -import json -import textwrap -from typing import Dict, Any, List -from sql_formatter.core import format_sql -from langchain.callbacks.streamlit.streamlit_callback_handler import ( - LLMThought, - StreamlitCallbackHandler, -) -from langchain.schema.output import 
LLMResult - - -class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler): - def __init__(self) -> None: - self.progress_bar = st.progress(value=0.0, text="Working...") - self.tokens_stream = "" - - def on_llm_start(self, serialized, prompts, **kwargs) -> None: - pass - - def on_text(self, text: str, **kwargs) -> None: - self.progress_bar.progress(value=0.2, text="Asking LLM...") - - def on_chain_end(self, outputs, **kwargs) -> None: - self.progress_bar.progress(value=0.6, text="Searching in DB...") - if "repr" in outputs: - st.markdown("### Generated Filter") - st.markdown( - f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True) - - def on_chain_start(self, serialized, inputs, **kwargs) -> None: - pass - - -class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler): - def __init__(self) -> None: - self.progress_bar = st.progress(value=0.0, text="Searching DB...") - self.status_bar = st.empty() - self.prog_value = 0.0 - self.prog_map = { - "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2, - "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4, - "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8, - } - - def on_llm_start(self, serialized, prompts, **kwargs) -> None: - pass - - def on_text(self, text: str, **kwargs) -> None: - pass - - def on_chain_start(self, serialized, inputs, **kwargs) -> None: - cid = ".".join(serialized["id"]) - if cid != "langchain.chains.llm.LLMChain": - self.progress_bar.progress( - value=self.prog_map[cid], text=f"Running Chain `{cid}`..." - ) - self.prog_value = self.prog_map[cid] - else: - self.prog_value += 0.1 - self.progress_bar.progress( - value=self.prog_value, text=f"Running Chain `{cid}`..." 
- ) - - def on_chain_end(self, outputs, **kwargs) -> None: - pass - - -class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler): - def __init__(self) -> None: - self.progress_bar = st.progress(value=0.0, text="Writing SQL...") - self.status_bar = st.empty() - self.prog_value = 0 - self.prog_interval = 0.2 - - def on_llm_start(self, serialized, prompts, **kwargs) -> None: - pass - - def on_llm_end( - self, - response: LLMResult, - *args, - **kwargs, - ): - text = response.generations[0][0].text - if text.replace(" ", "").upper().startswith("SELECT"): - st.write("We generated Vector SQL for you:") - st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""") - print(f"Vector SQL: {text}") - self.prog_value += self.prog_interval - self.progress_bar.progress( - value=self.prog_value, text="Searching in DB...") - - def on_chain_start(self, serialized, inputs, **kwargs) -> None: - cid = ".".join(serialized["id"]) - self.prog_value += self.prog_interval - self.progress_bar.progress( - value=self.prog_value, text=f"Running Chain `{cid}`..." 
- ) - - def on_chain_end(self, outputs, **kwargs) -> None: - pass - - -class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler): - def __init__(self) -> None: - self.progress_bar = st.progress(value=0.0, text="Writing SQL...") - self.status_bar = st.empty() - self.prog_value = 0 - self.prog_interval = 0.1 - - -class LLMThoughtWithKB(LLMThought): - def on_tool_end( - self, - output: str, - color=None, - observation_prefix=None, - llm_prefix=None, - **kwargs: Any, - ) -> None: - try: - self._container.markdown( - "\n\n".join( - ["### Retrieved Documents:"] - + [ - f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}" - for i, r in enumerate(json.loads(output)) - ] - ) - ) - except Exception as e: - super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs) - - -class ChatDataAgentCallBackHandler(StreamlitCallbackHandler): - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - if self._current_thought is None: - self._current_thought = LLMThoughtWithKB( - parent_container=self._parent_container, - expanded=self._expand_new_thoughts, - collapse_on_complete=self._collapse_completed_thoughts, - labeler=self._thought_labeler, - ) - - self._current_thought.on_llm_start(serialized, prompts) diff --git a/app/chains/arxiv_chains.py b/app/chains/arxiv_chains.py deleted file mode 100644 index 359c5f5..0000000 --- a/app/chains/arxiv_chains.py +++ /dev/null @@ -1,197 +0,0 @@ -import logging -import inspect -from typing import Dict, Any, Optional, List, Tuple - - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.embeddings.base import Embeddings -from langchain.callbacks.manager import Callbacks -from langchain.schema.prompt_template import format_document -from langchain.docstore.document import Document -from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain -from 
langchain.vectorstores.myscale import MyScale, MyScaleSettings -from langchain.chains.combine_documents.stuff import StuffDocumentsChain - -from langchain_experimental.sql.vector_sql import VectorSQLOutputParser - -logger = logging.getLogger() - - -class MyScaleWithoutMetadataJson(MyScale): - def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None: - super().__init__(embedding, config, **kwargs) - self.must_have_cols: List[str] = must_have_cols - - def _build_qstr( - self, q_emb: List[float], topk: int, where_str: Optional[str] = None - ) -> str: - q_emb_str = ",".join(map(str, q_emb)) - if where_str: - where_str = f"PREWHERE {where_str}" - else: - where_str = "" - - q_str = f""" - SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)} - FROM {self.config.database}.{self.config.table} - {where_str} - ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}]) - AS dist {self.dist_order} - LIMIT {topk} - """ - return q_str - - def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]: - q_str = self._build_qstr(embedding, k, where_str) - try: - return [ - Document( - page_content=r[self.config.column_map["text"]], - metadata={k: r[k] for k in self.must_have_cols}, - ) - for r in self.client.query(q_str).named_results() - ] - except Exception as e: - logger.error( - f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") - return [] - - -class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser): - """Based on VectorSQLOutputParser - It also modify the SQL to get all columns - """ - must_have_columns: List[str] - - @property - def _type(self) -> str: - return "vector_sql_retrieve_custom" - - def parse(self, text: str) -> Dict[str, Any]: - text = text.strip() - start = text.upper().find("SELECT") - if start >= 0: - end = text.upper().find("FROM") - text = 
text.replace( - text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns)) - return super().parse(text) - - -class ArXivStuffDocumentChain(StuffDocumentsChain): - """Combine arxiv documents with PDF reference number""" - - def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: - """Construct inputs from kwargs and docs. - - Format and the join all the documents together into one input with name - `self.document_variable_name`. The pluck any additional variables - from **kwargs. - - Args: - docs: List of documents to format and then join into single input - **kwargs: additional inputs to chain, will pluck any other required - arguments from here. - - Returns: - dictionary of inputs to LLMChain - """ - # Format each document according to the prompt - doc_strings = [] - for doc_id, doc in enumerate(docs): - # add temp reference number in metadata - doc.metadata.update({'ref_id': doc_id}) - doc.page_content = doc.page_content.replace('\n', ' ') - doc_strings.append(format_document(doc, self.document_prompt)) - # Join the documents together to put them in the prompt. - inputs = { - k: v - for k, v in kwargs.items() - if k in self.llm_chain.prompt.input_variables - } - inputs[self.document_variable_name] = self.document_separator.join( - doc_strings) - return inputs - - def combine_docs( - self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any - ) -> Tuple[str, dict]: - """Stuff all documents into one prompt and pass to LLM. - - Args: - docs: List of documents to join together into one variable - callbacks: Optional callbacks to pass along - **kwargs: additional parameters to use to get inputs to LLMChain. - - Returns: - The first element returned is the single string output. The second - element returned is a dictionary of other keys to return. - """ - inputs = self._get_inputs(docs, **kwargs) - # Call predict on the LLM. 
- output = self.llm_chain.predict(callbacks=callbacks, **inputs) - return output, {} - - @property - def _chain_type(self) -> str: - return "referenced_stuff_documents_chain" - - -class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain): - """QA with source chain for Chat ArXiv app with references - - This chain will automatically assign reference number to the article, - Then parse it back to titles or anything else. - """ - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - accepts_run_manager = ( - "run_manager" in inspect.signature(self._get_docs).parameters - ) - if accepts_run_manager: - docs = self._get_docs(inputs, run_manager=_run_manager) - else: - docs = self._get_docs(inputs) # type: ignore[call-arg] - - answer = self.combine_documents_chain.run( - input_documents=docs, callbacks=_run_manager.get_child(), **inputs - ) - # parse source with ref_id - sources = [] - ref_cnt = 1 - for d in docs: - ref_id = d.metadata['ref_id'] - if f"Doc #{ref_id}" in answer: - answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}") - if f"#{ref_id}" in answer: - title = d.metadata['title'].replace('\n', '') - d.metadata['ref_id'] = ref_cnt - answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]") - sources.append(d) - ref_cnt += 1 - - result: Dict[str, Any] = { - self.answer_key: answer, - self.sources_answer_key: sources, - } - if self.return_source_documents: - result["source_documents"] = docs - return result - - async def _acall( - self, - inputs: Dict[str, Any], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - raise NotImplementedError - - @property - def _chain_type(self) -> str: - return "arxiv_qa_with_sources_chain" diff --git a/app/chat.py b/app/chat.py deleted file mode 100644 index e1f6aee..0000000 --- a/app/chat.py +++ /dev/null @@ -1,402 +0,0 @@ -import 
json -import pandas as pd -from os import environ -from time import sleep -import datetime -import streamlit as st -from lib.sessions import SessionManager -from lib.private_kb import PrivateKnowledgeBase -from langchain.schema import HumanMessage, FunctionMessage -from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler -from lib.json_conv import CustomJSONDecoder - -from lib.helper import ( - build_agents, - MYSCALE_HOST, - MYSCALE_PASSWORD, - MYSCALE_PORT, - MYSCALE_USER, - DEFAULT_SYSTEM_PROMPT, - UNSTRUCTURED_API, -) -from login import back_to_main - -environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"] - -TOOL_NAMES = { - "langchain_retriever_tool": "Self-querying retriever", - "vecsql_retriever_tool": "Vector SQL", -} - - -def on_chat_submit(): - with st.session_state.next_round.container(): - with st.chat_message("user"): - st.write(st.session_state.chat_input) - with st.chat_message("assistant"): - container = st.container() - st_callback = ChatDataAgentCallBackHandler( - container, collapse_completed_thoughts=False - ) - ret = st.session_state.agent( - {"input": st.session_state.chat_input}, callbacks=[st_callback] - ) - print(ret) - - -def clear_history(): - if "agent" in st.session_state: - st.session_state.agent.memory.clear() - - -def back_to_main(): - if "user_info" in st.session_state: - del st.session_state.user_info - if "user_name" in st.session_state: - del st.session_state.user_name - if "jump_query_ask" in st.session_state: - del st.session_state.jump_query_ask - if "sel_sess" in st.session_state: - del st.session_state.sel_sess - if "current_sessions" in st.session_state: - del st.session_state.current_sessions - - -def on_session_change_submit(): - if "session_manager" in st.session_state and "session_editor" in st.session_state: - print(st.session_state.session_editor) - try: - for elem in st.session_state.session_editor["added_rows"]: - if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem: - if 
elem["session_id"] != "" and "?" not in elem["session_id"]: - st.session_state.session_manager.add_session( - user_id=st.session_state.user_name, - session_id=f"{st.session_state.user_name}?{elem['session_id']}", - system_prompt=elem["system_prompt"], - ) - else: - raise KeyError( - "`session_id` should NOT be neither empty nor contain question marks." - ) - else: - raise KeyError( - "You should fill both `session_id` and `system_prompt` to add a column!" - ) - for elem in st.session_state.session_editor["deleted_rows"]: - st.session_state.session_manager.remove_session( - session_id=f"{st.session_state.user_name}?{st.session_state.current_sessions[elem]['session_id']}", - ) - refresh_sessions() - except Exception as e: - sleep(2) - st.error(f"{type(e)}: {str(e)}") - finally: - st.session_state.session_editor["added_rows"] = [] - st.session_state.session_editor["deleted_rows"] = [] - refresh_agent() - - -def build_session_manager(): - return SessionManager( - st.session_state, - host=MYSCALE_HOST, - port=MYSCALE_PORT, - username=MYSCALE_USER, - password=MYSCALE_PASSWORD, - ) - - -def refresh_sessions(): - st.session_state[ - "current_sessions" - ] = st.session_state.session_manager.list_sessions(st.session_state.user_name) - if ( - type(st.session_state.current_sessions) is not dict - and len(st.session_state.current_sessions) <= 0 - ): - st.session_state.session_manager.add_session( - st.session_state.user_name, - f"{st.session_state.user_name}?default", - DEFAULT_SYSTEM_PROMPT, - ) - st.session_state[ - "current_sessions" - ] = st.session_state.session_manager.list_sessions(st.session_state.user_name) - st.session_state["user_files"] = st.session_state.private_kb.list_files( - st.session_state.user_name - ) - st.session_state["user_tools"] = st.session_state.private_kb.list_tools( - st.session_state.user_name - ) - st.session_state["tools_with_users"] = { - **st.session_state.tools, - **st.session_state.private_kb.as_tools(st.session_state.user_name), - } - try: 
- dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index( - "default" - if "" not in st.session_state - else st.session_state.sel_session["session_id"] - ) - except ValueError: - dfl_indx = 0 - st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx] - - -def build_kb_as_tool(): - if ( - "b_tool_name" in st.session_state - and "b_tool_desc" in st.session_state - and "b_tool_files" in st.session_state - and len(st.session_state.b_tool_name) > 0 - and len(st.session_state.b_tool_desc) > 0 - and len(st.session_state.b_tool_files) > 0 - ): - st.session_state.private_kb.create_tool( - st.session_state.user_name, - st.session_state.b_tool_name, - st.session_state.b_tool_desc, - [f["file_name"] for f in st.session_state.b_tool_files], - ) - refresh_sessions() - else: - st.session_state.tool_status.error( - "You should fill all fields to build up a tool!" - ) - sleep(2) - - -def remove_kb(): - if "r_tool_names" in st.session_state and len(st.session_state.r_tool_names) > 0: - st.session_state.private_kb.remove_tools( - st.session_state.user_name, - [f["tool_name"] for f in st.session_state.r_tool_names], - ) - refresh_sessions() - else: - st.session_state.tool_status.error( - "You should specify at least one tool to delete!" - ) - sleep(2) - - -def refresh_agent(): - with st.spinner("Initializing session..."): - print( - f"??? 
Changed to ", - f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}", - ) - st.session_state["agent"] = build_agents( - f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}", - ["LangChain Self Query Retriever For Wikipedia"] - if "selected_tools" not in st.session_state - else st.session_state.selected_tools, - system_prompt=DEFAULT_SYSTEM_PROMPT - if "sel_sess" not in st.session_state - else st.session_state.sel_sess["system_prompt"], - ) - - -def add_file(): - if ( - "uploaded_files" not in st.session_state - or len(st.session_state.uploaded_files) == 0 - ): - st.session_state.tool_status.error("Please upload files!", icon="⚠️") - sleep(2) - return - try: - st.session_state.tool_status.info("Uploading...") - st.session_state.private_kb.add_by_file( - st.session_state.user_name, st.session_state.uploaded_files - ) - refresh_sessions() - except ValueError as e: - st.session_state.tool_status.error("Failed to upload! " + str(e)) - sleep(2) - - -def clear_files(): - st.session_state.private_kb.clear(st.session_state.user_name) - refresh_sessions() - - -def chat_page(): - if "sel_sess" not in st.session_state: - st.session_state["sel_sess"] = { - "session_id": "default", - "system_prompt": DEFAULT_SYSTEM_PROMPT, - } - if "private_kb" not in st.session_state: - st.session_state["private_kb"] = PrivateKnowledgeBase( - host=MYSCALE_HOST, - port=MYSCALE_PORT, - username=MYSCALE_USER, - password=MYSCALE_PASSWORD, - embedding=st.session_state.embeddings["Wikipedia"], - parser_api_key=UNSTRUCTURED_API, - ) - if "session_manager" not in st.session_state: - st.session_state["session_manager"] = build_session_manager() - with st.sidebar: - with st.expander("Session Management"): - if "current_sessions" not in st.session_state: - refresh_sessions() - st.info( - "Here you can set up your session! 
\n\nYou can **change your prompt** here!", - icon="🤖", - ) - st.info( - ( - "**Add columns by clicking the empty row**.\n" - "And **delete columns by selecting rows with a press on `DEL` Key**" - ), - icon="💡", - ) - st.info( - "Don't forget to **click `Submit Change` to save your change**!", - icon="📒", - ) - st.data_editor( - st.session_state.current_sessions, - num_rows="dynamic", - key="session_editor", - use_container_width=True, - ) - st.button("Submit Change!", on_click=on_session_change_submit) - with st.expander("Session Selection", expanded=True): - st.info( - "If no session is attach to your account, then we will add a default session to you!", - icon="❤️", - ) - try: - dfl_indx = [ - x["session_id"] for x in st.session_state.current_sessions - ].index( - "default" - if "" not in st.session_state - else st.session_state.sel_session["session_id"] - ) - except Exception as e: - print("*** ", str(e)) - dfl_indx = 0 - st.selectbox( - "Choose a session to chat:", - options=st.session_state.current_sessions, - index=dfl_indx, - key="sel_sess", - format_func=lambda x: x["session_id"], - on_change=refresh_agent, - ) - print(st.session_state.sel_sess) - with st.expander("Tool Settings", expanded=True): - st.info( - "We provides you several knowledge base tools for you. 
We are building more tools!", - icon="🔧", - ) - st.session_state["tool_status"] = st.empty() - tab_kb, tab_file = st.tabs( - [ - "Knowledge Bases", - "File Upload", - ] - ) - with tab_kb: - st.markdown("#### Build You Own Knowledge") - st.multiselect( - "Select Files to Build up", - st.session_state.user_files, - placeholder="You should upload files first", - key="b_tool_files", - format_func=lambda x: x["file_name"], - ) - st.text_input( - "Tool Name", "get_relevant_documents", key="b_tool_name") - st.text_input( - "Tool Description", - "Searches among user's private files and returns related documents", - key="b_tool_desc", - ) - st.button("Build!", on_click=build_kb_as_tool) - st.markdown("### Knowledge Base Selection") - if ( - "user_tools" in st.session_state - and len(st.session_state.user_tools) > 0 - ): - st.markdown("***User Created Knowledge Bases***") - st.dataframe(st.session_state.user_tools) - st.multiselect( - "Select a Knowledge Base Tool", - st.session_state.tools.keys() - if "tools_with_users" not in st.session_state - else st.session_state.tools_with_users, - default=["Wikipedia + Self Querying"], - key="selected_tools", - on_change=refresh_agent, - ) - st.markdown("### Delete Knowledge Base") - st.multiselect( - "Choose Knowledge Base to Remove", - st.session_state.user_tools, - format_func=lambda x: x["tool_name"], - key="r_tool_names", - ) - st.button("Delete", on_click=remove_kb) - with tab_file: - st.info( - ( - "We adopted [Unstructured API](https://unstructured.io/api-key) " - "here and we only store the processed texts from your documents. " - "For privacy concerns, please refer to " - "[our policy issue](https://myscale.com/privacy/)." 
- ), - icon="📃", - ) - st.file_uploader( - "Upload files", key="uploaded_files", accept_multiple_files=True - ) - st.markdown("### Uploaded Files") - st.dataframe( - st.session_state.private_kb.list_files( - st.session_state.user_name), - use_container_width=True, - ) - col_1, col_2 = st.columns(2) - with col_1: - st.button("Add Files", on_click=add_file) - with col_2: - st.button("Clear Files and All Tools", - on_click=clear_files) - - st.button("Clear Chat History", on_click=clear_history) - st.button("Logout", on_click=back_to_main) - if "agent" not in st.session_state: - refresh_agent() - print("!!! ", st.session_state.agent.memory.chat_memory.session_id) - for msg in st.session_state.agent.memory.chat_memory.messages: - speaker = "user" if isinstance(msg, HumanMessage) else "assistant" - if isinstance(msg, FunctionMessage): - with st.chat_message("Knowledge Base", avatar="📖"): - st.write( - f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" - ) - st.write("Retrieved from knowledge base:") - try: - st.dataframe( - pd.DataFrame.from_records( - json.loads(msg.content, cls=CustomJSONDecoder) - ), - use_container_width=True, - ) - except: - st.write(msg.content) - else: - if len(msg.content) > 0: - with st.chat_message(speaker): - print(type(msg), msg.dict()) - st.write( - f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" - ) - st.write(f"{msg.content}") - st.session_state["next_round"] = st.empty() - st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input") diff --git a/app/lib/helper.py b/app/lib/helper.py deleted file mode 100644 index 9d04210..0000000 --- a/app/lib/helper.py +++ /dev/null @@ -1,573 +0,0 @@ - -import json -import time -import hashlib -from typing import Dict, Any, List, Tuple -import re -from os import environ -import streamlit as st -from langchain.schema import BaseRetriever -from langchain.tools import Tool -from langchain.pydantic_v1 import BaseModel, 
Field - -from sqlalchemy import Column, Text, create_engine, MetaData -from langchain.agents import AgentExecutor -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from clickhouse_sqlalchemy import ( - types, engines -) -from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain -from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever -from langchain.utilities.sql_database import SQLDatabase -from langchain.chains import LLMChain -from sqlalchemy import create_engine, MetaData -from langchain.prompts import PromptTemplate, ChatPromptTemplate, \ - SystemMessagePromptTemplate, HumanMessagePromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.chat_models import ChatOpenAI -from langchain.schema import BaseRetriever, Document -from langchain import OpenAI -from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName -from langchain.retrievers.self_query.base import SelfQueryRetriever -from langchain.retrievers.self_query.myscale import MyScaleTranslator -from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings -from langchain.vectorstores import MyScaleSettings -from chains.arxiv_chains import MyScaleWithoutMetadataJson -from langchain.prompts.prompt import PromptTemplate -from langchain.prompts.chat import MessagesPlaceholder -from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory -from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent -from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage, \ - SystemMessage, ChatMessage, ToolMessage -from langchain.memory import SQLChatMessageHistory -from langchain.memory.chat_message_histories.sql import \ - DefaultMessageConverter -from 
langchain.schema.messages import BaseMessage -# from langchain.agents.agent_toolkits import create_retriever_tool -from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt -from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain -from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser -from .json_conv import CustomJSONEncoder - -environ['TOKENIZERS_PARALLELISM'] = 'true' -environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE'] - -query_model_name = "gpt-3.5-turbo-0125" -chat_model_name = "gpt-3.5-turbo-0125" - - -OPENAI_API_KEY = st.secrets['OPENAI_API_KEY'] -OPENAI_API_BASE = st.secrets['OPENAI_API_BASE'] -MYSCALE_USER = st.secrets['MYSCALE_USER'] -MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD'] -MYSCALE_HOST = st.secrets['MYSCALE_HOST'] -MYSCALE_PORT = st.secrets['MYSCALE_PORT'] -UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API'] - -COMBINE_PROMPT = ChatPromptTemplate.from_strings( - string_messages=[(SystemMessagePromptTemplate, combine_prompt_template), - (HumanMessagePromptTemplate, '{question}')]) -DEFAULT_SYSTEM_PROMPT = ( - "Do your best to answer the questions. " - "Feel free to use any tools available to look up " - "relevant information. Please keep all details in query " - "when calling search functions." -) - - -def hint_arxiv(): - st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n" - "For example: \n\n" - "*If you want to search papers with complex filters*:\n\n" - "- What is a Bayesian network? 
Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n" - "*If you want to ask questions based on papers in database*:\n\n" - "- What is PageRank?\n" - "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n" - "- Introduce some applications of GANs published around 2019.\n" - "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n" - "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n" - "- Is it possible to synthesize room temperature super conductive material?") - - -def hint_sql_arxiv(): - st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') - st.markdown('''```sql -CREATE TABLE default.ChatArXiv ( - `abstract` String, - `id` String, - `vector` Array(Float32), - `metadata` Object('JSON'), - `pubdate` DateTime, - `title` String, - `categories` Array(String), - `authors` Array(String), - `comment` String, - `primary_category` String, - VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'), - CONSTRAINT vec_len CHECK length(vector) = 768) -ENGINE = ReplacingMergeTree ORDER BY id -```''') - - -def hint_wiki(): - st.info("We provides you metadata columns below for query. 
Please choose a natural expression to describe filters on those columns.\n\n" - "For example: \n\n" - "- Which company did Elon Musk found?\n" - "- What is Iron Gwazi?\n" - "- What is a Ring in mathematics?\n" - "- 苹果的发源地是那里?\n") - - -def hint_sql_wiki(): - st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') - st.markdown('''```sql -CREATE TABLE wiki.Wikipedia ( - `id` String, - `title` String, - `text` String, - `url` String, - `wiki_id` UInt64, - `views` Float32, - `paragraph_id` UInt64, - `langs` UInt32, - `emb` Array(Float32), - VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'), - CONSTRAINT emb_len CHECK length(emb) = 768) -ENGINE = ReplacingMergeTree ORDER BY id -```''') - - -sel_map = { - 'Wikipedia': { - "database": "wiki", - "table": "Wikipedia", - "hint": hint_wiki, - "hint_sql": hint_sql_wiki, - "doc_prompt": PromptTemplate( - input_variables=["page_content", - "url", "title", "ref_id", "views"], - template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"), - "metadata_cols": [ - AttributeInfo( - name="title", - description="title of the wikipedia page", - type="string", - ), - AttributeInfo( - name="text", - description="paragraph from this wiki page", - type="string", - ), - AttributeInfo( - name="views", - description="number of views", - type="float" - ), - ], - "must_have_cols": ['id', 'title', 'url', 'text', 'views'], - "vector_col": "emb", - "text_col": "text", - "metadata_col": "metadata", - "emb_model": lambda: SentenceTransformerEmbeddings( - model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',), - "tool_desc": ("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages"), - }, - 'ArXiv Papers': { - "database": "default", - "table": "ChatArXiv", - "hint": hint_arxiv, - "hint_sql": hint_sql_arxiv, - "doc_prompt": PromptTemplate( - 
input_variables=["page_content", "id", "title", "ref_id", - "authors", "pubdate", "categories"], - template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"), - "metadata_cols": [ - AttributeInfo( - name=VirtualColumnName(name="pubdate"), - description="The year the paper is published", - type="timestamp", - ), - AttributeInfo( - name="authors", - description="List of author names", - type="list[string]", - ), - AttributeInfo( - name="title", - description="Title of the paper", - type="string", - ), - AttributeInfo( - name="categories", - description="arxiv categories to this paper", - type="list[string]" - ), - AttributeInfo( - name="length(categories)", - description="length of arxiv categories to this paper", - type="int" - ), - ], - "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'], - "vector_col": "vector", - "text_col": "abstract", - "metadata_col": "metadata", - "emb_model": lambda: HuggingFaceInstructEmbeddings( - model_name='hkunlp/instructor-xl', - embed_instruction="Represent the question for retrieving supporting scientific papers: "), - "tool_desc": ("search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers"), - } -} - - -def build_embedding_model(_sel): - """Build embedding model - """ - with st.spinner("Loading Model..."): - embeddings = sel_map[_sel]["emb_model"]() - return embeddings - - -def build_chains_retrievers(_sel: str) -> Dict[str, Any]: - """build chains and retrievers - - :param _sel: selected knowledge base - :type _sel: str - :return: _description_ - :rtype: Dict[str, Any] - """ - metadata_field_info = sel_map[_sel]["metadata_cols"] - retriever = build_self_query(_sel) - chain = build_qa_chain(_sel, retriever, name="Self Query Retriever") - sql_retriever = build_vector_sql(_sel) - sql_chain = build_qa_chain(_sel, sql_retriever, name="Vector SQL") 
- - return { - "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], - "retriever": retriever, - "chain": chain, - "sql_retriever": sql_retriever, - "sql_chain": sql_chain - } - - -def build_self_query(_sel: str) -> SelfQueryRetriever: - """Build self querying retriever - - :param _sel: selected knowledge base - :type _sel: str - :return: retriever used by chains - :rtype: SelfQueryRetriever - """ - with st.spinner(f"Connecting DB for {_sel}..."): - myscale_connection = { - "host": MYSCALE_HOST, - "port": MYSCALE_PORT, - "username": MYSCALE_USER, - "password": MYSCALE_PASSWORD, - } - config = MyScaleSettings(**myscale_connection, - database=sel_map[_sel]["database"], - table=sel_map[_sel]["table"], - column_map={ - "id": "id", - "text": sel_map[_sel]["text_col"], - "vector": sel_map[_sel]["vector_col"], - "metadata": sel_map[_sel]["metadata_col"] - }) - doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config, - must_have_cols=sel_map[_sel]['must_have_cols']) - - with st.spinner(f"Building Self Query Retriever for {_sel}..."): - metadata_field_info = sel_map[_sel]["metadata_cols"] - retriever = SelfQueryRetriever.from_llm( - OpenAI(model_name=query_model_name, - openai_api_key=OPENAI_API_KEY, temperature=0), - doc_search, "Scientific papers indexes with abstracts. 
All in English.", metadata_field_info, - use_original_query=False, structured_query_translator=MyScaleTranslator()) - return retriever - - -def build_vector_sql(_sel: str) -> VectorSQLDatabaseChainRetriever: - """Build Vector SQL Database Retriever - - :param _sel: selected knowledge base - :type _sel: str - :return: retriever used by chains - :rtype: VectorSQLDatabaseChainRetriever - """ - with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'): - engine = create_engine( - f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https') - metadata = MetaData(bind=engine) - PROMPT = PromptTemplate( - input_variables=["input", "table_info", "top_k"], - template=_myscale_prompt, - ) - output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings( - model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"]) - sql_query_chain = VectorSQLDatabaseChain.from_llm( - llm=OpenAI(model_name=query_model_name, - openai_api_key=OPENAI_API_KEY, temperature=0), - prompt=PROMPT, - top_k=10, - return_direct=True, - db=SQLDatabase(engine, None, metadata, max_string_length=1024), - sql_cmd_parser=output_parser, - native_format=True - ) - sql_retriever = VectorSQLDatabaseChainRetriever( - sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"]) - return sql_retriever - - -def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str = "Self-query") -> ArXivQAwithSourcesChain: - """_summary_ - - :param _sel: selected knowledge base - :type _sel: str - :param retriever: retriever used by chains - :type retriever: BaseRetriever - :param name: display name, defaults to "Self-query" - :type name: str, optional - :return: QA chain interacts with user - :rtype: ArXivQAwithSourcesChain - """ - with st.spinner(f'Building QA Chain with {name} for {_sel}...'): - chain = ArXivQAwithSourcesChain( - retriever=retriever, - 
combine_documents_chain=ArXivStuffDocumentChain( - llm_chain=LLMChain( - prompt=COMBINE_PROMPT, - llm=ChatOpenAI(model_name=chat_model_name, - openai_api_key=OPENAI_API_KEY, temperature=0.6), - ), - document_prompt=sel_map[_sel]["doc_prompt"], - document_variable_name="summaries", - - ), - return_source_documents=True, - max_tokens_limit=12000, - ) - return chain - - -@st.cache_resource -def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]: - """build all resources - - :return: sel_map_obj - :rtype: Dict[str, Any] - """ - sel_map_obj = {} - embeddings = {} - for k in sel_map: - embeddings[k] = build_embedding_model(k) - st.session_state[f'emb_model_{k}'] = embeddings[k] - sel_map_obj[k] = build_chains_retrievers(k) - return sel_map_obj, embeddings - - -def create_message_model(table_name, DynamicBase): # type: ignore - """ - Create a message model for a given table name. - - Args: - table_name: The name of the table to use. - DynamicBase: The base class to use for the model. - - Returns: - The model class. 
- - """ - - # Model decleared inside a function to have a dynamic table name - class Message(DynamicBase): - __tablename__ = table_name - id = Column(types.Float64) - session_id = Column(Text) - user_id = Column(Text) - msg_id = Column(Text, primary_key=True) - type = Column(Text) - addtionals = Column(Text) - message = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - partition_by='session_id', - order_by=('id', 'msg_id')), - {'comment': 'Store Chat History'} - ) - - return Message - - -def _message_from_dict(message: dict) -> BaseMessage: - _type = message["type"] - if _type == "human": - return HumanMessage(**message["data"]) - elif _type == "ai": - return AIMessage(**message["data"]) - elif _type == "system": - return SystemMessage(**message["data"]) - elif _type == "chat": - return ChatMessage(**message["data"]) - elif _type == "function": - return FunctionMessage(**message["data"]) - elif _type == "tool": - return ToolMessage(**message["data"]) - elif _type == "AIMessageChunk": - message["data"]["type"] = "ai" - return AIMessage(**message["data"]) - else: - raise ValueError(f"Got unexpected message type: {_type}") - - -class DefaultClickhouseMessageConverter(DefaultMessageConverter): - """The default message converter for SQLChatMessageHistory.""" - - def __init__(self, table_name: str): - self.model_class = create_message_model(table_name, declarative_base()) - - def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: - tstamp = time.time() - msg_id = hashlib.sha256( - f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest() - user_id, _ = session_id.split("?") - return self.model_class( - id=tstamp, - msg_id=msg_id, - user_id=user_id, - session_id=session_id, - type=message.type, - addtionals=json.dumps(message.additional_kwargs), - message=json.dumps({ - "type": message.type, - "additional_kwargs": {"timestamp": tstamp}, - "data": message.dict()}) - ) - - def from_sql_model(self, sql_message: Any) -> BaseMessage: - 
msg_dump = json.loads(sql_message.message) - msg = _message_from_dict(msg_dump) - msg.additional_kwargs = msg_dump["additional_kwargs"] - return msg - - def get_sql_model_class(self) -> Any: - return self.model_class - - -def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs): - name = name.replace(" ", "_") - conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}' - chat_memory = SQLChatMessageHistory( - session_id, - connection_string=f'{conn_str}/chat?protocol=https', - custom_message_converter=DefaultClickhouseMessageConverter(name)) - memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory) - - _system_message = SystemMessage( - content=system_prompt - ) - prompt = OpenAIFunctionsAgent.create_prompt( - system_message=_system_message, - extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], - ) - agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) - return AgentExecutor( - agent=agent, - tools=tools, - memory=memory, - verbose=True, - return_intermediate_steps=True, - **kwargs - ) - - -class RetrieverInput(BaseModel): - query: str = Field(description="query to look up in retriever") - - -def create_retriever_tool( - retriever: BaseRetriever, name: str, description: str -) -> Tool: - """Create a tool to do retrieval of documents. - - Args: - retriever: The retriever to use for the retrieval - name: The name for the tool. This will be passed to the language model, - so should be unique and somewhat descriptive. - description: The description for the tool. This will be passed to the language - model, so should be descriptive. 
- - Returns: - Tool class to pass to an agent - """ - def wrap(func): - def wrapped_retrieve(*args, **kwargs): - docs: List[Document] = func(*args, **kwargs) - return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder) - return wrapped_retrieve - - return Tool( - name=name, - description=description, - func=wrap(retriever.get_relevant_documents), - coroutine=retriever.aget_relevant_documents, - args_schema=RetrieverInput, - ) - - -@st.cache_resource -def build_tools(): - """build all resources - - :return: sel_map_obj - :rtype: Dict[str, Any] - """ - sel_map_obj = {} - for k in sel_map: - if f'emb_model_{k}' not in st.session_state: - st.session_state[f'emb_model_{k}'] = build_embedding_model(k) - if "sel_map_obj" not in st.session_state: - st.session_state["sel_map_obj"] = {} - if k not in st.session_state.sel_map_obj: - st.session_state["sel_map_obj"][k] = {} - if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]: - st.session_state.sel_map_obj[k].update(build_chains_retrievers(k)) - sel_map_obj.update({ - f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],), - f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],), - }) - return sel_map_obj - - -def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT): - chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature, - openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True, - ) - tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users - sel_tools = [tools[k] for k in tool_names] - agent = create_agent_executor( - "chat_memory", - session_id, - chat_llm, - tools=sel_tools, - system_prompt=system_prompt - ) - return agent - - -def 
display(dataframe, columns_=None, index=None): - if len(dataframe) > 0: - if index: - dataframe.set_index(index) - if columns_: - st.dataframe(dataframe[columns_]) - else: - st.dataframe(dataframe) - else: - st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True) diff --git a/app/lib/private_kb.py b/app/lib/private_kb.py deleted file mode 100644 index b5c98e0..0000000 --- a/app/lib/private_kb.py +++ /dev/null @@ -1,213 +0,0 @@ -import pandas as pd -import hashlib -import requests -from typing import List, Optional -from datetime import datetime -from langchain.schema.embeddings import Embeddings -from streamlit.runtime.uploaded_file_manager import UploadedFile -from clickhouse_connect import get_client -from multiprocessing.pool import ThreadPool -from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings -from .helper import create_retriever_tool - -parser_url = "https://api.unstructured.io/general/v0/general" - - -def parse_files(api_key, user_id, files: List[UploadedFile]): - def parse_file(file: UploadedFile): - headers = { - "accept": "application/json", - "unstructured-api-key": api_key, - } - data = {"strategy": "auto", "ocr_languages": ["eng"]} - file_hash = hashlib.sha256(file.read()).hexdigest() - file_data = {"files": (file.name, file.getvalue(), file.type)} - response = requests.post( - parser_url, headers=headers, data=data, files=file_data - ) - json_response = response.json() - if response.status_code != 200: - raise ValueError(str(json_response)) - texts = [ - { - "text": t["text"], - "file_name": t["metadata"]["filename"], - "entity_id": hashlib.sha256( - (file_hash + t["text"]).encode() - ).hexdigest(), - "user_id": user_id, - "created_by": datetime.now(), - } - for t in json_response - if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10 - ] 
- return texts - - with ThreadPool(8) as p: - rows = [] - for r in p.imap_unordered(parse_file, files): - rows.extend(r) - return rows - - -def extract_embedding(embeddings: Embeddings, texts): - if len(texts) > 0: - embs = embeddings.embed_documents( - [t["text"] for _, t in enumerate(texts)]) - for i, _ in enumerate(texts): - texts[i]["vector"] = embs[i] - return texts - raise ValueError("No texts extracted!") - - -class PrivateKnowledgeBase: - def __init__( - self, - host, - port, - username, - password, - embedding: Embeddings, - parser_api_key, - db="chat", - kb_table="private_kb", - tool_table="private_tool", - ) -> None: - super().__init__() - kb_schema_ = f""" - CREATE TABLE IF NOT EXISTS {db}.{kb_table}( - entity_id String, - file_name String, - text String, - user_id String, - created_by DateTime, - vector Array(Float32), - CONSTRAINT cons_vec_len CHECK length(vector) = 768, - VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine') - ) ENGINE = ReplacingMergeTree ORDER BY entity_id - """ - tool_schema_ = f""" - CREATE TABLE IF NOT EXISTS {db}.{tool_table}( - tool_id String, - tool_name String, - file_names Array(String), - user_id String, - created_by DateTime, - tool_description String - ) ENGINE = ReplacingMergeTree ORDER BY tool_id - """ - self.kb_table = kb_table - self.tool_table = tool_table - config = MyScaleSettings( - host=host, - port=port, - username=username, - password=password, - database=db, - table=kb_table, - ) - client = get_client( - host=config.host, - port=config.port, - username=config.username, - password=config.password, - ) - client.command("SET allow_experimental_object_type=1") - client.command(kb_schema_) - client.command(tool_schema_) - self.parser_api_key = parser_api_key - self.vstore = MyScaleWithoutJSON( - embedding=embedding, - config=config, - must_have_cols=["file_name", "text", "created_by"], - ) - - def list_files(self, user_id, tool_name=None): - query = f""" - SELECT DISTINCT file_name, COUNT(entity_id) AS 
num_paragraph, - arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars - FROM {self.vstore.config.database}.{self.kb_table} - WHERE user_id = '{user_id}' GROUP BY file_name - """ - return [r for r in self.vstore.client.query(query).named_results()] - - def add_by_file( - self, user_id, files: List[UploadedFile], **kwargs - ): - data = parse_files(self.parser_api_key, user_id, files) - data = extract_embedding(self.vstore.embeddings, data) - self.vstore.client.insert_df( - self.kb_table, - pd.DataFrame(data), - database=self.vstore.config.database, - ) - - def clear(self, user_id): - self.vstore.client.command( - f"DELETE FROM {self.vstore.config.database}.{self.kb_table} " - f"WHERE user_id='{user_id}'" - ) - query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}'""" - self.vstore.client.command(query) - - def create_tool( - self, user_id, tool_name, tool_description, files: Optional[List[str]] = None - ): - self.vstore.client.insert_df( - self.tool_table, - pd.DataFrame( - [ - { - "tool_id": hashlib.sha256( - (user_id + tool_name).encode("utf-8") - ).hexdigest(), - "tool_name": tool_name, - "file_names": files, - "user_id": user_id, - "created_by": datetime.now(), - "tool_description": tool_description, - } - ] - ), - database=self.vstore.config.database, - ) - - def list_tools(self, user_id, tool_name=None): - extended_where = f"AND tool_name = '{tool_name}'" if tool_name else "" - query = f""" - SELECT tool_name, tool_description, length(file_names) - FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' {extended_where} - """ - return [r for r in self.vstore.client.query(query).named_results()] - - def remove_tools(self, user_id, tool_names): - tool_names = ",".join([f"'{t}'" for t in tool_names]) - query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' AND tool_name IN [{tool_names}]""" - self.vstore.client.command(query) - - def 
as_tools(self, user_id, tool_name=None): - tools = self.list_tools(user_id=user_id, tool_name=tool_name) - retrievers = { - t["tool_name"]: create_retriever_tool( - self.vstore.as_retriever( - search_kwargs={ - "where_str": ( - f"user_id='{user_id}' " - f"""AND file_name IN ( - SELECT arrayJoin(file_names) FROM ( - SELECT file_names - FROM {self.vstore.config.database}.{self.tool_table} - WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}') - )""" - ) - }, - ), - name=t["tool_name"], - description=t["tool_description"], - ) - for t in tools - } - return retrievers diff --git a/app/lib/schemas.py b/app/lib/schemas.py deleted file mode 100644 index ace331f..0000000 --- a/app/lib/schemas.py +++ /dev/null @@ -1,52 +0,0 @@ -from sqlalchemy import Column, Text -from clickhouse_sqlalchemy import types, engines - - -def create_message_model(table_name, DynamicBase): # type: ignore - """ - Create a message model for a given table name. - - Args: - table_name: The name of the table to use. - DynamicBase: The base class to use for the model. - - Returns: - The model class. 
- - """ - - # Model decleared inside a function to have a dynamic table name - class Message(DynamicBase): - __tablename__ = table_name - id = Column(types.Float64) - session_id = Column(Text) - user_id = Column(Text) - msg_id = Column(Text, primary_key=True) - type = Column(Text) - addtionals = Column(Text) - message = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - partition_by='session_id', - order_by=('id', 'msg_id')), - {'comment': 'Store Chat History'} - ) - - return Message - - -def create_session_table(table_name, DynamicBase): # type: ignore - # Model decleared inside a function to have a dynamic table name - class Session(DynamicBase): - __tablename__ = table_name - user_id = Column(Text) - session_id = Column(Text, primary_key=True) - system_prompt = Column(Text) - create_by = Column(types.DateTime) - additionals = Column(Text) - __table_args__ = ( - engines.ReplacingMergeTree( - order_by=('session_id')), - {'comment': 'Store Session and Prompts'} - ) - return Session diff --git a/app/lib/sessions.py b/app/lib/sessions.py deleted file mode 100644 index 5cefd63..0000000 --- a/app/lib/sessions.py +++ /dev/null @@ -1,78 +0,0 @@ -import json -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base -from langchain.schema import BaseChatMessageHistory -from datetime import datetime -from sqlalchemy import Column, Text, orm, create_engine -from .schemas import create_message_model, create_session_table - - -def get_sessions(engine, model_class, user_id): - with orm.sessionmaker(engine)() as session: - result = ( - session.query(model_class) - .where( - model_class.session_id == user_id - ) - .order_by(model_class.create_by.desc()) - ) - return json.loads(result) - - -class SessionManager: - def __init__(self, session_state, host, port, username, password, - db='chat', sess_table='sessions', msg_table='chat_memory') -> None: - conn_str = 
f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https' - self.engine = create_engine(conn_str, echo=False) - self.sess_model_class = create_session_table( - sess_table, declarative_base()) - self.sess_model_class.metadata.create_all(self.engine) - self.msg_model_class = create_message_model( - msg_table, declarative_base()) - self.msg_model_class.metadata.create_all(self.engine) - self.Session = orm.sessionmaker(self.engine) - self.session_state = session_state - - def list_sessions(self, user_id): - with self.Session() as session: - result = ( - session.query(self.sess_model_class) - .where( - self.sess_model_class.user_id == user_id - ) - .order_by(self.sess_model_class.create_by.desc()) - ) - sessions = [] - for r in result: - sessions.append({ - "session_id": r.session_id.split("?")[-1], - "system_prompt": r.system_prompt, - }) - return sessions - - def modify_system_prompt(self, session_id, sys_prompt): - with self.Session() as session: - session.update(self.sess_model_class).where( - self.sess_model_class == session_id).value(system_prompt=sys_prompt) - session.commit() - - def add_session(self, user_id, session_id, system_prompt, **kwargs): - with self.Session() as session: - elem = self.sess_model_class( - user_id=user_id, session_id=session_id, system_prompt=system_prompt, - create_by=datetime.now(), additionals=json.dumps(kwargs) - ) - session.add(elem) - session.commit() - - def remove_session(self, session_id): - with self.Session() as session: - session.query(self.sess_model_class).where( - self.sess_model_class.session_id == session_id).delete() - # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete() - if "agent" in self.session_state: - self.session_state.agent.memory.chat_memory.clear() - if "file_analyzer" in self.session_state: - self.session_state.file_analyzer.clear_files() diff --git a/app/logger.py b/app/logger.py new file mode 100644 index 0000000..a19a1c7 --- /dev/null +++ 
b/app/logger.py @@ -0,0 +1,18 @@ +import logging + + +def setup_logger(): + logger_ = logging.getLogger('chat-data') + logger_.setLevel(logging.INFO) + if not logger_.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter( + '%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]' + ) + console_handler.setFormatter(formatter) + logger_.addHandler(console_handler) + return logger_ + + +logger = setup_logger() diff --git a/app/login.py b/app/login.py deleted file mode 100644 index 63231e2..0000000 --- a/app/login.py +++ /dev/null @@ -1,55 +0,0 @@ -import streamlit as st -from auth0_component import login_button - -AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID'] -AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN'] - - -def login(): - if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask): - return True - st.subheader( - "🤗 Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ") - st.write("You can now chat with ArXiv and Wikipedia! 
🌟\n") - st.write("Built purely with streamlit 👑 , LangChain 🦜🔗 and love ❤️ for AI!") - st.write( - "Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!") - st.write( - "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!") - st.divider() - col1, col2 = st.columns(2, gap='large') - with col1.container(): - st.write("Try out MyScale's Self-query and Vector SQL retrievers!") - st.write("In this demo, you will be able to see how those retrievers " - "**digest** -> **translate** -> **retrieve** -> **answer** to your question!") - st.session_state["jump_query_ask"] = st.button("Query / Ask") - with col2.container(): - # st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)") - st.write("Now with the power of LangChain's Conversantional Agents, we are able to build " - "an RAG-enabled chatbot within one MyScale instance! ") - st.write("Log in to Chat with RAG!") - st.write("Recommended to use the standalone version of Chat-Data, available [here](https://myscale-chatdata.hf.space/).") - login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0") - st.divider() - st.write("- [Privacy Policy](https://myscale.com/privacy/)\n" - "- [Terms of Sevice](https://myscale.com/terms/)") - if st.session_state.auth0 is not None: - st.session_state.user_info = dict(st.session_state.auth0) - if 'email' in st.session_state.user_info: - email = st.session_state.user_info["email"] - else: - email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}" - st.session_state["user_name"] = email - del st.session_state.auth0 - st.experimental_rerun() - if st.session_state.jump_query_ask: - st.experimental_rerun() - - -def back_to_main(): - if "user_info" in st.session_state: - del st.session_state.user_info - if "user_name" in st.session_state: - del st.session_state.user_name - if "jump_query_ask" in st.session_state: - del 
st.session_state.jump_query_ask diff --git a/app/requirements.txt b/app/requirements.txt index 3acacf7..cc60e48 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,15 +1,17 @@ -langchain @ git+https://github.com/myscale/langchain.git@preview#egg=langchain&subdirectory=libs/langchain -langchain-experimental @ git+https://github.com/myscale/langchain.git@preview#egg=langchain-experimental&subdirectory=libs/experimental -# https://github.com/PromtEngineer/localGPT/issues/722 -sentence_transformers==2.2.2 +langchain==0.2.1 +langchain-community==0.2.1 +langchain-core==0.2.1 +langchain-experimental==0.0.59 +langchain-openai==0.1.7 +sentence-transformers==2.2.2 InstructorEmbedding pandas -sentence_transformers -streamlit==1.25 +streamlit +streamlit-extras streamlit-auth0-component altair==4.2.2 clickhouse-connect -openai==0.28 +openai==1.35.3 lark tiktoken sql-formatter diff --git a/app/ui/__init__.py b/app/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/ui/chat_page.py b/app/ui/chat_page.py new file mode 100644 index 0000000..44e659d --- /dev/null +++ b/app/ui/chat_page.py @@ -0,0 +1,196 @@ +import datetime +import json + +import pandas as pd +import streamlit as st +from langchain_core.messages import HumanMessage, FunctionMessage +from streamlit.delta_generator import DeltaGenerator + +from backend.chat_bot.json_decoder import CustomJSONDecoder +from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \ + EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \ + EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \ + USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \ + CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS +from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS +from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager +from backend.chat_bot.chat import 
refresh_sessions, on_session_change_submit, refresh_agent, \
+    create_private_knowledge_base_as_tool, \
+    remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit
+
+
+def render_session_manager():
+    with st.expander("🤖 Session Management"):
+        if CHAT_CURRENT_USER_SESSIONS not in st.session_state:
+            refresh_sessions()
+        st.markdown("Here you can update `session_id` and `system_prompt`")
+        st.markdown("- Click empty row to add a new item")
+        st.markdown("- If you need to delete an item, just click it and press the `DEL` key")
+        st.markdown("- Don't forget to submit your change.")
+
+        st.data_editor(
+            data=st.session_state[CHAT_CURRENT_USER_SESSIONS],
+            num_rows="dynamic",
+            key="session_editor",
+            use_container_width=True,
+        )
+        st.button("⏫ Submit", on_click=on_session_change_submit, type="primary")
+
+
+def render_session_selection():
+    with st.expander("✅ Session Selection", expanded=True):
+        st.selectbox(
+            "Choose a `session` to chat",
+            options=st.session_state[CHAT_CURRENT_USER_SESSIONS],
+            index=None,
+            key=EL_SESSION_SELECTOR,
+            format_func=lambda x: x["session_id"],
+            on_change=refresh_agent,
+        )
+
+
+def render_files_manager():
+    with st.expander("📃 **Upload your personal files**", expanded=False):
+        st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).")
+        st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).")
+        st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True)
+        # st.markdown("### Uploaded Files")
+        st.dataframe(
+            data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]),
+            use_container_width=True,
+        )
+        st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty()
+        col_1, col_2 = st.columns(2)
+        with col_1:
+            st.button(label="Upload files", on_click=add_file)
+        with col_2:
+            st.button(label="Clear all files and tools", on_click=clear_files)
+
+
+def 
_render_create_personal_knowledge_bases(div: DeltaGenerator):
+    with div:
+        st.markdown("- If you haven't uploaded your personal files, please upload them first.")
+        st.markdown("- Select some **files** to build your `personal knowledge base`.")
+        st.markdown("- Once your `personal knowledge base` is built, "
+                    "it will answer your questions using information from your personal **files**.")
+        st.multiselect(
+            label="⚡️Select some files to build a **personal knowledge base**",
+            options=st.session_state[USER_PRIVATE_FILES],
+            placeholder="You should upload some files first",
+            key=EL_BUILD_KB_WITH_FILES,
+            format_func=lambda x: x["file_name"],
+        )
+        st.text_input(
+            label="⚡️Personal knowledge base name",
+            value="get_relevant_documents",
+            key=EL_PERSONAL_KB_NAME
+        )
+        st.text_input(
+            label="⚡️Personal knowledge base description",
+            value="Searches from some personal files.",
+            key=EL_PERSONAL_KB_DESCRIPTION,
+        )
+        st.button(
+            label="Build 🔧",
+            on_click=create_private_knowledge_base_as_tool
+        )
+
+
+def _render_remove_personal_knowledge_bases(div: DeltaGenerator):
+    with div:
+        st.markdown("> Here are all your personal knowledge bases.")
+        if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0:
+            st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES])
+        else:
+            st.warning("You don't have any personal knowledge bases, please create a new one.")
+        st.multiselect(
+            label="Choose a personal knowledge base to delete",
+            placeholder="Choose a personal knowledge base to delete",
+            options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES],
+            format_func=lambda x: x["tool_name"],
+            key=EL_PERSONAL_KB_NEEDS_REMOVE,
+        )
+        st.button("Delete", on_click=remove_private_knowledge_bases, type="primary")
+
+
+def render_personal_tools_build():
+    with st.expander("🔨 **Build your personal knowledge base**", expanded=True):
+        create_new_kb, kb_manager = st.tabs(["Create personal knowledge base", "Personal knowledge 
base management"]) + _render_create_personal_knowledge_bases(create_new_kb) + _render_remove_personal_knowledge_bases(kb_manager) + + +def render_knowledge_base_selector(): + with st.expander("🙋 **Select some knowledge bases to query**", expanded=True): + st.markdown("- Knowledge bases come in two types: `public` and `private`.") + st.markdown("- All users can access our `public` knowledge bases.") + st.markdown("- Only you can access your `personal` knowledge bases.") + options = st.session_state[RETRIEVER_TOOLS].keys() + if AVAILABLE_RETRIEVAL_TOOLS in st.session_state: + options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS] + st.multiselect( + label="Select some knowledge base tool", + placeholder="Please select some knowledge bases to query", + options=options, + default=["Wikipedia + Self Querying"], + key=EL_SELECTED_KBS, + on_change=refresh_agent, + ) + + +def chat_page(): + # initialize resources + build_chat_knowledge_table() + initialize_session_manager() + + # render sidebar + with st.sidebar: + left, middle, right = st.columns([1, 1, 2]) + with left: + st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main) + with right: + st.markdown(f"👤 `{st.session_state[USER_NAME]}`") + st.markdown(DIVIDER_HTML, unsafe_allow_html=True) + render_session_manager() + render_session_selection() + render_files_manager() + render_personal_tools_build() + render_knowledge_base_selector() + + # render chat history + if "agent" not in st.session_state: + refresh_agent() + for msg in st.session_state.agent.memory.chat_memory.messages: + speaker = "user" if isinstance(msg, HumanMessage) else "assistant" + if isinstance(msg, FunctionMessage): + with st.chat_message(name="from knowledge base", avatar="📚"): + st.write( + f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" + ) + st.write("Retrieved from knowledge base:") + try: + st.dataframe( + pd.DataFrame.from_records( + json.loads(msg.content, 
cls=CustomJSONDecoder) + ), + use_container_width=True, + ) + except Exception as e: + st.warning(e) + st.write(msg.content) + else: + if len(msg.content) > 0: + with st.chat_message(speaker): + # print(type(msg), msg.dict()) + st.write( + f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" + ) + st.write(f"{msg.content}") + st.session_state["next_round"] = st.empty() + from streamlit import _bottom + with _bottom: + col1, col2 = st.columns([1, 16]) + with col1: + st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary") + with col2: + st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input") diff --git a/app/ui/home.py b/app/ui/home.py new file mode 100644 index 0000000..e2d71ca --- /dev/null +++ b/app/ui/home.py @@ -0,0 +1,156 @@ +import base64 + +from streamlit_extras.add_vertical_space import add_vertical_space +from streamlit_extras.card import card +from streamlit_extras.colored_header import colored_header +from streamlit_extras.mention import mention +from streamlit_extras.tags import tagger_component + +from logger import logger +import os + +import streamlit as st +from auth0_component import login_button + +from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML +from streamlit_extras.let_it_rain import rain + + +def render_home(): + render_home_header() + # st.divider() + # st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) + add_vertical_space(5) + render_home_content() + # st.divider() + st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) + render_home_footer() + + +def render_home_header(): + logger.info("render home header") + st.header("ChatData - Your Intelligent Assistant") + st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) + st.markdown("> [ChatData](https://github.com/myscale/ChatData) \ + is developed by [MyScale](https://myscale.com/), \ + it's an integration of 
[LangChain](https://www.langchain.com/) \ + and [MyScaleDB](https://github.com/myscale/myscaledb)") + + tagger_component( + "Keywords:", + ["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"], + color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"], + ) + text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4]) + with text: + st.markdown("Related:") + with col1.container(): + mention( + label="streamlit", + icon="streamlit", + url="https://streamlit.io/", + write=True + ) + with col2.container(): + mention( + label="langchain", + icon="🦜🔗", + url="https://www.langchain.com/", + write=True + ) + with col3.container(): + mention( + label="streamlit-extras", + icon="🪢", + url="https://github.com/arnaudmiribel/streamlit-extras", + write=True + ) + + +def _render_self_query_chain_content(): + col1, col2 = st.columns([1, 1], gap='large') + with col1.container(): + st.image(image='../assets/home_page_background_1.png', + caption=None, + width=None, + use_column_width=True, + clamp=False, + channels="RGB", + output_format="PNG") + with col2.container(): + st.header("VectorSearch & SelfQuery with Sources") + st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.") + st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options: + +1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content. + +2. 
`Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""") + add_vertical_space(3) + _, middle, _ = st.columns([2, 1, 2], gap='small') + with middle.container(): + st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary") + + +def _render_chat_bot_content(): + col1, col2 = st.columns(2, gap='large') + with col1.container(): + st.image(image='../assets/home_page_background_2.png', + caption=None, + width=None, + use_column_width=True, + clamp=False, + channels="RGB", + output_format="PNG") + with col2.container(): + st.header("Chat Bot") + st.info("Now you can try our chatbot, this chatbot is built with MyScale and LangChain.") + st.markdown("- You need to log in. We use `user_name` to identify each customer.") + st.markdown("- You can upload your own PDF files and build your own knowledge base. \ + (This is just a sample application. Please do not upload important or confidential files.)") + st.markdown("- A default session will be assigned as your initial chat session. 
\ + You can create and switch to other sessions to jump between different chat conversations.") + add_vertical_space(1) + _, middle, _ = st.columns([1, 2, 1], gap='small') + with middle.container(): + if USER_NAME not in st.session_state: + login_button(clientId=os.environ["AUTH0_CLIENT_ID"], + domain=os.environ["AUTH0_DOMAIN"], + key="auth0") + # if user_info: + # user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null") + # st.session_state[USER_NAME] = user_name + # print(user_info) + + +def render_home_content(): + logger.info("render home content") + _render_self_query_chain_content() + add_vertical_space(3) + _render_chat_bot_content() + + +def render_home_footer(): + logger.info("render home footer") + st.write( + "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!" + ) + st.write( + "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!") + st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)") + + # st.write( + # "Recommended to use the standalone version of Chat-Data, " + # "available [here](https://myscale-chatdata.hf.space/)." 
def render_home_footer():
    """Render footer links and finalize the login hand-off.

    When the Auth0 component has produced a result, copy it into
    ``USER_INFO``, derive a stable user name, then rerun the script so
    the page re-renders in the logged-in state.
    """
    logger.info("render home footer")
    st.write(
        "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!"
    )
    st.write(
        "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
    st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)")

    if st.session_state.auth0 is not None:
        st.session_state[USER_INFO] = dict(st.session_state.auth0)
        if 'email' in st.session_state[USER_INFO]:
            email = st.session_state[USER_INFO]["email"]
        else:
            # Profiles without an email: synthesize a unique id from nickname + Auth0 subject.
            email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}"
        # Consistency fix: use the USER_NAME constant (was the literal "user_name")
        # so the key matches the login check and back_to_main's cleanup.
        st.session_state[USER_NAME] = email
        del st.session_state.auth0
        st.rerun()
    # Consistency fix: the constant JUMP_QUERY_ASK (was attribute access on the
    # literal name jump_query_ask) — same session-state key, one spelling.
    if st.session_state[JUMP_QUERY_ASK]:
        st.rerun()


def back_to_main():
    """Clear login/session keys so the UI falls back to the home page."""
    for key in (USER_INFO, USER_NAME, JUMP_QUERY_ASK):
        if key in st.session_state:
            del st.session_state[key]


def _render_table_selector() -> str:
    """Let the user pick a read-only MyScaleDB table; show its hint and schema.

    Returns:
        The selected table name.
    """
    col1, col2 = st.columns(2)
    with col1:
        selected_table = st.selectbox(
            label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
            options=MYSCALE_TABLES.keys(),
        )
        MYSCALE_TABLES[selected_table].hint()
    with col2:
        add_vertical_space(1)
        # Idiom fix: plain literal — the original was an f-string with no placeholders.
        st.info("Here is your selected public knowledge base schema in MyScaleDB",
                icon='📚')
        MYSCALE_TABLES[selected_table].hint_sql()

    return selected_table
def render_retrievers():
    """Render the knowledge-base search page with its two retriever tabs."""
    st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
    st.subheader('Please choose a public knowledge base to search.')
    selected_table = _render_table_selector()

    tab_sql, tab_self_query = st.tabs(
        tabs=['Vector SQL', 'Self-querying Retriever']
    )
    with tab_sql:
        render_tab_sql(selected_table)
    with tab_self_query:
        render_tab_self_query(selected_table)


def render_tab_sql(selected_table: str):
    """Render the Vector SQL tab: query input plus DB-only and DB+LLM actions."""
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])

    query_col, db_btn_col, llm_btn_col, _ = st.columns([8, 3, 3, 2])
    query_col.text_input("Input your question:", key='query_sql')
    with db_btn_col.container():
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)
    with llm_btn_col.container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)

    # st.button records its clicked state under its key in session_state;
    # dispatch each requested action to the shared SQL-query processor.
    for action in (RetrieverButtons.vector_sql_query_from_db,
                   RetrieverButtons.vector_sql_query_with_llm):
        if st.session_state[action]:
            process_sql_query(selected_table, action)
def render_tab_self_query(selected_table):
    """Render the self-querying retriever tab: query input plus DB-only and DB+LLM actions."""
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])

    query_col, db_btn_col, llm_btn_col, _ = st.columns([8, 3, 3, 2])
    query_col.text_input("Input your question:", key='query_self')

    with db_btn_col.container():
        add_vertical_space(2)
        # NOTE(review): button keys here are the literals 'search_self'/'ask_self'
        # while the dispatch below uses RetrieverButtons constants — confirm
        # whether these should be unified with the vector-SQL tab's convention.
        st.button("Retrieve from MyScaleDB ➡️", key='search_self')
    with llm_btn_col.container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key='ask_self')

    if st.session_state['search_self']:
        process_self_query(selected_table, RetrieverButtons.self_query_from_db)
    if st.session_state['ask_self']:
        process_self_query(selected_table, RetrieverButtons.self_query_with_llm)
\n\n" + "Please try again and use verbs that may match the datatype.", + unsafe_allow_html=True + ) diff --git a/assets/chatdata-homepage.png b/assets/chatdata-homepage.png deleted file mode 100644 index 1580b7f..0000000 Binary files a/assets/chatdata-homepage.png and /dev/null differ diff --git a/assets/home.png b/assets/home.png new file mode 100644 index 0000000..22f1eaf Binary files /dev/null and b/assets/home.png differ diff --git a/assets/home_page_background_1.png b/assets/home_page_background_1.png new file mode 100644 index 0000000..1be1602 Binary files /dev/null and b/assets/home_page_background_1.png differ diff --git a/assets/home_page_background_2.png b/assets/home_page_background_2.png new file mode 100644 index 0000000..bf18230 Binary files /dev/null and b/assets/home_page_background_2.png differ