Skip to content

Commit 087bb02

Browse files
authored
Support serialization of internal objects (#312)
* Add a generic serialization interface * Add serialization for constraints * Add serialization for distributions * Add serialization for DataGenerators and ColumnGenerationSpecs * Add tests * Clean-up lint messages * Update serialization naming and implementation * Add serialization methods for TextGenerators * Update tests * Update documentation * Consolidate and rename interface methods * Remove PyYAML dependency * Update tests * Update documentation with new syntax
1 parent 7275090 commit 087bb02

28 files changed

+986
-22
lines changed

dbldatagen/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from ._version import __version__
3535
from .column_generation_spec import ColumnGenerationSpec
3636
from .column_spec_options import ColumnSpecOptions
37+
from .constraints import Constraint, ChainedRelation, LiteralRange, LiteralRelation, NegativeValues, PositiveValues, \
38+
RangedValues, SqlExpr, UniqueCombinations
3739
from .data_analyzer import DataAnalyzer
3840
from .schema_parser import SchemaParser
3941
from .daterange import DateRange
@@ -49,7 +51,7 @@
4951
__all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange",
5052
"column_generation_spec", "utils", "function_builder",
5153
"spark_singleton", "text_generators", "datarange", "datagen_constants",
52-
"text_generator_plugins", "html_utils", "datasets_object"
54+
"text_generator_plugins", "html_utils", "datasets_object", "constraints"
5355
]
5456

5557

dbldatagen/column_generation_spec.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .daterange import DateRange
2626
from .distributions import Normal, DataDistribution
2727
from .nrange import NRange
28+
from .serialization import SerializableToDict
2829
from .text_generators import TemplateGenerator
2930
from .utils import ensure, coalesce_values
3031
from .schema_parser import SchemaParser
@@ -40,7 +41,7 @@
4041
RAW_VALUES_COMPUTE_METHOD]
4142

4243

43-
class ColumnGenerationSpec(object):
44+
class ColumnGenerationSpec(SerializableToDict):
4445
""" Column generation spec object - specifies how column is to be generated
4546
4647
Each column to be output will have a corresponding ColumnGenerationSpec object.
@@ -119,7 +120,7 @@ def __init__(self, name, colType=None, minValue=0, maxValue=None, step=1, prefix
119120
if EXPR_OPTION not in kwargs:
120121
raise ValueError("Column generation spec must have `expr` attribute specified if datatype is inferred")
121122

122-
elif type(colType) == str:
123+
elif isinstance(colType, str):
123124
colType = SchemaParser.columnTypeFromString(colType)
124125

125126
assert isinstance(colType, DataType), f"colType `{colType}` is not instance of DataType"
@@ -299,6 +300,21 @@ def __init__(self, name, colType=None, minValue=0, maxValue=None, step=1, prefix
299300
# set up the temporary columns needed for data generation
300301
self._setupTemporaryColumns()
301302

303+
def _toInitializationDict(self):
304+
""" Converts an object to a Python dictionary. Keys represent the object's
305+
constructor arguments.
306+
:return: Python dictionary representation of the object
307+
"""
308+
_options = self._csOptions.options.copy()
309+
_options["colName"] = _options.pop("name", self.name)
310+
_options["colType"] = _options.pop("type", self.datatype).simpleString()
311+
_options["kind"] = self.__class__.__name__
312+
return {
313+
k: v._toInitializationDict()
314+
if isinstance(v, SerializableToDict) else v
315+
for k, v in _options.items() if v is not None
316+
}
317+
302318
def _temporaryRename(self, tmpName):
303319
""" Create enter / exit object to support temporary renaming of column spec
304320
@@ -451,7 +467,7 @@ def setBaseColumnDatatypes(self, columnDatatypes):
451467
assert type(columnDatatypes) is list, " `column_datatypes` parameter must be list"
452468
ensure(len(columnDatatypes) == len(self.baseColumns),
453469
"number of base column datatypes must match number of base columns")
454-
self._baseColumnDatatypes = [].append(columnDatatypes)
470+
self._baseColumnDatatypes = columnDatatypes.copy()
455471

