Skip to content

Commit acf9f67

Browse files
authored
Deriving Spark DataFrame schema on converting from RDD to DataFrame (#508)
1 parent 36fc26c commit acf9f67

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

pipeline_dp/dataframes.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Computing DP aggregations on (Pandas, Spark, Beam) Dataframes."""
1515
import abc
16+
import copy
1617
from collections import namedtuple
1718
from dataclasses import dataclass
1819
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -56,7 +57,7 @@ def dataframe_to_collection(df, columns: Columns):
5657
pass
5758

5859
@abc.abstractmethod
59-
def collection_to_dataframe(col, group_key_column: str):
60+
def collection_to_dataframe(col, metric_output_columns: Sequence[str]):
6061
pass
6162

6263

@@ -65,9 +66,11 @@ class SparkConverter(DataFrameConvertor):
6566

6667
def __init__(self, spark: pyspark.sql.SparkSession):
    """Creates a converter bound to the given Spark session.

    Args:
        spark: session used to materialize output DataFrames.
    """
    # Filled in by dataframe_to_collection; holds the StructFields of the
    # partition-key column(s) so the output schema can be rebuilt later.
    self._partition_key_schema = None
    self._spark = spark
6870

6971
def dataframe_to_collection(self, df: SparkDataFrame,
7072
columns: Columns) -> pyspark.RDD:
73+
self._save_partition_key_schema(df, columns.partition_key)
7174
columns_to_keep = [columns.privacy_key]
7275
if isinstance(columns.partition_key, str):
7376
num_partition_columns = 1
@@ -90,8 +93,29 @@ def extractor(row):
9093

9194
return df.rdd.map(extractor)
9295

93-
def collection_to_dataframe(self, col: pyspark.RDD) -> SparkDataFrame:
94-
return self._spark.createDataFrame(col)
96+
def _save_partition_key_schema(self, df: SparkDataFrame,
97+
partition_key: Union[str, Sequence[str]]):
98+
col_name_to_schema = dict((col.name, col) for col in df.schema)
99+
self._partition_key_schema = []
100+
if isinstance(partition_key, str):
101+
self._partition_key_schema.append(col_name_to_schema[partition_key])
102+
else:
103+
for column_name in partition_key:
104+
self._partition_key_schema.append(
105+
col_name_to_schema[column_name])
106+
107+
def collection_to_dataframe(
        self, col: pyspark.RDD,
        metric_output_columns: Sequence[str]) -> SparkDataFrame:
    """Converts an RDD of DP aggregation results into a Spark DataFrame.

    The output schema is built explicitly: the partition-key fields
    captured by dataframe_to_collection, followed by one non-nullable
    DoubleType field per metric column. Passing an explicit schema lets
    Spark construct the DataFrame even when the RDD is empty (where
    schema inference would otherwise fail).

    Args:
        col: RDD with one record per partition (keys followed by metrics).
        metric_output_columns: names of the metric output columns.

    Returns:
        A Spark DataFrame with the derived schema.
    """
    double_type = pyspark.sql.types.DoubleType()
    metric_fields = [
        pyspark.sql.types.StructField(name, double_type, nullable=False)
        for name in metric_output_columns
    ]
    # Deep-copy so the stored partition-key fields are never mutated.
    all_fields = copy.deepcopy(self._partition_key_schema) + metric_fields
    schema = pyspark.sql.types.StructType(all_fields)
    return self._spark.createDataFrame(col, schema)
95119

96120

97121
def _create_backend_for_dataframe(
@@ -217,7 +241,7 @@ def convert_to_partition_metrics_tuple(row):
217241
"Convert to NamedTuple")
218242
# dp_result: PartitionMetricsTuple
219243

220-
return converter.collection_to_dataframe(dp_result)
244+
return converter.collection_to_dataframe(dp_result, output_columns)
221245

222246

223247
@dataclass

tests/dataframes_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,32 @@ def test_run_query_multiple_partition_keys_e2e_run(self):
402402
self.assertAlmostEqual(row1["count"], 1, delta=1e-3)
403403
self.assertAlmostEqual(row1["sum_column"], 5, delta=1e-3)
404404

405+
def test_run_query_e2e_run_empty_result(self):
    """Checks that a query returning no partitions yields an empty DataFrame."""
    # Arrange: a tiny dataset with contribution bounds and no public
    # partitions, so private partition selection is applied.
    spark = self._get_spark_session()
    df = spark.createDataFrame(get_pandas_df())
    columns = dataframes.Columns("privacy_key", "group_key", "value")
    metrics = {pipeline_dp.Metrics.COUNT: "count_column"}
    bounds = dataframes.ContributionBounds(
        max_partitions_contributed=2,
        max_contributions_per_partition=2,
        min_value=-5,
        max_value=5)
    query = dataframes.Query(df,
                             columns,
                             metrics,
                             bounds,
                             public_partitions=None)

    # Act: run with a very small (eps, delta) budget.
    result_df = query.run_query(dataframes.Budget(1, 1e-10))

    # Assert: with such a small input and private partition selection,
    # all partitions are almost surely dropped, so the result is empty.
    self.assertTrue(result_df.toPandas().empty)
430+
405431

406432
if __name__ == '__main__':
407433
absltest.main()

0 commit comments

Comments
 (0)