Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 286 detection remove fkey #289

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions sdmetrics/multi_table/detection/parent_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,39 @@ def _extract_foreign_keys(metadata):
return foreign_keys

@staticmethod
def _denormalize(data, foreign_key):
def _denormalize(data, foreign_key, metadata = None):
"""Denormalize the child table over the parent."""
parent_table, parent_key, child_table, child_key = foreign_key

flat = data[parent_table].set_index(parent_key).merge(
data[child_table].set_index(child_key),

if not isinstance(metadata, dict):
metadata = metadata.to_dict()

if metadata is not None:
table_meta_parent = metadata['tables'][parent_table]
table_meta_child = metadata['tables'][child_table]
else:
table_meta_parent = None
table_meta_child = None
to_drop_parent = []
to_drop_child = []

if table_meta_child is not None and 'primary_key' in table_meta_child:
to_drop_child.append(table_meta_child['primary_key'])

if table_meta_child is not None:
for field in table_meta_child['fields'].keys():
if ('ref' in table_meta_child['fields'][field].keys()) and (field!=child_key):
to_drop_child.append(field)

if table_meta_parent is not None:
for field in table_meta_parent['fields'].keys():
if 'ref' in table_meta_parent['fields'][field].keys():
to_drop_parent.append(field)


flat = data[parent_table].drop(to_drop_parent, axis=1).set_index(parent_key).merge(
data[child_table].drop(to_drop_child, axis=1).set_index(child_key),
how='outer',
left_index=True,
right_index=True,
Expand Down Expand Up @@ -98,10 +125,78 @@ def compute(cls, real_data, synthetic_data, metadata=None, foreign_keys=None):
raise ValueError('No foreign keys given')

scores = []

for foreign_key in foreign_keys:
real = cls._denormalize(real_data, foreign_key)
synth = cls._denormalize(synthetic_data, foreign_key)
scores.append(cls.single_table_metric.compute(real, synth))
parent_table, parent_key, child_table, child_key = foreign_key

# Keep attributes only
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

no_attribute_parent = []
no_attribute_child = []
if metadata is not None:
table_meta_parent = metadata['tables'][parent_table]
table_meta_child = metadata['tables'][child_table]
else:
table_meta_parent = None
table_meta_child = None

if table_meta_parent is not None:
if 'primary_key' in table_meta_parent:
no_attribute_parent.append(table_meta_parent['primary_key'])
for field in table_meta_parent['fields'].keys():
if 'ref' in table_meta_parent['fields'][field].keys():
no_attribute_parent.append(field)
if table_meta_child is not None:
if 'primary_key' in table_meta_child:
no_attribute_child.append(table_meta_child['primary_key'])
for field in table_meta_child['fields'].keys():
if 'ref' in table_meta_child['fields'][field].keys():
no_attribute_child.append(field)

for c in real_data[parent_table].columns:
if c not in no_attribute_parent:
real_data[parent_table] = real_data[parent_table].rename(columns={c: "parent."+c}).copy()
synthetic_data[parent_table] = synthetic_data[parent_table].rename(columns={c: "parent."+c}).copy()
if c in metadata['tables'][parent_table]['fields'].keys():
metadata['tables'][parent_table]['fields']["parent."+c] = metadata['tables'][parent_table]['fields'].pop(c)
for c in real_data[child_table].columns:
if c not in no_attribute_child:
real_data[child_table] = real_data[child_table].rename(columns={c: "child."+c}).copy()
synthetic_data[child_table] = synthetic_data[child_table].rename(columns={c: "child."+c}).copy()
if c in metadata['tables'][child_table]['fields'].keys():
metadata['tables'][child_table]['fields']["child."+c] = metadata['tables'][child_table]['fields'].pop(c)

# Denormalize and apply model
real = cls._denormalize(real_data, foreign_key, metadata)
synth = cls._denormalize(synthetic_data, foreign_key, metadata)


metadata_merged = {'fields': {}}
for field in real.columns:
if field in metadata['tables'][parent_table]['fields'].keys():
metadata_merged['fields'][field] = metadata['tables'][parent_table]['fields'][field]
elif field in metadata['tables'][child_table]['fields'].keys():
metadata_merged['fields'][field] = metadata['tables'][child_table]['fields'][field]

for c in real_data[parent_table].columns:
if c not in no_attribute_parent:
to_c = c[-len(c)+len('parent.'):]
real_data[parent_table] = real_data[parent_table].rename(columns={c: to_c}).copy()
synthetic_data[parent_table] = synthetic_data[parent_table].rename(columns={c: to_c}).copy()
if c in metadata['tables'][parent_table]['fields'].keys():
metadata['tables'][parent_table]['fields'][to_c] = metadata['tables'][parent_table]['fields'].pop(c)
for c in real_data[child_table].columns:
if c not in no_attribute_child:
to_c = c[-len(c)+len('child.'):]
real_data[child_table] = real_data[child_table].rename(columns={c: to_c}).copy()
synthetic_data[child_table] = synthetic_data[child_table].rename(columns={c: to_c}).copy()
if c in metadata['tables'][child_table]['fields'].keys():
metadata['tables'][child_table]['fields'][to_c] = metadata['tables'][child_table]['fields'].pop(c)

score = cls.single_table_metric.compute(real, synth, metadata_merged)
scores.append(score)

return np.mean(scores)

Expand Down
14 changes: 10 additions & 4 deletions sdmetrics/single_table/detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,16 @@ def compute(cls, real_data, synthetic_data, metadata=None):
else:
transformed_real_data = real_data
transformed_synthetic_data = synthetic_data


for field in metadata['fields'].keys():
if 'ref' in metadata['fields'][field].keys():
transformed_real_data = transformed_real_data.drop(field, axis=1)
transformed_synthetic_data = transformed_synthetic_data.drop(field, axis=1)

ht = HyperTransformer()
transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy()
transformed_synthetic_data = ht.transform(transformed_synthetic_data).to_numpy()
col_names = list(transformed_real_data.columns)
transformed_real_data = ht.fit_transform(transformed_real_data[col_names]).to_numpy()
transformed_synthetic_data = ht.transform(transformed_synthetic_data[col_names]).to_numpy()
X = np.concatenate([transformed_real_data, transformed_synthetic_data])
y = np.hstack([
np.ones(len(transformed_real_data)), np.zeros(len(transformed_synthetic_data))
Expand All @@ -88,7 +94,7 @@ def compute(cls, real_data, synthetic_data, metadata=None):

try:
scores = []
kf = StratifiedKFold(n_splits=3, shuffle=True)
kf = StratifiedKFold(n_splits=3, shuffle=True, random_state=1234)
for train_index, test_index in kf.split(X, y):
y_pred = cls._fit_predict(X[train_index], y[train_index], X[test_index])
roc_auc = roc_auc_score(y[test_index], y_pred)
Expand Down