Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NeuralChat] Enable RAG's table extraction and summary #1417

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ function main() {
apt-get install -y libgl1-mesa-glx
apt-get install -y libgl1-mesa-dev
apt-get install libsm6 libxext6 -y
apt-get install tesseract-ocr -y
apt-get install libtesseract-dev -y
apt-get install poppler-utils -y
wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
python -m pip install --upgrade --force-reinstall torch==2.2.0
Expand Down
Binary file not shown.
32 changes: 25 additions & 7 deletions intel_extension_for_transformers/neural_chat/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def build_chatbot(config: PipelineConfig=None):
return
from .models.base_model import register_model_adapter
register_model_adapter(adapter)
# register plugin instance in model adaptor
# register plugin instance in model adapter
use_retrieval_plugin = False
if config.plugins:
for plugin_name, plugin_value in config.plugins.items():
enable_plugin = plugin_value.get('enable', False)
Expand Down Expand Up @@ -243,6 +244,10 @@ def build_chatbot(config: PipelineConfig=None):
elif plugin_name == "retrieval":
from .pipeline.plugins.retrieval.retrieval_agent import Agent_QA
plugins[plugin_name]['class'] = Agent_QA
use_retrieval_plugin = True
retrieval_plugin_value = plugin_value
retrieval_plugin_value['args']['table_summary_model_name_or_path'] = config.model_name_or_path
continue
elif plugin_name == "cache":
from .pipeline.plugins.caching.cache import ChatCache
plugins[plugin_name]['class'] = ChatCache
Expand All @@ -267,15 +272,12 @@ def build_chatbot(config: PipelineConfig=None):
try:
plugins[plugin_name]["instance"] = plugins[plugin_name]['class'](**plugin_value['args'])
except Exception as e:
if "[Rereieval ERROR] Document format not supported" in str(e):
set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
logger.error("build_chatbot: retrieval plugin init failed")
elif "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
if "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
set_latest_error(ErrorCodes.ERROR_SENSITIVE_CHECK_FILE_NOT_FOUND)
logger.error("build_chatbot: safety checker plugin init failed")
logging.error("build_chatbot: safety checker plugin init failed")
else:
set_latest_error(ErrorCodes.ERROR_GENERIC)
logger.error("build_chatbot: plugin init failed")
logging.error("build_chatbot: plugin init failed")
return
adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])

Expand Down Expand Up @@ -306,6 +308,22 @@ def build_chatbot(config: PipelineConfig=None):
if config.hf_endpoint_url:
return adapter
adapter.load_model(parameters)

if use_retrieval_plugin:
print(f"create retrieval plugin instance...")
print(f"plugin parameters: ", retrieval_plugin_value['args'])
try:
plugins["retrieval"]["instance"] = plugins["retrieval"]['class'](**retrieval_plugin_value['args'])
except Exception as e:
if "[Rereieval ERROR] Document format not supported" in str(e):
set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
logging.error("build_chatbot: retrieval plugin init failed")
else:
set_latest_error(ErrorCodes.ERROR_GENERIC)
logging.error("build_chatbot: plugin init failed")
return
adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])

if get_latest_error():
return
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ To ensure a smooth experience, we've made sure this plugin is compatible with co
| pdf | No format required |

