Include biblio in embeddings #21
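This change adds the document's bibliographic header to the embeddings. get_text_from_document gains an include parameter that, when it contains "biblio", appends one passage per available header field (title, authors, publication_year); create_memory_embeddings and create_embeddings expose this through an include_biblio flag. The biblio key year is renamed to publication_year, and the Streamlit app enables the flag when embedding an uploaded PDF, presumably so questions about title, authors, or publication year can be answered from header metadata.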

Merged: 2 commits, Nov 22, 2023
40 changes: 30 additions & 10 deletions document_qa/document_qa_engine.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Union, Any
 
+from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
 from langchain.chains import create_extraction_chain
 from langchain.chains.question_answering import load_qa_chain
@@ -12,8 +13,6 @@
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from document_qa.grobid_processors import GrobidProcessor
-
 
 class DocumentQAEngine:
     llm = None
@@ -173,8 +172,10 @@ def _get_context_multiquery(self, doc_id, query, context_size=4):
         relevant_documents = multi_query_retriever.get_relevant_documents(query)
         return relevant_documents
 
-    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
-        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
+    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
+        """
+        Extract text from documents using Grobid; if chunk_size is < 0, each paragraph is kept separately.
+        """
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -189,6 +190,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
         texts = []
         metadatas = []
         ids = []
+
         if chunk_size < 0:
             for passage in structure['passages']:
                 biblio_copy = copy.copy(biblio)
@@ -212,10 +214,25 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
             metadatas = [biblio for _ in range(len(texts))]
             ids = [id for id, t in enumerate(texts)]
 
+        if "biblio" in include:
+            biblio_metadata = copy.copy(biblio)
+            biblio_metadata['type'] = "biblio"
+            biblio_metadata['section'] = "header"
+            for key in ['title', 'authors', 'publication_year']:
+                if key in biblio_metadata:
+                    texts.append("{}: {}".format(key, biblio_metadata[key]))
+                    metadatas.append(biblio_metadata)
+                    ids.append(key)
+
         return texts, metadatas, ids
 
-    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
-        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
+    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
+        include = ["biblio"] if include_biblio else []
+        texts, metadata, ids = self.get_text_from_document(
+            pdf_path,
+            chunk_size=chunk_size,
+            perc_overlap=perc_overlap,
+            include=include)
         if doc_id:
             hash = doc_id
         else:
@@ -233,7 +250,7 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -250,9 +267,12 @@ def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
             if os.path.exists(data_path):
                 print(data_path, "exists. Skipping it ")
                 continue
-
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
-                                                               perc_overlap=perc_overlap)
+            include = ["biblio"] if include_biblio else []
+            texts, metadata, ids = self.get_text_from_document(
+                input_file,
+                chunk_size=chunk_size,
+                perc_overlap=perc_overlap,
+                include=include)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
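For reviewers, a minimal usage sketch of the two new parameters. The engine setup is elided because its constructor is not part of this diff; engine and "paper.pdf" are hypothetical:

# Assumes `engine` is an already-configured DocumentQAEngine instance;
# the constructor arguments are not shown in this diff, so setup is elided.

# With include=("biblio",), up to three extra passages are appended, one per
# available header field, each tagged type="biblio" and section="header".
texts, metadatas, ids = engine.get_text_from_document(
    "paper.pdf",           # hypothetical input file
    chunk_size=-1,         # < 0 keeps each paragraph as its own chunk
    perc_overlap=0.1,
    include=("biblio",))

# create_memory_embeddings wraps the same behaviour in a boolean flag and
# forwards include=["biblio"] internally.
doc_id = engine.create_memory_embeddings(
    "paper.pdf",
    chunk_size=500,
    perc_overlap=0.1,
    include_biblio=True)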
2 changes: 1 addition & 1 deletion document_qa/grobid_processors.py
@@ -171,7 +171,7 @@ def parse_grobid_xml(self, text):
         }
         try:
             year = dateparser.parse(doc_biblio.header.date).year
-            biblio["year"] = year
+            biblio["publication_year"] = year
         except:
             pass
 
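The rename matters because get_text_from_document now looks up publication_year (not year) when building the biblio passages; with the old key, the year passage would silently never be included. A small sketch of the parsing step, assuming dateparser is the dateparser package and the header date is a plain string (the date below is illustrative):

import dateparser

header_date = "2023-11-22"              # illustrative value
parsed = dateparser.parse(header_date)  # returns None if unparseable
if parsed is not None:
    # the key now matches the lookup in get_text_from_document
    biblio = {"publication_year": parsed.year}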
3 changes: 2 additions & 1 deletion streamlit_app.py
@@ -283,7 +283,8 @@ def play_old_messages():
                 # hash = get_file_hash(tmp_file.name)[:10]
                 st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                              chunk_size=chunk_size,
-                                                                                                             perc_overlap=0.1)
+                                                                                                             perc_overlap=0.1,
+                                                                                                             include_biblio=True)
                 st.session_state['loaded_embeddings'] = True
                 st.session_state.messages = []
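With include_biblio=True at the Streamlit call site, header metadata becomes retrievable alongside the body text. A sketch of the extra entries the new branch appends (field values are made up for the example):

# Mirrors the loop added in get_text_from_document; values are illustrative.
biblio = {"title": "An Example Paper",
          "authors": "J. Doe",
          "publication_year": 2023,
          "type": "biblio",
          "section": "header"}

texts = ["{}: {}".format(key, biblio[key])
         for key in ["title", "authors", "publication_year"] if key in biblio]
# -> ["title: An Example Paper", "authors: J. Doe", "publication_year: 2023"]
ids = ["title", "authors", "publication_year"]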