Skip to content

Commit

Permalink
update dependencies, remove biblio from search space
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Aug 23, 2024
1 parent bf050bb commit 2814cb7
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 30 deletions.
9 changes: 4 additions & 5 deletions document_qa/document_qa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from langchain.chains import create_extraction_chain
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
map_rerank_prompt
from langchain.evaluation import PairwiseEmbeddingDistanceEvalChain, load_evaluator, EmbeddingDistance
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain.schema import Document
Expand Down Expand Up @@ -273,7 +272,7 @@ def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> Li
"""
db = self.data_storage.embeddings_dict[doc_id]
retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
relevant_documents = retriever.get_relevant_documents(query)
relevant_documents = retriever.invoke(query)

return relevant_documents

Expand All @@ -284,7 +283,7 @@ def analyse_query(self, query, doc_id, context_size=4):
# search_type="similarity_score_threshold"
# )
retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
relevant_documents = retriever.get_relevant_documents(query)
relevant_documents = retriever.invoke(query)
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
for doc in
relevant_documents]
Expand Down Expand Up @@ -338,7 +337,7 @@ def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
db = self.data_storage.embeddings_dict[doc_id]
retriever = db.as_retriever(search_kwargs={"k": context_size})
relevant_documents = retriever.get_relevant_documents(query)
relevant_documents = retriever.invoke(query)
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
for doc in
relevant_documents]
Expand All @@ -361,7 +360,7 @@ def get_full_context_by_document(self, doc_id):
def _get_context_multiquery(self, doc_id, query, context_size=4):
db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
relevant_documents = multi_query_retriever.get_relevant_documents(query)
relevant_documents = multi_query_retriever.invoke(query)
return relevant_documents

def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
Expand Down
18 changes: 9 additions & 9 deletions document_qa/grobid_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ def parse_grobid_xml(self, text, coordinates=False):
soup = BeautifulSoup(text, 'xml')
blocks_header = get_xml_nodes_header(soup, use_paragraphs=True)

passages.append({
"text": f"authors: {biblio['authors']}",
"type": passage_type,
"section": "<header>",
"subSection": "<authors>",
"passage_id": "hauthors",
"coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
blocks_header['authors']])
})
# passages.append({
# "text": f"authors: {biblio['authors']}",
# "type": passage_type,
# "section": "<header>",
# "subSection": "<authors>",
# "passage_id": "hauthors",
# "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
# blocks_header['authors']])
# })

passages.append({
"text": self.post_process(" ".join([node.text for node in blocks_header['title']])),
Expand Down
13 changes: 8 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ dateparser

# LLM
chromadb==0.4.24
tiktoken==0.6.0
openai==1.16.2
langchain==0.1.14
langchain-core==0.1.40
tiktoken==0.7.0
openai==1.42.0
langchain==0.2.14
langchain-core==0.2.34
langchain-openai==0.1.22
langchain-huggingface==0.0.3
langchain-community==0.2.12
typing-inspect==0.9.0
typing_extensions==4.11.0
pydantic==2.6.4
sentence_transformers==2.6.1
streamlit-pdf-viewer==0.0.17
streamlit-pdf-viewer==0.0.18-dev1
umap-learn
plotly
31 changes: 20 additions & 11 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import dotenv
from grobid_quantities.quantities import QuantitiesAPI
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.callbacks import PromptLayerCallbackHandler
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from streamlit_pdf_viewer import pdf_viewer

from document_qa.ner_client_generic import NERClientGeneric
Expand Down Expand Up @@ -97,6 +98,9 @@
if 'embeddings' not in st.session_state:
st.session_state['embeddings'] = None

if 'scroll_to_first_annotation' not in st.session_state:
st.session_state['scroll_to_first_annotation'] = False

st.set_page_config(
page_title="Scientific Document Insights Q/A",
page_icon="📝",
Expand Down Expand Up @@ -169,7 +173,8 @@ def init_qa(model, embeddings_name=None, api_key=None):
repo_id=OPEN_MODELS[model],
temperature=0.01,
max_new_tokens=4092,
model_kwargs={"max_length": 8192}
model_kwargs={"max_length": 8192},
callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])]
)
embeddings = HuggingFaceEmbeddings(
model_name=OPEN_EMBEDDINGS[embeddings_name])
Expand Down Expand Up @@ -233,8 +238,8 @@ def play_old_messages(container):
# is_api_key_provided = st.session_state['api_key']

with st.sidebar:
st.title("📝 Scientific Document Insights Q/A")
st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
st.title("📝 Document Q/A")
st.markdown("Upload a scientific article in PDF, ask questions, get insights.")
st.markdown(
":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")

Expand Down Expand Up @@ -301,14 +306,14 @@ def play_old_messages(container):
# help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
# disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)

left_column, right_column = st.columns([1, 1])
left_column, right_column = st.columns([5, 4])
right_column = right_column.container(border=True)
left_column = left_column.container(border=True)

with right_column:
uploaded_file = st.file_uploader(
"Upload an article",
type=("pdf", "txt"),
"Upload a scientific article",
type=("pdf"),
on_change=new_file,
disabled=st.session_state['model'] is not None and st.session_state['model'] not in
st.session_state['api_keys'],
Expand Down Expand Up @@ -343,6 +348,10 @@ def play_old_messages(container):
"relevant paragraphs to the question in the paper. "
"Question coefficient attempt to estimate how effective the question will be answered."
)
st.session_state['scroll_to_first_annotation'] = st.checkbox(
"Scroll to context",
help='The PDF viewer will automatically scroll to the first relevant passage in the document.'
)
st.session_state['ner_processing'] = st.checkbox(
"Identify materials and properties.",
help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
Expand Down Expand Up @@ -415,7 +424,6 @@ def generate_color_gradient(num_elements):

with right_column:
if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
# messages.chat_message("user").markdown(question)
st.session_state.messages.append({"role": "user", "mode": mode, "content": question})

for message in st.session_state.messages:
Expand Down Expand Up @@ -491,5 +499,6 @@ def generate_color_gradient(num_elements):
input=st.session_state['binary'],
annotation_outline_size=2,
annotations=st.session_state['annotations'],
render_text=True
render_text=True,
scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state['scroll_to_first_annotation']) else None
)

0 comments on commit 2814cb7

Please sign in to comment.