# Usage
Before using RAG in NeuralChat, please install the necessary dependencies in [requirements.txt](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/requirements.txt) to avoid the import errors. The most convenient way to use is this plugin is via our `build_chatbot` api as introduced in the [example code](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/neural_chat/examples/plugins/retrieval). The user could refer to it for a simple test.
Before using RAG in NeuralChat, please install the necessary dependencies in [requirements.txt](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/requirements.txt) to avoid the import errors. Additionally, for users who take pdf format file as input, the installation of [tesseract](https://tesseract-ocr.github.io/tessdoc/Installation.html) and [poppler](https://poppler.freedesktop.org/) is also required. The most convenient way to use is this plugin is via our `build_chatbot` api as introduced in the [example code](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/neural_chat/examples/plugins/retrieval). The user could refer to it for a simple test.

We support multiple file formats for retrieval, including unstructured file formats such as pdf, docx, html, txt, and markdown, as well as structured file formats like jsonl/json, csv, xlsx. For structured file formats, they must adhere to predefined structures. We also support to upload the knowledge base via a http web link.

Expand Down Expand Up @@ -92,6 +92,7 @@ Below are the description for the available parameters in `agent_QA`,
| enable_rerank | bool | Whether to enable retrieval then rerank pipeline |True, False|
| reranker_model | str | The name of the reranker model from the Huggingface or a local path |-|
| top_n | int | The return number of the reranker model |-|
| table_strategy | str | The strategies to understand tables for table retrieval. As the setting progresses from "fast" to "hq" to "llm," the focus shifts towards deeper table understanding at the expense of processing speed. The default strategy is "fast" |"fast", "hq", "llm"|
Copy link
Collaborator

@XinyuYe-Intel XinyuYe-Intel Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the code, seems "fast" table_strategy would only return None instead of table content, is this somewhat unreasonable?

It appears "hq" strategy uses unstructured pkg to extract table, I also used this pkg, and find it actually performed worse than table-transformer.

Also does the "llm" strategy return the reliable table contents? From the code, looks like it uses LLM and a prompt to generate the table summarization of the document, but from my previous experience, such way would generate results that significantly deviate the table content sometimes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for insightful comments, my opinion on these issues are as follows:

From the code, seems "fast" table_strategy would only return None instead of table content, is this somewhat unreasonable?

In fact, by default, our program will use OCR to extract all text information in files including table information, which has been implemented in other PRs. This PR is just to further enhance the understanding of the table, so no content is returned in fast mode (fast mode is also the default mode).

It appears "hq" strategy uses unstructured pkg to extract table, I also used this pkg, and find it actually performed worse than table-transformer.

At present, we do use unstructured to extract table information, and the extraction performance is quite satisfactory. We have not tried the table transformer, but it is indeed worth considering.

Also does the "llm" strategy return the reliable table contents? From the code, looks like it uses LLM and a prompt to generate the table summarization of the document, but from my previous experience, such way would generate results that significantly deviate the table content sometimes.

Your understanding of what llm mode does is correct. It is true that llm's table summary is not completely reliable, but according to the experimental results, there will be much better table QA performance in llm mode overall.


More retriever- and vectorstore-related parameters please check [here](#langchain-extension)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def uni_pro(text):
return filtered_text


def read_pdf(pdf_path):
def read_pdf(pdf_path, table_strategy, table_summary_model_name_or_path):
"""Read the pdf file."""
doc = fitz.open(pdf_path)
reader = easyocr.Reader(['en'], gpu=False)
Expand Down Expand Up @@ -77,7 +77,8 @@ def read_pdf(pdf_path):
else:
pageimg=pageimg+'.'
result=result+pageimg
return result
tables_result = get_tables_result(pdf_path, table_strategy, table_summary_model_name_or_path)
return result, tables_result


def read_html(html_path):
Expand Down Expand Up @@ -214,10 +215,11 @@ def load_structured_data(input, process, max_length, min_length):
content = load_csv(input)
return content

def load_unstructured_data(input):
def load_unstructured_data(input, table_strategy, table_summary_model_name_or_path):
"""Load unstructured context."""
tables = None
if input.endswith("pdf"):
text = read_pdf(input)
text, tables = read_pdf(input, table_strategy, table_summary_model_name_or_path)
elif input.endswith("docx"):
text = read_docx(input)
elif input.endswith("html"):
Expand All @@ -231,7 +233,7 @@ def load_unstructured_data(input):
text = text.replace('\n\n', ' ')
text = uni_pro(text)
text = re.sub(r'\s+', ' ', text)
return text
return text, tables

def get_chuck_data(content, max_length, min_length, input):
"""Process the context to make it maintain a suitable length for the generation."""
Expand All @@ -257,6 +259,84 @@ def get_chuck_data(content, max_length, min_length, input):
return paragraphs



def get_tables_result(pdf_path, table_strategy, table_summary_model_name_or_path):
"""Extract tables information from pdf file."""
if table_strategy == 'fast':
return None

from unstructured.partition.pdf import partition_pdf
from unstructured.documents.elements import FigureCaption
from intel_extension_for_transformers.neural_chat.models.model_utils import predict
from intel_extension_for_transformers.neural_chat.prompts.prompt import TABLESUMMARY_PROMPT

tables_result = []
raw_pdf_elements = partition_pdf(
filename=pdf_path,
infer_table_structure=True,
)
tables = [el for el in raw_pdf_elements if el.category == "Table"]
for table in tables:
table_coords = table.metadata.coordinates.points
content = table.metadata.text_as_html
table_page_number = table.metadata.page_number
min_distance = float('inf')
table_summary = None
if table_strategy == 'hq':
for element in raw_pdf_elements:
if isinstance(element, FigureCaption) or element.text.startswith('Tab'):
caption_page_number = element.metadata.page_number
caption_coords = element.metadata.coordinates.points
related, y_distance = get_relation(table_coords, caption_coords, \
table_page_number, caption_page_number)
if related:
if y_distance < min_distance:
min_distance = y_distance
table_summary = element.text
if table_summary is None:
parent_id = table.metadata.parent_id
for element in raw_pdf_elements:
if element.id == parent_id:
table_summary = element.text
break
elif table_strategy == 'llm':
prompt = TABLESUMMARY_PROMPT.format(table_content=content)
params = {}
params["model_name"] = table_summary_model_name_or_path
params["prompt"] = prompt
params["temperature"] = 0.8
params["top_p"] = 0.9
params["top_k"] = 40
params["max_new_tokens"] = 1000
params["num_beams"] = 2
params["num_return_sequences"] = 2
params["use_cache"] = True
table_summary = predict(**params)
table_summary = table_summary[table_summary.find('### Generated Summary:\n'):]
table_summary = re.sub('### Generated Summary:\n', '', table_summary)
elif table_strategy == None:
table_summary = None
if table_summary is None:
text = f'[Table: {content}]'
else:
text = f'|Table: [Summary: {table_summary}], [Content: {content}]|'
tables_result.append([text, pdf_path])
return tables_result


def get_relation(table_coords, caption_coords, table_page_number, caption_page_number, threshold=100):
"""Get the relation of a pair of table and caption"""
same_page = table_page_number == caption_page_number
x_overlap = (min(table_coords[2][0], caption_coords[2][0]) - max(table_coords[0][0], caption_coords[0][0])) > 0
if table_coords[0][1] - caption_coords[1][1] >= 0:
y_distance = table_coords[0][1] - caption_coords[1][1]
elif caption_coords[0][1] - table_coords[1][1] >= 0:
y_distance = caption_coords[0][1] - table_coords[1][1]
else:
y_distance = 0
y_close = y_distance < threshold
return same_page and x_overlap and y_close, y_distance

def clean_filename(url):
# Characters to be removed or replaced
invalid_chars = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def load(self, input, **kwargs):
self.min_chuck_size = kwargs['min_chuck_size']
if 'process' in kwargs:
self.process = kwargs['process']
self.table_summary_model_name_or_path = kwargs['table_summary_model_name_or_path'] \
if 'table_summary_model_name_or_path' in kwargs else None
self.table_strategy = kwargs['table_strategy'] if 'table_strategy' in kwargs else 'fast'

if isinstance(input, str):
if os.path.isfile(input):
Expand All @@ -74,11 +77,14 @@ def parse_document(self, input):
"""
if input.endswith("pdf") or input.endswith("docx") or input.endswith("html") \
or input.endswith("txt") or input.endswith("md"):
content = load_unstructured_data(input)
content, tables = load_unstructured_data(input, self.table_strategy, \
self.table_summary_model_name_or_path)
if self.process:
chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
else:
chuck = [[content.strip(),input]]
if tables is not None:
chuck = chuck + tables
elif input.endswith("jsonl") or input.endswith("xlsx") or input.endswith("csv") or \
input.endswith("json"):
chuck = load_structured_data(input, self.process, \
Expand Down Expand Up @@ -118,11 +124,15 @@ def batch_parse_document(self, input):
for filename in filenames:
if filename.endswith("pdf") or filename.endswith("docx") or filename.endswith("html") \
or filename.endswith("txt") or filename.endswith("md"):
content = load_unstructured_data(os.path.join(dirpath, filename))
content, tables = load_unstructured_data(os.path.join(dirpath, filename), \
self.table_strategy, \
self.table_summary_model_name_or_path)
if self.process:
chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
else:
chuck = [[content.strip(),input]]
if tables is not None:
chuck = chuck + tables
paragraphs += chuck
elif filename.endswith("jsonl") or filename.endswith("xlsx") or filename.endswith("csv") or \
filename.endswith("json"):
Expand Down
11 changes: 11 additions & 0 deletions intel_extension_for_transformers/neural_chat/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,14 @@ def generate_sqlcoder_prompt(qurey, metadata_file):
qurey=qurey, table_metadata_string=table_metadata_string
)
return prompt

TABLESUMMARY_PROMPT = """
Task: Your task is to give a concise summary of the table. \
The summary should cover the overall table structure and all detailed information of the table. \
The table will be given in html format. Summarize the table below.
---
### Table:
{table_content}
---
### Generated Summary:
"""
Original file line number Diff line number Diff line change
Expand Up @@ -714,5 +714,86 @@ def test_embedding_precision_fp32(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestHQTableStrategy(unittest.TestCase):
def setUp(self):
if os.path.exists("./hq_table_strategy"):
shutil.rmtree("./hq_table_strategy", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./hq_table_strategy"):
shutil.rmtree("./hq_table_strategy", ignore_errors=True)
return super().tearDown()

def test_hq_table_strategy(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
plugins.retrieval.args["persist_directory"] = "./hq_table_strategy"
plugins.retrieval.args["retrieval_type"] = 'default'
plugins.retrieval.args["table_strategy"] = 'hq'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestLLMTableStrategy(unittest.TestCase):
def setUp(self):
if os.path.exists("./llm_table_strategy"):
shutil.rmtree("./llm_table_strategy", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./llm_table_strategy"):
shutil.rmtree("./llm_table_strategy", ignore_errors=True)
return super().tearDown()

def test_llm_table_strategy(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
plugins.retrieval.args["persist_directory"] = "./llm_table_strategy"
plugins.retrieval.args["retrieval_type"] = 'default'
plugins.retrieval.args["table_strategy"] = 'llm'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestTableSummaryNoneMode(unittest.TestCase):
def setUp(self):
if os.path.exists("./fast_table_strategy"):
shutil.rmtree("./fast_table_strategy", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./fast_table_strategy"):
shutil.rmtree("./fast_table_strategy", ignore_errors=True)
return super().tearDown()

def test_fast_table_strategy(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
plugins.retrieval.args["persist_directory"] = "./fast_table_strategy"
plugins.retrieval.args["retrieval_type"] = 'default'
plugins.retrieval.args["table_strategy"] = 'fast'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ torchvision==0.17.0
tqdm
transformers==4.36.2
transformers_stream_generator
unstructured
unstructured[all-docs]
urllib3
uvicorn
vector_quantize_pytorch
Expand Down
Loading