Review notes (leading text before the "diff --git" header; patch tools skip it):
  * Line structure restored from the hunk line counts; indentation was lost in
    the collapsed paste and is reconstructed, so verify whitespace of context
    lines against the target file before applying.
  * OPENAI_EMBEDDINGS: 'openai-text-embedding-3-small' corrected to the actual
    OpenAI model id 'text-embedding-3-small'.
  * init_qa: the embeddings_name validation for OpenAI models is hoisted above
    the `if api_key:` branch so it also covers the environment-key path.
  * The added lines were amended, so the post-image blob id on the "index"
    line no longer describes the result of applying this patch.

diff --git a/streamlit_app.py b/streamlit_app.py
index 643baa7..6ff5637 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -24,9 +24,23 @@
                  "gpt-4",
                  "gpt-4-1106-preview"]
 
+OPENAI_EMBEDDINGS = [
+    'text-embedding-ada-002',
+    'text-embedding-3-large',
+    'text-embedding-3-small'
+]
+
 OPEN_MODELS = {
-    'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
+    'mistral-7b-instruct-v0.2': 'mistralai/Mistral-7B-Instruct-v0.2',
     "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta'
+    # 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct",
+    # 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
+}
+
+DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)'
+OPEN_EMBEDDINGS = {
+    DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2',
+    'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
 }
 
 DISABLE_MEMORY = ['zephyr-7b-beta']
@@ -83,6 +97,9 @@
 if 'pdf_rendering' not in st.session_state:
     st.session_state['pdf_rendering'] = None
 
+if 'embeddings' not in st.session_state:
+    st.session_state['embeddings'] = None
+
 st.set_page_config(
     page_title="Scientific Document Insights Q/A",
     page_icon="📝",
@@ -139,24 +156,34 @@ def clear_memory():
 
 
 # @st.cache_resource
-def init_qa(model, api_key=None):
+def init_qa(model, embeddings_name=None, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model in OPENAI_MODELS:
+        if embeddings_name is None:
+            embeddings_name = 'text-embedding-ada-002'
+        if embeddings_name not in OPENAI_EMBEDDINGS:
+            st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
+            st.stop()
+            return
+
         st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name=model,
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
-            embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+            embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key)
 
         else:
             chat = ChatOpenAI(model_name=model,
                               temperature=0,
                               frequency_penalty=0.1)
-            embeddings = OpenAIEmbeddings()
+            embeddings = OpenAIEmbeddings(model=embeddings_name)
 
     elif model in OPEN_MODELS:
+        if embeddings_name is None:
+            embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME
+
         chat = HuggingFaceEndpoint(
             repo_id=OPEN_MODELS[model],
             temperature=0.01,
@@ -164,7 +191,7 @@ def init_qa(model, api_key=None):
             model_kwargs={"max_length": 4096}
         )
         embeddings = HuggingFaceEmbeddings(
-            model_name="all-MiniLM-L6-v2")
+            model_name=OPEN_EMBEDDINGS[embeddings_name])
         st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
@@ -231,15 +258,25 @@ def play_old_messages():
         "Model:",
         options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
         index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
-            "zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
+            "mistral-7b-instruct-v0.2") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
                 OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
         placeholder="Select model",
         help="Select the LLM model:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
     )
+    embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS
+
+    st.session_state['embeddings'] = embedding_name = st.selectbox(
+        "Embeddings:",
+        options=embedding_choices,
+        index=0,
+        placeholder="Select embedding",
+        help="Select the Embedding function:",
+        disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
+    )
 
     st.markdown(
-        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
+        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
 
     if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
         if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -256,7 +293,7 @@ def play_old_messages():
                 st.session_state['api_keys'][model] = api_key
                 # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
                 #     os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
-                st.session_state['rqa'][model] = init_qa(model)
+                st.session_state['rqa'][model] = init_qa(model, embedding_name)
 
     elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
         if 'OPENAI_API_KEY' not in os.environ:
@@ -270,9 +307,9 @@ def play_old_messages():
             with st.spinner("Preparing environment"):
                 st.session_state['api_keys'][model] = api_key
                 if 'OPENAI_API_KEY' not in os.environ:
-                    st.session_state['rqa'][model] = init_qa(model, api_key)
+                    st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key)
                 else:
-                    st.session_state['rqa'][model] = init_qa(model)
+                    st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
     # else:
     #     is_api_key_provided = st.session_state['api_key']
 
@@ -371,10 +408,13 @@ def play_old_messages():
 
     st.header("Query mode (Advanced use)")
     st.markdown(
-        """By default, the mode is set to LLM (Language Model) which enables question/answering. You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")
+        """By default, the mode is set to LLM (Language Model) which enables question/answering.
+        You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")
 
     st.markdown(
-        """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
+        """If you switch the mode to "Embedding," the system will return specific chunks from the document
+        that are semantically related to your query. This mode helps to test why sometimes the answers are not
+        satisfying or incomplete. """)
 
     if uploaded_file and not st.session_state.loaded_embeddings:
         if model not in st.session_state['api_keys']: