Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new component CSVDocumentSplitter to recursively split CSV documents #8815

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pydoc/config/preprocessors_api.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/preprocessors]
modules: ["document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
modules: ["csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
3 changes: 2 additions & 1 deletion haystack/components/preprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

from .csv_document_splitter import CSVDocumentSplitter
from .document_cleaner import DocumentCleaner
from .document_splitter import DocumentSplitter
from .recursive_splitter import RecursiveDocumentSplitter
from .text_cleaner import TextCleaner

__all__ = ["DocumentSplitter", "DocumentCleaner", "RecursiveDocumentSplitter", "TextCleaner"]
__all__ = ["DocumentSplitter", "DocumentCleaner", "RecursiveDocumentSplitter", "TextCleaner", "CSVDocumentSplitter"]
188 changes: 188 additions & 0 deletions haystack/components/preprocessors/csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from io import StringIO
from typing import Dict, List, Literal

import pandas as pd

from haystack import Document, component, logging

# Module-level logger; used below to report documents whose content cannot be parsed as CSV.
logger = logging.getLogger(__name__)


@component
class CSVDocumentSplitter:
    """
    A component for splitting CSV documents into sub-tables.

    Each document's content is parsed as CSV and split recursively: first on runs of consecutive
    empty rows, then on runs of consecutive empty columns, repeating while a column split exposes
    new runs of empty rows. Every resulting sub-table is emitted as a new `Document` whose content
    is the sub-table serialized back to CSV and whose `meta` is a copy of the source document's.
    """

    def __init__(self, row_split_threshold: int = 2, column_split_threshold: int = 2) -> None:
        """
        Initializes the CSVDocumentSplitter component.

        :param row_split_threshold:
            The minimum number of consecutive empty rows required to trigger a split.
            A higher threshold prevents excessive splitting, while a lower threshold may lead
            to more fragmented sub-tables.
        :param column_split_threshold:
            The minimum number of consecutive empty columns required to trigger a split.
            A higher threshold prevents excessive splitting, while a lower threshold may lead
            to more fragmented sub-tables.
        :raises ValueError:
            If either threshold is smaller than 1.
        """
        # Name the offending parameter so the caller knows which argument to correct.
        if row_split_threshold < 1:
            raise ValueError("row_split_threshold must be greater than 0")
        if column_split_threshold < 1:
            raise ValueError("column_split_threshold must be greater than 0")
        self.row_split_threshold = row_split_threshold
        self.column_split_threshold = column_split_threshold

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
        """
        Processes and splits a list of CSV documents into multiple sub-tables.

        **Splitting Process:**
        1. Row Splitting: Detects empty rows and separates tables stacked vertically.
        2. Column Splitting: Detects empty columns and separates side-by-side tables.
        3. Recursive Row Check: After splitting by columns, it checks for new row splits
           introduced by the column split.

        :param documents: A list of Documents containing CSV-formatted content.
            Each document is assumed to contain one or more tables separated by empty rows or columns.

        :return:
            A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
            each representing an extracted sub-table from the original CSV.

            - If a document cannot be parsed as CSV, it is returned unchanged.
            - The `meta` field from the original document is copied into each split document.
        """
        cleaned_documents = []
        for document in documents:
            try:
                # header=None keeps the first line as data; dtype=object avoids numeric coercion so
                # cell values round-trip unchanged through to_csv.
                df = pd.read_csv(StringIO(document.content), header=None, dtype=object)  # type: ignore
            except Exception as e:
                # Deliberate best-effort: an unparsable document is passed through, not dropped.
                logger.error(f"Error processing document {document.id}. Keeping it, but skipping splitting. Error: {e}")
                cleaned_documents.append(document)
                continue

            split_dfs = self._recursive_split(
                df=df, row_split_threshold=self.row_split_threshold, column_split_threshold=self.column_split_threshold
            )
            for split_df in split_dfs:
                cleaned_documents.append(
                    Document(
                        content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
                        meta=document.meta.copy(),
                    )
                )

        return {"documents": cleaned_documents}

    def _find_split_indices(self, df: pd.DataFrame, split_threshold: int, axis: Literal["row", "column"]) -> List[int]:
        """
        Finds the runs of consecutive empty rows or columns that are long enough to split on.

        Indices are *positional* (0-based offsets along the axis), not index/column labels, so that
        they can be used directly with `DataFrame.iloc` even on sub-tables that kept the labels of
        the table they were cut from.

        :param df: DataFrame to inspect.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to find empty elements. Either "row" or "column".
        :return: For every run of at least `split_threshold` consecutive empty rows or columns,
            the position of the last element of that run, in ascending order.
        """
        # A row (axis=1) or column (axis=0) is empty when every one of its cells is null.
        empty_mask = df.isnull().all(axis=1 if axis == "row" else 0)
        empty_positions = [position for position, is_empty in enumerate(empty_mask.tolist()) if is_empty]

        # Nothing is empty: no split points. Without this guard a threshold of 1 would reach the
        # final check below and index into an empty list.
        if not empty_positions:
            return []

        # Walk neighbouring empty positions, tracking the length of the current consecutive run.
        split_indices = []
        run_length = 1
        for previous, current in zip(empty_positions, empty_positions[1:]):
            if current == previous + 1:
                run_length += 1
                continue
            # The run ended at `previous`; record it if it was long enough.
            if run_length >= split_threshold:
                split_indices.append(previous)
            run_length = 1

        # The last run is never closed inside the loop.
        if run_length >= split_threshold:
            split_indices.append(empty_positions[-1])

        return split_indices

    def _split_dataframe(
        self, df: pd.DataFrame, split_threshold: int, axis: Literal["row", "column"]
    ) -> List[pd.DataFrame]:
        """
        Splits a DataFrame into sub-tables at runs of at least `split_threshold` empty rows or columns.

        All-empty rows (or columns) remaining inside a sub-table, including runs shorter than the
        threshold, are dropped from that sub-table.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to split. Either "row" or "column".
        :return: List of non-empty split DataFrames, in their original order.
        """
        # Positions of the last empty row/column of every qualifying run.
        split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis)

        sub_tables = []
        start_idx = 0
        df_length = df.shape[0] if axis == "row" else df.shape[1]
        # The table's length acts as a final sentinel so the trailing segment is emitted too.
        for end_idx in split_indices + [df_length]:
            # Any non-empty span may hold data; a span of length one is a valid single-row or
            # single-column table. Spans made only of empty elements are filtered out below.
            if end_idx > start_idx:
                if axis == "row":
                    sub_table = df.iloc[start_idx:end_idx].dropna(how="all", axis=0)
                else:
                    sub_table = df.iloc[:, start_idx:end_idx].dropna(how="all", axis=1)
                if not sub_table.empty:
                    sub_tables.append(sub_table)
            # Resume right after the last empty element of the run.
            start_idx = end_idx + 1

        return sub_tables

    def _recursive_split(
        self, df: pd.DataFrame, row_split_threshold: int, column_split_threshold: int
    ) -> List[pd.DataFrame]:
        """
        Recursively splits a DataFrame.

        Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process
        until no more splits are possible. Returns a list of DataFrames, each representing a fully separated sub-table.

        :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV.
        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split.
        :return: List of sub-table DataFrames, ordered top-to-bottom and, within each row block, left-to-right.

        **Splitting Process:**
        1. Row Splitting: Detects empty rows and separates tables stacked vertically.
        2. Column Splitting: Detects empty columns and separates side-by-side tables.
        3. Recursive Row Check: After splitting by columns, it checks for new row splits
           introduced by the column split.

        Termination Condition: If no further splits are detected, the recursion stops.
        """
        # Step 1: Split by rows
        row_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row")

        # Step 2: Split every row block by columns
        column_tables = []
        for table in row_tables:
            column_tables.extend(
                self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column")
            )

        # Step 3: Restricting a table to a subset of its columns can turn partially filled rows
        # into fully empty ones; recurse only on tables where that created a new row split.
        result = []
        for table in column_tables:
            if self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row"):
                result.extend(
                    self._recursive_split(
                        df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold
                    )
                )
            else:
                result.append(table)

        return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Introducing CSVDocumentSplitter: The CSVDocumentSplitter splits CSV documents into structured sub-tables by recursively splitting on runs of consecutive empty rows or columns whose length meets a specified threshold.
This is particularly useful when converting Excel files, which can often have multiple tables within one sheet.
90 changes: 90 additions & 0 deletions test/components/converters/test_csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from haystack import Document
from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter


@pytest.fixture
def splitter() -> CSVDocumentSplitter:
    """Provide a CSVDocumentSplitter constructed with its default arguments."""
    default_splitter = CSVDocumentSplitter()
    return default_splitter


def test_single_table_no_split(splitter: CSVDocumentSplitter) -> None:
    """A CSV without empty rows or columns comes back as one document with identical content."""
    table_csv = """A,B,C
1,2,3
4,5,6
"""
    outputs = splitter.run([Document(content=table_csv)])["documents"]
    assert len(outputs) == 1
    (only_doc,) = outputs
    assert only_doc.content == table_csv


def test_row_split(splitter: CSVDocumentSplitter) -> None:
    """Two consecutive empty rows separate the input into two vertically stacked tables."""
    stacked_csv = """A,B,C
1,2,3
,,
,,
X,Y,Z
7,8,9
"""
    outputs = splitter.run([Document(content=stacked_csv)])["documents"]
    assert len(outputs) == 2
    expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
    # Order matters: sub-tables are expected top to bottom.
    for produced, expected in zip(outputs, expected_tables):
        assert produced.content == expected


def test_column_split(splitter: CSVDocumentSplitter) -> None:
    """Two consecutive empty columns separate the input into two side-by-side tables."""
    side_by_side_csv = """A,B,,,X,Y
1,2,,,7,8
3,4,,,9,10
"""
    outputs = splitter.run([Document(content=side_by_side_csv)])["documents"]
    assert len(outputs) == 2
    expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
    # Order matters: sub-tables are expected left to right.
    for produced, expected in zip(outputs, expected_tables):
        assert produced.content == expected


def test_recursive_split(splitter: CSVDocumentSplitter) -> None:
    """A 2x2 grid of tables is separated by both empty rows and empty columns into four tables."""
    grid_csv = """A,B,,,X,Y
1,2,,,7,8
,,,,,
,,,,,
P,Q,,,M,N
3,4,,,9,10
"""
    outputs = splitter.run([Document(content=grid_csv)])["documents"]
    assert len(outputs) == 4
    # Expected order: upper-left, upper-right, lower-left, lower-right.
    expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
    for produced, expected in zip(outputs, expected_tables):
        assert produced.content == expected


def test_threshold_no_effect() -> None:
    """A gap of two empty rows does not split when the row threshold is three."""
    strict_splitter = CSVDocumentSplitter(row_split_threshold=3)
    gapped_csv = """A,B,C
1,2,3
,,
,,
X,Y,Z
7,8,9
"""
    outputs = strict_splitter.run([Document(content=gapped_csv)])["documents"]
    assert len(outputs) == 1


def test_empty_input(splitter: CSVDocumentSplitter) -> None:
    """A document with empty content yields exactly one output document with the same content."""
    empty_content = ""
    outputs = splitter.run([Document(content=empty_content)])["documents"]
    assert len(outputs) == 1
    (only_doc,) = outputs
    assert only_doc.content == empty_content