diff --git a/streamlit_app.py b/streamlit_app.py
index 31eebe4..8972e40 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -54,7 +54,7 @@
     st.session_state['uploaded'] = False
 
 if 'memory' not in st.session_state:
-    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+    st.session_state['memory'] = None
 
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
@@ -117,12 +117,14 @@ def clear_memory():
 def init_qa(model, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
             embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+
         else:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
@@ -134,11 +136,13 @@ def init_qa(model, api_key=None):
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
 
     elif model == 'zephyr-7b-beta':
         chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
@@ -255,7 +259,8 @@ def play_old_messages():
         'Reset chat memory.',
         key="reset-memory-button",
         on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
+        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
 
@@ -267,8 +272,8 @@ def play_old_messages():
         ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
 
     uploaded_file = st.file_uploader("Upload an article",
-                                 type=("pdf", "txt"),
-                                 on_change=new_file,
+                                     type=("pdf", "txt"),
+                                     on_change=new_file,
                                      disabled=st.session_state['model'] is not None and st.session_state['model'] not in st.session_state['api_keys'],
                                      help="The full-text is extracted using Grobid. ")
 
@@ -335,8 +340,8 @@ def get_pdf_display(binary):
 
             st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                         chunk_size=chunk_size,
-                                                                                                    perc_overlap=0.1,
-                                                                                                    include_biblio=True)
+                                                                                                        perc_overlap=0.1,
+                                                                                                        include_biblio=True)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
 
@@ -389,7 +394,7 @@ def get_pdf_display(binary):
         elif mode == "LLM":
             with st.spinner("Generating response..."):
                 _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size)
+                                                                                 context_size=context_size)
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
 