From dae869ec9791df0eb0e2d03ec254df1e4962f642 Mon Sep 17 00:00:00 2001 From: MoooCat <141886018+MooooCat@users.noreply.github.com> Date: Tue, 16 Jan 2024 21:39:26 +0800 Subject: [PATCH] Add mutual information metric (#101) * test * test_v2 * no-test * pair_v1 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove_old_mi_sim * modify single&multi_table MISim * modify single_mi_sim by using pair_sim instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify multi_mi_sim by using pair_sim instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change_class_name_err * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify_paircolumn * mi only needs dataframe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify based on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * complete test_mi_sim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify test file * change_var_name * Update sdgx/metrics/multi_table/multitable_mi_sim.py Co-authored-by: MoooCat <141886018+MooooCat@users.noreply.github.com> * add MULTI_TABLE_DEMO_DATA * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify comments * JSD->MISIM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify base of pair_column * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add cls * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change self into cls instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change cls * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * series2array * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test * test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add label_encoder for category in mi_sim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use series.array * change le_fit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change transform type to np.array instead of list * add astype * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * series2array * foo * change test_suit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * all right? * all right * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Z712023 <3422685015@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Z712023 <132286135+Z712023@users.noreply.github.com> --- sdgx/metrics/column/base.py | 2 - sdgx/metrics/multi_table/base.py | 2 +- sdgx/metrics/multi_table/multitable_mi_sim.py | 71 ++++++++++++++ sdgx/metrics/pair_column/base.py | 75 ++++++++++++++ sdgx/metrics/pair_column/mi_sim.py | 97 +++++++++++++++++++ sdgx/metrics/single_table/base.py | 4 +- sdgx/metrics/single_table/single_mi_sim.py | 67 +++++++++++++ sdgx/utils.py | 9 +- tests/metrics/test_MISim.py | 77 +++++++++++++++ 9 files changed, 398 insertions(+), 6 deletions(-) create mode 100644 sdgx/metrics/multi_table/multitable_mi_sim.py create mode 100644 sdgx/metrics/pair_column/base.py create mode 100644 sdgx/metrics/pair_column/mi_sim.py create mode 100644 sdgx/metrics/single_table/single_mi_sim.py create mode 100644 tests/metrics/test_MISim.py diff --git a/sdgx/metrics/column/base.py b/sdgx/metrics/column/base.py index 7302404c..c6933572 100644 --- a/sdgx/metrics/column/base.py +++ b/sdgx/metrics/column/base.py @@ -61,10 +61,8 @@ def calculate( cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame ): """Calculate the metric value between columns between real table and synthetic table. - Args: real_data(pd.DataFrame or pd.Series): the real (original) data table / column. - synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column. """ # This method should first check the input diff --git a/sdgx/metrics/multi_table/base.py b/sdgx/metrics/multi_table/base.py index 1604e1b7..da8301b0 100644 --- a/sdgx/metrics/multi_table/base.py +++ b/sdgx/metrics/multi_table/base.py @@ -70,7 +70,7 @@ def check_output(raw_metric_value: float): """Check the output value. Args: - raw_metric_value (float): the calculated raw value of the JSD metric. + raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. """ raise NotImplementedError() diff --git a/sdgx/metrics/multi_table/multitable_mi_sim.py b/sdgx/metrics/multi_table/multitable_mi_sim.py new file mode 100644 index 00000000..34e73e92 --- /dev/null +++ b/sdgx/metrics/multi_table/multitable_mi_sim.py @@ -0,0 +1,71 @@ +import numpy as np +import pandas as pd +from scipy.stats import entropy +from sklearn.metrics.cluster import normalized_mutual_info_score + +from sdgx.metrics.multi_table.base import MultiTableMetric +from sdgx.metrics.pair_column.mi_sim import MISim + + +class MISim(MultiTableMetric): + """MISim : Mutual Information Similarity + + This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. + + Currently, we support discrete and continuous(need to be discretized) columns as inputs. + """ + + def __init__(self) -> None: + super().__init__() + self.lower_bound = 0 + self.upper_bound = 1 + self.metric_name = "mutual_information_similarity" + self.numerical_bins = 50 + + @classmethod + def calculate( + real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict + ) -> pd.DataFrame: + """ + Calculate the Mutual Information Similarity between a real column and a synthetic column. + Args: + real_data (pd.DataFrame): The real data. + synthetic_data (pd.DataFrame): The synthetic data. + metadata(dict): The metadata that describes the data type of each column + + Returns: + MI_similarity (float): The metric value. + """ + + # 传入概率分布数组 + + columns = synthetic_data.columns + n = len(columns) + mi_sim_instance = MISim() + nMI_sim = np.zeros((n, n)) + + for i in range(len(columns)): + for j in range(len(columns)): + syn_data = pd.concat( + [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1 + ) + real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1) + + nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata) + + MI_sim = np.sum(nMI_sim) / n / n + # test + MISim.check_output(MI_sim) + + return MI_sim + + @classmethod + def check_output(cls, raw_metric_value: float): + """Check the output value. + + Args: + raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. + """ + instance = cls() + if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound: + raise ValueError diff --git a/sdgx/metrics/pair_column/base.py b/sdgx/metrics/pair_column/base.py new file mode 100644 index 00000000..8f7e05d7 --- /dev/null +++ b/sdgx/metrics/pair_column/base.py @@ -0,0 +1,75 @@ +import pandas as pd + +from sdgx.log import logger + + +class PairMetric(object): + """PairMetric + Metrics used to evaluate the quality of synthetic data columns. + """ + + upper_bound = None + lower_bound = None + metric_name = "Correlation" + + def __init__(self) -> None: + pass + + @classmethod + def check_input(cls, src_col: pd.Series, tar_col: pd.Series, metadata: dict): + """Input check for table input. + Args: + src_data(pd.Series ): the source data column. + tar_data(pd.Series): the target data column . + metadata(dict): The metadata that describes the data type of each column + """ + # Input parameter must not contain None value + if real_data is None or synthetic_data is None: + raise TypeError("Input contains None.") + # check column_names + tar_name = tar_col.name + src_name = src_col.name + + # check column_types + if metadata[tar_name] != metadata[src_name]: + raise TypeError("Type of Pair is Conflicting.") + + # if type is pd.Series, return directly + if isinstance(real_data, pd.Series): + return src_col, tar_col + + # if type is not pd.Series or pd.DataFrame tranfer it to Series + try: + src_col = pd.Series(src_col) + tar_col = pd.Series(tar_col) + return src_col, tar_col + except Exception as e: + logger.error(f"An error occurred while converting to pd.Series: {e}") + + return None, None + + @classmethod + def calculate(cls, src_col: pd.Series, tar_col: pd.Series, metadata): + """Calculate the metric value between pair-columns between real table and synthetic table. + + Args: + src_data(pd.Series ): the source data column. + tar_data(pd.Series): the target data column . + metadata(dict): The metadata that describes the data type of each column + """ + # This method should first check the input + # such as: + real_data, synthetic_data = PairMetric.check_input(src_col, tar_col) + + raise NotImplementedError() + + @classmethod + def check_output(cls, raw_metric_value: float): + """Check the output value. + + Args: + raw_metric_value (float): the calculated raw value of the Mutual Information. + """ + raise NotImplementedError() + + pass diff --git a/sdgx/metrics/pair_column/mi_sim.py b/sdgx/metrics/pair_column/mi_sim.py new file mode 100644 index 00000000..b8963e9a --- /dev/null +++ b/sdgx/metrics/pair_column/mi_sim.py @@ -0,0 +1,97 @@ +import numpy as np +import pandas as pd +from scipy.stats import entropy +from sklearn.metrics.cluster import normalized_mutual_info_score +from sklearn.preprocessing import LabelEncoder + +from sdgx.metrics.pair_column.base import PairMetric +from sdgx.utils import time2int + + +class MISim(PairMetric): + """MISim : Mutual Information Similarity + + This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. + + Currently, we support discrete and continuous(need to be discretized) columns as inputs. + """ + + def __init__(instance) -> None: + super().__init__() + instance.lower_bound = 0 + instance.upper_bound = 1 + instance.metric_name = "mutual_information_similarity" + instance.numerical_bins = 50 + + @classmethod + def calculate( + cls, + src_col: pd.Series, + tar_col: pd.Series, + metadata: dict, + ) -> float: + """ + Calculate the MI similarity for the source data colum and the target data column. + Args: + src_data(pd.Series ): the source data column. + tar_data(pd.Series): the target data column . + metadata(dict): The metadata that describes the data type of each columns + Returns: + MI_similarity (float): The metric value. + """ + + # 传入概率分布数组 + instance = cls() + + col_name = src_col.name + data_type = metadata[col_name] + + if data_type == "numerical": + x = np.array(src_col.array) + src_col = pd.cut( + x, + instance.numerical_bins, + labels=range(instance.numerical_bins), + ) + x = np.array(tar_col.array) + tar_col = pd.cut( + x, + instance.numerical_bins, + labels=range(instance.numerical_bins), + ) + src_col = src_col.to_numpy() + tar_col = tar_col.to_numpy() + + elif data_type == "category": + le = LabelEncoder() + src_list = list(set(src_col.array)) + tar_list = list(set(tar_col.array)) + fit_list = tar_list + src_list + le.fit(fit_list) + + src_col = le.transform(np.array(src_col.array)) + tar_col = le.transform(np.array(tar_col.array)) + + elif data_type == "datetime": + src_col = src_col.apply(time2int) + tar_col = tar_col.apply(time2int) + src_col = pd.cut( + src_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins) + ) + tar_col = pd.cut( + tar_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins) + ) + src_col = src_col.to_numpy() + tar_col = tar_col.to_numpy() + + MI_sim = normalized_mutual_info_score(src_col, tar_col) + return MI_sim + + @classmethod + def check_output(cls, raw_metric_value: float): + """Check the output value. + + Args: + raw_metric_value (float): the calculated raw value of the MI similarity. + """ + pass diff --git a/sdgx/metrics/single_table/base.py b/sdgx/metrics/single_table/base.py index 6212f362..7addd8ae 100644 --- a/sdgx/metrics/single_table/base.py +++ b/sdgx/metrics/single_table/base.py @@ -55,7 +55,7 @@ def check_input(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame): return None, None - def calculate(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame): + def calculate(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame): """Calculate the metric value between a real table and a synthetic table. Args: @@ -71,7 +71,7 @@ def check_output(raw_metric_value: float): """Check the output value. Args: - raw_metric_value (float): the calculated raw value of the JSD metric. + raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. """ raise NotImplementedError() diff --git a/sdgx/metrics/single_table/single_mi_sim.py b/sdgx/metrics/single_table/single_mi_sim.py new file mode 100644 index 00000000..38d56c92 --- /dev/null +++ b/sdgx/metrics/single_table/single_mi_sim.py @@ -0,0 +1,67 @@ +import numpy as np +import pandas as pd +from scipy.stats import entropy +from sklearn.metrics.cluster import normalized_mutual_info_score + +from sdgx.metrics.pair_column.mi_sim import MISim +from sdgx.metrics.single_table.base import SingleTableMetric + + +class SinTabMISim(SingleTableMetric): + """MISim : Mutual Information Similarity + + This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. + + Currently, we support discrete and continuous(need to be discretized) columns as inputs. + """ + + def __init__(self) -> None: + super().__init__() + self.lower_bound = 0 + self.upper_bound = 1 + self.metric_name = "mutual_information_similarity" + self.numerical_bins = 50 + + @classmethod + def calculate(real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata) -> pd.DataFrame: + """ + Calculate the Mutual Information Similarity between a real column and a synthetic column. + Args: + real_data (pd.DataFrame): The real data. + synthetic_data (pd.DataFrame): The synthetic data. + metadata(dict): The metadata that describes the data type of each column + Returns: + MI_similarity (float): The metric value. + """ + + # 传入概率分布数组 + + columns = synthetic_data.columns + n = len(columns) + mi_sim_instance = MISim() + nMI_sim = np.zeros((n, n)) + + for i in range(len(columns)): + for j in range(len(columns)): + syn_data = pd.concat( + [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1 + ) + real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1) + + nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata) + + MI_sim = np.sum(nMI_sim) / n / n + MISim.check_output(MI_sim) + + return MI_sim + + @classmethod + def check_output(cls, raw_metric_value: float): + """Check the output value. + + Args: + raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. + """ + instance = cls() + if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound: + raise ValueError diff --git a/sdgx/utils.py b/sdgx/utils.py index b8799ce6..66fbfbde 100644 --- a/sdgx/utils.py +++ b/sdgx/utils.py @@ -3,6 +3,7 @@ import functools import socket import threading +import time import urllib.request import warnings from contextlib import closing @@ -26,8 +27,8 @@ "find_free_port", "download_multi_table_demo_data", "get_demo_single_table", + "time2int", ] - MULTI_TABLE_DEMO_DATA = { "rossman": { "parent_table": "store", @@ -99,6 +100,12 @@ def get_demo_single_table(data_dir: str | Path = "./dataset"): return pd_obj, discrete_cols +def time2int(datetime, form): + time_array = time.strptime(datetime, form) + time_stamp = int(time.mktime(time_array)) + return time_stamp + + class Singleton(type): """ metaclass for singleton, thread-safe. diff --git a/tests/metrics/test_MISim.py b/tests/metrics/test_MISim.py new file mode 100644 index 00000000..76879ace --- /dev/null +++ b/tests/metrics/test_MISim.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import random + +import numpy as np +import pandas as pd +import pytest + +from sdgx.metrics.pair_column.mi_sim import MISim + +# 创建测试数据 + + +@pytest.fixture +def test_data_category(): + role_set = ["admin", "user", "guest"] + df = pd.DataFrame( + { + "role1": [random.choice(role_set) for _ in range(10)], + "role2": [random.choice(role_set) for _ in range(10)], + } + ) + return df + + +@pytest.fixture +def test_data_num(): + df = pd.DataFrame( + { + "feature_x": [random.random() for _ in range(10)], + "feature_y": [random.random() for _ in range(10)], + } + ) + return df + + +@pytest.fixture +def mi_sim_instance(): + return MISim() + + +def test_MISim_discrete(test_data_category, mi_sim_instance): + metadata = {"role1": "category", "role2": "category"} + col_src = "role1" + col_tar = "role2" + result = mi_sim_instance.calculate( + test_data_category[col_src], test_data_category[col_tar], metadata + ) + result1 = mi_sim_instance.calculate( + test_data_category[col_src], test_data_category[col_src], metadata + ) + result2 = mi_sim_instance.calculate( + test_data_category[col_tar], test_data_category[col_src], metadata + ) + + assert result >= 0 + assert result <= 1 + assert result1 == 1 + assert result2 == result + + +def test_MISim_continuous(test_data_num, mi_sim_instance): + metadata = {"feature_x": "numerical", "feature_y": "numerical"} + col_src = "feature_x" + col_tar = "feature_y" + result = mi_sim_instance.calculate(test_data_num[col_src], test_data_num[col_tar], metadata) + result1 = mi_sim_instance.calculate(test_data_num[col_src], test_data_num[col_src], metadata) + result2 = mi_sim_instance.calculate(test_data_num[col_tar], test_data_num[col_src], metadata) + + assert result >= 0 + assert result <= 1 + assert result1 == 1 + assert result2 == result + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__])