-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ChnPiiGenerator and Enhance Models (#191)
* add generator * add testcases * update DataProcessorManager * add `fit_data_empty` in base class * update data_processor order * add `fit_data_empty` support in ctgan * Improving the robustness of the 'remove_columns' method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
14ad5e8
commit a5936e7
Showing
6 changed files
with
250 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pandas as pd | ||
from faker import Faker | ||
|
||
from sdgx.data_models.metadata import Metadata | ||
from sdgx.data_processors.extension import hookimpl | ||
from sdgx.data_processors.generators.pii import PIIGenerator | ||
|
||
fake = Faker(locale="zh_CN") | ||
|
||
|
||
class ChnPiiGenerator(PIIGenerator):
    """Swap Chinese PII columns for values generated by the ``zh_CN`` Faker.

    ``fit`` records which columns the metadata labels as ``"chinese_name"``,
    ``"china_mainland_mobile_phone"`` or ``"china_mainland_id"``.
    ``convert`` removes those columns from a dataframe, and
    ``reverse_convert`` generates new values for them and attaches the
    result under the same column names.
    """

    # Class-level defaults describe the unfitted state only. They are never
    # mutated: ``fit`` rebinds new lists on the instance, so recorded columns
    # do not leak between instances or pile up across repeated ``fit`` calls.
    chn_id_columns_list: list = []

    chn_phone_columns_list: list = []

    chn_name_columns_list: list = []

    fitted: bool = False

    @property
    def chn_pii_columns(self):
        """Return all recorded PII columns: ID, then name, then phone columns."""
        return self.chn_id_columns_list + self.chn_name_columns_list + self.chn_phone_columns_list

    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
        """Record the Chinese PII columns described by ``metadata``.

        Args:
            metadata: Source of ``column_list`` and of each column's data
                type via ``get_column_data_type``. It is dereferenced
                unconditionally, so it must not be ``None``.
            **kwargs: Accepted but not used by this method.
        """
        # Collect into fresh local lists and assign them at the end, so a
        # failure part-way through leaves the previous state untouched and
        # the shared class-level defaults are never appended to.
        name_columns = []
        phone_columns = []
        id_columns = []

        for each_col in metadata.column_list:
            data_type = metadata.get_column_data_type(each_col)
            if data_type == "chinese_name":
                name_columns.append(each_col)
            elif data_type == "china_mainland_mobile_phone":
                phone_columns.append(each_col)
            elif data_type == "china_mainland_id":
                id_columns.append(each_col)

        self.chn_name_columns_list = name_columns
        self.chn_phone_columns_list = phone_columns
        self.chn_id_columns_list = id_columns
        self.fitted = True

    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
        """Return ``raw_data`` with every recorded PII column removed.

        When no PII column has been recorded, ``raw_data`` is returned as is.
        """
        pii_columns = self.chn_pii_columns
        if not pii_columns:
            return raw_data

        processed_data = raw_data
        for each_col in pii_columns:
            processed_data = self.remove_columns(processed_data, each_col)
        return processed_data

    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
        """Attach generated values for every recorded PII column.

        One value per row of ``processed_data`` is generated for each
        recorded column and passed to ``attach_columns`` as a single-column
        dataframe. When no PII column has been recorded, ``processed_data``
        is returned as is.
        """
        if not self.chn_pii_columns:
            return processed_data
        df_length = processed_data.shape[0]

        # Order matters for reproducibility with a seeded Faker: all ID
        # columns are generated first, then phone columns, then name columns.
        generators = (
            (self.chn_id_columns_list, fake.ssn),
            (self.chn_phone_columns_list, fake.phone_number),
            (self.chn_name_columns_list, fake.name),
        )
        for column_names, make_value in generators:
            for each_col_name in column_names:
                generated_values = [make_value() for _ in range(df_length)]
                generated_df = pd.DataFrame({each_col_name: generated_values})
                processed_data = self.attach_columns(processed_data, generated_df)

        return processed_data
|
||
|
||
@hookimpl
def register(manager):
    """Register ``ChnPiiGenerator`` with ``manager`` as ``"chnpiigenerator"``."""
    processor_name = "chnpiigenerator"
    manager.register(processor_name, ChnPiiGenerator)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,5 @@ class MLSynthesizerModel(SynthesizerModel): | |
""" | ||
Base class for ML models | ||
""" | ||
|
||
fit_data_empty: bool = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
121 changes: 121 additions & 0 deletions
121
tests/data_processors/generators/test_chn_pii_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
import random | ||
import re | ||
|
||
import pandas as pd | ||
import pytest | ||
from faker import Faker | ||
from pydantic import BaseModel, EmailStr | ||
|
||
from sdgx.data_models.metadata import Metadata | ||
from sdgx.data_processors.generators.chn_pii import ChnPiiGenerator | ||
|
||
fake = Faker(locale="zh_CN") | ||
fake_en = Faker(["en_US"]) | ||
|
||
|
||
@pytest.fixture
def chn_personal_test_df():
    """Yield a 1000-row dataframe of fake personal records for the tests.

    Values come from the module-level ``zh_CN`` Faker, except the English
    name, which comes from the ``en_US`` Faker.
    """
    row_cnt = 1000
    current_year = datetime.datetime.today().year
    header = [
        "ssn_sfz",
        "chn_name",
        "eng_name",
        "gender",
        "birth_date",
        "age",
        "email",
        "mobile_phone_no",
        "chn_address",
        "postcode",
        "job",
        "company_name",
    ]

    records = []
    for _ in range(row_cnt):
        # The order of the Faker calls below is kept fixed so that a seeded
        # run produces the same records.
        gender = random.choice(["male", "female"])
        chn_name = fake.last_name() + (
            fake.name_male() if gender == "male" else fake.name_female()
        )
        eng_name = fake_en.name()
        birth_date = fake.date()
        # ``birth_date`` is a string whose first four characters are the year.
        age = current_year - int(birth_date[:4])
        email = fake.email()
        phone = fake.phone_number()
        sfz = fake.ssn()
        address = fake.address()
        job = fake.job()
        company = fake.company()
        postcode = fake.postcode()

        # Field order must match ``header``.
        records.append(
            [
                sfz,
                chn_name,
                eng_name,
                gender,
                birth_date,
                age,
                email,
                phone,
                address,
                postcode,
                job,
                company,
            ]
        )

    yield pd.DataFrame(records, columns=header)
|
||
|
||
def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
    """Check the fit / convert / reverse_convert round trip of ChnPiiGenerator.

    The generator must detect the three Chinese PII columns, remove them in
    ``convert``, and restore them in ``reverse_convert`` with generated
    values that have the expected format.
    """
    assert "chn_name" in chn_personal_test_df.columns
    assert "mobile_phone_no" in chn_personal_test_df.columns
    assert "ssn_sfz" in chn_personal_test_df.columns

    # get metadata
    metadata_df = Metadata.from_dataframe(chn_personal_test_df)

    # generator
    pii_generator = ChnPiiGenerator()
    assert not pii_generator.fitted
    pii_generator.fit(metadata_df)
    assert pii_generator.fitted
    assert pii_generator.chn_name_columns_list == ["chn_name"]
    assert pii_generator.chn_phone_columns_list == ["mobile_phone_no"]
    assert pii_generator.chn_id_columns_list == ["ssn_sfz"]

    converted_df = pii_generator.convert(chn_personal_test_df)
    assert len(converted_df) == len(chn_personal_test_df)
    assert converted_df.shape[1] != chn_personal_test_df.shape[1]
    assert converted_df.shape[1] == chn_personal_test_df.shape[1] - len(
        pii_generator.chn_pii_columns
    )
    assert "chn_name" not in converted_df.columns
    assert "mobile_phone_no" not in converted_df.columns
    assert "ssn_sfz" not in converted_df.columns

    reverse_converted_df = pii_generator.reverse_convert(converted_df)
    assert len(reverse_converted_df) == len(chn_personal_test_df)
    assert "chn_name" in reverse_converted_df.columns
    assert "mobile_phone_no" in reverse_converted_df.columns
    assert "ssn_sfz" in reverse_converted_df.columns

    # The format checks below run on the generator's output
    # (``reverse_converted_df``), not on the input fixture.
    sfz_pattern = r"^\d{17}[0-9X]$"
    name_pattern = r"^[\u4e00-\u9fa5]{2,5}$"
    phone_pattern = r"^1[3-9]\d{9}$"

    # each generated value is sfz
    for each_value in reverse_converted_df["ssn_sfz"].values:
        assert len(each_value) == 18
        assert bool(re.match(sfz_pattern, each_value))
    # each generated value is a chinese name
    for each_value in reverse_converted_df["chn_name"].values:
        assert len(each_value) >= 2 and len(each_value) <= 5
        assert bool(re.match(name_pattern, each_value))
    # each generated value is a china mainland mobile phone number
    for each_value in reverse_converted_df["mobile_phone_no"].values:
        assert each_value.startswith("1")
        assert len(each_value) == 11
        assert bool(re.match(phone_pattern, each_value))