From 1e83f1d8dfaaaa33d9230dc3a23f86e2619c414b Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 14:41:20 -0400 Subject: [PATCH 01/10] add logic and integration test --- sdv/sequential/par.py | 8 ++++++++ tests/integration/sequential/test_par.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 98d8e4b6d..f5ac308d7 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -110,6 +110,7 @@ def __init__( self._sequence_index = self.metadata.sequence_index self.context_columns = context_columns or [] + self._validate_sequence_key_and_context_columns(self._sequence_key, self.context_columns) self._extra_context_columns = {} self.extended_columns = {} self.segment_size = segment_size @@ -194,6 +195,13 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') + def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): + if set(sequence_key).intersection(set(context_columns)): + raise SynthesizerInputError( + f'The sequence key {self._sequence_key} cannot be a context column. ' + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + def _validate_context_columns(self, data): errors = [] if self.context_columns: diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 640caef9b..c37546407 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -385,3 +385,22 @@ def test_par_sequence_index_is_numerical(): s1.fit(data) sample = s1.sample(2, 5) assert sample.columns.to_list() == data.columns.to_list() + + +def test_par_error_on_context_columns(): + metadata_dict = { + 'columns': { + 'A': {'sdtype': 'id'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'C': {'sdtype': 'numerical'}, + 'D': {'sdtype': 'categorical'}, + }, + 'sequence_key': 'A', + } + metadata = SingleTableMetadata.load_from_dict(metadata_dict) + sequence_key_context_column_error_msg = re.escape( + "The sequence key ['A'] cannot be a context column. " + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): + PARSynthesizer(metadata, context_columns=['A']) From cb880e003ef5a3521e0619fe61ccf535b5ec3e9c Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:07:12 -0400 Subject: [PATCH 02/10] fix unit test --- tests/unit/sequential/test_par.py | 33 ++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 0afc16e47..239816794 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -110,12 +110,8 @@ def test___init___no_sequence_key(self): def test_add_constraints(self): """Test that that only simple constraints can be added to PARSynthesizer.""" # Setup - metadata = self.get_metadata() - synthesizer = PARSynthesizer(metadata=metadata, context_columns=['name', 'measurement']) - name_constraint = { - 'constraint_class': 'Mock', - 'constraint_parameters': {'column_name': 'name'}, - } + metadata = self.get_metadata(add_sequence_key=True) + synthesizer = PARSynthesizer(metadata=metadata, context_columns=['gender', 'measurement']) measurement_constraint = { 'constraint_class': 'Mock', 'constraint_parameters': {'column_name': 'measurement'}, @@ -130,8 +126,12 @@ def test_add_constraints(self): } multi_constraint = { 'constraint_class': 'Mock', - 'constraint_parameters': {'column_names': ['name', 'time']}, + 'constraint_parameters': {'column_names': ['time', 'gender']}, } + # 'time': ['2020-01-01', '2020-01-02', '2020-01-03'], + # 'gender': ['F', 'M', 'M'], + # 'name': ['Jane', 'John', 'Doe'], + # 'measurement': [55, 60, 65], overlapping_error_msg = re.escape( 'The PARSynthesizer cannot accommodate multiple constraints ' 'that overlap on the same columns.' @@ -143,7 +143,7 @@ def test_add_constraints(self): # Run and Assert with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg): - synthesizer.add_constraints([name_constraint, gender_constraint]) + synthesizer.add_constraints([time_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg): synthesizer.add_constraints([time_constraint, measurement_constraint]) @@ -152,10 +152,7 @@ def test_add_constraints(self): synthesizer.add_constraints([multi_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([multi_constraint, name_constraint]) - - with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([name_constraint, name_constraint]) + synthesizer.add_constraints([multi_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): synthesizer.add_constraints([gender_constraint, gender_constraint]) @@ -935,3 +932,15 @@ def test_load(self, mock_file, cloudpickle_mock): mock_file.assert_called_once_with('synth.pkl', 'rb') cloudpickle_mock.load.assert_called_once_with(mock_file.return_value) assert loaded_instance == synthesizer_mock + + def test__par_error_on_context_columns(self): + metadata = self.get_metadata(add_sequence_key=True) + sequence_key_context_column_error_msg = re.escape( + "The sequence key ['name'] cannot be a context column. " + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): + PARSynthesizer( + metadata=metadata, + context_columns=['name'], + ) From b1b1423b462d832f0ac14fa6a3d550c4a6001164 Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:08:02 -0400 Subject: [PATCH 03/10] fix error msg --- sdv/sequential/par.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index f5ac308d7..fc5d74a21 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -198,7 +198,7 @@ def add_custom_constraint_class(self, class_object, class_name): def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): if set(sequence_key).intersection(set(context_columns)): raise SynthesizerInputError( - f'The sequence key {self._sequence_key} cannot be a context column. ' + f'The sequence key {sequence_key} cannot be a context column. ' 'To proceed, please remove the sequence key from the context_columns parameter.' ) From dc41006197f61608d4b14bea94d890be6676beb3 Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:09:39 -0400 Subject: [PATCH 04/10] shorter test --- tests/integration/sequential/test_par.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index c37546407..5eacaa644 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -392,8 +392,6 @@ def test_par_error_on_context_columns(): 'columns': { 'A': {'sdtype': 'id'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'C': {'sdtype': 'numerical'}, - 'D': {'sdtype': 'categorical'}, }, 'sequence_key': 'A', } From 56257008928cb51764b21ece8e8a007187c97ae0 Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:10:36 -0400 Subject: [PATCH 05/10] remove comments --- tests/unit/sequential/test_par.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 239816794..cd3928c6b 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -128,10 +128,6 @@ def test_add_constraints(self): 'constraint_class': 'Mock', 'constraint_parameters': {'column_names': ['time', 'gender']}, } - # 'time': ['2020-01-01', '2020-01-02', '2020-01-03'], - # 'gender': ['F', 'M', 'M'], - # 'name': ['Jane', 'John', 'Doe'], - # 'measurement': [55, 60, 65], overlapping_error_msg = re.escape( 'The PARSynthesizer cannot accommodate multiple constraints ' 'that overlap on the same columns.' From b0478b49bda5b581fd2dba64fb02f23b732f74c2 Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:37:55 -0400 Subject: [PATCH 06/10] fix constraints --- tests/unit/sequential/test_par.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index cd3928c6b..2e03b4cbe 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -126,7 +126,7 @@ def test_add_constraints(self): } multi_constraint = { 'constraint_class': 'Mock', - 'constraint_parameters': {'column_names': ['time', 'gender']}, + 'constraint_parameters': {'column_names': ['gender', 'time']}, } overlapping_error_msg = re.escape( 'The PARSynthesizer cannot accommodate multiple constraints ' @@ -148,7 +148,10 @@ def test_add_constraints(self): synthesizer.add_constraints([multi_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([multi_constraint, gender_constraint]) + synthesizer.add_constraints([multi_constraint, time_constraint]) + + with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): + synthesizer.add_constraints([multi_constraint, multi_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): synthesizer.add_constraints([gender_constraint, gender_constraint]) From 667d201129584bb2ee822e5656076d8d789b6be5 Mon Sep 17 00:00:00 2001 From: gsheni Date: Wed, 3 Jul 2024 15:41:54 -0400 Subject: [PATCH 07/10] make constraints similar --- tests/unit/sequential/test_par.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 2e03b4cbe..8fe19a7e7 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -148,10 +148,10 @@ def test_add_constraints(self): synthesizer.add_constraints([multi_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([multi_constraint, time_constraint]) + synthesizer.add_constraints([multi_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([multi_constraint, multi_constraint]) + synthesizer.add_constraints([gender_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): synthesizer.add_constraints([gender_constraint, gender_constraint]) From ab3c87b279dc66bbd177cb6803c5881f748dd745 Mon Sep 17 00:00:00 2001 From: gsheni Date: Mon, 8 Jul 2024 14:51:24 -0400 Subject: [PATCH 08/10] Address feedback --- sdv/sequential/par.py | 4 +++- tests/integration/sequential/test_par.py | 2 ++ tests/unit/sequential/test_par.py | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index fc5d74a21..760acbbc8 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -195,7 +195,9 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') - def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): + def _validate_sequence_key_and_context_columns( + self, sequence_key: list[str], context_columns: list[str] + ): if set(sequence_key).intersection(set(context_columns)): raise SynthesizerInputError( f'The sequence key {sequence_key} cannot be a context column. ' diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 5eacaa644..5d78f905f 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -388,6 +388,7 @@ def test_par_sequence_index_is_numerical(): def test_par_error_on_context_columns(): + # Setup metadata_dict = { 'columns': { 'A': {'sdtype': 'id'}, @@ -400,5 +401,6 @@ def test_par_error_on_context_columns(): "The sequence key ['A'] cannot be a context column. " 'To proceed, please remove the sequence key from the context_columns parameter.' ) + # Run and Assert with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): PARSynthesizer(metadata, context_columns=['A']) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 8fe19a7e7..4187acd88 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -933,11 +933,14 @@ def test_load(self, mock_file, cloudpickle_mock): assert loaded_instance == synthesizer_mock def test__par_error_on_context_columns(self): + """Test that the sequence_key is not a context column""" + # Setup metadata = self.get_metadata(add_sequence_key=True) sequence_key_context_column_error_msg = re.escape( "The sequence key ['name'] cannot be a context column. " 'To proceed, please remove the sequence key from the context_columns parameter.' ) + # Run and Assert with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): PARSynthesizer( metadata=metadata, From fd50434b8d1f0ff423cc9e58495c53f9ad969bf1 Mon Sep 17 00:00:00 2001 From: gsheni Date: Mon, 8 Jul 2024 15:01:58 -0400 Subject: [PATCH 09/10] update docstring --- sdv/sequential/par.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 760acbbc8..6fe4d0965 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -195,9 +195,15 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') - def _validate_sequence_key_and_context_columns( - self, sequence_key: list[str], context_columns: list[str] - ): + def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): + """Check that the sequence key is not present in the context colums. + + Args: + sequence_key (list[str]): + A list of column that identify which row(s) belong to which sequences. + context_columns (list[str]): + A list of strings, representing the columns that do not vary in a sequence. + """ if set(sequence_key).intersection(set(context_columns)): raise SynthesizerInputError( f'The sequence key {sequence_key} cannot be a context column. ' From 4c11961d057bf6de3e4b6aca6bd2fe8bcbabe173 Mon Sep 17 00:00:00 2001 From: gsheni Date: Mon, 8 Jul 2024 18:09:38 -0400 Subject: [PATCH 10/10] address feedback --- sdv/sequential/par.py | 8 ++++---- tests/integration/sequential/test_par.py | 2 +- tests/unit/sequential/test_par.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 6fe4d0965..8401f982c 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -110,7 +110,7 @@ def __init__( self._sequence_index = self.metadata.sequence_index self.context_columns = context_columns or [] - self._validate_sequence_key_and_context_columns(self._sequence_key, self.context_columns) + self._validate_sequence_key_and_context_columns() self._extra_context_columns = {} self.extended_columns = {} self.segment_size = segment_size @@ -195,7 +195,7 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') - def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): + def _validate_sequence_key_and_context_columns(self): """Check that the sequence key is not present in the context colums. Args: @@ -204,9 +204,9 @@ def _validate_sequence_key_and_context_columns(self, sequence_key, context_colum context_columns (list[str]): A list of strings, representing the columns that do not vary in a sequence. """ - if set(sequence_key).intersection(set(context_columns)): + if set(self._sequence_key).intersection(set(self.context_columns)): raise SynthesizerInputError( - f'The sequence key {sequence_key} cannot be a context column. ' + f'The sequence key {self._sequence_key} cannot be a context column. ' 'To proceed, please remove the sequence key from the context_columns parameter.' ) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 5d78f905f..724b91b52 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -387,7 +387,7 @@ def test_par_sequence_index_is_numerical(): assert sample.columns.to_list() == data.columns.to_list() -def test_par_error_on_context_columns(): +def test_init_error_sequence_key_in_context(): # Setup metadata_dict = { 'columns': { diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 4187acd88..3ee048f29 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -932,7 +932,7 @@ def test_load(self, mock_file, cloudpickle_mock): cloudpickle_mock.load.assert_called_once_with(mock_file.return_value) assert loaded_instance == synthesizer_mock - def test__par_error_on_context_columns(self): + def test___init___error_sequence_key_in_context(self): """Test that the sequence_key is not a context column""" # Setup metadata = self.get_metadata(add_sequence_key=True)