Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 124 additions & 5 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def make_batches(
return batches, schema

@classmethod
def make_batch_groups(
def make_grouped_batches(
cls,
*,
num_groups: int,
Expand Down Expand Up @@ -289,6 +289,38 @@ def make_batch_groups(
)
return groups, schema

@classmethod
def make_cogrouped_batches(
    cls,
    *,
    num_groups: int,
    num_rows: int,
    num_cols: int,
    batch_size: int = MAX_RECORDS_PER_BATCH,
    spark_type_pool: list[tuple[Callable, Any]],
) -> tuple[list[tuple[pa.RecordBatch, pa.RecordBatch]], StructType]:
    """Create cogroups of batch pairs (left, right).

    Each cogroup pairs two DataFrames that share one schema but carry
    independently generated data; every side has ``num_rows`` rows and
    ``num_cols`` flat columns.
    """
    # Both sides are generated with identical parameters, so they share a
    # schema and group count; only the random data differs between them.
    shared_kwargs = dict(
        num_groups=num_groups,
        num_rows=num_rows,
        num_cols=num_cols,
        batch_size=batch_size,
        spark_type_pool=spark_type_pool,
    )
    left_side, schema = cls.make_grouped_batches(**shared_kwargs)
    right_side, _ = cls.make_grouped_batches(**shared_kwargs)
    # Pair the first batch of each group positionally, one pair per group.
    paired = [
        (left_side[idx][0], right_side[idx][0]) for idx in range(num_groups)
    ]
    return paired, schema


class MockUDFFactory:
"""Constructs UDF command payloads for the worker protocol."""
Expand Down Expand Up @@ -424,6 +456,93 @@ class ArrowBatchedUDFPeakmemBench(_ArrowBatchedBenchMixin, _PeakmemBenchBase):
pass


# -- SQL_COGROUPED_MAP_ARROW_UDF ------------------------------------------------
# UDF receives two ``pa.Table`` (left, right) per co-group, returns ``pa.Table``.


class _CogroupedMapArrowBenchMixin:
    """Provides _write_scenario for SQL_COGROUPED_MAP_ARROW_UDF."""

    # NOTE: these UDFs are deliberately plain functions (no @staticmethod);
    # they are only ever fetched through the ``_udfs`` class dict, which
    # bypasses descriptor binding, and called as ``f(left, right)``.
    def _cogrouped_map_arrow_identity(left, right):
        """Identity cogroup UDF: hands the left table back untouched."""
        return left

    def _cogrouped_map_arrow_concat(left, right):
        """Concat cogroup UDF: vertically concatenates left and right tables."""
        import pyarrow as pa

        return pa.concat_tables([left, right])

    def _cogrouped_map_arrow_left_semi(left, right):
        """Left-semi cogroup UDF: filters left rows whose key exists in right."""
        key = left.column_names[0]
        return left.join(right.select([key]), keys=key, join_type="left semi")

    # (num_groups, rows_per_group, num_key_cols, num_value_cols)
    _scenario_configs = {
        "few_groups_sm": (50, 5_000, 1, 4),
        "few_groups_lg": (50, 50_000, 1, 4),
        "many_groups_sm": (2_000, 500, 1, 4),
        "many_groups_lg": (500, 10_000, 1, 4),
        "wide_values": (200, 5_000, 1, 20),
        "multi_key": (200, 5_000, 3, 5),
    }

    @staticmethod
    def _build_scenario(name):
        """Build a cogroup scenario: two DataFrames with the same grouping structure.

        Unlike grouped map (which wraps columns in a struct), cogroup batches
        have flat columns: [key_col_0, ..., key_col_k, val_col_0, ..., val_col_v].
        """
        np.random.seed(42)
        cfg = _CogroupedMapArrowBenchMixin._scenario_configs[name]
        num_groups, rows_per_group, num_key_cols, num_value_cols = cfg
        total_cols = num_key_cols + num_value_cols
        # Cycle through MIXED_TYPES until the pool covers every column.
        pool = list(MockDataFactory.MIXED_TYPES[:total_cols])
        while len(pool) < total_cols:
            pool.extend(MockDataFactory.MIXED_TYPES[: total_cols - len(pool)])

        cogroups, schema = MockDataFactory.make_cogrouped_batches(
            num_groups=num_groups,
            num_rows=rows_per_group,
            num_cols=total_cols,
            spark_type_pool=pool,
            batch_size=rows_per_group,
        )
        # The UDF's declared return type spans only the value columns.
        return_type = StructType(schema.fields[num_key_cols:])
        return (cogroups, return_type, num_key_cols, num_value_cols)

    _udfs = {
        "identity_udf": _cogrouped_map_arrow_identity,
        "concat_udf": _cogrouped_map_arrow_concat,
        "left_semi_udf": _cogrouped_map_arrow_left_semi,
    }
    params = [list(_scenario_configs), list(_udfs)]
    param_names = ["scenario", "udf"]

    def _write_scenario(self, scenario, udf_name, buf):
        cogroups, return_type, n_keys, n_vals = self._build_scenario(scenario)
        func = self._udfs[udf_name]
        # One offset list per cogroup side (left, then right); both sides
        # share the same key/value layout.
        arg_offsets = MockUDFFactory.make_grouped_arg_offsets(
            n_keys, n_vals
        ) + MockUDFFactory.make_grouped_arg_offsets(n_keys, n_vals)
        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
            lambda b: MockProtocolWriter.write_udf_payload(func, return_type, arg_offsets, b),
            lambda b: MockProtocolWriter.write_grouped_data_payload(cogroups, num_dfs=2, buf=b),
            buf,
        )


class CogroupedMapArrowUDFTimeBench(_CogroupedMapArrowBenchMixin, _TimeBenchBase):
    """Wall-time benchmarks for SQL_COGROUPED_MAP_ARROW_UDF scenarios."""


class CogroupedMapArrowUDFPeakmemBench(_CogroupedMapArrowBenchMixin, _PeakmemBenchBase):
    """Peak-memory benchmarks for SQL_COGROUPED_MAP_ARROW_UDF scenarios."""


# -- SQL_GROUPED_AGG_ARROW_UDF ------------------------------------------------
# UDF receives ``pa.Array`` columns per group, returns scalar.

Expand Down Expand Up @@ -456,7 +575,7 @@ def _build_scenario(name):
"""Build a single scenario by name."""
np.random.seed(42)
num_groups, rows_per_group, n_cols = _GroupedAggArrowBenchMixin._scenario_configs[name]
return MockDataFactory.make_batch_groups(
return MockDataFactory.make_grouped_batches(
num_groups=num_groups,
num_rows=rows_per_group,
num_cols=n_cols,
Expand Down Expand Up @@ -603,7 +722,7 @@ def _build_scenario(name):
num_fields=n_fields,
base_types=MockDataFactory.MIXED_TYPES,
)
groups, schema = MockDataFactory.make_batch_groups(
groups, schema = MockDataFactory.make_grouped_batches(
num_groups=num_groups,
num_rows=rows_per_group,
num_cols=1,
Expand Down Expand Up @@ -732,7 +851,7 @@ def _build_scenario(name):
)
return ([(b,) for b in batches] * 200, schema)
_kind, rows, n_cols, num_groups = cfg
groups, schema = MockDataFactory.make_batch_groups(
groups, schema = MockDataFactory.make_grouped_batches(
num_groups=num_groups,
num_rows=rows,
num_cols=n_cols,
Expand Down Expand Up @@ -1012,7 +1131,7 @@ def _build_scenarios():
"many_groups_lg": (500, 10_000, 5),
"wide_cols": (200, 5_000, 20),
}.items():
groups, schema = MockDataFactory.make_batch_groups(
groups, schema = MockDataFactory.make_grouped_batches(
num_groups=num_groups,
num_rows=rows_per_group,
num_cols=n_cols,
Expand Down