Skip to content

Commit 39d0f3a

Browse files
committed
test added
1 parent 1af1c22 commit 39d0f3a

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

analysis/tests/parameter_tuning_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,58 @@ def test_find_candidate_parameters_sum(self, bins, max_candidates,
273273
self.assertEqual([0] * len(expected_candidates),
274274
candidates.min_sum_per_partition)
275275

276+
def test_find_candidate_parameters_count_multi_columns_sum(self):
277+
l0_histogram = histograms.Histogram(
278+
histograms.HistogramType.L0_CONTRIBUTIONS,
279+
bins=[
280+
_frequency_bin(max_value=2),
281+
_frequency_bin(max_value=5),
282+
])
283+
linf_histogram = histograms.Histogram(
284+
histograms.HistogramType.LINF_CONTRIBUTIONS,
285+
bins=[
286+
_frequency_bin(max_value=4),
287+
_frequency_bin(max_value=6),
288+
])
289+
linf_sum_contributions_histogram1 = histograms.Histogram(
290+
histograms.HistogramType.LINF_SUM_CONTRIBUTIONS,
291+
bins=[
292+
_frequency_bin(max_value=1),
293+
_frequency_bin(max_value=2),
294+
_frequency_bin(max_value=3)
295+
])
296+
linf_sum_contributions_histogram2 = histograms.Histogram(
297+
histograms.HistogramType.LINF_SUM_CONTRIBUTIONS,
298+
bins=[
299+
_frequency_bin(max_value=5),
300+
_frequency_bin(max_value=7),
301+
])
302+
303+
hist = histograms.DatasetHistograms(
304+
l0_histogram, None, linf_histogram, [
305+
linf_sum_contributions_histogram1,
306+
linf_sum_contributions_histogram2
307+
], None, None, None)
308+
parameters_to_tune = parameter_tuning.ParametersToTune(
309+
max_partitions_contributed=True,
310+
max_contributions_per_partition=True,
311+
min_sum_per_partition=False,
312+
max_sum_per_partition=True)
313+
314+
candidates = parameter_tuning._find_candidate_parameters(
315+
hist,
316+
parameters_to_tune,
317+
_get_aggregate_params(
318+
[pipeline_dp.Metrics.SUM, pipeline_dp.Metrics.COUNT]),
319+
max_candidates=4)
320+
321+
self.assertEqual([1, 1, 5, 5], candidates.max_partitions_contributed)
322+
self.assertEqual([1, 6, 1, 6],
323+
candidates.max_contributions_per_partition)
324+
self.assertEqual([(1.0, 5.0), (3.0, 7.0), (1.0, 5.0), (3.0, 7.0)],
325+
candidates.max_sum_per_partition)
326+
self.assertEqual([(0, 0)] * 4, candidates.min_sum_per_partition)
327+
276328
def test_find_candidate_parameters_both_l0_and_linf_sum_to_be_tuned(self):
277329
mock_l0_histogram = histograms.Histogram(
278330
histograms.HistogramType.L0_CONTRIBUTIONS, bins=[])

0 commit comments

Comments
 (0)