disable conversational memory with zephyr
lfoppiano committed Nov 29, 2023
1 parent d2299da commit a2bcc71
Showing 1 changed file with 12 additions and 7 deletions.
streamlit_app.py: 19 changes (12 additions & 7 deletions)
@@ -54,7 +54,7 @@
     st.session_state['uploaded'] = False
 
 if 'memory' not in st.session_state:
-    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+    st.session_state['memory'] = None
 
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
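
For context: the ConversationBufferWindowMemory(k=4) that is no longer created unconditionally at startup keeps only a sliding window of recent turns. A minimal sketch of that behavior, assuming the legacy langchain memory API (not part of this diff):

from langchain.memory import ConversationBufferWindowMemory

# Keeps only the k most recent exchanges; older turns drop out of the prompt.
memory = ConversationBufferWindowMemory(k=4)
memory.save_context({"input": "What does Grobid extract?"}, {"output": "Full text and bibliography."})
memory.save_context({"input": "Which models are supported?"}, {"output": "OpenAI and Hugging Face models."})

# Returns at most the last 4 exchanges under the default "history" key.
print(memory.load_memory_variables({})["history"])
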
@@ -117,12 +117,14 @@ def clear_memory():
 def init_qa(model, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
             embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+
         else:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
@@ -134,11 +136,13 @@ def init_qa(model, api_key=None):
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
 
     elif model == 'zephyr-7b-beta':
         chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
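
Condensed, init_qa now gates memory per model: chat-capable models get a fresh window memory, while zephyr-7b-beta runs stateless. A sketch of just that branching (select_memory is an illustrative helper, not a function in the app):

from langchain.memory import ConversationBufferWindowMemory

def select_memory(model: str):
    # Illustrative helper mirroring the branching above; not the app's API.
    if model == 'zephyr-7b-beta':
        return None  # stateless: queries are answered without chat history
    return ConversationBufferWindowMemory(k=4)  # windowed memory otherwise
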
@@ -255,7 +259,8 @@ def play_old_messages():
     'Reset chat memory.',
     key="reset-memory-button",
     on_click=clear_memory,
-    help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")
+    help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.",
+    disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
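
The new disabled expression greys out the reset button whenever the active model's chain was built without memory, so there is nothing to clear. The same guard in isolation (a sketch; rqa maps model names to QA chains, as in the app's session state):

def reset_disabled(model, rqa) -> bool:
    # True when this model's chain exists and carries no conversational memory.
    return model in rqa and rqa[model].memory is None
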

@@ -267,8 +272,8 @@ def play_old_messages():
     ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third-party LLMs.")
 
 uploaded_file = st.file_uploader("Upload an article",
-                                     type=("pdf", "txt"),
-                                     on_change=new_file,
+                                 type=("pdf", "txt"),
+                                 on_change=new_file,
                                  disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                                  st.session_state['api_keys'],
                                  help="The full-text is extracted using Grobid. ")
@@ -335,8 +340,8 @@ def get_pdf_display(binary):
 
     st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                 chunk_size=chunk_size,
-                                                                                            perc_overlap=0.1,
-                                                                                            include_biblio=True)
+                                                                                                perc_overlap=0.1,
+                                                                                                include_biblio=True)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
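
The create_memory_embeddings call above (only realigned here) chunks the document with 10% overlap between consecutive chunks. Illustrative arithmetic for what perc_overlap=0.1 implies about chunk boundaries (the real splitting happens inside the rqa chain):

def chunk_bounds(text_len: int, chunk_size: int, perc_overlap: float):
    # Consecutive windows share perc_overlap * chunk_size characters.
    step = int(chunk_size * (1 - perc_overlap))
    return [(start, min(start + chunk_size, text_len))
            for start in range(0, text_len, step)]

# e.g. chunk_size=250, perc_overlap=0.1: windows advance by 225 characters
print(chunk_bounds(1000, 250, 0.1))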

@@ -389,7 +394,7 @@ def get_pdf_display(binary):
         elif mode == "LLM":
             with st.spinner("Generating response..."):
                 _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size)
+                                                                                 context_size=context_size)
 
                 if not text_response:
                     st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