456472
def _setupTemporaryColumns(self):
457473
""" Set up any temporary columns needed for test data generation.

dbldatagen/constraints/chained_relation.py

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pyspark.sql import DataFrame
99
import pyspark.sql.functions as F
1010
from .constraint import Constraint, NoPrepareTransformMixin
11+
from ..serialization import SerializableToDict
1112

1213

1314
class ChainedRelation(NoPrepareTransformMixin, Constraint):
@@ -38,6 +39,18 @@ def __init__(self, columns, relation):
3839
if not isinstance(self._columns, list) or len(self._columns) <= 1:
3940
raise ValueError("ChainedRelation constraints must be defined across more than one column")
4041

42+
def _toInitializationDict(self):
43+
""" Converts an object to a Python dictionary. Keys represent the object's
44+
constructor arguments.
45+
:return: Python dictionary representation of the object
46+
"""
47+
_options = {"kind": self.__class__.__name__, "relation": self._relation, "columns": self._columns}
48+
return {
49+
k: v._toInitializationDict()
50+
if isinstance(v, SerializableToDict) else v
51+
for k, v in _options.items() if v is not None
52+
}
53+
4154
def _generateFilterExpression(self):
4255
""" Generated composite filter expression for chained set of filter expressions
4356

dbldatagen/constraints/constraint.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import types
99
from abc import ABC, abstractmethod
1010
from pyspark.sql import Column
11+
from ..serialization import SerializableToDict
1112

1213

