Skip to content

Commit

Permalink
Per partition combiner tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym committed Sep 11, 2024
1 parent 861005a commit 6f92164
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
5 changes: 4 additions & 1 deletion analysis/per_partition_combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def create_accumulator(
self, data: Tuple[np.ndarray, np.ndarray,
np.ndarray]) -> AccumulatorType:
count, partition_sum, n_partitions = data
if self._i_column != -1:
if self._i_column is not None:
# When i_column is set, this is the multi-column case and this
# combiner processes the i-th column. partition_sum is then a 2D
# np array of shape (n_examples, n_columns), so extract the
# corresponding column.
partition_sum = partition_sum[:, self._i_column]
del count # not used for SumCombiner
Expand Down
29 changes: 23 additions & 6 deletions analysis/tests/per_partition_combiners_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _create_sparse_combiner_acc(
return (counts, sums, n_partitions)


class UtilityAnalysisCountCombinerTest(parameterized.TestCase):
class CountCombinerTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(testcase_name='empty',
Expand Down Expand Up @@ -273,7 +273,7 @@ def _create_combiner_params_for_sum(
))


class UtilityAnalysisSumCombinerTest(parameterized.TestCase):
class SumCombinerTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(testcase_name='empty',
Expand Down Expand Up @@ -366,6 +366,23 @@ def test_merge(self):
# Test that no type is np.float64
self.assertTrue(_check_none_are_np_float64(merged_acc))

def test_create_accumulator_for_multi_columns(self):
    """Checks create_accumulator when SumCombiner is bound to one column.

    The combiner is built with i_column=1 and fed a 2-row, 2-column array
    of partition sums; the expected partition_sum of 30 corresponds to the
    second column only (10 + 20).
    """
    combiner = combiners.SumCombiner(*_create_combiner_params_for_sum(0, 5),
                                     i_column=1)
    counts = np.array([1, 1])
    # One row per example, one column per aggregated value.
    sums = np.array([[1, 10], [2, 20]])
    n_partitions = np.array([100, 150])

    accumulator = combiner.create_accumulator((counts, sums, n_partitions))
    (partition_sum, clipping_to_min_error, clipping_to_max_error,
     expected_l0_bounding_error, var_cross_partition_error) = accumulator

    self.assertEqual(partition_sum, 30)
    self.assertEqual(clipping_to_min_error, 0)
    self.assertEqual(clipping_to_max_error, -20)
    self.assertAlmostEqual(expected_l0_bounding_error,
                           -9.91666667,
                           delta=1e-8)
    self.assertAlmostEqual(var_cross_partition_error,
                           0.41305556,
                           delta=1e-8)


def _create_combiner_params_for_privacy_id_count() -> Tuple[
pipeline_dp.budget_accounting.MechanismSpec, pipeline_dp.AggregateParams]:
Expand All @@ -381,7 +398,7 @@ def _create_combiner_params_for_privacy_id_count() -> Tuple[
))


class UtilityAnalysisPrivacyIdCountCombinerTest(parameterized.TestCase):
class PrivacyIdCountCombinerTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(testcase_name='empty',
Expand Down Expand Up @@ -463,13 +480,13 @@ def test_merge(self):
self.assertTrue(_check_none_are_np_float64(merged_acc))


class UtilityAnalysisCompoundCombinerTest(parameterized.TestCase):
class CompoundCombinerTest(parameterized.TestCase):

def _create_combiner(self) -> combiners.CompoundCombiner:
mechanism_spec, params = _create_combiner_params_for_count()
count_combiner = combiners.CountCombiner(mechanism_spec, params)
return combiners.CompoundCombiner([count_combiner],
return_named_tuple=False)
n_sum_aggregations=0)

def test_create_accumulator_empty_data(self):
sparse, dense = self._create_combiner().create_accumulator(())
Expand Down Expand Up @@ -611,7 +628,7 @@ def test_two_internal_combiners(self):
sum_mechanism_spec, sum_params = _create_combiner_params_for_sum(0, 5)
sum_combiner = combiners.SumCombiner(sum_mechanism_spec, sum_params)
combiner = combiners.CompoundCombiner([count_combiner, sum_combiner],
return_named_tuple=False)
n_sum_aggregations=1)

data, n_partitions = [1, 2, 3], 100
acc = combiner.create_accumulator((len(data), sum(data), n_partitions))
Expand Down

0 comments on commit 6f92164

Please sign in to comment.