Support null foreign key values in HMA #2100

Closed · wants to merge 7 commits
2 changes: 0 additions & 2 deletions sdv/multi_table/base.py
@@ -14,7 +14,6 @@

 from sdv import version
 from sdv._utils import (
-    _validate_foreign_keys_not_null,
     check_sdv_versions_and_warn,
     check_synthesizer_version,
     generate_synthesizer_id,
@@ -448,7 +447,6 @@ def fit(self, data):
         })

         check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
-        _validate_foreign_keys_not_null(self.metadata, data)
         self._check_metadata_updated()
         self._fitted = False
         processed_data = self.preprocess(data)
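Removing `_validate_foreign_keys_not_null` from `fit` is what lets training data with null foreign keys through in the first place. Below is a minimal end-to-end sketch of the behavior this PR enables, assuming the public SDV 1.x API (`MultiTableMetadata.load_from_dict`, `HMASynthesizer`); the table names, columns, and data are illustrative only:

```python
import numpy as np
import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.multi_table import HMASynthesizer

# Two tables: 'sessions.user_id' references 'users.user_id'.
metadata = MultiTableMetadata.load_from_dict({
    'tables': {
        'users': {
            'primary_key': 'user_id',
            'columns': {
                'user_id': {'sdtype': 'id'},
                'country': {'sdtype': 'categorical'},
            },
        },
        'sessions': {
            'columns': {
                'user_id': {'sdtype': 'id'},
                'duration': {'sdtype': 'numerical'},
            },
        },
    },
    'relationships': [{
        'parent_table_name': 'users',
        'parent_primary_key': 'user_id',
        'child_table_name': 'sessions',
        'child_foreign_key': 'user_id',
    }],
})

data = {
    'users': pd.DataFrame({'user_id': [0, 1, 2], 'country': ['US', 'FR', 'US']}),
    'sessions': pd.DataFrame({
        'user_id': [0, 1, np.nan, np.nan],  # two rows have a null foreign key
        'duration': [12.5, 3.0, 7.1, 9.9],
    }),
}

synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)  # previously rejected because of the NaN foreign keys
sampled = synthesizer.sample(scale=1.0)  # sampled['sessions'] may contain NaN user_id
```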
45 changes: 32 additions & 13 deletions sdv/multi_table/hma.py
@@ -156,6 +156,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
         self._table_sizes = {}
         self._max_child_rows = {}
         self._min_child_rows = {}
+        self._null_child_synthesizers = {}
         self._augmented_tables = []
         self._learned_relationships = 0
         self._default_parameters = {}
@@ -303,26 +304,30 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
         child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)]

         try:
-            if child_rows.empty:
+            if child_rows.empty and not pd.isna(foreign_key_value):
                 row = pd.Series({'num_rows': len(child_rows)})
                 row.index = f'__{child_name}__{foreign_key}__' + row.index
             else:
                 synthesizer = self._synthesizer(
                     table_meta, **self._table_parameters[child_name]
                 )
-                synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
-                row = synthesizer._get_parameters()
-                row = pd.Series(row)
-                row.index = f'__{child_name}__{foreign_key}__' + row.index
+                if not child_rows.empty:
+                    synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
+                    row = synthesizer._get_parameters()
+                    row = pd.Series(row)
+                    row.index = f'__{child_name}__{foreign_key}__' + row.index

                 if scale_columns is None:
                     scale_columns = [column for column in row.index if column.endswith('scale')]

                 if len(child_rows) == 1:
                     row.loc[scale_columns] = None

-            extension_rows.append(row)
-            index.append(foreign_key_value)
+            if pd.isna(foreign_key_value):
+                self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer
+            else:
+                extension_rows.append(row)
+                index.append(foreign_key_value)
         except Exception:
             # Skip children rows subsets that fail
             pass
@@ -392,11 +397,12 @@ def _augment_table(self, table, tables, table_name):
             table[num_rows_key] = table[num_rows_key].fillna(0)
             self._max_child_rows[num_rows_key] = table[num_rows_key].max()
             self._min_child_rows[num_rows_key] = table[num_rows_key].min()
+            self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - (table[num_rows_key].sum() / child_table.shape[0])
             tables[table_name] = table
             self._learned_relationships += 1

         self._augmented_tables.append(table_name)
-        self._clear_nans(table)
+        # self._clear_nans(table) TODO: replace with standardizing nans?

         return table

@@ -507,12 +513,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key):
     def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
         # A child table is created based on only one foreign key.
         foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
-        parameters = self._extract_parameters(parent_row, child_name, foreign_key)
-        default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})
-
-        table_meta = self.metadata.tables[child_name]
-        synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
-        synthesizer._set_parameters(parameters, default_parameters)
+        if parent_row is not None:
+            parameters = self._extract_parameters(parent_row, child_name, foreign_key)
+            default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})
+
+            table_meta = self.metadata.tables[child_name]
+            synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
+            synthesizer._set_parameters(parameters, default_parameters)
+        else:
+            synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}']

         synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor

         return synthesizer
@@ -610,6 +621,13 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
             except (AttributeError, np.linalg.LinAlgError):
                 likelihoods[parent_id] = None

+        if f'__{table_name}__{foreign_key}' in self._null_child_synthesizers:
+            try:
+                likelihoods[np.nan] = synthesizer._get_likelihood(table_rows)
+
+            except (AttributeError, np.linalg.LinAlgError):
+                likelihoods[np.nan] = None
+
         return pd.DataFrame(likelihoods, index=table_rows.index)

     def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
@@ -638,6 +656,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key):
         primary_key = self.metadata.tables[parent_name].primary_key
         parent_table = parent_table.set_index(primary_key)
         num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy()
+        num_rows.loc[np.nan] = child_table.shape[0] - num_rows.sum()

         likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
         return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)
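Taken together, the `hma.py` changes do three pieces of bookkeeping: `_get_extension` keeps a dedicated synthesizer for the null-keyed child rows, `_augment_table` records what fraction of each child table has a null foreign key, and `_find_parent_ids` gives NaN its own bucket when redistributing child rows among parents. A small numeric sketch of that arithmetic, with made-up values:

```python
import numpy as np
import pandas as pd

# Assume 'sessions' has 8 rows and its 3 parents account for 5 of them.
num_rows = pd.Series([2, 0, 3], index=['u0', 'u1', 'u2'])  # __sessions__user_id__num_rows

# _augment_table: share of child rows whose foreign key is null.
null_fk_pctg = 1 - (num_rows.sum() / 8)  # 1 - 5/8 = 0.375

# _find_parent_ids: the NaN 'parent' absorbs the unaccounted-for rows.
num_rows.loc[np.nan] = 8 - num_rows.sum()  # the NaN bucket gets 3 rows
```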
69 changes: 38 additions & 31 deletions sdv/multi_table/utils.py
@@ -1,15 +1,14 @@
 """Utility functions for the MultiTable models."""

 import math
-import warnings
 from collections import defaultdict
 from copy import deepcopy

 import numpy as np
 import pandas as pd

-from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null
-from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
+from sdv._utils import _get_root_tables
+from sdv.errors import InvalidDataError, SamplingError
 from sdv.multi_table import HMASynthesizer
 from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS

@@ -449,22 +448,22 @@ def _drop_rows(data, metadata, drop_missing_values):
     ])


-def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep):
+def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_missing_values):
     """Subsample the disconnected roots tables and their descendants."""
     relationships = metadata.relationships
     roots = _get_disconnected_roots_from_table(relationships, table)
     for root in roots:
         data[root] = data[root].sample(frac=ratio_to_keep)

-    _drop_rows(data, metadata, drop_missing_values=True)
+    _drop_rows(data, metadata, drop_missing_values)


-def _subsample_table_and_descendants(data, metadata, table, num_rows):
+def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values):
     """Subsample the table and its descendants.

-    The logic is to first subsample all the NaN foreign keys of the table.
-    We raise an error if we cannot reach referential integrity while keeping the number of rows.
-    Then, we drop rows of the descendants to ensure referential integrity.
+    The logic is to first subsample all the NaN foreign keys of the table when
+    ``drop_missing_values`` is True. We raise an error if we cannot reach referential integrity
+    while keeping the number of rows. Then, we drop rows of the descendants to ensure
+    referential integrity.

     Args:
         data (dict):
@@ -474,19 +473,26 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows):
             Metadata of the datasets.
         table (str):
             Name of the table.
+        num_rows (int):
+            Number of rows to keep in the table.
+        drop_missing_values (bool):
+            Boolean describing whether or not to also drop foreign keys with missing values
+            If True, drop rows with missing values in the foreign keys.
+            Defaults to False.
     """
-    idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table)
-    num_rows_to_drop = len(data[table]) - num_rows
-    if len(idx_nan_fk) > num_rows_to_drop:
-        raise SamplingError(
-            f"Referential integrity cannot be reached for table '{table}' while keeping "
-            f'{num_rows} rows. Please try again with a bigger number of rows.'
-        )
-    else:
-        data[table] = data[table].drop(idx_nan_fk)
+    if drop_missing_values:
+        idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table)
+        num_rows_to_drop = len(data[table]) - num_rows
+        if len(idx_nan_fk) > num_rows_to_drop:
+            raise SamplingError(
+                f"Referential integrity cannot be reached for table '{table}' while keeping "
+                f'{num_rows} rows. Please try again with a bigger number of rows.'
+            )
+        else:
+            data[table] = data[table].drop(idx_nan_fk)

     data[table] = data[table].sample(num_rows)
-    _drop_rows(data, metadata, drop_missing_values=True)
+    _drop_rows(data, metadata, drop_missing_values)


def _get_primary_keys_referenced(data, metadata):
Expand Down Expand Up @@ -593,7 +599,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
_subsample_ancestors(data, metadata, parent, primary_keys_referenced)


def _subsample_data(data, metadata, main_table_name, num_rows):
def _subsample_data(data, metadata, main_table_name, num_rows, drop_missing_values=False):
"""Subsample multi-table table based on a table and a number of rows.

The strategy is to:
Expand All @@ -613,6 +619,10 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
Name of the main table.
num_rows (int):
Number of rows to keep in the main table.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values
If True, drop rows with missing values in the foreign keys.
Defaults to False.

Returns:
dict:
Expand All @@ -621,20 +631,17 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
result = deepcopy(data)
primary_keys_referenced = _get_primary_keys_referenced(result, metadata)
ratio_to_keep = num_rows / len(result[main_table_name])
try:
_validate_foreign_keys_not_null(metadata, result)
except SynthesizerInputError:
warnings.warn(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils'
' to drop these rows before using ``get_random_subset``.'
)

try:
_subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
_subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
_subsample_disconnected_roots(
result, metadata, main_table_name, ratio_to_keep, drop_missing_values
)
_subsample_table_and_descendants(
result, metadata, main_table_name, num_rows, drop_missing_values
)
_subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced)
_drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys
_drop_rows(result, metadata, drop_missing_values)

except InvalidDataError as error:
if 'All references in table' not in str(error.args[0]):
raise error
Expand Down
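In `utils.py`, the hard validation-plus-warning path is gone: `_subsample_data` and its helpers now thread a `drop_missing_values` flag through instead of unconditionally dropping null foreign keys. A usage sketch of the new signature; note these are private helpers subject to change, and `data`/`metadata` are the illustrative objects from the first sketch:

```python
from sdv.multi_table.utils import _subsample_data

# New default: rows with null foreign keys survive subsampling.
subset = _subsample_data(data, metadata, main_table_name='users', num_rows=2)

# Old behavior is now opt-in; may raise SamplingError when referential
# integrity cannot be kept at the requested size.
strict = _subsample_data(
    data, metadata, main_table_name='users', num_rows=2, drop_missing_values=True
)
```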
29 changes: 23 additions & 6 deletions sdv/sampling/hierarchical_sampler.py
@@ -3,6 +3,7 @@
 import logging
 import warnings

+import numpy as np
 import pandas as pd

 LOGGER = logging.getLogger(__name__)
@@ -24,6 +25,7 @@ class BaseHierarchicalSampler:

     def __init__(self, metadata, table_synthesizers, table_sizes):
         self.metadata = metadata
+        self._null_foreign_key_percentages = {}
         self._table_synthesizers = table_synthesizers
         self._table_sizes = table_sizes
@@ -103,7 +105,7 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num_rows=None):
             row_indices = sampled_rows.index
             sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key]
         else:
-            sampled_rows[foreign_key] = parent_row[parent_key]
+            sampled_rows[foreign_key] = parent_row[parent_key] if parent_row is not None else np.nan

         previous = sampled_data.get(child_name)
         if previous is None:
@@ -143,16 +145,18 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
         """
         total_num_rows = round(self._table_sizes[child_name] * scale)
         for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
+            null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}']
+            total_parent_rows = round(total_num_rows * (1 - null_fk_pctg))
             num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
             min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
             max_rows = self._max_child_rows[num_rows_key]
             key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
             sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int)

-            while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
+            while sum(sampled_data[table_name][num_rows_key]) != total_parent_rows:
                 num_rows_column = sampled_data[table_name][num_rows_key].argsort()

-                if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
+                if sum(sampled_data[table_name][num_rows_key]) < total_parent_rows:
                     for i in num_rows_column:
                         # If the number of rows is already at the maximum, skip
                         # The exception is when the smallest value is already at the maximum,
@@ -164,7 +168,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
                             break

                         sampled_data[table_name].loc[i, num_rows_key] += 1
-                        if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
+                        if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows:
                             break

                 else:
@@ -179,7 +183,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
                             break

                         sampled_data[table_name].loc[i, num_rows_key] -= 1
-                        if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
+                        if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows:
                             break

     def _sample_children(self, table_name, sampled_data, scale=1.0):
@@ -207,8 +211,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
                     sampled_data=sampled_data,
                 )

+            foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]
+
             if child_name not in sampled_data:  # No child rows sampled, force row creation
-                foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]
                 num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
                 if num_rows_key in sampled_data[table_name].columns:
                     max_num_child_index = sampled_data[table_name][num_rows_key].idxmax()
@@ -224,6 +229,18 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
                         num_rows=1,
                     )

+            total_num_rows = round(self._table_sizes[child_name] * scale)
+            null_fk_pctg = self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}']
+            num_null_rows = round(total_num_rows * null_fk_pctg)
+            if num_null_rows > 0:
+                self._add_child_rows(
+                    child_name=child_name,
+                    parent_name=table_name,
+                    parent_row=None,
+                    sampled_data=sampled_data,
+                    num_rows=num_null_rows,
+                )
+
             self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale)

     def _finalize(self, sampled_data):
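On the sampling side, the learned percentage splits each child table's row budget in two: rows attached to sampled parents (`total_parent_rows`, enforced by `_enforce_table_size`) and rows created with a NaN foreign key (`num_null_rows`, added by `_sample_children` with `parent_row=None`). The arithmetic, using the same made-up numbers as the earlier sketch:

```python
# Illustrative values: 8 learned rows, 37.5% of them null-keyed.
scale = 1.0
table_size = 8
null_fk_pctg = 0.375

total_num_rows = round(table_size * scale)                      # 8 rows overall
total_parent_rows = round(total_num_rows * (1 - null_fk_pctg))  # 5 rows get real parents
num_null_rows = round(total_num_rows * null_fk_pctg)            # 3 rows get NaN foreign keys

# The two independent rounds can disagree with the total in edge cases; here they add up.
assert total_parent_rows + num_null_rows == total_num_rows
```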
4 changes: 2 additions & 2 deletions sdv/utils/poc.py
@@ -16,14 +16,14 @@
 from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references


-def drop_unknown_references(data, metadata):
+def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True):
     """Wrap the drop_unknown_references function from the utils module."""
     warnings.warn(
         "Please access the 'drop_unknown_references' function directly from the sdv.utils module"
         'instead of sdv.utils.poc.',
         FutureWarning,
     )
-    return utils_drop_unknown_references(data, metadata)
+    return utils_drop_unknown_references(data, metadata, drop_missing_values, verbose)


 def simplify_schema(data, metadata, verbose=True):
4 changes: 2 additions & 2 deletions sdv/utils/utils.py
@@ -10,7 +10,7 @@
 from sdv.multi_table.utils import _drop_rows


-def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True):
+def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True):
     """Drop rows with unknown foreign keys.

     Args:
@@ -22,7 +22,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True):
         drop_missing_values (bool):
             Boolean describing whether or not to also drop foreign keys with missing values
             If True, drop rows with missing values in the foreign keys.
-            Defaults to True.
+            Defaults to False.
         verbose (bool):
             If True, print information about the rows that are dropped.
             Defaults to True.
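The user-facing consequence of this default flip is that `drop_unknown_references` now keeps null-keyed rows unless asked otherwise. A sketch against the public `sdv.utils` entry point, reusing the illustrative `data`/`metadata` from the first example:

```python
from sdv.utils import drop_unknown_references

# Drops only rows pointing at a nonexistent parent; NaN foreign keys stay.
cleaned = drop_unknown_references(data, metadata)

# Previous behavior, now opt-in: also drop rows whose foreign keys are NaN.
cleaned_strict = drop_unknown_references(data, metadata, drop_missing_values=True)
```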