Add ConstInspector and ConstValueTransformer for Handling Constant Columns (#202)

* add value_fields in metadata

* add ConstInspector

* add ConstValueTransformer and its testcase

* update comments and validator parameter name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update typo in test case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix metadata test error of value fields

The management of metadata fields may be flawed, necessitating an examination of the eq method or the manner in which fields are retrieved. We will open a separate pull request to address this issue.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add const_columns in default

* Refreshing the test cases

Some tests held erroneous references to shared pytest.fixture instances; taking a deepcopy of the fixture data resolves this.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revise the unit tests for the module handling two constant collections to ensure they are comprehensive and reflect the latest functionality.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add deepcopy in const.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* restore metadata, modify const inspector and transformer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove const_values in inspector's unit test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove multiple metadata in unit test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* set all inspectors not ready before fit chunk

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try reset the const columns after inspect

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clear column set before inspector fit

* change test func name

* add const type in test case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
MooooCat and pre-commit-ci[bot] authored Jul 31, 2024
1 parent 3e0366c commit 7d37e58
Showing 12 changed files with 318 additions and 2 deletions.
1 change: 1 addition & 0 deletions sdgx/data_models/inspectors/bool.py
@@ -22,6 +22,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.bool_columns = set()
         self.bool_columns = self.bool_columns.union(
             set(raw_data.infer_objects().select_dtypes(include=["bool"]).columns)
         )
70 changes: 70 additions & 0 deletions sdgx/data_models/inspectors/const.py
@@ -0,0 +1,70 @@
from __future__ import annotations

import copy
from typing import Any

import pandas as pd

from sdgx.data_models.inspectors.base import Inspector
from sdgx.data_models.inspectors.extension import hookimpl


class ConstInspector(Inspector):
    """
    ConstInspector is a class designed to identify columns in a DataFrame that contain constant values.
    It extends the base Inspector class and is used to fit the data and inspect it for constant columns.

    Attributes:
        const_columns (set[str]): A set of column names that contain constant values.
        const_values (dict[str, Any]): A dictionary mapping column names to their constant values.
        _inspect_level (int): The inspection level for this inspector, set to 80.
    """

    const_columns: set[str] = set()
    """
    A set of column names that contain constant values. This attribute is populated during the fit method by identifying columns in the DataFrame where all values are the same.
    """

    const_values: dict[str, Any] = {}
    """
    A dictionary mapping column names to their constant values. In this commit the fit method does not populate it (the assignment is commented out); ConstValueTransformer.convert records the values instead.
    """

    _inspect_level = 80
    """
    The inspection level for this inspector, set to 80. This attribute indicates the priority or depth of inspection that this inspector performs relative to other inspectors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
        """
        Fit the inspector to the raw data.

        This method identifies columns in the DataFrame that contain constant values and populates the `const_columns` set with their names. Missing values are ignored, because `value_counts` drops NaN by default.

        Args:
            raw_data (pd.DataFrame): The raw data to be inspected.

        Returns:
            None
        """
        self.const_columns = set()
        # iterate each column
        for column in raw_data.columns:
            if len(raw_data[column].value_counts(normalize=True)) == 1:
                self.const_columns.add(column)
                # self.const_values[column] = raw_data[column][0]

        self.ready = True

    def inspect(self, *args, **kwargs) -> dict[str, Any]:
        """Inspect raw data and generate metadata."""

        return {"const_columns": self.const_columns}


@hookimpl
def register(manager):
    manager.register("ConstInspector", ConstInspector)
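The detection rule in `fit` can be tried outside sdgx. The sketch below is a minimal stand-alone version, assuming only pandas; `find_const_columns` and its `ignore_missing` flag are hypothetical names, not part of sdgx. It uses `nunique` rather than `value_counts`, which gives the same answer for this purpose, and shows how a column with one value plus missing entries is classified.

```python
import pandas as pd


def find_const_columns(df: pd.DataFrame, ignore_missing: bool = True) -> set:
    """Return the names of columns holding exactly one distinct value.

    Hypothetical helper, not sdgx code. With ignore_missing=True, NaN is
    not counted as a value, which matches value_counts() dropping NaN.
    """
    return {
        column
        for column in df.columns
        if df[column].nunique(dropna=ignore_missing) == 1
    }


df = pd.DataFrame(
    {
        "country": ["US", "US", "US"],  # one value: constant
        "age": [31, 45, 27],            # several values: not constant
        "flag": [1.0, None, 1.0],       # one value plus a missing entry
    }
)

const_lenient = find_const_columns(df)
const_strict = find_const_columns(df, ignore_missing=False)
```

With the lenient rule, `flag` counts as constant; with the strict rule it does not. Which behaviour is wanted depends on whether missing values should survive a drop-and-restore round trip.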
2 changes: 2 additions & 0 deletions sdgx/data_models/inspectors/datetime.py
@@ -62,6 +62,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.datetime_columns = set()
+
         self.datetime_columns = self.datetime_columns.union(
             set(raw_data.infer_objects().select_dtypes(include=["datetime64"]).columns)
         )
1 change: 1 addition & 0 deletions sdgx/data_models/inspectors/discrete.py
@@ -21,6 +21,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.discrete_columns = set()

         self.discrete_columns = self.discrete_columns.union(
             set(raw_data.select_dtypes(include="object").columns)
2 changes: 2 additions & 0 deletions sdgx/data_models/inspectors/i_id.py
@@ -29,6 +29,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
             raw_data (pd.DataFrame): Raw data
         """

+        self.ID_columns = set()
+
         df_length = len(raw_data)
         candidate_columns = set(raw_data.select_dtypes(include=["object", "int64"]).columns)

3 changes: 3 additions & 0 deletions sdgx/data_models/inspectors/numeric.py
@@ -87,6 +87,9 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
             raw_data (pd.DataFrame): Raw data
         """

+        self.int_columns = set()
+        self.float_columns = set()
+
         self.df_length = len(raw_data)

         float_candidate = self.float_columns.union(
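The inspector hunks above all add the same kind of line: the column set is rebound to an empty set before `fit` unions new names into it. A minimal sketch of why that matters follows. `ToyInspector` and its `reset_on_fit` switch are hypothetical, and the class-level default is an assumption modelled on how `ConstInspector` declares `const_columns`.

```python
class ToyInspector:
    """Hypothetical stand-in for an inspector; not sdgx code."""

    columns: set = set()  # class-level default (assumed, as in ConstInspector)

    def __init__(self, reset_on_fit: bool):
        self.reset_on_fit = reset_on_fit

    def fit(self, column_names):
        if self.reset_on_fit:
            # Equivalent of the reset line added in this commit.
            self.columns = set()
        # union() returns a new set, so earlier results are carried along
        # unless they were cleared first.
        self.columns = self.columns.union(column_names)


stale = ToyInspector(reset_on_fit=False)
stale.fit({"a", "b"})
stale.fit({"c"})  # second fit on different data

fresh = ToyInspector(reset_on_fit=True)
fresh.fit({"a", "b"})
fresh.fit({"c"})
```

Without the reset, `stale.columns` still reports `a` and `b` after the second fit; with it, `fresh.columns` reflects only the latest data.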
4 changes: 4 additions & 0 deletions sdgx/data_models/metadata.py
@@ -71,6 +71,7 @@ def check_column_list(cls, value) -> Any:
     bool_columns: Set[str] = set()
     discrete_columns: Set[str] = set()
     datetime_columns: Set[str] = set()
+    const_columns: Set[str] = set()
     datetime_format: Dict = defaultdict(str)

     # version info
@@ -298,6 +299,9 @@ def from_dataloader(
         inspectors = im.init_inspcetors(
             include_inspectors, exclude_inspectors, **(inspector_init_kwargs or {})
         )
+        # set all inspectors not ready
+        for inspector in inspectors:
+            inspector.ready = False
         for i, chunk in enumerate(dataloader.iter()):
             for inspector in inspectors:
                 if not inspector.ready:
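The second `metadata.py` hunk clears every inspector's `ready` flag before the chunk loop. The sketch below illustrates the effect under stated assumptions: `OneShotInspector` and `fit_inspectors` are hypothetical, and the loop body after `if not inspector.ready:` (calling `fit`, then stopping once all inspectors are ready) is a guess, since the diff truncates there.

```python
class OneShotInspector:
    """Hypothetical inspector that is satisfied after a single chunk."""

    def __init__(self):
        self.ready = False
        self.fit_calls = 0

    def fit(self, chunk):
        self.fit_calls += 1
        self.ready = True


def fit_inspectors(inspectors, chunks):
    # Mirrors the added lines: clear any flag left over from earlier use.
    for inspector in inspectors:
        inspector.ready = False
    for chunk in chunks:
        for inspector in inspectors:
            if not inspector.ready:
                inspector.fit(chunk)  # assumed loop body
        if all(inspector.ready for inspector in inspectors):
            break  # assumed early exit


reused = OneShotInspector()
reused.ready = True  # simulate a flag left set by a previous run
fit_inspectors([reused], chunks=[[1, 2], [3, 4]])
```

Because the flag is cleared first, the reused inspector is fitted once on the new data instead of being skipped entirely.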
6 changes: 5 additions & 1 deletion sdgx/data_processors/manager.py
@@ -53,7 +53,11 @@ class DataProcessorManager(Manager):
             "IntValueFormatter",
             "DatetimeFormatter",
         ]
-    ] + ["EmptyTransformer".lower(), "ColumnOrderTransformer".lower()]
+    ] + [
+        "ConstValueTransformer".lower(),
+        "EmptyTransformer".lower(),
+        "ColumnOrderTransformer".lower(),
+    ]
     """
     preset_defalut_processors list stores the lowercase names of the transformers loaded by default. When using the synthesizer, they will be loaded by default to facilitate user operations.
110 changes: 110 additions & 0 deletions sdgx/data_processors/transformers/const.py
@@ -0,0 +1,110 @@
from __future__ import annotations

import copy
from typing import Any

import pandas as pd

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.extension import hookimpl
from sdgx.data_processors.transformers.base import Transformer
from sdgx.utils import logger


class ConstValueTransformer(Transformer):
    """
    A transformer that removes constant columns before processing and restores them afterwards.

    Columns whose metadata data type is "const" are dropped in `convert`, where their single value is recorded,
    and are re-created in `reverse_convert`, filled with that recorded value.

    Attributes:
        const_columns (list): Names of the columns identified as constant during `fit`.
        const_values (dict[str, Any]): Mapping from column name to its constant value, recorded during `convert`.
    """

    const_columns: list = []

    const_values: dict[str, Any] = {}

    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
        """
        Fit method for the transformer.

        This method processes the metadata to identify columns whose data type is "const".
        It updates the internal state of the transformer with the names of those columns; their values are recorded later, in `convert`.

        Args:
            metadata (Metadata | None): The metadata object containing information about the columns and their data types.
            **kwargs (dict[str, Any]): Additional keyword arguments.

        Returns:
            None
        """

        for each_col in metadata.column_list:
            if metadata.get_column_data_type(each_col) == "const":
                self.const_columns.append(each_col)
                # self.const_values[each_col] = metadata.get("const_values")[each_col]

        logger.info("ConstValueTransformer Fitted.")

        self.fitted = True

    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
        """
        Convert method that removes constant columns from the input data and records their values.

        This method iterates over the columns identified as constant during fitting, records each column's value the first time it is seen, and removes the column from a copy of the input DataFrame.

        Args:
            raw_data (pd.DataFrame): The input DataFrame containing the data to be processed.

        Returns:
            pd.DataFrame: A DataFrame with the specified columns removed.
        """

        processed_data = copy.deepcopy(raw_data)

        logger.info("Converting data using ConstValueTransformer...")

        for each_col in self.const_columns:
            # record values here
            if each_col not in self.const_values.keys():
                self.const_values[each_col] = processed_data[each_col].unique()[0]
            processed_data = self.remove_columns(processed_data, [each_col])

        logger.info("Converting data using ConstValueTransformer... Finished.")

        return processed_data

    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
        """
        Reverse_convert method for the transformer.

        This method restores the constant columns that were removed during the conversion process.
        It iterates over the columns identified as constant and adds them back to the DataFrame,
        filled with the values recorded in `convert`.

        Args:
            processed_data (pd.DataFrame): The input DataFrame containing the processed data.

        Returns:
            pd.DataFrame: A DataFrame with the original columns restored, filled with their corresponding constant values.
        """
        df_length = processed_data.shape[0]

        for each_col_name in self.const_columns:
            each_value = self.const_values[each_col_name]
            each_const_col = [each_value for _ in range(df_length)]
            each_const_df = pd.DataFrame({each_col_name: each_const_col})
            processed_data = self.attach_columns(processed_data, each_const_df)

        logger.info("Data reverse-converted by ConstValueTransformer.")

        return processed_data


@hookimpl
def register(manager):
    manager.register("ConstValueTransformer", ConstValueTransformer)
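One point a reviewer may want to check in the file above: `const_columns` and `const_values` are declared as mutable class-level defaults, and `fit` appends to the list in place. Unless the base `Transformer` rebinds them per instance (not visible in this diff), every instance would share one list. The sketch below reproduces that pattern with hypothetical classes `SharedState` and `PerInstanceState`; it is an illustration of the Python behaviour, not sdgx code.

```python
class SharedState:
    """Hypothetical: appends to a class-level list, as fit() does above."""

    names: list = []

    def fit(self, new_names):
        for name in new_names:
            # Mutates the single list object owned by the class.
            self.names.append(name)


class PerInstanceState:
    """Hypothetical variant that rebinds the attribute on the instance."""

    names: list = []

    def fit(self, new_names):
        # Assignment creates an instance attribute; the class default is untouched.
        self.names = list(new_names)


first, second = SharedState(), SharedState()
first.fit(["age"])
second.fit(["fnlwgt"])
shared_result = second.names

third, fourth = PerInstanceState(), PerInstanceState()
third.fit(["age"])
fourth.fit(["fnlwgt"])
isolated_result = fourth.names
```

In the shared variant the second instance also reports `age`, which it never saw. Rebinding the attribute inside `fit` (or in `__init__`) is one common way to avoid this.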
36 changes: 36 additions & 0 deletions tests/data_models/inspector/test_const.py
@@ -0,0 +1,36 @@
import copy

import pandas as pd
import pytest

from sdgx.data_models.inspectors.const import ConstInspector


@pytest.fixture
def test_const_data(demo_single_table_path):
    const_col_df = pd.read_csv(demo_single_table_path)

    # Convert the columns to float so they can hold the float constants below
    const_col_df["age"] = const_col_df["age"].astype(float)
    const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float)

    # Overwrite every value in these columns with a single constant
    const_col_df["age"].values[:] = 100
    const_col_df["fnlwgt"].values[:] = 3.14
    const_col_df["workclass"].values[:] = "President"

    yield const_col_df


def test_const_inspector(test_const_data: pd.DataFrame):
    inspector = ConstInspector()
    inspector.fit(test_const_data)
    assert inspector.ready
    assert inspector.const_columns

    assert sorted(inspector.inspect()["const_columns"]) == sorted(["age", "fnlwgt", "workclass"])
    assert inspector.inspect_level == 80


if __name__ == "__main__":
    pytest.main(["-vv", "-s", __file__])
2 changes: 1 addition & 1 deletion tests/data_models/test_metadata.py
@@ -90,7 +90,7 @@ def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata):
     assert demo_multi_data_child_matadata.get_column_data_type("Store") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("Date") == "datetime"
     assert demo_multi_data_child_matadata.get_column_data_type("Customers") == "int"
-    assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "int"
+    assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "const"
     assert demo_multi_data_child_matadata.get_column_data_type("Sales") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("Promo") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("DayOfWeek") == "int"
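The changed assertion shows `StateHoliday` moving from "int" to "const" once `ConstInspector` is active. One plausible reading is that a column claimed by several inspectors takes the type from the inspector with the highest inspect level. The sketch below illustrates that reading only: the resolution logic in `Metadata` is not shown in this diff, `resolve_column_types` is a hypothetical function, and the level 10 for the integer finding is an invented value (80 is the level `ConstInspector` declares).

```python
def resolve_column_types(findings):
    """Pick one type per column, assuming the highest inspect level wins.

    findings: iterable of (inspect_level, type_name, column_names) tuples.
    Hypothetical helper; not the sdgx implementation.
    """
    resolved = {}
    for level, type_name, columns in findings:
        for column in columns:
            current = resolved.get(column)
            if current is None or level > current[0]:
                resolved[column] = (level, type_name)
    return {column: type_name for column, (_, type_name) in resolved.items()}


findings = [
    (10, "int", {"Store", "StateHoliday"}),  # level 10 is invented for the example
    (80, "const", {"StateHoliday"}),         # 80 is ConstInspector's declared level
]
types = resolve_column_types(findings)
```

Under this assumption `Store` stays "int" while `StateHoliday` becomes "const", which is consistent with the updated test.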
83 changes: 83 additions & 0 deletions tests/data_processors/transformers/test_transformers_const.py
@@ -0,0 +1,83 @@
import copy

import numpy as np
import pandas as pd
import pytest

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.transformers.const import ConstValueTransformer


@pytest.fixture
def test_const_data(demo_single_table_path):

    const_col_df = pd.read_csv(demo_single_table_path)
    # Convert the columns to float so they can hold the float constants below
    const_col_df["age"] = const_col_df["age"].astype(float)
    const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float)

    # Overwrite every value in these columns with a single constant
    const_col_df["age"].values[:] = 100
    const_col_df["fnlwgt"].values[:] = 1.41421
    const_col_df["workclass"].values[:] = "President"

    yield const_col_df


def test_const_handling_test_df(test_const_data: pd.DataFrame):
    """
    Test the handling of const columns in a DataFrame.
    This function tests the behavior of a DataFrame when it contains const columns.
    It is designed to be used in a testing environment, where the DataFrame is passed as an argument.
    Parameters:
    test_const_data (pd.DataFrame): The DataFrame to test.
    Returns:
    None
    Raises:
    AssertionError: If the DataFrame does not handle const columns as expected.
    """

    metadata = Metadata.from_dataframe(test_const_data)

    # Initialize the ConstValueTransformer.
    const_transformer = ConstValueTransformer()
    # Check if the transformer has not been fitted yet.
    assert const_transformer.fitted is False

    # Fit the transformer with the metadata.
    const_transformer.fit(metadata)

    # Check if the transformer has been fitted after the fit operation.
    assert const_transformer.fitted

    # Check the const column
    assert sorted(const_transformer.const_columns) == [
        "age",
        "fnlwgt",
        "workclass",
    ]

    # Transform the DataFrame using the transformer.
    transformed_df = const_transformer.convert(test_const_data)

    assert "age" not in transformed_df.columns
    assert "fnlwgt" not in transformed_df.columns
    assert "workclass" not in transformed_df.columns

    # reverse convert the df
    reverse_converted_df = const_transformer.reverse_convert(transformed_df)

    assert "age" in reverse_converted_df.columns
    assert "fnlwgt" in reverse_converted_df.columns
    assert "workclass" in reverse_converted_df.columns

    assert reverse_converted_df["age"][0] == 100
    assert reverse_converted_df["fnlwgt"][0] == 1.41421
    assert reverse_converted_df["workclass"][0] == "President"

    assert len(reverse_converted_df["age"].unique()) == 1
    assert len(reverse_converted_df["fnlwgt"].unique()) == 1
    assert len(reverse_converted_df["workclass"].unique()) == 1
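The round trip this test exercises (drop the constant columns, then restore them) can be sketched with plain pandas. `drop_const` and `restore_const` are hypothetical helpers standing in for `convert` and `reverse_convert`; they do not use the transformer's `remove_columns` or `attach_columns` methods, whose implementations are not part of this diff.

```python
import pandas as pd


def drop_const(df: pd.DataFrame, const_columns):
    """Drop constant columns; return the slimmer frame and the saved values."""
    saved = {column: df[column].iloc[0] for column in const_columns}
    return df.drop(columns=list(const_columns)), saved


def restore_const(df: pd.DataFrame, saved):
    """Re-create each dropped column, filled with its saved value."""
    restored = df.copy()
    for column, value in saved.items():
        restored[column] = value  # a scalar is broadcast to every row
    return restored


original = pd.DataFrame(
    {
        "age": [100.0, 100.0],
        "workclass": ["President", "President"],
        "income": [1, 2],
    }
)
slim, saved = drop_const(original, ["age", "workclass"])
round_trip = restore_const(slim, saved)
```

Restored columns are appended at the end, so the column order differs from the original; in sdgx the default processor list also includes `ColumnOrderTransformer`, which presumably handles ordering. Because the value is broadcast, the restore step also works when the frame being restored has a different number of rows than the one that was converted.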
