 # limitations under the License.
 """Computing DP aggregations on (Pandas, Spark, Beam) Dataframes."""
 import abc
+import copy
 from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -56,7 +57,7 @@ def dataframe_to_collection(df, columns: Columns):
         pass

     @abc.abstractmethod
-    def collection_to_dataframe(col, group_key_column: str):
+    def collection_to_dataframe(col, metric_output_columns: Sequence[str]):
         pass

@@ -65,9 +66,11 @@ class SparkConverter(DataFrameConvertor):

     def __init__(self, spark: pyspark.sql.SparkSession):
         self._spark = spark
+        self._partition_key_schema = None

     def dataframe_to_collection(self, df: SparkDataFrame,
                                 columns: Columns) -> pyspark.RDD:
+        self._save_partition_key_schema(df, columns.partition_key)
         columns_to_keep = [columns.privacy_key]
         if isinstance(columns.partition_key, str):
             num_partition_columns = 1
@@ -90,8 +93,29 @@ def extractor(row):

         return df.rdd.map(extractor)

-    def collection_to_dataframe(self, col: pyspark.RDD) -> SparkDataFrame:
-        return self._spark.createDataFrame(col)
+    def _save_partition_key_schema(self, df: SparkDataFrame,
+                                   partition_key: Union[str, Sequence[str]]):
+        col_name_to_schema = dict((col.name, col) for col in df.schema)
+        self._partition_key_schema = []
+        if isinstance(partition_key, str):
+            self._partition_key_schema.append(col_name_to_schema[partition_key])
+        else:
+            for column_name in partition_key:
+                self._partition_key_schema.append(
+                    col_name_to_schema[column_name])
+
+    def collection_to_dataframe(
+            self, col: pyspark.RDD,
+            metric_output_columns: Sequence[str]) -> SparkDataFrame:
+        schema_fields = copy.deepcopy(self._partition_key_schema)
+        float_type = pyspark.sql.types.DoubleType()
+        for metric_column in metric_output_columns:
+            schema_fields.append(
+                pyspark.sql.types.StructField(metric_column,
+                                              float_type,
+                                              nullable=False))
+        schema = pyspark.sql.types.StructType(schema_fields)
+        return self._spark.createDataFrame(col, schema)


 def _create_backend_for_dataframe(
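
For context, the new `collection_to_dataframe` builds an explicit output schema instead of letting Spark infer one: the partition-key fields saved by `_save_partition_key_schema` are followed by one non-nullable `DoubleType` field per metric column. A minimal standalone sketch of that schema construction (the column names below are illustrative assumptions, not taken from the commit):

```python
from pyspark.sql import types as T

# Hypothetical partition-key field, as _save_partition_key_schema would
# capture it from the input DataFrame's schema.
partition_key_schema = [T.StructField("day", T.StringType(), nullable=True)]

# Metric columns requested by the DP aggregation (illustrative names).
metric_output_columns = ["count", "sum"]

schema_fields = list(partition_key_schema)
for metric_column in metric_output_columns:
    # Every DP metric is emitted as a non-nullable double.
    schema_fields.append(
        T.StructField(metric_column, T.DoubleType(), nullable=False))

schema = T.StructType(schema_fields)
print(schema.simpleString())  # struct<day:string,count:double,sum:double>
```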
@@ -217,7 +241,7 @@ def convert_to_partition_metrics_tuple(row):
                                       "Convert to NamedTuple")
         # dp_result: PartitionMetricsTuple

-        return converter.collection_to_dataframe(dp_result)
+        return converter.collection_to_dataframe(dp_result, output_columns)


 @dataclass
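
Passing an explicit schema to `createDataFrame` assumes that each element of the output RDD lays out its values as the partition-key fields followed by the metric values, in the same order as `metric_output_columns`. A hedged end-to-end sketch of that final conversion (the SparkSession setup, column names, and sample values are assumptions for illustration, not part of the change):

```python
from pyspark.sql import SparkSession, types as T

spark = SparkSession.builder.master("local[1]").appName("dp-output-schema").getOrCreate()

# Explicit output schema: partition key first, then one double per metric.
schema = T.StructType([
    T.StructField("day", T.StringType(), nullable=True),
    T.StructField("count", T.DoubleType(), nullable=False),
    T.StructField("sum", T.DoubleType(), nullable=False),
])

# Stand-in for the anonymized per-partition results: (partition_key, *metrics).
dp_result = spark.sparkContext.parallelize([("mon", 10.0, 42.5),
                                            ("tue", 7.0, 13.0)])

df = spark.createDataFrame(dp_result, schema)
df.show()
spark.stop()
```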