From 407307a4d40f46e22833c8de022ba93c60539e1e Mon Sep 17 00:00:00 2001 From: Roy Wedge Date: Thu, 20 Jun 2024 09:52:50 -0400 Subject: [PATCH 1/7] Remove input error on null foreign keys from BaseMultiTableSynthesizer.fit (#2077) --- sdv/multi_table/base.py | 2 -- tests/integration/multi_table/test_hma.py | 13 ++----------- tests/unit/multi_table/test_base.py | 4 +--- 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index c789ee3c5..978ac8c0a 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -14,7 +14,6 @@ from sdv import version from sdv._utils import ( - _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, @@ -448,7 +447,6 @@ def fit(self, data): }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() self._fitted = False processed_data = self.preprocess(data) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index a82049200..7f80f3ab3 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1314,7 +1314,7 @@ def test_metadata_updated_warning_detect(self): assert len(record) == 1 def test_null_foreign_keys(self): - """Test that the synthesizer crashes when there are null foreign keys.""" + """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup metadata = MultiTableMetadata() metadata.add_table('parent_table') @@ -1370,16 +1370,7 @@ def test_null_foreign_keys(self): metadata.validate_data(data) # Run and Assert - err_msg = re.escape( - 'The data contains null values in foreign key columns. This feature is currently ' - 'unsupported. Please remove null values to fit the synthesizer.\n' - '\n' - 'Affected columns:\n' - "Table 'child_table1', column(s) ['fk']\n" - "Table 'child_table2', column(s) ['fk1', 'fk2']\n" - ) - with pytest.raises(SynthesizerInputError, match=err_msg): - synthesizer.fit(data) + synthesizer.fit(data) parametrization = [ diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 3c493606f..83a5ccca3 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -981,8 +981,7 @@ def test_fit_processed_data_raises_version_error(self): instance._check_metadata_updated.assert_not_called() @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base._validate_foreign_keys_not_null') - def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): + def test_fit(self, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1002,7 +1001,6 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): BaseMultiTableSynthesizer.fit(instance, data) # Assert - mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data) instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() From 4c04c80facc12e31eeac976bdfa93669dd82b5c2 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:13:54 +0100 Subject: [PATCH 2/7] Switch drop_missing_values in in drop_unknown_references to support null foreign keys by default (#2081) --- sdv/utils/poc.py | 4 ++-- sdv/utils/utils.py | 4 ++-- tests/integration/utils/test_utils.py | 2 +- tests/unit/utils/test_poc.py | 8 ++++++-- tests/unit/utils/test_utils.py | 4 ++-- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index 682895bad..360fcc164 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -16,14 +16,14 @@ from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references -def drop_unknown_references(data, metadata): +def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True): """Wrap the drop_unknown_references function from the utils module.""" warnings.warn( "Please access the 'drop_unknown_references' function directly from the sdv.utils module" 'instead of sdv.utils.poc.', FutureWarning, ) - return utils_drop_unknown_references(data, metadata) + return utils_drop_unknown_references(data, metadata, drop_missing_values, verbose) def simplify_schema(data, metadata, verbose=True): diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index f6e5db7c0..2c4b6b6ae 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -10,7 +10,7 @@ from sdv.multi_table.utils import _drop_rows -def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): +def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True): """Drop rows with unknown foreign keys. Args: @@ -22,7 +22,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr drop_missing_values (bool): Boolean describing whether or not to also drop foreign keys with missing values If True, drop rows with missing values in the foreign keys. - Defaults to True. + Defaults to False. verbose (bool): If True, print information about the rows that are dropped. Defaults to True. diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 5139ab0c1..0405703a8 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -110,7 +110,7 @@ def test_drop_unknown_references_drop_missing_values(metadata, data, capsys): data['child'].loc[4, 'parent_id'] = np.nan # Run - cleaned_data = drop_unknown_references(data, metadata) + cleaned_data = drop_unknown_references(data, metadata, drop_missing_values=True) metadata.validate_data(cleaned_data) captured = capsys.readouterr() diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index bbd9723c3..c8873e2c9 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -17,6 +17,8 @@ def test_drop_unknown_references(mock_drop_unknown_references): # Setup data = Mock() metadata = Mock() + drop_missing_values = Mock() + verbose = Mock() expected_message = re.escape( "Please access the 'drop_unknown_references' function directly from the sdv.utils module" 'instead of sdv.utils.poc.' @@ -24,10 +26,12 @@ def test_drop_unknown_references(mock_drop_unknown_references): # Run with pytest.warns(FutureWarning, match=expected_message): - drop_unknown_references(data, metadata) + drop_unknown_references(data, metadata, drop_missing_values, verbose) # Assert - mock_drop_unknown_references.assert_called_once_with(data, metadata) + mock_drop_unknown_references.assert_called_once_with( + data, metadata, drop_missing_values, verbose + ) @patch('sdv.utils.poc._get_total_estimated_columns') diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 7aad95693..a3a2d810c 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -87,7 +87,7 @@ def _drop_rows(data, metadata, drop_missing_values): } metadata.validate.assert_called_once() metadata.validate_data.assert_called_once_with(data) - mock_drop_rows.assert_called_once_with(result, metadata, True) + mock_drop_rows.assert_called_once_with(result, metadata, False) for table_name, table in result.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -189,7 +189,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {4}, 'grandchild': {0, 3, 4}}) # Run - result = drop_unknown_references(data, metadata, verbose=False) + result = drop_unknown_references(data, metadata, drop_missing_values=True, verbose=False) # Assert metadata.validate.assert_called_once() From 49213caec760ee368c01e83c0d7a1a6eb03deadd Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:17:46 +0100 Subject: [PATCH 3/7] Support null foreign keys in get_random_subset (#2082) --- sdv/multi_table/utils.py | 69 +++++++++--------- tests/integration/utils/test_poc.py | 21 +++--- tests/unit/multi_table/test_utils.py | 100 +++++++++++++++++---------- 3 files changed, 110 insertions(+), 80 deletions(-) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index e339bcca2..561c74f78 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -1,15 +1,14 @@ """Utility functions for the MultiTable models.""" import math -import warnings from collections import defaultdict from copy import deepcopy import numpy as np import pandas as pd -from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null -from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError +from sdv._utils import _get_root_tables +from sdv.errors import InvalidDataError, SamplingError from sdv.multi_table import HMASynthesizer from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS @@ -449,22 +448,22 @@ def _drop_rows(data, metadata, drop_missing_values): ]) -def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep): +def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_missing_values): """Subsample the disconnected roots tables and their descendants.""" relationships = metadata.relationships roots = _get_disconnected_roots_from_table(relationships, table) for root in roots: data[root] = data[root].sample(frac=ratio_to_keep) - _drop_rows(data, metadata, drop_missing_values=True) + _drop_rows(data, metadata, drop_missing_values) -def _subsample_table_and_descendants(data, metadata, table, num_rows): +def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values): """Subsample the table and its descendants. - The logic is to first subsample all the NaN foreign keys of the table. - We raise an error if we cannot reach referential integrity while keeping the number of rows. - Then, we drop rows of the descendants to ensure referential integrity. + The logic is to first subsample all the NaN foreign keys of the table when ``drop_missing_values`` + is True. We raise an error if we cannot reach referential integrity while keeping + the number of rows. Then, we drop rows of the descendants to ensure referential integrity. Args: data (dict): @@ -474,19 +473,26 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows): Metadata of the datasets. table (str): Name of the table. + num_rows (int): + Number of rows to keep in the table. + drop_missing_values (bool): + Boolean describing whether or not to also drop foreign keys with missing values + If True, drop rows with missing values in the foreign keys. + Defaults to False. """ - idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table) - num_rows_to_drop = len(data[table]) - num_rows - if len(idx_nan_fk) > num_rows_to_drop: - raise SamplingError( - f"Referential integrity cannot be reached for table '{table}' while keeping " - f'{num_rows} rows. Please try again with a bigger number of rows.' - ) - else: - data[table] = data[table].drop(idx_nan_fk) + if drop_missing_values: + idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table) + num_rows_to_drop = len(data[table]) - num_rows + if len(idx_nan_fk) > num_rows_to_drop: + raise SamplingError( + f"Referential integrity cannot be reached for table '{table}' while keeping " + f'{num_rows} rows. Please try again with a bigger number of rows.' + ) + else: + data[table] = data[table].drop(idx_nan_fk) data[table] = data[table].sample(num_rows) - _drop_rows(data, metadata, drop_missing_values=True) + _drop_rows(data, metadata, drop_missing_values) def _get_primary_keys_referenced(data, metadata): @@ -593,7 +599,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced): _subsample_ancestors(data, metadata, parent, primary_keys_referenced) -def _subsample_data(data, metadata, main_table_name, num_rows): +def _subsample_data(data, metadata, main_table_name, num_rows, drop_missing_values=False): """Subsample multi-table table based on a table and a number of rows. The strategy is to: @@ -613,6 +619,10 @@ def _subsample_data(data, metadata, main_table_name, num_rows): Name of the main table. num_rows (int): Number of rows to keep in the main table. + drop_missing_values (bool): + Boolean describing whether or not to also drop foreign keys with missing values + If True, drop rows with missing values in the foreign keys. + Defaults to False. Returns: dict: @@ -621,20 +631,17 @@ def _subsample_data(data, metadata, main_table_name, num_rows): result = deepcopy(data) primary_keys_referenced = _get_primary_keys_referenced(result, metadata) ratio_to_keep = num_rows / len(result[main_table_name]) - try: - _validate_foreign_keys_not_null(metadata, result) - except SynthesizerInputError: - warnings.warn( - 'The data contains null values in foreign key columns. ' - 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils' - ' to drop these rows before using ``get_random_subset``.' - ) try: - _subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep) - _subsample_table_and_descendants(result, metadata, main_table_name, num_rows) + _subsample_disconnected_roots( + result, metadata, main_table_name, ratio_to_keep, drop_missing_values + ) + _subsample_table_and_descendants( + result, metadata, main_table_name, num_rows, drop_missing_values + ) _subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced) - _drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys + _drop_rows(result, metadata, drop_missing_values) + except InvalidDataError as error: if 'All references in table' not in str(error.args[0]): raise error diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index b74ac7905..0a3e02135 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -49,7 +49,7 @@ def data(): ) child = pd.DataFrame( - data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maye', 'No', 'No']} + data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maybe', 'No', 'No']} ) return {'parent': parent, 'child': child} @@ -229,20 +229,17 @@ def test_get_random_subset_disconnected_schema(): def test_get_random_subset_with_missing_values(metadata, data): - """Test ``get_random_subset`` when there is missing values in the foreign keys.""" + """Test ``get_random_subset`` when there is missing values in the foreign keys. + + Here there should be at least one missing values in the random subset. + """ # Setup data = deepcopy(data) - data['child'].loc[4, 'parent_id'] = np.nan - expected_warning = re.escape( - 'The data contains null values in foreign key columns. ' - 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils' - ' to drop these rows before using ``get_random_subset``.' - ) + data['child'].loc[[2, 3, 4], 'parent_id'] = np.nan # Run - with pytest.warns(UserWarning, match=expected_warning): - cleaned_data = get_random_subset(data, metadata, 'child', 3) + result = get_random_subset(data, metadata, 'child', 3) # Assert - assert len(cleaned_data['child']) == 3 - assert not pd.isna(cleaned_data['child']['parent_id']).any() + assert len(result['child']) == 3 + assert result['child']['parent_id'].isnull().sum() > 0 diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index ad3770f71..cbea576b4 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1302,13 +1302,15 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo expected_result = deepcopy(data) # Run - _subsample_disconnected_roots(data, metadata, 'disconnected_root', ratio_to_keep) + _subsample_disconnected_roots( + data, metadata, 'disconnected_root', ratio_to_keep, drop_missing_values=False + ) # Assert mock_get_disconnected_roots_from_table.assert_called_once_with( metadata.relationships, 'disconnected_root' ) - mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) + mock_drop_rows.assert_called_once_with(data, metadata, False) for table_name in metadata.tables: if table_name not in {'grandparent', 'other_root'}: pd.testing.assert_frame_equal(data[table_name], expected_result[table_name]) @@ -1317,8 +1319,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo @patch('sdv.multi_table.utils._drop_rows') -@patch('sdv.multi_table.utils._get_nan_fk_indices_table') -def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, mock_drop_rows): +def test__subsample_table_and_descendants(mock_drop_rows): """Test the ``_subsample_table_and_descendants`` method.""" # Setup data = { @@ -1339,40 +1340,17 @@ def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, mock_dr 'col_8': [6, 7, 8, 9, 10], }), } - mock_get_nan_fk_indices_table.return_value = {0} metadata = Mock() metadata.relationships = Mock() # Run - _subsample_table_and_descendants(data, metadata, 'parent', 3) + _subsample_table_and_descendants(data, metadata, 'parent', 3, drop_missing_values=False) # Assert - mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent') - mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) + mock_drop_rows.assert_called_once_with(data, metadata, False) assert len(data['parent']) == 3 -@patch('sdv.multi_table.utils._get_nan_fk_indices_table') -def test__subsample_table_and_descendants_nan_fk(mock_get_nan_fk_indices_table): - """Test the ``_subsample_table_and_descendants`` when there are too many NaN foreign keys.""" - # Setup - data = {'parent': [1, 2, 3, 4, 5, 6]} - mock_get_nan_fk_indices_table.return_value = {0, 1, 2, 3, 4} - metadata = Mock() - metadata.relationships = Mock() - expected_message = re.escape( - "Referential integrity cannot be reached for table 'parent' while keeping " - '3 rows. Please try again with a bigger number of rows.' - ) - - # Run - with pytest.raises(SamplingError, match=expected_message): - _subsample_table_and_descendants(data, metadata, 'parent', 3) - - # Assert - mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent') - - def test__get_primary_keys_referenced(): """Test the ``_get_primary_keys_referenced`` method.""" data = { @@ -1930,9 +1908,7 @@ def test__subsample_ancestors_schema_diamond_shape(): @patch('sdv.multi_table.utils._subsample_ancestors') @patch('sdv.multi_table.utils._get_primary_keys_referenced') @patch('sdv.multi_table.utils._drop_rows') -@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') def test__subsample_data( - mock_validate_foreign_keys_not_null, mock_drop_rows, mock_get_primary_keys_referenced, mock_subsample_ancestors, @@ -1954,24 +1930,74 @@ def test__subsample_data( result = _subsample_data(data, metadata, main_table, num_rows) # Assert - mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data) + mock_drop_rows.assert_called_once_with(data, metadata, False) mock_get_primary_keys_referenced.assert_called_once_with(data, metadata) - mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5) + mock_subsample_disconnected_roots.assert_called_once_with( + data, metadata, main_table, 0.5, False + ) mock_subsample_table_and_descendants.assert_called_once_with( - data, metadata, main_table, num_rows + data, metadata, main_table, num_rows, False ) mock_subsample_ancestors.assert_called_once_with( data, metadata, main_table, primary_key_reference ) - mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True) assert result == data +def test__subsample_data_with_null_foreing_keys(): + """Test the ``_subsample_data`` method when there are null foreign keys.""" + # Setup + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'parent': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'A': {'sdtype': 'categorical'}, + 'B': {'sdtype': 'numerical'}, + }, + 'primary_key': 'id', + }, + 'child': {'columns': {'parent_id': {'sdtype': 'id'}, 'C': {'sdtype': 'categorical'}}}, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'child_table_name': 'child', + 'parent_primary_key': 'id', + 'child_foreign_key': 'parent_id', + } + ], + }) + + parent = pd.DataFrame( + data={ + 'id': [0, 1, 2, 3, 4], + 'A': [True, True, False, True, False], + 'B': [0.434, 0.312, 0.212, 0.339, 0.491], + } + ) + + child = pd.DataFrame( + data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maybe', 'No', 'No']} + ) + + data = {'parent': parent, 'child': child} + data['child'].loc[[2, 3, 4], 'parent_id'] = np.nan + + # Run + result_with_nan = _subsample_data(data, metadata, 'child', 4, drop_missing_values=False) + result_without_nan = _subsample_data(data, metadata, 'child', 2, drop_missing_values=True) + + # Assert + assert len(result_with_nan['child']) == 4 + assert result_with_nan['child']['parent_id'].isnull().sum() > 0 + assert len(result_without_nan['child']) == 2 + assert set(result_without_nan['child'].index) == {0, 1} + + @patch('sdv.multi_table.utils._subsample_disconnected_roots') @patch('sdv.multi_table.utils._get_primary_keys_referenced') -@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') def test__subsample_data_empty_dataset( - mock_validate_foreign_keys_not_null, mock_get_primary_keys_referenced, mock_subsample_disconnected_roots, ): From a3275ce1a481f36c14618fb08c502da70c349160 Mon Sep 17 00:00:00 2001 From: rwedge Date: Wed, 26 Jun 2024 14:42:47 -0400 Subject: [PATCH 4/7] fit --- sdv/multi_table/hma.py | 12 +++++++++--- tests/integration/multi_table/test_hma.py | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index d7a4712db..3fcb64ba0 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -156,6 +156,8 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._table_sizes = {} self._max_child_rows = {} self._min_child_rows = {} + self._null_child_synthesizers = {} + self._null_foreign_key_percentages = {} self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} @@ -321,8 +323,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) == 1: row.loc[scale_columns] = None - extension_rows.append(row) - index.append(foreign_key_value) + if pd.isna(foreign_key_value): + self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer + else: + extension_rows.append(row) + index.append(foreign_key_value) except Exception: # Skip children rows subsets that fail pass @@ -392,11 +397,12 @@ def _augment_table(self, table, tables, table_name): table[num_rows_key] = table[num_rows_key].fillna(0) self._max_child_rows[num_rows_key] = table[num_rows_key].max() self._min_child_rows[num_rows_key] = table[num_rows_key].min() + self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - (table[num_rows_key].sum() / child_table.shape[0]) tables[table_name] = table self._learned_relationships += 1 self._augmented_tables.append(table_name) - self._clear_nans(table) + # self._clear_nans(table) TODO: replace with standardizing nans? return table diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 7f80f3ab3..7d19e715e 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1331,6 +1331,7 @@ def test_null_foreign_keys(self): metadata.set_primary_key('child_table2', 'id') metadata.add_column('child_table2', 'fk1', sdtype='id') metadata.add_column('child_table2', 'fk2', sdtype='id') + metadata.add_column('child_table2', 'cat_type', sdtype='categorical') metadata.add_relationship( parent_table_name='parent_table', @@ -1360,6 +1361,7 @@ def test_null_foreign_keys(self): 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], 'fk2': [1, 2, np.nan], + 'cat_type': ['siamese','persian', 'american shorthair'], }), } @@ -1371,6 +1373,8 @@ def test_null_foreign_keys(self): # Run and Assert synthesizer.fit(data) + breakpoint() + parametrization = [ From 21cbbadf0e23ff9e873866b1e6bef8c88e324eca Mon Sep 17 00:00:00 2001 From: rwedge Date: Fri, 28 Jun 2024 15:00:20 -0400 Subject: [PATCH 5/7] sample (wip) --- sdv/multi_table/hma.py | 25 +++++++++++++++++------ sdv/sampling/hierarchical_sampler.py | 25 ++++++++++++++++++----- tests/integration/multi_table/test_hma.py | 24 +++++++++++----------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 3fcb64ba0..9058a60c8 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -157,7 +157,6 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._max_child_rows = {} self._min_child_rows = {} self._null_child_synthesizers = {} - self._null_foreign_key_percentages = {} self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} @@ -323,6 +322,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) == 1: row.loc[scale_columns] = None + # TODO: handle null synthesizer when child_rows is empty if pd.isna(foreign_key_value): self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer else: @@ -513,12 +513,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): # A child table is created based on only one foreign key. foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] - parameters = self._extract_parameters(parent_row, child_name, foreign_key) - default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) - table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) - synthesizer._set_parameters(parameters, default_parameters) + if parent_row is not None: + parameters = self._extract_parameters(parent_row, child_name, foreign_key) + default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) + + table_meta = self.metadata.tables[child_name] + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer._set_parameters(parameters, default_parameters) + else: + synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] + synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor return synthesizer @@ -616,6 +621,13 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): except (AttributeError, np.linalg.LinAlgError): likelihoods[parent_id] = None + if f'__{table_name}__{foreign_key}' in self._null_child_synthesizers: + try: + likelihoods[np.nan] = synthesizer._get_likelihood(table_rows) + + except (AttributeError, np.linalg.LinAlgError): + likelihoods[np.nan] = None + return pd.DataFrame(likelihoods, index=table_rows.index) def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key): @@ -644,6 +656,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy() + num_rows.loc[np.nan] = child_table.shape[0] - num_rows.sum() likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key) return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 7ba175904..8fd4e9df2 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -3,6 +3,7 @@ import logging import warnings +import numpy as np import pandas as pd LOGGER = logging.getLogger(__name__) @@ -24,6 +25,7 @@ class BaseHierarchicalSampler: def __init__(self, metadata, table_synthesizers, table_sizes): self.metadata = metadata + self._null_foreign_key_percentages = {} self._table_synthesizers = table_synthesizers self._table_sizes = table_sizes @@ -103,7 +105,7 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num row_indices = sampled_rows.index sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key] else: - sampled_rows[foreign_key] = parent_row[parent_key] + sampled_rows[foreign_key] = parent_row[parent_key] if parent_row is not None else np.nan previous = sampled_data.get(child_name) if previous is None: @@ -143,16 +145,18 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): """ total_num_rows = round(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): + null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + total_parent_rows = round(total_num_rows * (1 - null_fk_pctg)) num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] max_rows = self._max_child_rows[num_rows_key] key_data = sampled_data[table_name][num_rows_key].fillna(0).round() sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int) - while sum(sampled_data[table_name][num_rows_key]) != total_num_rows: + while sum(sampled_data[table_name][num_rows_key]) != total_parent_rows: num_rows_column = sampled_data[table_name][num_rows_key].argsort() - if sum(sampled_data[table_name][num_rows_key]) < total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) < total_parent_rows: for i in num_rows_column: # If the number of rows is already at the maximum, skip # The exception is when the smallest value is already at the maximum, @@ -164,7 +168,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] += 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break else: @@ -179,7 +183,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] -= 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break def _sample_children(self, table_name, sampled_data, scale=1.0): @@ -224,6 +228,17 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): num_rows=1, ) + total_num_rows = round(self._table_sizes[child_name] * scale) + num_null_rows = total_num_rows - sampled_data[child_name].shape[0] + if num_null_rows > 0: + self._add_child_rows( + child_name=child_name, + parent_name=table_name, + parent_row=None, + sampled_data=sampled_data, + num_rows=num_null_rows + ) + self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale) def _finalize(self, sampled_data): diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 7d19e715e..25aab3397 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1321,10 +1321,10 @@ def test_null_foreign_keys(self): metadata.add_column('parent_table', 'id', sdtype='id') metadata.set_primary_key('parent_table', 'id') - metadata.add_table('child_table1') - metadata.add_column('child_table1', 'id', sdtype='id') - metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk', sdtype='id') + # metadata.add_table('child_table1') + # metadata.add_column('child_table1', 'id', sdtype='id') + # metadata.set_primary_key('child_table1', 'id') + # metadata.add_column('child_table1', 'fk', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') @@ -1333,12 +1333,12 @@ def test_null_foreign_keys(self): metadata.add_column('child_table2', 'fk2', sdtype='id') metadata.add_column('child_table2', 'cat_type', sdtype='categorical') - metadata.add_relationship( - parent_table_name='parent_table', - child_table_name='child_table1', - parent_primary_key='id', - child_foreign_key='fk', - ) + # metadata.add_relationship( + # parent_table_name='parent_table', + # child_table_name='child_table1', + # parent_primary_key='id', + # child_foreign_key='fk', + # ) metadata.add_relationship( parent_table_name='parent_table', @@ -1356,7 +1356,7 @@ def test_null_foreign_keys(self): data = { 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + # 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], @@ -1373,7 +1373,7 @@ def test_null_foreign_keys(self): # Run and Assert synthesizer.fit(data) - breakpoint() + sampled_data = synthesizer.sample() From d78f6108d70348f6148e305b6097db8ed1fb5c45 Mon Sep 17 00:00:00 2001 From: rwedge Date: Mon, 8 Jul 2024 19:15:31 -0400 Subject: [PATCH 6/7] handle no columns to learn from case --- sdv/multi_table/hma.py | 12 ++++++------ tests/integration/multi_table/test_hma.py | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 9058a60c8..e9f9abbd2 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -304,17 +304,18 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)] try: - if child_rows.empty: + if child_rows.empty and not pd.isna(foreign_key_value): row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: synthesizer = self._synthesizer( table_meta, **self._table_parameters[child_name] ) - synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) - row = synthesizer._get_parameters() - row = pd.Series(row) - row.index = f'__{child_name}__{foreign_key}__' + row.index + if not child_rows.empty: + synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) + row = synthesizer._get_parameters() + row = pd.Series(row) + row.index = f'__{child_name}__{foreign_key}__' + row.index if scale_columns is None: scale_columns = [column for column in row.index if column.endswith('scale')] @@ -322,7 +323,6 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc if len(child_rows) == 1: row.loc[scale_columns] = None - # TODO: handle null synthesizer when child_rows is empty if pd.isna(foreign_key_value): self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer else: diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 25aab3397..998cb3d91 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1321,10 +1321,10 @@ def test_null_foreign_keys(self): metadata.add_column('parent_table', 'id', sdtype='id') metadata.set_primary_key('parent_table', 'id') - # metadata.add_table('child_table1') - # metadata.add_column('child_table1', 'id', sdtype='id') - # metadata.set_primary_key('child_table1', 'id') - # metadata.add_column('child_table1', 'fk', sdtype='id') + metadata.add_table('child_table1') + metadata.add_column('child_table1', 'id', sdtype='id') + metadata.set_primary_key('child_table1', 'id') + metadata.add_column('child_table1', 'fk', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') @@ -1333,12 +1333,12 @@ def test_null_foreign_keys(self): metadata.add_column('child_table2', 'fk2', sdtype='id') metadata.add_column('child_table2', 'cat_type', sdtype='categorical') - # metadata.add_relationship( - # parent_table_name='parent_table', - # child_table_name='child_table1', - # parent_primary_key='id', - # child_foreign_key='fk', - # ) + metadata.add_relationship( + parent_table_name='parent_table', + child_table_name='child_table1', + parent_primary_key='id', + child_foreign_key='fk', + ) metadata.add_relationship( parent_table_name='parent_table', @@ -1356,7 +1356,7 @@ def test_null_foreign_keys(self): data = { 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - # 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], From a7af7eae0c23ba394be09e402fa9db960ec3057e Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 9 Jul 2024 12:54:49 -0400 Subject: [PATCH 7/7] adjust num null parent calculation --- sdv/sampling/hierarchical_sampler.py | 6 ++++-- tests/unit/sampling/test_hierarchical_sampler.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 8fd4e9df2..93068c5a0 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -211,8 +211,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): sampled_data=sampled_data, ) + foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] + if child_name not in sampled_data: # No child rows sampled, force row creation - foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] num_rows_key = f'__{child_name}__{foreign_key}__num_rows' if num_rows_key in sampled_data[table_name].columns: max_num_child_index = sampled_data[table_name][num_rows_key].idxmax() @@ -229,7 +230,8 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): ) total_num_rows = round(self._table_sizes[child_name] * scale) - num_null_rows = total_num_rows - sampled_data[child_name].shape[0] + null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] + num_null_rows = round(total_num_rows * null_fk_pctg) if num_null_rows > 0: self._add_child_rows( child_name=child_name, diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index 2689b2789..5f84c108f 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -177,7 +177,7 @@ def sample_children(table_name, sampled_data, scale): 'session_id': ['a', 'a', 'b'], }) - def _add_child_rows(child_name, parent_name, parent_row, sampled_data): + def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=None): if parent_name == 'users': if parent_row['user_id'] == 1: sampled_data[child_name] = pd.DataFrame({