
Sweep: add a base class for metrics that evaluate how similar the synthetic dataset is to the real dataset #3

Closed
1 task done
MooooCat opened this issue Aug 17, 2023 · 1 comment · Fixed by #5
Labels
sweep Assigns Sweep to an issue or pull request.

Comments

@MooooCat
Contributor

MooooCat commented Aug 17, 2023

Please add the code for the base class to metrics/base.py, and use Python 3.

Please define enough methods in the base class to avoid adding the same code repeatedly when implementing individual metrics.

Checklist
  • sdgx/metrics/base.py
    • Remove the pass statement from the file.
    • Import ABC and abstractmethod from the abc module at the top of the file.
    • Create a new class named BaseMetric that inherits from ABC.
    • Inside the BaseMetric class, define an __init__ method that takes two parameters: real_data and synthetic_data. These parameters should be stored as instance variables.
    • Still inside the BaseMetric class, define an abstract method named calculate that takes no parameters. This method will be used to calculate the metric and should be implemented in each subclass.
    • Still inside the BaseMetric class, define a method named validate_datasets that takes no parameters. This method should check whether the real_data and synthetic_data instance variables are valid datasets. For now, this method can simply pass.
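Following the checklist literally gives roughly the code below. This is a minimal sketch, not the code merged in #5. The `RowCountRatio` subclass in the usage lines is a hypothetical example invented for illustration, not an sdgx metric.

```python
from abc import ABC, abstractmethod


class BaseMetric(ABC):
    """Base class for metrics that compare a synthetic dataset with a real one."""

    def __init__(self, real_data, synthetic_data):
        # Store both datasets as instance variables for use by subclasses.
        self.real_data = real_data
        self.synthetic_data = synthetic_data

    @abstractmethod
    def calculate(self):
        """Compute the metric. Each concrete subclass must override this."""

    def validate_datasets(self):
        """Check that real_data and synthetic_data are valid datasets.

        Per the checklist, this is a no-op for now.
        """
        pass


# Hypothetical subclass, for illustration only: synthetic-to-real row count ratio.
class RowCountRatio(BaseMetric):
    def calculate(self):
        return len(self.synthetic_data) / len(self.real_data)


print(RowCountRatio([1, 2, 3, 4], [1, 2]).calculate())  # 0.5
```

Because `calculate` is decorated with `@abstractmethod`, `BaseMetric` itself cannot be instantiated. Only subclasses that override `calculate` can be instantiated.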
sweep-ai bot added the sweep label on Aug 17, 2023
@sweep-ai
Contributor

sweep-ai bot commented Aug 17, 2023

Here's the PR! #5.

⚡ Sweep Free Trial: I used GPT-4 to create this ticket. You have 3 GPT-4 tickets left for the month and 1 for the day. For more GPT-4 tickets, visit our payment portal. To retrigger Sweep, edit the issue.


Step 1: 🔍 Code Search

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I looked at are listed below. If a file is missing, you can mention its path in the ticket description.

```python
class base_metric:
    pass
```

---
sidebar_position: 1
---
# Quick Start
## Quick Example: Single-Table Data Synthesis
```python
# Import the required modules
from sdgx.tabular.synthesizers import CTGAN
from sdgx.tabular.data import get_single_table
import pandas as pd
# Load the data
data = get_single_table()
```
The real data is shown below:
```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K
[32561 rows x 15 columns]
```
```python
# Define the model
model = CTGAN()
# Train the model
model.fit(data)
# Generate synthetic data
sampled = model.generate(num_rows=10)
```
The synthetic data is shown below:
```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```
## Quick Example: Multi-Table Data Synthesis
```python
# Import the required modules
from sdgx.tabular.synthesizers import CWAMT
from sdgx.tabular.data import get_multi_table
import pandas as pd
# Load the data
data = get_multi_table()
```
The real data is shown below:
```
{'tables': {'table1': {'table_name': 'train', 'table_value': Store DayOfWeek Date ... Promo StateHoliday SchoolHoliday
0 1 5 2015-07-31 ... 1 0 1
1 2 5 2015-07-31 ... 1 0 1
2 3 5 2015-07-31 ... 1 0 1
3 4 5 2015-07-31 ... 1 0 1
4 5 5 2015-07-31 ... 1 0 1
... ... ... ... ... ... ... ...
1017204 1111 2 2013-01-01 ... 0 a 1
1017205 1112 2 2013-01-01 ... 0 a 1
1017206 1113 2 2013-01-01 ... 0 a 1
1017207 1114 2 2013-01-01 ... 0 a 1
1017208 1115 2 2013-01-01 ... 0 a 1
[1017209 rows x 9 columns]}, 'table2': {'table_name': 'store', 'table_value': Store StoreType ... Promo2SinceYear PromoInterval
0 1 c ... NaN NaN
1 2 a ... 2010.0 Jan,Apr,Jul,Oct
2 3 a ... 2011.0 Jan,Apr,Jul,Oct
3 4 c ... NaN NaN
4 5 a ... NaN NaN
... ... ... ... ... ...
1110 1111 a ... 2013.0 Jan,Apr,Jul,Oct
1111 1112 c ... NaN NaN
1112 1113 a ... NaN NaN
1113 1114 a ... NaN NaN
1114 1115 d ... 2012.0 Mar,Jun,Sept,Dec
[1115 rows x 10 columns]}}, 'relations': {'table1-table2': 'store'}}
```
```python
# Define the model
model = CWAMT()
# Train the model
model.fit(data)
# Generate synthetic data
sampled = model.generate(num_rows=10)
```
The synthetic data is shown below:
```
{'table1': {'table_name': 'train', 'table_value': Store DayOfWeek Date ... Promo StateHoliday SchoolHoliday
0 3 2 2013-01-01 ... 0 a 1
1 5 2 2013-01-01 ... 0 a 1
2 5 2 2013-01-01 ... 0 a 1
3 6 2 2013-01-01 ... 0 a 1
4 2 2 2013-01-01 ... 0 a 1
5 1 2 2013-01-01 ... 0 a 1
6 7 2 2013-01-01 ... 0 a 1
7 2 2 2013-01-01 ... 0 a 1
8 8 2 2013-01-01 ... 0 a 1
9 5 2 2013-01-01 ... 0 a 1
10 9 2 2013-01-01 ... 0 a 1
11 3 2 2013-01-01 ... 0 a 1
12 2 2 2013-01-01 ... 0 a 1
13 4 2 2013-01-01 ... 0 a 1
14 4 2 2013-01-01 ... 0 a 1
15 7 2 2013-01-01 ... 0 a 1
16 8 2 2013-01-01 ... 0 a 1
17 10 2 2013-01-01 ... 0 a 1
18 3 2 2013-01-01 ... 0 a 1
19 7 2 2013-01-01 ... 0 a 1
[20 rows x 9 columns]}, 'table2': {'table_name': 'store', 'table_value': Store StoreType ... Promo2SinceYear PromoInterval
0 1 a ... 2013.0 Jan,Apr,Jul,Oct
1 2 a ... 2010.0 Jan,Apr,Jul,Oct
2 3 a ... NaN NaN
3 4 c ... 2012.0 Jan,Apr,Jul,Oct
4 5 c ... NaN NaN
5 6 a ... 2013.0 Jan,Apr,Jul,Oct
6 7 c ... NaN NaN
7 8 a ... NaN NaN
8 9 a ... NaN NaN
9 10 d ... 2012.0 Mar,Jun,Sept,Dec
```

```python
"""
DataSampler module:
Currently uses code from the open-source CTGAN project.
It will be further rewritten to fit actual business requirements,
and some performance optimizations will be made.
"""
import numpy as np


class DataSamplerCTGAN(object):
    """DataSampler samples the conditional vector and corresponding data for CTGAN."""

    def __init__(self, data, output_info, log_frequency):
        self._data = data

        def is_discrete_column(column_info):
            return len(column_info) == 1 and column_info[0].activation_fn == "softmax"

        n_discrete_columns = sum(
            [1 for column_info in output_info if is_discrete_column(column_info)]
        )

        self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype="int32")

        # Store the row id for each category in each discrete column.
        # For example _rid_by_cat_cols[a][b] is a list of all rows where the
        # a-th discrete column equals value b.
        self._rid_by_cat_cols = []

        # Compute _rid_by_cat_cols
        st = 0
        for column_info in output_info:
            if is_discrete_column(column_info):
                span_info = column_info[0]
                ed = st + span_info.dim

                rid_by_cat = []
                for j in range(span_info.dim):
                    rid_by_cat.append(np.nonzero(data[:, st + j])[0])
                self._rid_by_cat_cols.append(rid_by_cat)
                st = ed
            else:
                st += sum([span_info.dim for span_info in column_info])
        assert st == data.shape[1]

        # Prepare an interval matrix for efficiently sampling the conditional vector
        max_category = max(
```
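The comment about `_rid_by_cat_cols` in the snippet above describes the structure only in words. The toy example below uses a made-up one-hot matrix for a single discrete column, so the column offset `st` is 0, and shows what one entry of that list holds. The final `choice` call roughly mirrors how upstream CTGAN's `DataSampler` uses these lists to pick a real row with a given category. The truncated snippet does not show that part, so treat it as an assumption about later code.

```python
import numpy as np

# Toy one-hot encoding of ONE discrete column with 3 categories (a, b, c).
# Each row has a single 1 marking its category.
onehot = np.array(
    [
        [1, 0, 0],  # row 0 -> a
        [0, 1, 0],  # row 1 -> b
        [1, 0, 0],  # row 2 -> a
        [0, 0, 1],  # row 3 -> c
    ]
)

# Same construction as the inner loop of DataSamplerCTGAN.__init__ (with st == 0):
# for category j, keep the indices of rows whose j-th entry is non-zero.
rid_by_cat = [np.nonzero(onehot[:, j])[0] for j in range(onehot.shape[1])]
print([ids.tolist() for ids in rid_by_cat])  # [[0, 2], [1], [3]]

# To draw a real row whose category is "a", pick uniformly from rid_by_cat[0].
rng = np.random.default_rng(0)
row = int(rng.choice(rid_by_cat[0]))
print(row in (0, 2))  # True
```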

```toml
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "sdgx"
description = "synthetic-data-generator"
keywords = ["sdg", "hitsz-ids"]
requires-python = ">=3.8"
# FIXME: Incompatibilities may exist and need to be covered by CI and test cases
# In case of incompatibility, please refer to the archived `requirements.txt.old`
dependencies = [
    "setproctitle",
    "PyMySQL",
    "pandas",
    "numpy",
    "scikit-learn",
    "torch",
    "torchvision",
    "rdt",
    "joblib",
    "dython",
    "seaborn",
    "table-evaluator",
]
dynamic = ["version"]
classifiers = [
    "Programming Language :: Python :: 3",
    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
]

[project.optional-dependencies]
test = ["pytest"]

[[project.authors]]
name = "hitsz-ids"

[project.readme]
file = "README.md"
content-type = "text/markdown"

[project.license]
text = "Apache Software License 2.0"

[project.urls]
Source = "https://github.com/hitsz-ids/synthetic-data-generator"

[tool.check-manifest]
ignore = [".*"]

[tool.hatch.version]
path = "sdgx/__init__.py"
```

https://github.com/hitsz-ids/synthetic-data-generator/blob/bbe929a09e1b4dd2acaf6f0b922e8e942b6126b9/sdgx/statistics/base.py#L1-L0


Step 2: 🧐 Snippet Analysis

From looking through the relevant snippets, I decided to make the following modifications:

File path: sdgx/metrics/base.py
Proposed changes: Modify sdgx/metrics/base.py as follows:
  • Remove the pass statement from the file.
  • Import ABC and abstractmethod from the abc module at the top of the file.
  • Create a new class named BaseMetric that inherits from ABC.
  • Inside the BaseMetric class, define an __init__ method that takes two parameters: real_data and synthetic_data. These parameters should be stored as instance variables.
  • Still inside the BaseMetric class, define an abstract method named calculate that takes no parameters. This method will be used to calculate the metric and should be implemented in each subclass.
  • Still inside the BaseMetric class, define a method named validate_datasets that takes no parameters. This method should check whether the real_data and synthetic_data instance variables are valid datasets. For now, this method can simply pass.

Step 3: 📝 Planning

I have created a plan for writing the pull request. I am now working through the plan and coding the changes required to address this issue. Here is the planned pull request:

Add base class for metrics in metrics/base.py
sweep/add-base-class-for-metrics

Description

This PR adds a base class for metrics in the metrics/base.py file. The base class, BaseMetric, provides a foundation for implementing various metrics that can evaluate the similarity between synthetic and real datasets.

Summary

  • Removed the pass statement from the metrics/base.py file.
  • Imported the ABC and abstractmethod from the abc module.
  • Created a new class named BaseMetric that inherits from ABC.
  • Defined an __init__ method in the BaseMetric class to store the real_data and synthetic_data as instance variables.
  • Defined an abstract method named calculate in the BaseMetric class to calculate the metric. This method will be implemented in each subclass.
  • Defined a method named validate_datasets in the BaseMetric class to validate the input datasets. Currently, this method does not perform any validation.
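As the summary notes, `validate_datasets` does not validate anything yet. One possible implementation is sketched below. It assumes both datasets are pandas DataFrames and checks that each is non-empty and that the synthetic data contains every column of the real data. These checks, the use of `ValueError`, and the `MeanGap` metric are hypothetical choices made for illustration. They are not part of this PR.

```python
from abc import ABC, abstractmethod

import pandas as pd


class BaseMetric(ABC):
    def __init__(self, real_data, synthetic_data):
        self.real_data = real_data
        self.synthetic_data = synthetic_data

    @abstractmethod
    def calculate(self):
        """Compute the metric; implemented by subclasses."""

    def validate_datasets(self):
        """Raise ValueError unless both datasets are comparable DataFrames (assumption)."""
        pairs = (("real_data", self.real_data), ("synthetic_data", self.synthetic_data))
        for name, df in pairs:
            if not isinstance(df, pd.DataFrame):
                raise ValueError(f"{name} must be a pandas DataFrame, got {type(df).__name__}")
            if df.empty:
                raise ValueError(f"{name} is empty")
        # Every real column should also exist in the synthetic data.
        missing = set(self.real_data.columns) - set(self.synthetic_data.columns)
        if missing:
            raise ValueError(f"synthetic_data is missing columns: {sorted(missing)}")


# Hypothetical metric: average absolute gap between numeric column means.
class MeanGap(BaseMetric):
    def calculate(self):
        self.validate_datasets()  # shared check, written once in the base class
        cols = self.real_data.select_dtypes(include="number").columns
        gaps = (self.real_data[cols].mean() - self.synthetic_data[cols].mean()).abs()
        return float(gaps.mean())


real = pd.DataFrame({"age": [20, 30, 40], "hours": [40, 40, 40]})
synthetic = pd.DataFrame({"age": [25, 35, 45], "hours": [38, 40, 42]})
print(MeanGap(real, synthetic).calculate())  # 2.5
```

With the check in the base class, each metric only calls `self.validate_datasets()` and does not need its own copy, which is the kind of reuse the issue asks for.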

Please review and merge this PR. Thank you!


Step 4: ⌨️ Coding

File: sdgx/metrics/base.py
Instructions: the changes listed under Step 2
Progress: ✅ Commit b96c716

Step 5: 🔁 Code Review

Here are my self-reviews of my changes at sweep/add-base-class-for-metrics.

Here is the 1st review

Thanks for your contribution. There are a few changes that need to be made:

  • In sdgx/metrics/base.py on lines 7-9, the calculate method is an abstract method but it doesn't raise a NotImplementedError. Please modify it to raise this error if it's not overridden in a subclass. Here's an example of how to do this:
    ```python
    @abstractmethod
    def calculate(self):
        raise NotImplementedError("Subclass must implement this method")
    ```
  • Also in sdgx/metrics/base.py, on lines 11-13, the validate_datasets method is currently unimplemented. It's unclear what this method is supposed to do. If it's not necessary, please remove it. If it is necessary, please implement it or provide a comment explaining what it should do.
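The first review point concerns Python's `abc` semantics. With `@abstractmethod`, Python already refuses to instantiate a subclass that does not override `calculate` and raises `TypeError` at construction time. The `raise NotImplementedError` body only runs if an override explicitly calls `super().calculate()`. The short sketch below uses hypothetical class names to demonstrate both cases. The exact `TypeError` message varies across Python versions, so it is not printed.

```python
from abc import ABC, abstractmethod


class BaseMetric(ABC):
    @abstractmethod
    def calculate(self):
        raise NotImplementedError("Subclass must implement this method")


class MissingCalculate(BaseMetric):
    """Hypothetical subclass that forgets to override calculate."""


class DelegatesToBase(BaseMetric):
    """Hypothetical subclass that overrides calculate but calls the base version."""

    def calculate(self):
        return super().calculate()


# Case 1: the decorator alone blocks instantiation.
try:
    MissingCalculate()
    blocked_at_init = False
except TypeError:
    blocked_at_init = True

# Case 2: the NotImplementedError fires only through super().calculate().
try:
    DelegatesToBase().calculate()
    raised_not_implemented = False
except NotImplementedError:
    raised_not_implemented = True

print(blocked_at_init, raised_not_implemented)  # True True
```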

Please make these changes and update the pull request. Thanks!

I finished incorporating these changes.


To recreate the pull request, edit the issue title or description.