Skip to content

Commit

Permalink
Add Chinese Company Name Support and Inspector (#201)
Browse files Browse the repository at this point in the history
* add ChineseCompanyNameInspector

* add chn_company_name support in ChnPiiGenerator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MooooCat and pre-commit-ci[bot] committed Jul 11, 2024
1 parent 5ef8916 commit 531a3e9
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
15 changes: 15 additions & 0 deletions sdgx/data_models/inspectors/personal.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ def has_number(s):
return True


# 公司名
class ChineseCompanyNameInspector(RegexInspector):
pattern = r".*?公司.*?"

_match_percentage = 0.7

data_type_name = "chinese_company_name"

_inspect_level = 40

pii = False


@hookimpl
def register(manager):
manager.register("EmailInspector", EmailInspector)
Expand All @@ -198,3 +211,5 @@ def register(manager):
manager.register("ChineseNameInspector", ChineseNameInspector)

manager.register("EnglishNameInspector", EnglishNameInspector)

manager.register("ChineseCompanyNameInspector", ChineseCompanyNameInspector)
18 changes: 17 additions & 1 deletion sdgx/data_processors/generators/chn_pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ class ChnPiiGenerator(PIIGenerator):

chn_name_columns_list: list = []

chn_company_name_list: list = []

fitted: bool = False

@property
def chn_pii_columns(self):
return self.chn_id_columns_list + self.chn_name_columns_list + self.chn_phone_columns_list
return (
self.chn_id_columns_list
+ self.chn_name_columns_list
+ self.chn_phone_columns_list
+ self.chn_company_name_list
)

def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):

Expand All @@ -40,6 +47,9 @@ def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
if data_type == "china_mainland_id":
self.chn_id_columns_list.append(each_col)
continue
if data_type == "chinese_company_name":
self.chn_company_name_list.append(each_col)

self.fitted = True

def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -56,6 +66,7 @@ def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
# if empty, return directly
if not self.chn_pii_columns:
return processed_data

df_length = processed_data.shape[0]

# chn id
Expand All @@ -73,6 +84,11 @@ def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
each_email_col = [fake.name() for _ in range(df_length)]
each_email_df = pd.DataFrame({each_col_name: each_email_col})
processed_data = self.attach_columns(processed_data, each_email_df)
# chn company
for each_col_name in self.chn_company_name_list:
each_company_col = [fake.company() for _ in range(df_length)]
each_company_df = pd.DataFrame({each_col_name: each_company_col})
processed_data = self.attach_columns(processed_data, each_company_df)

return processed_data

Expand Down
22 changes: 22 additions & 0 deletions tests/data_models/inspector/test_personal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ChinaMainlandMobilePhoneInspector,
ChinaMainlandPostCode,
ChinaMainlandUnifiedSocialCreditCode,
ChineseCompanyNameInspector,
ChineseNameInspector,
EmailInspector,
EnglishNameInspector,
Expand Down Expand Up @@ -260,5 +261,26 @@ def test_eng_name_inspector_generated_data(chn_personal_test_df: pd.DataFrame):
assert inspector_ENG_name.pii is True


# Chinese Company Name
def test_chn_company_inspector_demo_data(raw_data):
inspector_PostCode = ChineseCompanyNameInspector()
inspector_PostCode.fit(raw_data)
assert not inspector_PostCode.regex_columns
assert sorted(inspector_PostCode.inspect()["chinese_company_name_columns"]) == sorted([])
assert inspector_PostCode.inspect_level == 40
assert inspector_PostCode.pii is False


def test_chn_company_inspector_generated_data(chn_personal_test_df: pd.DataFrame):
inspector_PostCode = ChineseCompanyNameInspector()
inspector_PostCode.fit(chn_personal_test_df)
assert inspector_PostCode.regex_columns
assert sorted(inspector_PostCode.inspect()["chinese_company_name_columns"]) == sorted(
["company_name"]
)
assert inspector_PostCode.inspect_level == 40
assert inspector_PostCode.pii is False


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])
7 changes: 7 additions & 0 deletions tests/data_processors/generators/test_chn_pii_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
assert "chn_name" in chn_personal_test_df.columns
assert "mobile_phone_no" in chn_personal_test_df.columns
assert "ssn_sfz" in chn_personal_test_df.columns
assert "company_name" in chn_personal_test_df.columns

# get metadata
metadata_df = Metadata.from_dataframe(chn_personal_test_df)
Expand All @@ -89,6 +90,7 @@ def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
assert pii_generator.chn_name_columns_list == ["chn_name"]
assert pii_generator.chn_phone_columns_list == ["mobile_phone_no"]
assert pii_generator.chn_id_columns_list == ["ssn_sfz"]
assert pii_generator.chn_company_name_list == ["company_name"]

converted_df = pii_generator.convert(chn_personal_test_df)
assert len(converted_df) == len(chn_personal_test_df)
Expand All @@ -99,12 +101,14 @@ def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
assert "chn_name" not in converted_df.columns
assert "mobile_phone_no" not in converted_df.columns
assert "ssn_sfz" not in converted_df.columns
assert "company_name" not in converted_df.columns

reverse_converted_df = pii_generator.reverse_convert(converted_df)
assert len(reverse_converted_df) == len(chn_personal_test_df)
assert "chn_name" in reverse_converted_df.columns
assert "mobile_phone_no" in reverse_converted_df.columns
assert "ssn_sfz" in reverse_converted_df.columns
assert "company_name" in reverse_converted_df.columns
# each generated value is sfz
for each_value in chn_personal_test_df["ssn_sfz"].values:
assert len(each_value) == 18
Expand All @@ -119,3 +123,6 @@ def test_chn_pii_generator(chn_personal_test_df: pd.DataFrame):
assert len(each_value) == 11
pattern = r"^1[3-9]\d{9}$"
assert bool(re.match(pattern, each_value))
for each_value in chn_personal_test_df["company_name"].values:
pattern = r".*?公司.*?"
assert bool(re.match(pattern, each_value))

0 comments on commit 531a3e9

Please sign in to comment.