@@ -258,7 +258,7 @@ def make_batches(
258258 return batches , schema
259259
260260 @classmethod
261- def make_batch_groups (
261+ def make_grouped_batches (
262262 cls ,
263263 * ,
264264 num_groups : int ,
@@ -289,6 +289,38 @@ def make_batch_groups(
289289 )
290290 return groups , schema
291291
@classmethod
def make_cogrouped_batches(
    cls,
    *,
    num_groups: int,
    num_rows: int,
    num_cols: int,
    batch_size: int = MAX_RECORDS_PER_BATCH,
    spark_type_pool: list[tuple[Callable, Any]],
) -> tuple[list[tuple[pa.RecordBatch, pa.RecordBatch]], StructType]:
    """Build ``num_groups`` (left, right) record-batch pairs for cogroup benchmarks.

    ``make_grouped_batches`` is invoked twice with exactly the same keyword
    arguments, once for the left side and once for the right side. Pair
    ``i`` is made of element 0 of left group ``i`` and element 0 of right
    group ``i``; any further elements of a group are not used.

    Returns the list of pairs together with the schema reported by the
    left-side call (the schema reported by the right-side call is dropped).
    """
    # Both sides are generated from one shared set of arguments.
    shared_kwargs = dict(
        num_groups=num_groups,
        num_rows=num_rows,
        num_cols=num_cols,
        batch_size=batch_size,
        spark_type_pool=spark_type_pool,
    )
    left_side, schema = cls.make_grouped_batches(**shared_kwargs)
    right_side, _right_schema = cls.make_grouped_batches(**shared_kwargs)

    pairs = []
    for idx in range(num_groups):
        # Only the first element of each group takes part in the pair.
        pairs.append((left_side[idx][0], right_side[idx][0]))
    return pairs, schema
323+
292324
293325class MockUDFFactory :
294326 """Constructs UDF command payloads for the worker protocol."""
@@ -424,6 +456,93 @@ class ArrowBatchedUDFPeakmemBench(_ArrowBatchedBenchMixin, _PeakmemBenchBase):
424456 pass
425457
426458
459+ # -- SQL_COGROUPED_MAP_ARROW_UDF ------------------------------------------------
460+ # UDF receives two ``pa.Table`` (left, right) per co-group, returns ``pa.Table``.
461+
462+
class _CogroupedMapArrowBenchMixin:
    """Scenario data, UDF registry and ``_write_scenario`` for SQL_COGROUPED_MAP_ARROW_UDF."""

    # The three UDFs below are plain two-argument functions (no ``self``);
    # within this class they are looked up through the ``_udfs`` mapping.
    def _cogrouped_map_arrow_identity(left, right):
        """Identity cogroup UDF: hand back the left table and ignore the right one."""
        return left

    def _cogrouped_map_arrow_concat(left, right):
        """Concat cogroup UDF: stack the left and right tables vertically."""
        import pyarrow as pa

        both_sides = [left, right]
        return pa.concat_tables(both_sides)

    def _cogrouped_map_arrow_left_semi(left, right):
        """Left-semi cogroup UDF: keep left rows whose first-column value occurs in right."""
        first_col = left.column_names[0]
        right_keys = right.select([first_col])
        return left.join(right_keys, keys=first_col, join_type="left semi")

    # scenario name -> (num_groups, rows_per_group, num_key_cols, num_value_cols)
    _scenario_configs = {
        "few_groups_sm": (50, 5_000, 1, 4),
        "few_groups_lg": (50, 50_000, 1, 4),
        "many_groups_sm": (2_000, 500, 1, 4),
        "many_groups_lg": (500, 10_000, 1, 4),
        "wide_values": (200, 5_000, 1, 20),
        "multi_key": (200, 5_000, 3, 5),
    }

    @staticmethod
    def _build_scenario(name):
        """Generate the cogroup data and return type for the scenario called ``name``.

        Unlike grouped map (which wraps columns in a struct), cogroup batches
        have flat columns: [key_col_0, ..., key_col_k, val_col_0, ..., val_col_v].

        Returns ``(cogroups, return_type, num_key_cols, num_value_cols)`` where
        ``return_type`` is a ``StructType`` built from the schema fields that
        follow the first ``num_key_cols`` fields.
        """
        # Seed before anything else so data generation starts from a fixed state.
        np.random.seed(42)
        config = _CogroupedMapArrowBenchMixin._scenario_configs[name]
        num_groups, rows_per_group, num_key_cols, num_value_cols = config
        total_cols = num_key_cols + num_value_cols

        # Take one type per column, re-reading MIXED_TYPES from its start
        # whenever the pool is still shorter than the column count.
        mixed_types = MockDataFactory.MIXED_TYPES
        type_pool = mixed_types[:total_cols]
        shortfall = total_cols - len(type_pool)
        while shortfall > 0:
            type_pool = type_pool + mixed_types[:shortfall]
            shortfall = total_cols - len(type_pool)

        cogroups, schema = MockDataFactory.make_cogrouped_batches(
            num_groups=num_groups,
            num_rows=rows_per_group,
            num_cols=total_cols,
            spark_type_pool=type_pool,
            batch_size=rows_per_group,
        )
        value_fields = schema.fields[num_key_cols:]
        return (cogroups, StructType(value_fields), num_key_cols, num_value_cols)

    # UDF name -> callable taking (left, right)
    _udfs = {
        "identity_udf": _cogrouped_map_arrow_identity,
        "concat_udf": _cogrouped_map_arrow_concat,
        "left_semi_udf": _cogrouped_map_arrow_left_semi,
    }
    # Benchmark parameter grid: every scenario name crossed with every UDF name.
    params = [list(_scenario_configs.keys()), list(_udfs.keys())]
    param_names = ["scenario", "udf"]

    def _write_scenario(self, scenario, udf_name, buf):
        """Write the worker input for one (scenario, UDF) combination into ``buf``."""
        cogroups, return_type, num_key_cols, num_value_cols = self._build_scenario(scenario)
        udf_func = self._udfs[udf_name]

        # One offsets list per side, concatenated left-then-right.
        left_offsets = MockUDFFactory.make_grouped_arg_offsets(num_key_cols, num_value_cols)
        right_offsets = MockUDFFactory.make_grouped_arg_offsets(num_key_cols, num_value_cols)
        arg_offsets = left_offsets + right_offsets

        def write_udf(b):
            return MockProtocolWriter.write_udf_payload(udf_func, return_type, arg_offsets, b)

        def write_data(b):
            return MockProtocolWriter.write_grouped_data_payload(cogroups, num_dfs=2, buf=b)

        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
            write_udf,
            write_data,
            buf,
        )
536+
537+
class CogroupedMapArrowUDFTimeBench(_CogroupedMapArrowBenchMixin, _TimeBenchBase):
    """Pairs the SQL_COGROUPED_MAP_ARROW_UDF scenario mixin with ``_TimeBenchBase``."""
540+
541+
class CogroupedMapArrowUDFPeakmemBench(_CogroupedMapArrowBenchMixin, _PeakmemBenchBase):
    """Pairs the SQL_COGROUPED_MAP_ARROW_UDF scenario mixin with ``_PeakmemBenchBase``."""
544+
545+
427546# -- SQL_GROUPED_AGG_ARROW_UDF ------------------------------------------------
428547# UDF receives ``pa.Array`` columns per group, returns scalar.
429548
@@ -456,7 +575,7 @@ def _build_scenario(name):
456575 """Build a single scenario by name."""
457576 np .random .seed (42 )
458577 num_groups , rows_per_group , n_cols = _GroupedAggArrowBenchMixin ._scenario_configs [name ]
459- return MockDataFactory .make_batch_groups (
578+ return MockDataFactory .make_grouped_batches (
460579 num_groups = num_groups ,
461580 num_rows = rows_per_group ,
462581 num_cols = n_cols ,
@@ -603,7 +722,7 @@ def _build_scenario(name):
603722 num_fields = n_fields ,
604723 base_types = MockDataFactory .MIXED_TYPES ,
605724 )
606- groups , schema = MockDataFactory .make_batch_groups (
725+ groups , schema = MockDataFactory .make_grouped_batches (
607726 num_groups = num_groups ,
608727 num_rows = rows_per_group ,
609728 num_cols = 1 ,
@@ -732,7 +851,7 @@ def _build_scenario(name):
732851 )
733852 return ([(b ,) for b in batches ] * 200 , schema )
734853 _kind , rows , n_cols , num_groups = cfg
735- groups , schema = MockDataFactory .make_batch_groups (
854+ groups , schema = MockDataFactory .make_grouped_batches (
736855 num_groups = num_groups ,
737856 num_rows = rows ,
738857 num_cols = n_cols ,
@@ -1012,7 +1131,7 @@ def _build_scenarios():
10121131 "many_groups_lg" : (500 , 10_000 , 5 ),
10131132 "wide_cols" : (200 , 5_000 , 20 ),
10141133 }.items ():
1015- groups , schema = MockDataFactory .make_batch_groups (
1134+ groups , schema = MockDataFactory .make_grouped_batches (
10161135 num_groups = num_groups ,
10171136 num_rows = rows_per_group ,
10181137 num_cols = n_cols ,
0 commit comments