diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 98d8e4b6d..8401f982c 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -110,6 +110,7 @@ def __init__( self._sequence_index = self.metadata.sequence_index self.context_columns = context_columns or [] + self._validate_sequence_key_and_context_columns() self._extra_context_columns = {} self.extended_columns = {} self.segment_size = segment_size @@ -194,6 +195,21 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') + def _validate_sequence_key_and_context_columns(self): + """Check that the sequence key is not present in the context colums. + + Args: + sequence_key (list[str]): + A list of column that identify which row(s) belong to which sequences. + context_columns (list[str]): + A list of strings, representing the columns that do not vary in a sequence. + """ + if set(self._sequence_key).intersection(set(self.context_columns)): + raise SynthesizerInputError( + f'The sequence key {self._sequence_key} cannot be a context column. ' + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + def _validate_context_columns(self, data): errors = [] if self.context_columns: diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 640caef9b..724b91b52 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -385,3 +385,22 @@ def test_par_sequence_index_is_numerical(): s1.fit(data) sample = s1.sample(2, 5) assert sample.columns.to_list() == data.columns.to_list() + + +def test_init_error_sequence_key_in_context(): + # Setup + metadata_dict = { + 'columns': { + 'A': {'sdtype': 'id'}, + 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + }, + 'sequence_key': 'A', + } + metadata = SingleTableMetadata.load_from_dict(metadata_dict) + sequence_key_context_column_error_msg = re.escape( + "The sequence key ['A'] cannot be a context column. " + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + # Run and Assert + with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): + PARSynthesizer(metadata, context_columns=['A']) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 0afc16e47..3ee048f29 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -110,12 +110,8 @@ def test___init___no_sequence_key(self): def test_add_constraints(self): """Test that that only simple constraints can be added to PARSynthesizer.""" # Setup - metadata = self.get_metadata() - synthesizer = PARSynthesizer(metadata=metadata, context_columns=['name', 'measurement']) - name_constraint = { - 'constraint_class': 'Mock', - 'constraint_parameters': {'column_name': 'name'}, - } + metadata = self.get_metadata(add_sequence_key=True) + synthesizer = PARSynthesizer(metadata=metadata, context_columns=['gender', 'measurement']) measurement_constraint = { 'constraint_class': 'Mock', 'constraint_parameters': {'column_name': 'measurement'}, @@ -130,7 +126,7 @@ def test_add_constraints(self): } multi_constraint = { 'constraint_class': 'Mock', - 'constraint_parameters': {'column_names': ['name', 'time']}, + 'constraint_parameters': {'column_names': ['gender', 'time']}, } overlapping_error_msg = re.escape( 'The PARSynthesizer cannot accommodate multiple constraints ' @@ -143,7 +139,7 @@ def test_add_constraints(self): # Run and Assert with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg): - synthesizer.add_constraints([name_constraint, gender_constraint]) + synthesizer.add_constraints([time_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg): synthesizer.add_constraints([time_constraint, measurement_constraint]) @@ -152,10 +148,10 @@ def test_add_constraints(self): synthesizer.add_constraints([multi_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([multi_constraint, name_constraint]) + synthesizer.add_constraints([multi_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): - synthesizer.add_constraints([name_constraint, name_constraint]) + synthesizer.add_constraints([gender_constraint, gender_constraint]) with pytest.raises(SynthesizerInputError, match=overlapping_error_msg): synthesizer.add_constraints([gender_constraint, gender_constraint]) @@ -935,3 +931,18 @@ def test_load(self, mock_file, cloudpickle_mock): mock_file.assert_called_once_with('synth.pkl', 'rb') cloudpickle_mock.load.assert_called_once_with(mock_file.return_value) assert loaded_instance == synthesizer_mock + + def test___init___error_sequence_key_in_context(self): + """Test that the sequence_key is not a context column""" + # Setup + metadata = self.get_metadata(add_sequence_key=True) + sequence_key_context_column_error_msg = re.escape( + "The sequence key ['name'] cannot be a context column. " + 'To proceed, please remove the sequence key from the context_columns parameter.' + ) + # Run and Assert + with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg): + PARSynthesizer( + metadata=metadata, + context_columns=['name'], + )