diff --git a/.github/workflows/script/unitTest/run_unit_test_neuralchat.sh b/.github/workflows/script/unitTest/run_unit_test_neuralchat.sh
index 07e601a5e69..c47c0aa147d 100644
--- a/.github/workflows/script/unitTest/run_unit_test_neuralchat.sh
+++ b/.github/workflows/script/unitTest/run_unit_test_neuralchat.sh
@@ -87,6 +87,9 @@ function main() {
     apt-get install -y libgl1-mesa-glx
     apt-get install -y libgl1-mesa-dev
     apt-get install libsm6 libxext6 -y
+    apt-get install tesseract-ocr -y
+    apt-get install libtesseract-dev -y
+    apt-get install poppler-utils -y
     wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
     dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
     python -m pip install --upgrade --force-reinstall torch==2.2.0
diff --git a/intel_extension_for_transformers/neural_chat/assets/docs/LLAMA2_page6.pdf b/intel_extension_for_transformers/neural_chat/assets/docs/LLAMA2_page6.pdf
new file mode 100644
index 00000000000..e741ee1993e
Binary files /dev/null and b/intel_extension_for_transformers/neural_chat/assets/docs/LLAMA2_page6.pdf differ
diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py
index f523f092a16..a7f0c3cc5d7 100644
--- a/intel_extension_for_transformers/neural_chat/chatbot.py
+++ b/intel_extension_for_transformers/neural_chat/chatbot.py
@@ -178,7 +178,8 @@ def build_chatbot(config: PipelineConfig=None):
             return
     from .models.base_model import register_model_adapter
     register_model_adapter(adapter)
-    # register plugin instance in model adaptor
+    # register plugin instance in model adapter
+    use_retrieval_plugin = False
     if config.plugins:
         for plugin_name, plugin_value in config.plugins.items():
             enable_plugin = plugin_value.get('enable', False)
@@ -243,6 +244,10 @@ def build_chatbot(config: PipelineConfig=None):
             elif plugin_name == "retrieval":
                 from .pipeline.plugins.retrieval.retrieval_agent import Agent_QA
                 plugins[plugin_name]['class'] = Agent_QA
+                use_retrieval_plugin = True
+                retrieval_plugin_value = plugin_value
+                retrieval_plugin_value['args']['table_summary_model_name_or_path'] = config.model_name_or_path
+                continue
             elif plugin_name == "cache":
                 from .pipeline.plugins.caching.cache import ChatCache
                 plugins[plugin_name]['class'] = ChatCache
@@ -267,15 +272,12 @@ def build_chatbot(config: PipelineConfig=None):
             try:
                 plugins[plugin_name]["instance"] = plugins[plugin_name]['class'](**plugin_value['args'])
             except Exception as e:
-                if "[Rereieval ERROR] Document format not supported" in str(e):
-                    set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
-                    logger.error("build_chatbot: retrieval plugin init failed")
-                elif "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
+                if "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
                     set_latest_error(ErrorCodes.ERROR_SENSITIVE_CHECK_FILE_NOT_FOUND)
-                    logger.error("build_chatbot: safety checker plugin init failed")
+                    logging.error("build_chatbot: safety checker plugin init failed")
                 else:
                     set_latest_error(ErrorCodes.ERROR_GENERIC)
-                    logger.error("build_chatbot: plugin init failed")
+                    logging.error("build_chatbot: plugin init failed")
                 return
             adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])

@@ -306,6 +308,22 @@ def build_chatbot(config: PipelineConfig=None):
     if config.hf_endpoint_url:
         return adapter
     adapter.load_model(parameters)
+
+    if use_retrieval_plugin:
+        print("create retrieval plugin instance...")
+        print("plugin parameters: ", retrieval_plugin_value['args'])
+        try:
+            plugins["retrieval"]["instance"] = plugins["retrieval"]['class'](**retrieval_plugin_value['args'])
+        except Exception as e:
+            if "[Rereieval ERROR] Document format not supported" in str(e):
+                set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
+                logging.error("build_chatbot: retrieval plugin init failed")
+            else:
+                set_latest_error(ErrorCodes.ERROR_GENERIC)
+                logging.error("build_chatbot: plugin init failed")
+            return
+        adapter.register_plugin_instance("retrieval", plugins["retrieval"]["instance"])
+
     if get_latest_error():
         return
     else:
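A note on the control flow above: the retrieval plugin is now skipped in the main plugin loop and built only after `adapter.load_model(parameters)`, because with `table_strategy='llm'` the document parsing triggered by `Agent_QA`'s constructor can call `predict()` on the very model being loaded. Below is a minimal, runnable sketch of this deferred-initialization pattern; `Adapter` and `RetrievalPlugin` are hypothetical stand-ins, not the real neural_chat classes — only the control flow mirrors the diff.

```python
# Minimal sketch of the deferred plugin initialization above.
# `Adapter` and `RetrievalPlugin` are hypothetical stand-ins.

class RetrievalPlugin:
    def __init__(self, table_summary_model_name_or_path=None, **kwargs):
        # With table_strategy='llm', construction triggers table
        # summarization, which needs the chat model already loaded.
        self.summary_model = table_summary_model_name_or_path

class Adapter:
    def __init__(self):
        self.plugins = {}
        self.model_name = None

    def load_model(self, name):
        self.model_name = name

    def register_plugin_instance(self, name, instance):
        self.plugins[name] = instance

def build(config_plugins, model_name):
    adapter = Adapter()
    deferred_args = None
    for name, value in config_plugins.items():
        if name == "retrieval":
            # Record the args, but do not instantiate yet.
            value["args"]["table_summary_model_name_or_path"] = model_name
            deferred_args = value["args"]
            continue
        # ...other plugins would be instantiated immediately here...
    adapter.load_model(model_name)  # the model must exist first
    if deferred_args is not None:   # now retrieval can safely be built
        adapter.register_plugin_instance("retrieval", RetrievalPlugin(**deferred_args))
    return adapter

adapter = build({"retrieval": {"enable": True, "args": {}}}, "facebook/opt-125m")
print(adapter.plugins)
```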
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/README.md b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/README.md
index 87f9e984cc3..11f3d15c580 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/README.md
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/README.md
@@ -42,7 +42,7 @@ To ensure a smooth experience, we've made sure this plugin is compatible with co
 | pdf | No format required |
 
 # Usage
-Before using RAG in NeuralChat, please install the necessary dependencies in [requirements.txt](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/requirements.txt) to avoid the import errors. The most convenient way to use is this plugin is via our `build_chatbot` api as introduced in the [example code](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/neural_chat/examples/plugins/retrieval). The user could refer to it for a simple test.
+Before using RAG in NeuralChat, please install the necessary dependencies in [requirements.txt](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/requirements.txt) to avoid import errors. Additionally, users who take PDF files as input also need to install [tesseract](https://tesseract-ocr.github.io/tessdoc/Installation.html) and [poppler](https://poppler.freedesktop.org/). The most convenient way to use this plugin is via our `build_chatbot` API, as introduced in the [example code](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/neural_chat/examples/plugins/retrieval). The user can refer to it for a simple test.
 
 We support multiple file formats for retrieval, including unstructured file formats such as pdf, docx, html, txt, and markdown, as well as structured file formats like jsonl/json, csv, xlsx. For structured file formats, they must adhere to predefined structures. We also support to upload the knowledge base via a http web link.
 
@@ -92,6 +92,7 @@ Below are the description for the available parameters in `agent_QA`,
 | enable_rerank | bool | Whether to enable retrieval then rerank pipeline |True, False|
 | reranker_model | str | The name of the reranker model from the Huggingface or a local path |-|
 | top_n | int | The return number of the reranker model |-|
+| table_strategy | str | The strategy for table understanding in table retrieval. Moving from "fast" to "hq" to "llm" shifts the focus towards deeper table understanding at the expense of processing speed. The default strategy is "fast". |"fast", "hq", "llm"|
 
 More retriever- and vectorstore-related parameters please check [here](#langchain-extension)
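For the new `table_strategy` knob in context, here is a short sketch of enabling table retrieval through `build_chatbot`. It mirrors the tests added later in this PR; the input path is a placeholder, not a file shipped with the repo.

```python
# Sketch: retrieval with the new `table_strategy` parameter. Mirrors the
# usage in the tests added by this PR; the input path is a placeholder.
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot, plugins

plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "./docs/my_report.pdf"  # placeholder
plugins.retrieval.args["persist_directory"] = "./retrieval_db"
plugins.retrieval.args["retrieval_type"] = "default"
plugins.retrieval.args["table_strategy"] = "hq"  # "fast" (default), "hq", or "llm"

config = PipelineConfig(model_name_or_path="facebook/opt-125m", plugins=plugins)
chatbot = build_chatbot(config)
print(chatbot.predict("What is the number of training tokens for LLaMA2?"))
```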
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/context_utils.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/context_utils.py
index 91b8de54cb4..ebdb8421002 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/context_utils.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/context_utils.py
@@ -48,7 +48,7 @@ def uni_pro(text):
     return filtered_text
 
 
-def read_pdf(pdf_path):
+def read_pdf(pdf_path, table_strategy, table_summary_model_name_or_path):
     """Read the pdf file."""
     doc = fitz.open(pdf_path)
     reader = easyocr.Reader(['en'], gpu=False)
@@ -77,7 +77,8 @@ def read_pdf(pdf_path):
         else:
             pageimg=pageimg+'.'
         result=result+pageimg
-    return result
+    tables_result = get_tables_result(pdf_path, table_strategy, table_summary_model_name_or_path)
+    return result, tables_result
 
 
 def read_html(html_path):
@@ -214,10 +215,11 @@ def load_structured_data(input, process, max_length, min_length):
         content = load_csv(input)
     return content
 
-def load_unstructured_data(input):
+def load_unstructured_data(input, table_strategy, table_summary_model_name_or_path):
     """Load unstructured context."""
+    tables = None
     if input.endswith("pdf"):
-        text = read_pdf(input)
+        text, tables = read_pdf(input, table_strategy, table_summary_model_name_or_path)
     elif input.endswith("docx"):
         text = read_docx(input)
     elif input.endswith("html"):
@@ -231,7 +233,7 @@ def load_unstructured_data(input):
     text = text.replace('\n\n', ' ')
     text = uni_pro(text)
     text = re.sub(r'\s+', ' ', text)
-    return text
+    return text, tables
 
 def get_chuck_data(content, max_length, min_length, input):
     """Process the context to make it maintain a suitable length for the generation."""
@@ -257,6 +259,84 @@ def get_chuck_data(content, max_length, min_length, input):
 
     return paragraphs
 
+
+def get_tables_result(pdf_path, table_strategy, table_summary_model_name_or_path):
+    """Extract table information from a pdf file."""
+    if table_strategy == 'fast':
+        return None
+
+    from unstructured.partition.pdf import partition_pdf
+    from unstructured.documents.elements import FigureCaption
+    from intel_extension_for_transformers.neural_chat.models.model_utils import predict
+    from intel_extension_for_transformers.neural_chat.prompts.prompt import TABLESUMMARY_PROMPT
+
+    tables_result = []
+    raw_pdf_elements = partition_pdf(
+        filename=pdf_path,
+        infer_table_structure=True,
+    )
+    tables = [el for el in raw_pdf_elements if el.category == "Table"]
+    for table in tables:
+        table_coords = table.metadata.coordinates.points
+        content = table.metadata.text_as_html
+        table_page_number = table.metadata.page_number
+        min_distance = float('inf')
+        table_summary = None
+        if table_strategy == 'hq':
+            for element in raw_pdf_elements:
+                if isinstance(element, FigureCaption) or element.text.startswith('Tab'):
+                    caption_page_number = element.metadata.page_number
+                    caption_coords = element.metadata.coordinates.points
+                    related, y_distance = get_relation(table_coords, caption_coords, \
+                                                       table_page_number, caption_page_number)
+                    if related:
+                        if y_distance < min_distance:
+                            min_distance = y_distance
+                            table_summary = element.text
+            if table_summary is None:
+                parent_id = table.metadata.parent_id
+                for element in raw_pdf_elements:
+                    if element.id == parent_id:
+                        table_summary = element.text
+                        break
+        elif table_strategy == 'llm':
+            prompt = TABLESUMMARY_PROMPT.format(table_content=content)
+            params = {}
+            params["model_name"] = table_summary_model_name_or_path
+            params["prompt"] = prompt
+            params["temperature"] = 0.8
+            params["top_p"] = 0.9
+            params["top_k"] = 40
+            params["max_new_tokens"] = 1000
+            params["num_beams"] = 2
+            params["num_return_sequences"] = 2
+            params["use_cache"] = True
+            table_summary = predict(**params)
+            table_summary = table_summary[table_summary.find('### Generated Summary:\n'):]
+            table_summary = re.sub('### Generated Summary:\n', '', table_summary)
+        elif table_strategy is None:
+            table_summary = None
+        if table_summary is None:
+            text = f'[Table: {content}]'
+        else:
+            text = f'|Table: [Summary: {table_summary}], [Content: {content}]|'
+        tables_result.append([text, pdf_path])
+    return tables_result
+
+
+def get_relation(table_coords, caption_coords, table_page_number, caption_page_number, threshold=100):
+    """Get the relation of a pair of table and caption."""
+    same_page = table_page_number == caption_page_number
+    x_overlap = (min(table_coords[2][0], caption_coords[2][0]) - max(table_coords[0][0], caption_coords[0][0])) > 0
+    if table_coords[0][1] - caption_coords[1][1] >= 0:
+        y_distance = table_coords[0][1] - caption_coords[1][1]
+    elif caption_coords[0][1] - table_coords[1][1] >= 0:
+        y_distance = caption_coords[0][1] - table_coords[1][1]
+    else:
+        y_distance = 0
+    y_close = y_distance < threshold
+    return same_page and x_overlap and y_close, y_distance
+
 def clean_filename(url):
     # Characters to be removed or replaced
     invalid_chars = {
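To make the `hq` caption-matching geometry concrete: `get_relation` treats `coordinates.points` as the four corners of a bounding box and pairs a caption with a table only when both are on the same page, their x-ranges overlap, and the vertical gap is under `threshold`. Below is a self-contained check with synthetic boxes; the corner ordering (top-left, bottom-left, bottom-right, top-right, with y growing downward) is an assumption about unstructured's layout, and the function body is copied from `context_utils.py` above so the sketch runs without the plugin's heavy imports.

```python
# Self-contained check of the caption/table pairing heuristic.
# `get_relation` is copied from context_utils.py above; the corner
# ordering (top-left, bottom-left, bottom-right, top-right) is an
# assumption about unstructured's coordinate convention.

def get_relation(table_coords, caption_coords, table_page_number, caption_page_number, threshold=100):
    same_page = table_page_number == caption_page_number
    x_overlap = (min(table_coords[2][0], caption_coords[2][0]) - max(table_coords[0][0], caption_coords[0][0])) > 0
    if table_coords[0][1] - caption_coords[1][1] >= 0:
        y_distance = table_coords[0][1] - caption_coords[1][1]   # caption above table
    elif caption_coords[0][1] - table_coords[1][1] >= 0:
        y_distance = caption_coords[0][1] - table_coords[1][1]   # caption below table
    else:
        y_distance = 0                                           # boxes overlap vertically
    return same_page and x_overlap and y_distance < threshold, y_distance

def box(x0, y0, x1, y1):
    """Four corners in the assumed ordering, y increasing downward."""
    return ((x0, y0), (x0, y1), (x1, y1), (x1, y0))

table   = box(100, 300, 500, 600)   # table body
caption = box(120, 260, 480, 290)   # "Table 1: ..." sitting 10pt above it
footer  = box(100, 750, 500, 780)   # unrelated text 150pt below

print(get_relation(table, caption, table_page_number=3, caption_page_number=3))  # (True, 10)
print(get_relation(table, footer,  table_page_number=3, caption_page_number=3))  # (False, 150)
```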
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/parser.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/parser.py
index 3337470f430..810869b492f 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/parser.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/parser.py
@@ -49,6 +49,9 @@ def load(self, input, **kwargs):
             self.min_chuck_size = kwargs['min_chuck_size']
         if 'process' in kwargs:
             self.process = kwargs['process']
+        self.table_summary_model_name_or_path = kwargs['table_summary_model_name_or_path'] \
+            if 'table_summary_model_name_or_path' in kwargs else None
+        self.table_strategy = kwargs['table_strategy'] if 'table_strategy' in kwargs else 'fast'
 
         if isinstance(input, str):
             if os.path.isfile(input):
@@ -74,11 +77,14 @@ def parse_document(self, input):
         """
         if input.endswith("pdf") or input.endswith("docx") or input.endswith("html") \
            or input.endswith("txt") or input.endswith("md"):
-            content = load_unstructured_data(input)
+            content, tables = load_unstructured_data(input, self.table_strategy, \
+                self.table_summary_model_name_or_path)
             if self.process:
                 chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
             else:
                 chuck = [[content.strip(),input]]
+            if tables is not None:
+                chuck = chuck + tables
         elif input.endswith("jsonl") or input.endswith("xlsx") or input.endswith("csv") or \
              input.endswith("json"):
             chuck = load_structured_data(input, self.process, \
@@ -118,11 +124,15 @@ def batch_parse_document(self, input):
             for filename in filenames:
                 if filename.endswith("pdf") or filename.endswith("docx") or filename.endswith("html") \
                    or filename.endswith("txt") or filename.endswith("md"):
-                    content = load_unstructured_data(os.path.join(dirpath, filename))
+                    content, tables = load_unstructured_data(os.path.join(dirpath, filename), \
+                        self.table_strategy, \
+                        self.table_summary_model_name_or_path)
                     if self.process:
                         chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
                     else:
                         chuck = [[content.strip(),input]]
+                    if tables is not None:
+                        chuck = chuck + tables
                     paragraphs += chuck
                 elif filename.endswith("jsonl") or filename.endswith("xlsx") or filename.endswith("csv") or \
                      filename.endswith("json"):
diff --git a/intel_extension_for_transformers/neural_chat/prompts/prompt.py b/intel_extension_for_transformers/neural_chat/prompts/prompt.py
index 6a1da02af39..f1f2b6fb0f1 100644
--- a/intel_extension_for_transformers/neural_chat/prompts/prompt.py
+++ b/intel_extension_for_transformers/neural_chat/prompts/prompt.py
@@ -321,3 +321,14 @@ def generate_sqlcoder_prompt(qurey, metadata_file):
         qurey=qurey, table_metadata_string=table_metadata_string
     )
     return prompt
+
+TABLESUMMARY_PROMPT = """
+Task: Your task is to give a concise summary of the table. \
+The summary should cover the overall table structure and all detailed information of the table. \
+The table will be given in html format. Summarize the table below.
+---
+### Table:
+{table_content}
+---
+### Generated Summary:
+"""
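One detail of the `llm` strategy worth seeing end to end: `predict()` returns the whole prompt plus the completion, so `get_tables_result` slices from the `### Generated Summary:` sentinel and then strips it. A runnable sketch of that post-processing, with a canned model output standing in for `predict()`:

```python
import re

# TABLESUMMARY_PROMPT copied from prompt.py above.
TABLESUMMARY_PROMPT = """
Task: Your task is to give a concise summary of the table. \
The summary should cover the overall table structure and all detailed information of the table. \
The table will be given in html format. Summarize the table below.
---
### Table:
{table_content}
---
### Generated Summary:
"""

prompt = TABLESUMMARY_PROMPT.format(
    table_content="<table><tr><td>7B</td><td>2.0T tokens</td></tr></table>")

# predict() returns the prompt plus the completion; fake it here.
fake_output = prompt + "The table lists one model size (7B) trained on 2.0T tokens."

# Post-processing as in get_tables_result in context_utils.py above.
summary = fake_output[fake_output.find('### Generated Summary:\n'):]
summary = re.sub('### Generated Summary:\n', '', summary)
print(summary)  # -> The table lists one model size (7B) trained on 2.0T tokens.
```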
diff --git a/intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py b/intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py
index 3de0f5a6567..6f0034e10ae 100644
--- a/intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py
+++ b/intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py
@@ -714,5 +714,86 @@ def test_embedding_precision_fp32(self):
         plugins.retrieval.args = {}
         plugins.retrieval.enable = False
 
+class TestHQTableStrategy(unittest.TestCase):
+    def setUp(self):
+        if os.path.exists("./hq_table_strategy"):
+            shutil.rmtree("./hq_table_strategy", ignore_errors=True)
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        if os.path.exists("./hq_table_strategy"):
+            shutil.rmtree("./hq_table_strategy", ignore_errors=True)
+        return super().tearDown()
+
+    def test_hq_table_strategy(self):
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = True
+        plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
+        plugins.retrieval.args["persist_directory"] = "./hq_table_strategy"
+        plugins.retrieval.args["retrieval_type"] = 'default'
+        plugins.retrieval.args["table_strategy"] = 'hq'
+        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
+                                plugins=plugins)
+        chatbot = build_chatbot(config)
+        response = chatbot.predict("What is the number of training tokens for LLaMA2?")
+        print(response)
+        self.assertIsNotNone(response)
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = False
+
+class TestLLMTableStrategy(unittest.TestCase):
+    def setUp(self):
+        if os.path.exists("./llm_table_strategy"):
+            shutil.rmtree("./llm_table_strategy", ignore_errors=True)
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        if os.path.exists("./llm_table_strategy"):
+            shutil.rmtree("./llm_table_strategy", ignore_errors=True)
+        return super().tearDown()
+
+    def test_llm_table_strategy(self):
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = True
+        plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
+        plugins.retrieval.args["persist_directory"] = "./llm_table_strategy"
+        plugins.retrieval.args["retrieval_type"] = 'default'
+        plugins.retrieval.args["table_strategy"] = 'llm'
+        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
+                                plugins=plugins)
+        chatbot = build_chatbot(config)
+        response = chatbot.predict("What is the number of training tokens for LLaMA2?")
+        print(response)
+        self.assertIsNotNone(response)
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = False
+
+class TestFastTableStrategy(unittest.TestCase):
+    def setUp(self):
+        if os.path.exists("./fast_table_strategy"):
+            shutil.rmtree("./fast_table_strategy", ignore_errors=True)
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        if os.path.exists("./fast_table_strategy"):
+            shutil.rmtree("./fast_table_strategy", ignore_errors=True)
+        return super().tearDown()
+
+    def test_fast_table_strategy(self):
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = True
+        plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
+        plugins.retrieval.args["persist_directory"] = "./fast_table_strategy"
+        plugins.retrieval.args["retrieval_type"] = 'default'
+        plugins.retrieval.args["table_strategy"] = 'fast'
+        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
+                                plugins=plugins)
+        chatbot = build_chatbot(config)
+        response = chatbot.predict("What is the number of training tokens for LLaMA2?")
+        print(response)
+        self.assertIsNotNone(response)
+        plugins.retrieval.args = {}
+        plugins.retrieval.enable = False
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/intel_extension_for_transformers/neural_chat/tests/requirements.txt b/intel_extension_for_transformers/neural_chat/tests/requirements.txt
index 8d79ecdd5d5..c1e8414351f 100644
--- a/intel_extension_for_transformers/neural_chat/tests/requirements.txt
+++ b/intel_extension_for_transformers/neural_chat/tests/requirements.txt
@@ -84,7 +84,7 @@ torchvision==0.17.0
 tqdm
 transformers==4.36.2
 transformers_stream_generator
-unstructured
+unstructured[all-docs]
 urllib3
 uvicorn
 vector_quantize_pytorch