Skip to content

Commit

Permalink
format python code and remove unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
lqhl committed Jan 17, 2024
1 parent 6cb0f75 commit 67687ec
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 78 deletions.
16 changes: 7 additions & 9 deletions app/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json
import time
import pandas as pd
from os import environ
import streamlit as st
Expand All @@ -13,10 +11,10 @@
from lib.helper import build_tools, build_all, sel_map, display



environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
st.set_page_config(page_title="ChatData",
page_icon="https://myscale.com/favicon.ico")
st.markdown(
f"""
<style>
Expand All @@ -36,11 +34,12 @@
if "user_name" in st.session_state:
chat_page()
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:

sel = st.selectbox('Choose the knowledge base you want to ask with:',
options=['ArXiv Papers', 'Wikipedia'])
options=['ArXiv Papers', 'Wikipedia'])
sel_map[sel]['hint']()
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
tab_sql, tab_self_query = st.tabs(
['Vector SQL', 'Self-Query Retrievers'])
with tab_sql:
sel_map[sel]['hint_sql']()
st.text_input("Ask a question:", key='query_sql')
Expand Down Expand Up @@ -85,7 +84,6 @@
st.write('Oops 😵 Something bad happened...')
raise e


with tab_self_query:
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
Expand Down Expand Up @@ -132,4 +130,4 @@
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
except Exception as e:
st.write('Oops 😵 Something bad happened...')
raise e
raise e
7 changes: 4 additions & 3 deletions app/callbacks/arxiv_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from streamlit.delta_generator import DeltaGenerator


class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
Expand All @@ -26,7 +25,8 @@ def on_chain_end(self, outputs, **kwargs) -> None:
self.progress_bar.progress(value=0.6, text="Searching in DB...")
if "repr" in outputs:
st.markdown("### Generated Filter")
st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
st.markdown(
f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)

def on_chain_start(self, serialized, inputs, **kwargs) -> None:
    """Do nothing when a chain starts; every argument is ignored."""
    return None
Expand Down Expand Up @@ -88,7 +88,8 @@ def on_llm_end(
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
print(f"Vector SQL: {text}")
self.prog_value += self.prog_interval
self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
self.progress_bar.progress(
value=self.prog_value, text="Searching in DB...")

def on_chain_start(self, serialized, inputs, **kwargs) -> None:
cid = ".".join(serialized["id"])
Expand Down
19 changes: 11 additions & 8 deletions app/chains/arxiv_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
CallbackManagerForChainRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain.schema.prompt_template import format_document
from langchain.docstore.document import Document
Expand All @@ -20,11 +19,12 @@

logger = logging.getLogger()


class MyScaleWithoutMetadataJson(MyScale):
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None:
    """Initialize the parent store and remember the required column names.

    Args:
        embedding: Passed through to the parent constructor.
        config: Optional settings, passed through to the parent constructor.
        must_have_cols: Column names stored on the instance as
            ``must_have_cols`` (presumably the columns every query must
            select -- confirm against ``_build_qstr``). ``None`` (the
            default) results in an empty list.
        **kwargs: Extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(embedding, config, **kwargs)
    # A literal ``[]`` default is a single list object shared by every
    # instance created without this argument, so mutating it on one instance
    # would leak into the others. Build a fresh list per instance instead;
    # a caller-supplied list is still stored as-is, exactly as before.
    self.must_have_cols: List[str] = [] if must_have_cols is None else must_have_cols

def _build_qstr(
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
) -> str:
Expand All @@ -43,7 +43,7 @@ def _build_qstr(
LIMIT {topk}
"""
return q_str

def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
q_str = self._build_qstr(embedding, k, where_str)
try:
Expand All @@ -55,9 +55,11 @@ def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_
for r in self.client.query(q_str).named_results()
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
logger.error(
f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []


class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
"""Based on VectorSQLOutputParser
It also modifies the SQL to get all columns
Expand All @@ -73,9 +75,11 @@ def parse(self, text: str) -> Dict[str, Any]:
start = text.upper().find("SELECT")
if start >= 0:
end = text.upper().find("FROM")
text = text.replace(text[start + len("SELECT") + 1 : end - 1], ", ".join(self.must_have_columns))
text = text.replace(
text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
return super().parse(text)


class ArXivStuffDocumentChain(StuffDocumentsChain):
"""Combine arxiv documents with PDF reference number"""

Expand Down Expand Up @@ -172,8 +176,7 @@ def _call(
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
sources.append(d)
ref_cnt += 1



result: Dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
Expand All @@ -191,4 +194,4 @@ async def _acall(

@property
def _chain_type(self) -> str:
return "arxiv_qa_with_sources_chain"
return "arxiv_qa_with_sources_chain"
12 changes: 6 additions & 6 deletions app/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from lib.private_kb import PrivateKnowledgeBase
from langchain.schema import HumanMessage, FunctionMessage
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import (
StreamlitCallbackHandler,
)
from lib.json_conv import CustomJSONDecoder

from lib.helper import (
Expand Down Expand Up @@ -313,7 +310,8 @@ def chat_page():
key="b_tool_files",
format_func=lambda x: x["file_name"],
)
st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
st.text_input(
"Tool Name", "get_relevant_documents", key="b_tool_name")
st.text_input(
"Tool Description",
"Searches among user's private files and returns related documents",
Expand Down Expand Up @@ -359,14 +357,16 @@ def chat_page():
)
st.markdown("### Uploaded Files")
st.dataframe(
st.session_state.private_kb.list_files(st.session_state.user_name),
st.session_state.private_kb.list_files(
st.session_state.user_name),
use_container_width=True,
)
col_1, col_2 = st.columns(2)
with col_1:
st.button("Add Files", on_click=add_file)
with col_2:
st.button("Clear Files and All Tools", on_click=clear_files)
st.button("Clear Files and All Tools",
on_click=clear_files)

st.button("Clear Chat History", on_click=clear_history)
st.button("Logout", on_click=back_to_main)
Expand Down
Loading

0 comments on commit 67687ec

Please sign in to comment.