Add ChnPiiGenerator and Enhance Models (#191)
* add generator

* add testcases

* update DataProcessorManager

* add `fit_data_empty` in base class

* update data_processor order

* add `fit_data_empty` support in ctgan

* Improve the robustness of the `remove_columns` method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
MooooCat and pre-commit-ci[bot] authored Jun 24, 2024
1 parent 14ad5e8 commit a5936e7
Showing 6 changed files with 250 additions and 3 deletions.
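For orientation, the three kinds of Chinese PII this commit targets correspond to zh_CN Faker providers. The snippet below is an illustrative sketch (not code from the commit) of the Faker calls the new generator uses to regenerate each column type:

from faker import Faker

fake = Faker(locale="zh_CN")

# Column data types handled by ChnPiiGenerator and the Faker calls used to
# regenerate them when sampling:
print(fake.name())          # chinese_name                -> a 2-4 character Chinese name
print(fake.phone_number())  # china_mainland_mobile_phone -> 11 digits starting with 1
print(fake.ssn())           # china_mainland_id           -> 18-character mainland ID number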
8 changes: 7 additions & 1 deletion sdgx/data_processors/base.py
@@ -6,6 +6,7 @@

from sdgx.data_models.metadata import Metadata
from sdgx.exceptions import SynthesizerProcessorError
from sdgx.log import logger


class DataProcessor:
@@ -72,7 +73,12 @@ def remove_columns(tabular_data: pd.DataFrame, column_name_to_remove: list) -> pd.DataFrame:
        result_data = tabular_data.copy()

        # Remove specified columns
        result_data = result_data.drop(columns=column_name_to_remove)
        try:
            result_data = result_data.drop(columns=column_name_to_remove)
        except KeyError:
            logger.warning(
                "Some columns to remove were not found in the data (possibly already removed); skipping them."
            )

        return result_data

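A minimal sketch of the behaviour the new try/except in remove_columns is meant to provide: dropping a column that is already gone no longer raises, it only logs a warning (replaced by a print here so the snippet is self-contained):

import pandas as pd

def remove_columns(tabular_data: pd.DataFrame, column_name_to_remove: list) -> pd.DataFrame:
    # Simplified stand-in for DataProcessor.remove_columns
    result_data = tabular_data.copy()
    try:
        result_data = result_data.drop(columns=column_name_to_remove)
    except KeyError:
        print("columns not found, skipping removal")
    return result_data

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
once = remove_columns(df, ["b"])     # drops "b"
twice = remove_columns(once, ["b"])  # "b" already gone: warns instead of raising
print(list(twice.columns))           # ['a']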
82 changes: 82 additions & 0 deletions sdgx/data_processors/generators/chn_pii.py
@@ -0,0 +1,82 @@
from __future__ import annotations

from typing import Any

import pandas as pd
from faker import Faker

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.extension import hookimpl
from sdgx.data_processors.generators.pii import PIIGenerator

fake = Faker(locale="zh_CN")


class ChnPiiGenerator(PIIGenerator):
    """Generator for Chinese PII columns.

    Chinese-name, mainland-mobile-phone and mainland-ID columns are removed
    before model training and regenerated with Faker (zh_CN) when sampling.
    """

    chn_id_columns_list: list = []

    chn_phone_columns_list: list = []

    chn_name_columns_list: list = []

    fitted: bool = False

    @property
    def chn_pii_columns(self):
        return self.chn_id_columns_list + self.chn_name_columns_list + self.chn_phone_columns_list
    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
        for each_col in metadata.column_list:
            data_type = metadata.get_column_data_type(each_col)
            if data_type == "chinese_name":
                self.chn_name_columns_list.append(each_col)
                continue
            if data_type == "china_mainland_mobile_phone":
                self.chn_phone_columns_list.append(each_col)
                continue
            if data_type == "china_mainland_id":
                self.chn_id_columns_list.append(each_col)
                continue
        self.fitted = True

    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
        # if empty, return directly
        if not self.chn_pii_columns:
            return raw_data
        processed_data = raw_data
        # remove every chn pii column from the dataframe
        for each_col in self.chn_pii_columns:
            processed_data = self.remove_columns(processed_data, each_col)
        return processed_data

    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
        # if empty, return directly
        if not self.chn_pii_columns:
            return processed_data
        df_length = processed_data.shape[0]

        # chn id
        for each_col_name in self.chn_id_columns_list:
            each_id_col = [fake.ssn() for _ in range(df_length)]
            each_id_df = pd.DataFrame({each_col_name: each_id_col})
            processed_data = self.attach_columns(processed_data, each_id_df)
        # chn phone
        for each_col_name in self.chn_phone_columns_list:
            each_phone_col = [fake.phone_number() for _ in range(df_length)]
            each_phone_df = pd.DataFrame({each_col_name: each_phone_col})
            processed_data = self.attach_columns(processed_data, each_phone_df)
        # chn name
        for each_col_name in self.chn_name_columns_list:
            each_name_col = [fake.name() for _ in range(df_length)]
            each_name_df = pd.DataFrame({each_col_name: each_name_col})
            processed_data = self.attach_columns(processed_data, each_name_df)

        return processed_data


@hookimpl
def register(manager):
    manager.register("chnpiigenerator", ChnPiiGenerator)
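A short usage sketch of the new generator, following the pattern of the test added below; the column names and Faker-generated values mirror the test fixture, and whether the PII columns are actually detected depends on the metadata inspectors (the test uses a 1000-row fixture):

import pandas as pd
from faker import Faker

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.generators.chn_pii import ChnPiiGenerator

fake = Faker(locale="zh_CN")
df = pd.DataFrame(
    {
        "chn_name": [fake.name() for _ in range(100)],
        "mobile_phone_no": [fake.phone_number() for _ in range(100)],
        "age": [20 + i % 40 for i in range(100)],
    }
)

generator = ChnPiiGenerator()
generator.fit(Metadata.from_dataframe(df))

converted = generator.convert(df)                # PII columns removed before model training
restored = generator.reverse_convert(converted)  # fresh fake PII attached back after sampling
print(list(converted.columns), list(restored.columns))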
8 changes: 7 additions & 1 deletion sdgx/data_processors/manager.py
@@ -45,7 +45,13 @@ class DataProcessorManager(Manager):

    preset_defalut_processors = [
        p.lower()
        for p in ["IntValueFormatter", "DatetimeFormatter", "NonValueTransformer", "EmailGenerator"]
        for p in [
            "NonValueTransformer",
            "EmailGenerator",
            "ChnPiiGenerator",
            "IntValueFormatter",
            "DatetimeFormatter",
        ]
    ] + ["ColumnOrderTransformer".lower()]
    """
    preset_defalut_processors stores the lowercase names of the data processors that are loaded by default; when using the synthesizer they are applied automatically, so users do not have to configure them.
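The reordering above also determines the order in which the default processors run: generators (including the new ChnPiiGenerator) now come before the value formatters. A quick sketch of the resulting lowercase names, relying only on the list literal shown above:

preset_defalut_processors = [
    p.lower()
    for p in [
        "NonValueTransformer",
        "EmailGenerator",
        "ChnPiiGenerator",
        "IntValueFormatter",
        "DatetimeFormatter",
    ]
] + ["ColumnOrderTransformer".lower()]

print(preset_defalut_processors)
# ['nonvaluetransformer', 'emailgenerator', 'chnpiigenerator',
#  'intvalueformatter', 'datetimeformatter', 'columnordertransformer']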
2 changes: 2 additions & 0 deletions sdgx/models/ml/single_table/base.py
@@ -5,3 +5,5 @@ class MLSynthesizerModel(SynthesizerModel):
"""
Base class for ML models
"""

fit_data_empty: bool = False
32 changes: 31 additions & 1 deletion sdgx/models/ml/single_table/ctgan.py
@@ -213,6 +213,9 @@ def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs):
        if epochs is not None:
            self._epochs = epochs
        self._pre_fit(dataloader, discrete_columns)
        if self.fit_data_empty:
            logger.info("CTGAN fit skipped because an empty DataFrame was detected.")
            return
        logger.info("CTGAN prefit finished, starting CTGAN training.")
        self._fit(len(self._ndarry_loader))
        logger.info("CTGAN training finished.")
@@ -221,7 +224,11 @@ def _pre_fit(self, dataloader: DataLoader, discrete_columns: list[str] = None) -> None:
        if not discrete_columns:
            discrete_columns = []

        self._validate_discrete_columns(dataloader.columns(), discrete_columns)
        # self._validate_discrete_columns(dataloader.columns(), discrete_columns)
        discrete_columns = self._filter_discrete_columns(dataloader.columns(), discrete_columns)
        # if the df is empty, we don't need to do anything
        if self.fit_data_empty:
            return
        # Fit Transformer and DataSampler
        self._transformer = DataTransformer()
        logger.info("Fitting model's transformer...")
@@ -364,6 +371,8 @@ def _fit(self, data_size: int):
        )

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        if self.fit_data_empty:
            return pd.DataFrame(index=range(count))
        return self._sample(count, *args, **kwargs)

    @random_state
@@ -509,6 +518,27 @@ def _cond_loss(self, data, c, m):

        return (loss * m).sum() / data.size()[0]

    def _filter_discrete_columns(self, train_data, discrete_columns):
        """Filter out entries of ``discrete_columns`` that are not present in
        ``train_data``. If no valid discrete column remains, ``fit_data_empty``
        is set to True so that fitting and sampling can short-circuit.
        """
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        elif isinstance(train_data, list):
            invalid_columns = set(discrete_columns) - set(train_data)
        else:
            raise TypeError("``train_data`` should be a pd.DataFrame, np.ndarray or list.")

        rest_discrete_columns = set(discrete_columns) - set(invalid_columns)

        if len(rest_discrete_columns) == 0:
            self.fit_data_empty = True

        return rest_discrete_columns

    def _validate_discrete_columns(self, train_data, discrete_columns):
        """Check whether ``discrete_columns`` exists in ``train_data``.
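A standalone sketch of the filtering behaviour introduced above, mirroring only the pd.DataFrame branch of _filter_discrete_columns; it shows when fit_data_empty would be switched on, which in turn makes fit and sample short-circuit:

import pandas as pd

def filter_discrete_columns(train_data: pd.DataFrame, discrete_columns: list):
    # Sketch of the pd.DataFrame branch of CTGAN._filter_discrete_columns
    invalid_columns = set(discrete_columns) - set(train_data.columns)
    rest_discrete_columns = set(discrete_columns) - invalid_columns
    fit_data_empty = len(rest_discrete_columns) == 0
    return rest_discrete_columns, fit_data_empty

df = pd.DataFrame({"job": ["a", "b"], "age": [30, 40]})
print(filter_discrete_columns(df, ["job", "missing_col"]))  # ({'job'}, False)
print(filter_discrete_columns(df, ["missing_col"]))         # (set(), True) -> fit/sample skip work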
121 changes: 121 additions & 0 deletions tests/data_processors/generators/test_chn_pii_generator.py
@@ -0,0 +1,121 @@
from __future__ import annotations

import datetime
import random
import re

import pandas as pd
import pytest
from faker import Faker

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.generators.chn_pii import ChnPiiGenerator

fake = Faker(locale="zh_CN")
fake_en = Faker(["en_US"])


@pytest.fixture
def chn_personal_test_df():
    row_cnt = 1000
    today = datetime.datetime.today()
    X = []
    header = [
        "ssn_sfz",
        "chn_name",
        "eng_name",
        "gender",
        "birth_date",
        "age",
        "email",
        "mobile_phone_no",
        "chn_address",
        "postcode",
        "job",
        "company_name",
    ]
    for _ in range(row_cnt):
        each_gender = random.choice(["male", "female"])
        if each_gender == "male":
            each_name = fake.last_name() + fake.name_male()
        else:
            each_name = fake.last_name() + fake.name_female()
        each_eng_name = fake_en.name()
        each_birth_date = fake.date()
        each_age = today.year - int(each_birth_date[:4])
        each_email = fake.email()
        each_phone = fake.phone_number()
        each_sfz = fake.ssn()
        each_address = fake.address()
        each_job = fake.job()
        each_corp = fake.company()
        each_postcode = fake.postcode()

        each_x = [
            each_sfz,
            each_name,
            each_eng_name,
            each_gender,
            each_birth_date,
            each_age,
            each_email,
            each_phone,
            each_address,
            each_postcode,
            each_job,
            each_corp,
        ]

        X.append(each_x)

    yield pd.DataFrame(X, columns=header)


def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
    assert "chn_name" in chn_personal_test_df.columns
    assert "mobile_phone_no" in chn_personal_test_df.columns
    assert "ssn_sfz" in chn_personal_test_df.columns

    # get metadata
    metadata_df = Metadata.from_dataframe(chn_personal_test_df)

    # generator
    pii_generator = ChnPiiGenerator()
    assert not pii_generator.fitted
    pii_generator.fit(metadata_df)
    assert pii_generator.fitted
    assert pii_generator.chn_name_columns_list == ["chn_name"]
    assert pii_generator.chn_phone_columns_list == ["mobile_phone_no"]
    assert pii_generator.chn_id_columns_list == ["ssn_sfz"]

    converted_df = pii_generator.convert(chn_personal_test_df)
    assert len(converted_df) == len(chn_personal_test_df)
    assert converted_df.shape[1] != chn_personal_test_df.shape[1]
    assert converted_df.shape[1] == chn_personal_test_df.shape[1] - len(
        pii_generator.chn_pii_columns
    )
    assert "chn_name" not in converted_df.columns
    assert "mobile_phone_no" not in converted_df.columns
    assert "ssn_sfz" not in converted_df.columns

    reverse_converted_df = pii_generator.reverse_convert(converted_df)
    assert len(reverse_converted_df) == len(chn_personal_test_df)
    assert "chn_name" in reverse_converted_df.columns
    assert "mobile_phone_no" in reverse_converted_df.columns
    assert "ssn_sfz" in reverse_converted_df.columns
    # each regenerated id value looks like a mainland ID number (shenfenzheng)
    for each_value in reverse_converted_df["ssn_sfz"].values:
        assert len(each_value) == 18
        pattern = r"^\d{17}[0-9X]$"
        assert bool(re.match(pattern, each_value))
    # each regenerated name is a 2-5 character Chinese name
    for each_value in reverse_converted_df["chn_name"].values:
        pattern = r"^[\u4e00-\u9fa5]{2,5}$"
        assert len(each_value) >= 2 and len(each_value) <= 5
        assert bool(re.match(pattern, each_value))
    # each regenerated phone number is a valid mainland mobile number
    for each_value in reverse_converted_df["mobile_phone_no"].values:
        assert each_value.startswith("1")
        assert len(each_value) == 11
        pattern = r"^1[3-9]\d{9}$"
        assert bool(re.match(pattern, each_value))
