Skip to content

Commit

Permalink
update to mistral v0.2, add selectable embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed May 6, 2024
1 parent d74cacd commit 7bc374b
Showing 1 changed file with 52 additions and 12 deletions.
64 changes: 52 additions & 12 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,23 @@
"gpt-4",
"gpt-4-1106-preview"]

OPENAI_EMBEDDINGS = [
'text-embedding-ada-002',
'text-embedding-3-large',
'openai-text-embedding-3-small'
]

OPEN_MODELS = {
'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
'mistral-7b-instruct-v0.2': 'mistralai/Mistral-7B-Instruct-v0.2',
"zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta'
# 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct",
# 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
}

DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)'
OPEN_EMBEDDINGS = {
DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2',
'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
}

DISABLE_MEMORY = ['zephyr-7b-beta']
Expand Down Expand Up @@ -83,6 +97,9 @@
if 'pdf_rendering' not in st.session_state:
st.session_state['pdf_rendering'] = None

if 'embeddings' not in st.session_state:
st.session_state['embeddings'] = None

st.set_page_config(
page_title="Scientific Document Insights Q/A",
page_icon="📝",
Expand Down Expand Up @@ -139,32 +156,42 @@ def clear_memory():


# @st.cache_resource
def init_qa(model, api_key=None):
def init_qa(model, embeddings_name=None, api_key=None):
## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
if model in OPENAI_MODELS:
if embeddings_name is None:
embeddings_name = 'text-embedding-ada-002'

st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
if api_key:
chat = ChatOpenAI(model_name=model,
temperature=0,
openai_api_key=api_key,
frequency_penalty=0.1)
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
if embeddings_name not in OPENAI_EMBEDDINGS:
st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
st.stop()
return
embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key)

else:
chat = ChatOpenAI(model_name=model,
temperature=0,
frequency_penalty=0.1)
embeddings = OpenAIEmbeddings()
embeddings = OpenAIEmbeddings(model=embeddings_name)

elif model in OPEN_MODELS:
if embeddings_name is None:
embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME

chat = HuggingFaceEndpoint(
repo_id=OPEN_MODELS[model],
temperature=0.01,
max_new_tokens=2048,
model_kwargs={"max_length": 4096}
)
embeddings = HuggingFaceEmbeddings(
model_name="all-MiniLM-L6-v2")
model_name=OPEN_EMBEDDINGS[embeddings_name])
st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
else:
st.error("The model was not loaded properly. Try reloading. ")
Expand Down Expand Up @@ -231,15 +258,25 @@ def play_old_messages():
"Model:",
options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
"zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
"mistral-7b-instruct-v0.2") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
placeholder="Select model",
help="Select the LLM model:",
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
)
embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS

st.session_state['embeddings'] = embedding_name = st.selectbox(
"Embeddings:",
options=embedding_choices,
index=0,
placeholder="Select embedding",
help="Select the Embedding function:",
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
)

st.markdown(
":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")

if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
Expand All @@ -256,7 +293,7 @@ def play_old_messages():
st.session_state['api_keys'][model] = api_key
# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
st.session_state['rqa'][model] = init_qa(model)
st.session_state['rqa'][model] = init_qa(model, embedding_name)

elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
if 'OPENAI_API_KEY' not in os.environ:
Expand All @@ -270,9 +307,9 @@ def play_old_messages():
with st.spinner("Preparing environment"):
st.session_state['api_keys'][model] = api_key
if 'OPENAI_API_KEY' not in os.environ:
st.session_state['rqa'][model] = init_qa(model, api_key)
st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key)
else:
st.session_state['rqa'][model] = init_qa(model)
st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
# else:
# is_api_key_provided = st.session_state['api_key']

Expand Down Expand Up @@ -371,10 +408,13 @@ def play_old_messages():

st.header("Query mode (Advanced use)")
st.markdown(
"""By default, the mode is set to LLM (Language Model) which enables question/answering. You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")
"""By default, the mode is set to LLM (Language Model) which enables question/answering.
You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")

st.markdown(
"""If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
"""If you switch the mode to "Embedding," the system will return specific chunks from the document
that are semantically related to your query. This mode helps to test why sometimes the answers are not
satisfying or incomplete. """)

if uploaded_file and not st.session_state.loaded_embeddings:
if model not in st.session_state['api_keys']:
Expand Down

0 comments on commit 7bc374b

Please sign in to comment.