diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index f91fc6d..0986460 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Union, Any
 
+from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
 from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
 from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
@@ -14,8 +15,6 @@
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from document_qa.grobid_processors import GrobidProcessor
-
 
 class DocumentQAEngine:
     llm = None
@@ -188,8 +187,10 @@ def _get_context_multiquery(self, doc_id, query, context_size=4):
         relevant_documents = multi_query_retriever.get_relevant_documents(query)
         return relevant_documents
 
-    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
-        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
+    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
+        """
+        Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
+        """
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -204,6 +205,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
         texts = []
         metadatas = []
         ids = []
+
         if chunk_size < 0:
             for passage in structure['passages']:
                 biblio_copy = copy.copy(biblio)
@@ -227,10 +229,25 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
             metadatas = [biblio for _ in range(len(texts))]
             ids = [id for id, t in enumerate(texts)]
 
+        if "biblio" in include:
+            biblio_metadata = copy.copy(biblio)
+            biblio_metadata['type'] = "biblio"
+            biblio_metadata['section'] = "header"
+            for key in ['title', 'authors', 'publication_year']:
+                if key in biblio_metadata:
+                    texts.append("{}: {}".format(key, biblio_metadata[key]))
+                    metadatas.append(biblio_metadata)
+                    ids.append(key)
+
         return texts, metadatas, ids
 
-    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
-        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
+    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
+        include = ["biblio"] if include_biblio else []
+        texts, metadata, ids = self.get_text_from_document(
+            pdf_path,
+            chunk_size=chunk_size,
+            perc_overlap=perc_overlap,
+            include=include)
         if doc_id:
             hash = doc_id
         else:
@@ -252,7 +269,7 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -269,9 +286,12 @@ def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.
             if os.path.exists(data_path):
                 print(data_path, "exists. Skipping it ")
                 continue
-
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
-                                                               perc_overlap=perc_overlap)
+            include = ["biblio"] if include_biblio else []
+            texts, metadata, ids = self.get_text_from_document(
+                input_file,
+                chunk_size=chunk_size,
+                perc_overlap=perc_overlap,
+                include=include)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py
index e21b1f1..4d8f36b 100644
--- a/document_qa/grobid_processors.py
+++ b/document_qa/grobid_processors.py
@@ -171,7 +171,7 @@ def parse_grobid_xml(self, text):
         }
         try:
             year = dateparser.parse(doc_biblio.header.date).year
-            biblio["year"] = year
+            biblio["publication_year"] = year
         except:
             pass
 
diff --git a/streamlit_app.py b/streamlit_app.py
index f5a9f60..8a4118e 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -288,7 +288,8 @@ def play_old_messages():
         # hash = get_file_hash(tmp_file.name)[:10]
         st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                     chunk_size=chunk_size,
-                                                                                                    perc_overlap=0.1)
+                                                                                                    perc_overlap=0.1,
+                                                                                                    include_biblio=True)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []