Skip to content

Commit

Permalink
Merge branch 'main' into fix-conversational-memory
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano authored Nov 22, 2023
2 parents ad304e2 + 55e39a2 commit 16cf398
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
40 changes: 30 additions & 10 deletions document_qa/document_qa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Union, Any

from document_qa.grobid_processors import GrobidProcessor
from grobid_client.grobid_client import GrobidClient
from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
Expand All @@ -14,8 +15,6 @@
from langchain.vectorstores import Chroma
from tqdm import tqdm

from document_qa.grobid_processors import GrobidProcessor


class DocumentQAEngine:
llm = None
Expand Down Expand Up @@ -188,8 +187,10 @@ def _get_context_multiquery(self, doc_id, query, context_size=4):
relevant_documents = multi_query_retriever.get_relevant_documents(query)
return relevant_documents

def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
"""Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
"""
Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
"""
if verbose:
print("File", pdf_file_path)
filename = Path(pdf_file_path).stem
Expand All @@ -204,6 +205,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
texts = []
metadatas = []
ids = []

if chunk_size < 0:
for passage in structure['passages']:
biblio_copy = copy.copy(biblio)
Expand All @@ -227,10 +229,25 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
metadatas = [biblio for _ in range(len(texts))]
ids = [id for id, t in enumerate(texts)]

if "biblio" in include:
biblio_metadata = copy.copy(biblio)
biblio_metadata['type'] = "biblio"
biblio_metadata['section'] = "header"
for key in ['title', 'authors', 'publication_year']:
if key in biblio_metadata:
texts.append("{}: {}".format(key, biblio_metadata[key]))
metadatas.append(biblio_metadata)
ids.append(key)

return texts, metadatas, ids

def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
include = ["biblio"] if include_biblio else []
texts, metadata, ids = self.get_text_from_document(
pdf_path,
chunk_size=chunk_size,
perc_overlap=perc_overlap,
include=include)
if doc_id:
hash = doc_id
else:
Expand All @@ -252,7 +269,7 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o

return hash

def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
input_files = []
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
for file_ in files:
Expand All @@ -269,9 +286,12 @@ def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.
if os.path.exists(data_path):
print(data_path, "exists. Skipping it ")
continue

texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
perc_overlap=perc_overlap)
include = ["biblio"] if include_biblio else []
texts, metadata, ids = self.get_text_from_document(
input_file,
chunk_size=chunk_size,
perc_overlap=perc_overlap,
include=include)
filename = metadata[0]['filename']

vector_db_document = Chroma.from_texts(texts,
Expand Down
2 changes: 1 addition & 1 deletion document_qa/grobid_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def parse_grobid_xml(self, text):
}
try:
year = dateparser.parse(doc_biblio.header.date).year
biblio["year"] = year
biblio["publication_year"] = year
except:
pass

Expand Down
3 changes: 2 additions & 1 deletion streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def play_old_messages():
# hash = get_file_hash(tmp_file.name)[:10]
st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
chunk_size=chunk_size,
perc_overlap=0.1)
perc_overlap=0.1,
include_biblio=True)
st.session_state['loaded_embeddings'] = True
st.session_state.messages = []

Expand Down

0 comments on commit 16cf398

Please sign in to comment.