13-
class Constraint(ABC):
14+
class Constraint(SerializableToDict, ABC):
1415
""" Constraint object - base class for predefined and custom constraints
1516
1617
This class is meant for internal use only.

dbldatagen/constraints/literal_range_constraint.py

+19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pyspark.sql.functions as F
99

1010
from .constraint import Constraint, NoPrepareTransformMixin
11+
from ..serialization import SerializableToDict
1112

1213

1314
class LiteralRange(NoPrepareTransformMixin, Constraint):
@@ -29,6 +30,24 @@ def __init__(self, columns, lowValue, highValue, strict=False):
2930
self._highValue = highValue
3031
self._strict = strict
3132

33+
def _toInitializationDict(self):
34+
""" Converts an object to a Python dictionary. Keys represent the object's
35+
constructor arguments.
36+
:return: Python dictionary representation of the object
37+
"""
38+
_options = {
39+
"kind": self.__class__.__name__,
40+
"columns": self._columns,
41+
"lowValue": self._lowValue,
42+
"highValue": self._highValue,
43+
"strict": self._strict
44+
}
45+
return {
46+
k: v._toInitializationDict()
47+
if isinstance(v, SerializableToDict) else v
48+
for k, v in _options.items() if v is not None
49+
}
50+
3251
def _generateFilterExpression(self):
3352
""" Generate a SQL filter expression that may be used for filtering"""
3453
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/literal_relation_constraint.py

+18
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pyspark.sql.functions as F
99

1010
from .constraint import Constraint, NoPrepareTransformMixin
11+
from ..serialization import SerializableToDict
1112

1213

1314
class LiteralRelation(NoPrepareTransformMixin, Constraint):
@@ -29,6 +30,23 @@ def __init__(self, columns, relation, value):
2930
if relation not in self.SUPPORTED_OPERATORS:
3031
raise ValueError(f"Parameter `relation` should be one of the operators :{self.SUPPORTED_OPERATORS}")
3132

33+
def _toInitializationDict(self):
34+
""" Converts an object to a Python dictionary. Keys represent the object's
35+
constructor arguments.
36+
:return: Python dictionary representation of the object
37+
"""
38+
_options = {
39+
"kind": self.__class__.__name__,
40+
"columns": self._columns,
41+
"relation": self._relation,
42+
"value": self._value
43+
}
44+
return {
45+
k: v._toInitializationDict()
46+
if isinstance(v, SerializableToDict) else v
47+
for k, v in _options.items() if v is not None
48+
}
49+
3250
def _generateFilterExpression(self):
3351
expressions = [F.col(colname) for colname in self._columns]
3452
literalValue = F.lit(self._value)

dbldatagen/constraints/negative_values.py

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
import pyspark.sql.functions as F
99
from .constraint import Constraint, NoPrepareTransformMixin
10+
from ..serialization import SerializableToDict
1011

1112

1213
class NegativeValues(NoPrepareTransformMixin, Constraint):
@@ -27,6 +28,18 @@ def __init__(self, columns, strict=False):
2728
self._columns = self._columnsFromListOrString(columns)
2829
self._strict = strict
2930

31+
def _toInitializationDict(self):
32+
""" Converts an object to a Python dictionary. Keys represent the object's
33+
constructor arguments.
34+
:return: Python dictionary representation of the object
35+
"""
36+
_options = {"kind": self.__class__.__name__, "columns": self._columns, "strict": self._strict}
37+
return {
38+
k: v._toInitializationDict()
39+
if isinstance(v, SerializableToDict) else v
40+
for k, v in _options.items() if v is not None
41+
}
42+
3043
def _generateFilterExpression(self):
3144
expressions = [F.col(colname) for colname in self._columns]
3245
if self._strict:

dbldatagen/constraints/positive_values.py

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
import pyspark.sql.functions as F
99
from .constraint import Constraint, NoPrepareTransformMixin
10+
from ..serialization import SerializableToDict
1011

1112

1213
class PositiveValues(NoPrepareTransformMixin, Constraint):
@@ -27,6 +28,18 @@ def __init__(self, columns, strict=False):
2728
self._columns = self._columnsFromListOrString(columns)
2829
self._strict = strict
2930

31+
def _toInitializationDict(self):
32+
""" Converts an object to a Python dictionary. Keys represent the object's
33+
constructor arguments.
34+
:return: Python dictionary representation of the object
35+
"""
36+
_options = {"kind": self.__class__.__name__, "columns": self._columns, "strict": self._strict}
37+
return {
38+
k: v._toInitializationDict()
39+
if isinstance(v, SerializableToDict) else v
40+
for k, v in _options.items() if v is not None
41+
}
42+
3043
def _generateFilterExpression(self):
3144
""" Generate a filter expression that may be used for filtering"""
3245
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/ranged_values_constraint.py

+19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pyspark.sql.functions as F
99

1010
from .constraint import Constraint, NoPrepareTransformMixin
11+
from ..serialization import SerializableToDict
1112

1213

1314
class RangedValues(NoPrepareTransformMixin, Constraint):
@@ -28,6 +29,24 @@ def __init__(self, columns, lowValue, highValue, strict=False):
2829
self._highValue = highValue
2930
self._strict = strict
3031

32+
def _toInitializationDict(self):
33+
""" Returns an internal mapping dictionary for the object. Keys represent the
34+
class constructor arguments and values representing the object's internal data.
35+
:return: Python dictionary mapping constructor options to the object properties
36+
"""
37+
_options = {
38+
"kind": self.__class__.__name__,
39+
"columns": self._columns,
40+
"lowValue": self._lowValue,
41+
"highValue": self._highValue,
42+
"strict": self._strict
43+
}
44+
return {
45+
k: v._toInitializationDict()
46+
if isinstance(v, SerializableToDict) else v
47+
for k, v in _options.items() if v is not None
48+
}
49+
3150
def _generateFilterExpression(self):
3251
""" Generate a SQL filter expression that may be used for filtering"""
3352
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/sql_expr.py

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pyspark.sql.functions as F
99

1010
from .constraint import Constraint, NoPrepareTransformMixin
11+
from ..serialization import SerializableToDict
1112

1213

1314
class SqlExpr(NoPrepareTransformMixin, Constraint):
@@ -25,6 +26,18 @@ def __init__(self, expr: str):
2526
assert isinstance(expr, str) and len(expr.strip()) > 0, "Expression must be a valid SQL string"
2627
self._expr = expr
2728

29+
def _toInitializationDict(self):
30+
""" Converts an object to a Python dictionary. Keys represent the object's
31+
constructor arguments.
32+
:return: Python dictionary representation of the object
33+
"""
34+
_options = {"kind": self.__class__.__name__, "expr": self._expr}
35+
return {
36+
k: v._toInitializationDict()
37+
if isinstance(v, SerializableToDict) else v
38+
for k, v in _options.items() if v is not None
39+
}
40+
2841
def _generateFilterExpression(self):
2942
""" Generate a SQL filter expression that may be used for filtering"""
3043
return F.expr(self._expr)

dbldatagen/constraints/unique_combinations.py

+13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
This module defines the Positive class
77
"""
88
from .constraint import Constraint, NoFilterMixin
9+
from ..serialization import SerializableToDict
910

1011

1112
class UniqueCombinations(NoFilterMixin, Constraint):
@@ -45,6 +46,18 @@ def __init__(self, columns=None):
4546
else:
4647
self._columns = None
4748

49+
def _toInitializationDict(self):
50+
""" Converts an object to a Python dictionary. Keys represent the object's
51+
constructor arguments.
52+
:return: Python dictionary representation of the object
53+
"""
54+
_options = {"kind": self.__class__.__name__, "columns": self._columns}
55+
return {
56+
k: v._toInitializationDict()
57+
if isinstance(v, SerializableToDict) else v
58+
for k, v in _options.items() if v is not None
59+
}
60+
4861
def prepareDataGenerator(self, dataGenerator):
4962
""" Prepare the data generator to generate data that matches the constraint
5063

0 commit comments

Comments
 (0)