diff --git a/docs/pydoc/config/preprocessors_api.yml b/docs/pydoc/config/preprocessors_api.yml index abbf221239..a578b73fdc 100644 --- a/docs/pydoc/config/preprocessors_api.yml +++ b/docs/pydoc/config/preprocessors_api.yml @@ -1,7 +1,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/components/preprocessors] - modules: ["csv_document_cleaner", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"] + modules: ["csv_document_cleaner", "csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/preprocessors/__init__.py b/haystack/components/preprocessors/__init__.py index 6836f7d8b5..371e160b49 100644 --- a/haystack/components/preprocessors/__init__.py +++ b/haystack/components/preprocessors/__init__.py @@ -3,9 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 from .csv_document_cleaner import CSVDocumentCleaner +from .csv_document_splitter import CSVDocumentSplitter from .document_cleaner import DocumentCleaner from .document_splitter import DocumentSplitter from .recursive_splitter import RecursiveDocumentSplitter from .text_cleaner import TextCleaner -__all__ = ["CSVDocumentCleaner", "DocumentCleaner", "DocumentSplitter", "RecursiveDocumentSplitter", "TextCleaner"] +__all__ = [ + "CSVDocumentCleaner", + "CSVDocumentSplitter", + "DocumentCleaner", + "DocumentSplitter", + "RecursiveDocumentSplitter", + "TextCleaner", +] diff --git a/haystack/components/preprocessors/csv_document_splitter.py b/haystack/components/preprocessors/csv_document_splitter.py new file mode 100644 index 0000000000..4809bf8381 --- /dev/null +++ b/haystack/components/preprocessors/csv_document_splitter.py @@ -0,0 +1,244 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from io import StringIO +from typing import Any, Dict, List, Literal, Optional, Tuple + +from haystack import Document, component, logging +from haystack.lazy_imports import LazyImport + +with LazyImport("Run 'pip install pandas'") as pandas_import: + import pandas as pd + +logger = logging.getLogger(__name__) + + +@component +class CSVDocumentSplitter: + """ + A component for splitting CSV documents into sub-tables based on empty rows and columns. + + The splitter identifies consecutive empty rows or columns that exceed a given threshold + and uses them as delimiters to segment the document into smaller tables. + """ + + def __init__( + self, + row_split_threshold: Optional[int] = 2, + column_split_threshold: Optional[int] = 2, + read_csv_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initializes the CSVDocumentSplitter component. + + :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split. + :param column_split_threshold: The minimum number of consecutive empty columns required to trigger a split. + :param read_csv_kwargs: Additional keyword arguments to pass to `pandas.read_csv`. + By default, the component with options: + - `header=None` + - `skip_blank_lines=False` to preserve blank lines + - `dtype=object` to prevent type inference (e.g., converting numbers to floats). + See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information. + """ + pandas_import.check() + if row_split_threshold is not None and row_split_threshold < 1: + raise ValueError("row_split_threshold must be greater than 0") + + if column_split_threshold is not None and column_split_threshold < 1: + raise ValueError("column_split_threshold must be greater than 0") + + if row_split_threshold is None and column_split_threshold is None: + raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.") + + self.row_split_threshold = row_split_threshold + self.column_split_threshold = column_split_threshold + self.read_csv_kwargs = read_csv_kwargs or {} + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]) -> Dict[str, List[Document]]: + """ + Processes and splits a list of CSV documents into multiple sub-tables. + + **Splitting Process:** + 1. Applies a row-based split if `row_split_threshold` is provided. + 2. Applies a column-based split if `column_split_threshold` is provided. + 3. If both thresholds are specified, performs a recursive split by rows first, then columns, ensuring + further fragmentation of any sub-tables that still contain empty sections. + 4. Sorts the resulting sub-tables based on their original positions within the document. + + :param documents: A list of Documents containing CSV-formatted content. + Each document is assumed to contain one or more tables separated by empty rows or columns. + + :return: + A dictionary with a key `"documents"`, mapping to a list of new `Document` objects, + each representing an extracted sub-table from the original CSV. + The metadata of each document includes: + - A field `source_id` to track the original document. + - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table. + - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table. + - A field `split_id` to indicate the order of the split in the original document. + - All other metadata copied from the original document. + + - If a document cannot be processed, it is returned unchanged. + - The `meta` field from the original document is preserved in the split documents. + """ + if len(documents) == 0: + return {"documents": documents} + + resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs} + + split_documents = [] + for document in documents: + try: + df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs) # type: ignore + except Exception as e: + logger.error(f"Error processing document {document.id}. Keeping it, but skipping splitting. Error: {e}") + split_documents.append(document) + continue + + if self.row_split_threshold is not None and self.column_split_threshold is None: + # split by rows + split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row") + elif self.column_split_threshold is not None and self.row_split_threshold is None: + # split by columns + split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column") + else: + # recursive split + split_dfs = self._recursive_split( + df=df, + row_split_threshold=self.row_split_threshold, # type: ignore + column_split_threshold=self.column_split_threshold, # type: ignore + ) + + # Sort split_dfs first by row index, then by column index + split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0])) + + for split_id, split_df in enumerate(split_dfs): + split_documents.append( + Document( + content=split_df.to_csv(index=False, header=False, lineterminator="\n"), + meta={ + **document.meta.copy(), + "source_id": document.id, + "row_idx_start": int(split_df.index[0]), + "col_idx_start": int(split_df.columns[0]), + "split_id": split_id, + }, + ) + ) + + return {"documents": split_documents} + + @staticmethod + def _find_split_indices( + df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"] + ) -> List[Tuple[int, int]]: + """ + Finds the indices of consecutive empty rows or columns in a DataFrame. + + :param df: DataFrame to split. + :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split. + :param axis: Axis along which to find empty elements. Either "row" or "column". + :return: List of indices where consecutive empty rows or columns start. + """ + if axis == "row": + empty_elements = df[df.isnull().all(axis=1)].index.tolist() + else: + empty_elements = df.columns[df.isnull().all(axis=0)].tolist() + + # If no empty elements found, return empty list + if len(empty_elements) == 0: + return [] + + # Identify groups of consecutive empty elements + split_indices = [] + consecutive_count = 1 + start_index = empty_elements[0] + + for i in range(1, len(empty_elements)): + if empty_elements[i] == empty_elements[i - 1] + 1: + consecutive_count += 1 + else: + if consecutive_count >= split_threshold: + split_indices.append((start_index, empty_elements[i - 1])) + consecutive_count = 1 + start_index = empty_elements[i] + + # Handle the last group of consecutive elements + if consecutive_count >= split_threshold: + split_indices.append((start_index, empty_elements[-1])) + + return split_indices + + def _split_dataframe( + self, df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"] + ) -> List["pd.DataFrame"]: + """ + Splits a DataFrame into sub-tables based on consecutive empty rows or columns exceeding `split_threshold`. + + :param df: DataFrame to split. + :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split. + :param axis: Axis along which to split. Either "row" or "column". + :return: List of split DataFrames. + """ + # Find indices of consecutive empty rows or columns + split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis) + + # If no split_indices are found, return the original DataFrame + if len(split_indices) == 0: + return [df] + + # Split the DataFrame at identified indices + sub_tables = [] + table_start_idx = 0 + df_length = df.shape[0] if axis == "row" else df.shape[1] + for empty_start_idx, empty_end_idx in split_indices + [(df_length, df_length)]: + # Avoid empty splits + if empty_start_idx - table_start_idx > 1: + if axis == "row": + sub_table = df.iloc[table_start_idx:empty_start_idx] + else: + sub_table = df.iloc[:, table_start_idx:empty_start_idx] + if not sub_table.empty: + sub_tables.append(sub_table) + table_start_idx = empty_end_idx + 1 + + return sub_tables + + def _recursive_split( + self, df: "pd.DataFrame", row_split_threshold: int, column_split_threshold: int + ) -> List["pd.DataFrame"]: + """ + Recursively splits a DataFrame. + + Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process + until no more splits are possible. Returns a list of DataFrames, each representing a fully separated sub-table. + + :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV. + :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split. + :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split. + """ + + # Step 1: Split by rows + new_sub_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row") + + # Step 2: Split by columns + final_tables = [] + for table in new_sub_tables: + final_tables.extend(self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column")) + + # Step 3: Recursively reapply splitting checked by whether any new empty rows appear after column split + result = [] + for table in final_tables: + # Check if there are consecutive rows >= row_split_threshold now present + if len(self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row")) > 0: + result.extend( + self._recursive_split( + df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold + ) + ) + else: + result.append(table) + + return result diff --git a/releasenotes/notes/csv-document-splitter-426dcc0392c08f62.yaml b/releasenotes/notes/csv-document-splitter-426dcc0392c08f62.yaml new file mode 100644 index 0000000000..9f59c03d12 --- /dev/null +++ b/releasenotes/notes/csv-document-splitter-426dcc0392c08f62.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Introducing CSVDocumentSplitter: The CSVDocumentSplitter splits CSV documents into structured sub-tables by recursively splitting by empty rows and columns larger than a specified threshold. + This is particularly useful when converting Excel files which can often have multiple tables within one sheet. diff --git a/test/components/preprocessors/test_csv_document_splitter.py b/test/components/preprocessors/test_csv_document_splitter.py new file mode 100644 index 0000000000..e94efd349a --- /dev/null +++ b/test/components/preprocessors/test_csv_document_splitter.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import pandas as pd +from io import StringIO +from haystack import Document, Pipeline +from haystack.core.serialization import component_from_dict, component_to_dict +from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter + + +@pytest.fixture +def splitter() -> CSVDocumentSplitter: + return CSVDocumentSplitter() + + +@pytest.fixture +def two_tables_sep_by_two_empty_rows() -> str: + return """A,B,C +1,2,3 +,, +,, +X,Y,Z +7,8,9 +""" + + +@pytest.fixture +def three_tables_sep_by_empty_rows() -> str: + return """A,B,C +,, +1,2,3 +,, +,, +X,Y,Z +7,8,9 +""" + + +@pytest.fixture +def two_tables_sep_by_two_empty_columns() -> str: + return """A,B,,,X,Y +1,2,,,7,8 +3,4,,,9,10 +""" + + +class TestFindSplitIndices: + def test_find_split_indices_row_two_tables( + self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str + ) -> None: + df = pd.read_csv(StringIO(two_tables_sep_by_two_empty_rows), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=2, axis="row") + assert result == [(2, 3)] + + def test_find_split_indices_row_two_tables_with_empty_row( + self, splitter: CSVDocumentSplitter, three_tables_sep_by_empty_rows: str + ) -> None: + df = pd.read_csv(StringIO(three_tables_sep_by_empty_rows), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=2, axis="row") + assert result == [(3, 4)] + + def test_find_split_indices_row_three_tables(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,B,C +1,2,3 +,, +,, +X,Y,Z +7,8,9 +,, +,, +P,Q,R +""" + df = pd.read_csv(StringIO(csv_content), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=2, axis="row") + assert result == [(2, 3), (6, 7)] + + def test_find_split_indices_column_two_tables( + self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str + ) -> None: + df = pd.read_csv(StringIO(two_tables_sep_by_two_empty_columns), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=1, axis="column") + assert result == [(2, 3)] + + def test_find_split_indices_column_two_tables_with_empty_column(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,,B,,,X,Y +1,,2,,,7,8 +3,,4,,,9,10 +""" + df = pd.read_csv(StringIO(csv_content), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=2, axis="column") + assert result == [(3, 4)] + + def test_find_split_indices_column_three_tables(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,B,,,X,Y,,,P,Q +1,2,,,7,8,,,11,12 +3,4,,,9,10,,,13,14 +""" + df = pd.read_csv(StringIO(csv_content), header=None, dtype=object) # type: ignore + result = splitter._find_split_indices(df, split_threshold=2, axis="column") + assert result == [(2, 3), (6, 7)] + + +class TestInit: + def test_row_split_threshold_raises_error(self) -> None: + with pytest.raises(ValueError, match="row_split_threshold must be greater than 0"): + CSVDocumentSplitter(row_split_threshold=-1) + + def test_column_split_threshold_raises_error(self) -> None: + with pytest.raises(ValueError, match="column_split_threshold must be greater than 0"): + CSVDocumentSplitter(column_split_threshold=-1) + + def test_row_split_threshold_and_row_column_threshold_none(self) -> None: + with pytest.raises( + ValueError, match="At least one of row_split_threshold or column_split_threshold must be specified." + ): + CSVDocumentSplitter(row_split_threshold=None, column_split_threshold=None) + + +class TestCSVDocumentSplitter: + def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,B,C +1,2,3 +4,5,6 +""" + doc = Document(content=csv_content, id="test_id") + result = splitter.run([doc])["documents"] + assert len(result) == 1 + assert result[0].content == csv_content + assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0} + + def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None: + doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id") + result = splitter.run([doc])["documents"] + assert len(result) == 2 + expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1}, + ] + for i, table in enumerate(result): + assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] + + def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None: + doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id") + result = splitter.run([doc])["documents"] + assert len(result) == 2 + expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + ] + for i, table in enumerate(result): + assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] + + def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,B,,,X,Y +1,2,,,7,8 +,,,,, +,,,,, +P,Q,,,M,N +3,4,,,9,10 +""" + doc = Document(content=csv_content, id="test_id") + result = splitter.run([doc])["documents"] + assert len(result) == 4 + expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3}, + ] + for i, table in enumerate(result): + assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] + + def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None: + csv_content = """A,B,,,X,Y +1,2,,,7,8 +,,,,M,N +,,,,9,10 +P,Q,,,, +3,4,,,, +""" + doc = Document(content=csv_content, id="test_id") + result = splitter.run([doc])["documents"] + assert len(result) == 3 + expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2}, + ] + for i, table in enumerate(result): + assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] + + def test_csv_with_blank_lines(self, splitter: CSVDocumentSplitter) -> None: + csv_data = """ID,LeftVal,,,RightVal,Extra +1,Hello,,,World,Joined +2,StillLeft,,,StillRight,Bridge + +A,B,,,C,D +E,F,,,G,H +""" + splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=1) + result = splitter.run([Document(content=csv_data, id="test_id")]) + docs = result["documents"] + assert len(docs) == 4 + expected_tables = [ + "ID,LeftVal\n1,Hello\n2,StillLeft\n", + "RightVal,Extra\nWorld,Joined\nStillRight,Bridge\n", + "A,B\nE,F\n", + "C,D\nG,H\n", + ] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3}, + ] + for i, table in enumerate(docs): + assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] + + def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None: + splitter = CSVDocumentSplitter(row_split_threshold=3) + doc = Document(content=two_tables_sep_by_two_empty_rows) + result = splitter.run([doc])["documents"] + assert len(result) == 1 + + def test_empty_input(self, splitter: CSVDocumentSplitter) -> None: + csv_content = "" + doc = Document(content=csv_content) + result = splitter.run([doc])["documents"] + assert len(result) == 1 + assert result[0].content == csv_content + + def test_empty_documents(self, splitter: CSVDocumentSplitter) -> None: + result = splitter.run([])["documents"] + assert len(result) == 0 + + def test_to_dict_with_defaults(self) -> None: + splitter = CSVDocumentSplitter() + config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter") + config = { + "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter", + "init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}}, + } + assert config_serialized == config + + def test_to_dict_non_defaults(self) -> None: + splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=None, read_csv_kwargs={"sep": ";"}) + config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter") + config = { + "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter", + "init_parameters": { + "row_split_threshold": 1, + "column_split_threshold": None, + "read_csv_kwargs": {"sep": ";"}, + }, + } + assert config_serialized == config + + def test_from_dict_defaults(self) -> None: + splitter = component_from_dict( + CSVDocumentSplitter, + data={ + "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter", + "init_parameters": {}, + }, + name="CSVDocumentSplitter", + ) + assert splitter.row_split_threshold == 2 + assert splitter.column_split_threshold == 2 + assert splitter.read_csv_kwargs == {} + + def test_from_dict_non_defaults(self) -> None: + splitter = component_from_dict( + CSVDocumentSplitter, + data={ + "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter", + "init_parameters": { + "row_split_threshold": 1, + "column_split_threshold": None, + "read_csv_kwargs": {"sep": ";"}, + }, + }, + name="CSVDocumentSplitter", + ) + assert splitter.row_split_threshold == 1 + assert splitter.column_split_threshold is None + assert splitter.read_csv_kwargs == {"sep": ";"}