@@ -273,6 +273,58 @@ def test_find_candidate_parameters_sum(self, bins, max_candidates,
273
273
self .assertEqual ([0 ] * len (expected_candidates ),
274
274
candidates .min_sum_per_partition )
275
275
276
+ def test_find_candidate_parameters_count_multi_columns_sum (self ):
277
+ l0_histogram = histograms .Histogram (
278
+ histograms .HistogramType .L0_CONTRIBUTIONS ,
279
+ bins = [
280
+ _frequency_bin (max_value = 2 ),
281
+ _frequency_bin (max_value = 5 ),
282
+ ])
283
+ linf_histogram = histograms .Histogram (
284
+ histograms .HistogramType .LINF_CONTRIBUTIONS ,
285
+ bins = [
286
+ _frequency_bin (max_value = 4 ),
287
+ _frequency_bin (max_value = 6 ),
288
+ ])
289
+ linf_sum_contributions_histogram1 = histograms .Histogram (
290
+ histograms .HistogramType .LINF_SUM_CONTRIBUTIONS ,
291
+ bins = [
292
+ _frequency_bin (max_value = 1 ),
293
+ _frequency_bin (max_value = 2 ),
294
+ _frequency_bin (max_value = 3 )
295
+ ])
296
+ linf_sum_contributions_histogram2 = histograms .Histogram (
297
+ histograms .HistogramType .LINF_SUM_CONTRIBUTIONS ,
298
+ bins = [
299
+ _frequency_bin (max_value = 5 ),
300
+ _frequency_bin (max_value = 7 ),
301
+ ])
302
+
303
+ hist = histograms .DatasetHistograms (
304
+ l0_histogram , None , linf_histogram , [
305
+ linf_sum_contributions_histogram1 ,
306
+ linf_sum_contributions_histogram2
307
+ ], None , None , None )
308
+ parameters_to_tune = parameter_tuning .ParametersToTune (
309
+ max_partitions_contributed = True ,
310
+ max_contributions_per_partition = True ,
311
+ min_sum_per_partition = False ,
312
+ max_sum_per_partition = True )
313
+
314
+ candidates = parameter_tuning ._find_candidate_parameters (
315
+ hist ,
316
+ parameters_to_tune ,
317
+ _get_aggregate_params (
318
+ [pipeline_dp .Metrics .SUM , pipeline_dp .Metrics .COUNT ]),
319
+ max_candidates = 4 )
320
+
321
+ self .assertEqual ([1 , 1 , 5 , 5 ], candidates .max_partitions_contributed )
322
+ self .assertEqual ([1 , 6 , 1 , 6 ],
323
+ candidates .max_contributions_per_partition )
324
+ self .assertEqual ([(1.0 , 5.0 ), (3.0 , 7.0 ), (1.0 , 5.0 ), (3.0 , 7.0 )],
325
+ candidates .max_sum_per_partition )
326
+ self .assertEqual ([(0 , 0 )] * 4 , candidates .min_sum_per_partition )
327
+
276
328
def test_find_candidate_parameters_both_l0_and_linf_sum_to_be_tuned (self ):
277
329
mock_l0_histogram = histograms .Histogram (
278
330
histograms .HistogramType .L0_CONTRIBUTIONS , bins = [])
0 commit comments