Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpsynth/data_generation_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,14 @@ def generate(
cross_attribute_constraints, discrete.domain
)

model = discrete_config.calibrate(zcdp_rho=discrete_zcdp_rho)(
result = discrete_config.calibrate(zcdp_rho=discrete_zcdp_rho)(
rng,
data=discrete,
initial_measurements=one_way_measurements,
initial_potentials=initial_potentials,
)

synthetic_data = model.synthetic_data()
synthetic_data = result.model.synthetic_data()
logging.info('[SynthKit Tabular]: Generated discrete synthetic data.')

# Convert synthetic data back to the original domain.
Expand Down
19 changes: 14 additions & 5 deletions dpsynth/data_generation_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def _create_initializers(
return initializers


@dataclasses.dataclass
class DataGenerationResult:
  """Result of end-to-end DP synthetic data generation.

  Attributes:
    synthetic_data: The generated synthetic DataFrame.
  """

  synthetic_data: pd.DataFrame


@dataclasses.dataclass
class DataGenerationV3(primitives.DPMechanism):
"""End-to-end DP synthetic data generation mechanism.
Expand Down Expand Up @@ -297,7 +304,7 @@ def dp_event(self) -> dp_accounting.DpEvent:

def __call__(
self, rng: np.random.Generator, data: pd.DataFrame
) -> pd.DataFrame:
) -> DataGenerationResult:
"""Generates differentially private synthetic data.

Args:
Expand All @@ -306,7 +313,7 @@ def __call__(
specified in ``domains``.

Returns:
A synthetic DataFrame with the same domain columns as the input.
A DataGenerationResult containing the synthetic DataFrame.

Raises:
ValueError: If calibrate() has not been called or if required columns are
Expand Down Expand Up @@ -349,13 +356,13 @@ def __call__(
initial_potentials = constraints.get_initial_parameters(
self.cross_attribute_constraints, discrete.domain
)
model = self.discrete_mechanism(
mechanism_result = self.discrete_mechanism(
rng,
data=discrete,
initial_measurements=one_way_measurements,
initial_potentials=initial_potentials,
)
synthetic_data = model.synthetic_data()
synthetic_data = mechanism_result.model.synthetic_data()
logging.info('[DPSynth]: Generated discrete synthetic data.')

# Phase 4: Decode synthetic data back to original domain.
Expand All @@ -373,4 +380,6 @@ def __call__(
logging.info('[DPSynth]: Converted data back to original domain.')

column_order = [col for col in data.columns if col in self.domains]
return pd.DataFrame(synthetic_columns)[column_order]
return DataGenerationResult(
synthetic_data=pd.DataFrame(synthetic_columns)[column_order]
)
6 changes: 6 additions & 0 deletions dpsynth/discrete_mechanisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@
# pylint: disable=g-importing-member

from dpsynth.discrete_mechanisms.aim import AIMMechanism
from dpsynth.discrete_mechanisms.aim import AIMMechanismResult
from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanism
from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanismResult
from dpsynth.discrete_mechanisms.direct import DirectMechanism
from dpsynth.discrete_mechanisms.direct import DirectMechanismResult
from dpsynth.discrete_mechanisms.independent import IndependentMechanism
from dpsynth.discrete_mechanisms.independent import IndependentMechanismResult
from dpsynth.discrete_mechanisms.mst import MSTMechanism
from dpsynth.discrete_mechanisms.mst import MSTMechanismResult
from dpsynth.discrete_mechanisms.swift import SWIFTMechanism
from dpsynth.discrete_mechanisms.swift import SWIFTMechanismResult
from dpsynth.local_mode.primitives import DPMechanism as DiscreteMechanism

# Backwards-compatible aliases.
Expand Down
13 changes: 10 additions & 3 deletions dpsynth/discrete_mechanisms/aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def _worst_approximated(
return keys[idx]


@dataclasses.dataclass
class AIMMechanismResult:
  """Result of running the AIM mechanism.

  Attributes:
    model: The MarkovRandomField representing the estimated data
      distribution.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class AIMMechanism(primitives.DPMechanism):
"""Configuration for the AIM mechanism.
Expand Down Expand Up @@ -155,7 +162,7 @@ def __call__(
*,
initial_measurements: list[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> AIMMechanismResult:
"""Runs the AIM mechanism on the given data.

Args:
Expand All @@ -165,7 +172,7 @@ def __call__(
initial_potentials: Optional initial potentials (constraints).

Returns:
A MarkovRandomField representing the estimated data distribution.
An AIMMechanismResult containing the estimated data distribution.
"""
if self.zcdp_rho is None:
raise ValueError('Must call calibrate() before using the mechanism.')
Expand Down Expand Up @@ -288,4 +295,4 @@ def __call__(
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
logging.info('[AIM] Reducing sigma: %.1f', sigma)

return model
return AIMMechanismResult(model=model)
13 changes: 10 additions & 3 deletions dpsynth/discrete_mechanisms/aim_gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def _worst_approximated(
return max(current_scores, key=current_scores.get)


@dataclasses.dataclass
class AIMGDPMechanismResult:
  """Result of running the AIM-GDP mechanism.

  Attributes:
    model: The MarkovRandomField representing the estimated data
      distribution.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class AIMGDPMechanism(primitives.DPMechanism):
"""Configuration for the AIM mechanism with Gaussian DP.
Expand Down Expand Up @@ -201,7 +208,7 @@ def __call__(
*,
initial_measurements: list[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> AIMGDPMechanismResult:
"""Runs the AIM-GDP mechanism on the given data.

Args:
Expand All @@ -211,7 +218,7 @@ def __call__(
initial_potentials: Optional initial potentials (constraints).

Returns:
A MarkovRandomField representing the estimated data distribution.
An AIMGDPMechanismResult containing the estimated data distribution.
"""
if self.gdp_sigma is None:
raise ValueError('Must call calibrate() before using the mechanism.')
Expand Down Expand Up @@ -358,4 +365,4 @@ def __call__(
'[AIM] Increasing budget per round: %.5f', budget_per_round
)

return model
return AIMGDPMechanismResult(model=model)
11 changes: 9 additions & 2 deletions dpsynth/discrete_mechanisms/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
import numpy as np


@dataclasses.dataclass
class DirectMechanismResult:
  """Result of running the direct mechanism.

  Attributes:
    model: The MarkovRandomField produced by the mechanism.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class DirectMechanism(primitives.DPMechanism):
"""Configuration for the direct mechanism.
Expand Down Expand Up @@ -63,7 +70,7 @@ def __call__(
*,
initial_measurements: list[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> DirectMechanismResult:
"""Generate synthetic data using user specified two way marginals."""
if self.gdp_sigma is None:
raise ValueError('Must call calibrate() before using the mechanism.')
Expand All @@ -88,4 +95,4 @@ def __call__(
potentials=initial_potentials,
marginal_oracle=marginal_oracle,
)
return model
return DirectMechanismResult(model=model)
11 changes: 9 additions & 2 deletions dpsynth/discrete_mechanisms/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
import numpy as np


@dataclasses.dataclass
class IndependentMechanismResult:
  """Result of running the independent mechanism.

  Attributes:
    model: The MarkovRandomField produced by the mechanism.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class IndependentMechanism(primitives.DPMechanism):
"""Configuration for the independent mechanism.
Expand Down Expand Up @@ -60,7 +67,7 @@ def __call__(
*,
initial_measurements: list[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> IndependentMechanismResult:
"""Generate synthetic data via the independent mechanism."""
if self.gdp_sigma is None:
raise ValueError('Must call calibrate() before using the mechanism.')
Expand Down Expand Up @@ -91,4 +98,4 @@ def __call__(
potentials=potentials,
marginal_oracle=marginal_oracle,
)
return model
return IndependentMechanismResult(model=model)
13 changes: 10 additions & 3 deletions dpsynth/discrete_mechanisms/mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ def _select_two_way_marginal_queries(
)


@dataclasses.dataclass
class MSTMechanismResult:
  """Result of running the MST mechanism.

  Attributes:
    model: The fitted MarkovRandomField representing the estimated data
      distribution.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class MSTMechanism(primitives.DPMechanism):
"""Configuration for the maximum spanning tree mechanism.
Expand Down Expand Up @@ -201,7 +208,7 @@ def __call__(
*,
initial_measurements: list[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> MSTMechanismResult:
"""Runs the MST mechanism on the given data.

Args:
Expand All @@ -212,7 +219,7 @@ def __call__(
estimation.

Returns:
A fitted MarkovRandomField model.
An MSTMechanismResult containing the estimated data distribution.

Raises:
ValueError: If calibrate() has not been called.
Expand Down Expand Up @@ -271,4 +278,4 @@ def __call__(
marginal_oracle=marginal_oracle,
)
logging.info('[MST]: Fit distribution to the noisy measurements.')
return model
return MSTMechanismResult(model=model)
13 changes: 10 additions & 3 deletions dpsynth/discrete_mechanisms/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@
import tqdm


@dataclasses.dataclass
class SWIFTMechanismResult:
  """Result of running the SWIFT mechanism.

  Attributes:
    model: The fitted MarkovRandomField representing the estimated data
      distribution.
  """

  model: mbi.MarkovRandomField


@dataclasses.dataclass
class SWIFTMechanism(primitives.DPMechanism):
"""Configuration for the SWIFT mechanism.
Expand Down Expand Up @@ -94,7 +101,7 @@ def __call__(
*,
initial_measurements: Sequence[mbi.LinearMeasurement] | None = None,
initial_potentials: mbi.CliqueVector | None = None,
) -> mbi.MarkovRandomField:
) -> SWIFTMechanismResult:
"""Runs the SWIFT mechanism on the given data.

Args:
Expand All @@ -105,7 +112,7 @@ def __call__(
estimation.

Returns:
A fitted MarkovRandomField model.
A SWIFTMechanismResult containing the estimated data distribution.

Raises:
ValueError: If calibrate() has not been called.
Expand Down Expand Up @@ -197,7 +204,7 @@ def __call__(
)
logging.info('[SWIFT] Estimated final model.')

return model
return SWIFTMechanismResult(model=model)


def _is_supported(clique: mbi.Clique, tree: nx.Graph) -> bool:
Expand Down
7 changes: 4 additions & 3 deletions dpsynth/local_mode/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __call__(
) -> ColumnMeasurement:
"""Returns a ColumnMeasurement with the discretization transform."""
# Dedup: concentrated data can make quantiles return duplicate edges.
edges = _validate_mechanism(self.mechanism)(rng, data)
edges = _validate_mechanism(self.mechanism)(rng, data).quantiles
bin_edges = np.unique(np.asarray(edges, dtype=float))
cat_attr = vtx.categorical_attribute_from_edges(bin_edges, self.attribute)
return ColumnMeasurement(cat_attr, bin_edges)
Expand Down Expand Up @@ -136,7 +136,7 @@ def __call__(
"""Returns a ColumnMeasurement with the noisy histogram."""
mechanism = _validate_mechanism(self.mechanism)
encoded = vtx.discrete_encode(data, self.attribute)
noisy_counts = mechanism(rng, encoded)
noisy_counts = mechanism(rng, encoded).counts
measurement = mbi.LinearMeasurement(
noisy_counts, (self.name,), stddev=mechanism.sigma
)
Expand Down Expand Up @@ -185,7 +185,8 @@ def __call__(
mechanism = _validate_mechanism(self.mechanism)
# Map raw values to integer partition IDs for thresholding.
unique_values, inverse = np.unique(data, return_inverse=True)
selected_ids, counts, _ = mechanism(rng, inverse)
result = mechanism(rng, inverse)
selected_ids, counts = result.selected_partitions, result.estimated_counts
selected_values = list(unique_values[selected_ids])

# Build the discovered domain: default first, then selected values.
Expand Down
Loading
Loading