-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ConstInspector and ConstValueTransformer for Handling Constant Co…
…lumns (#202) * add value_fields in metadata * add ConstInspector * add ConstValueTransformer and its testcase * update comments and validator parameter name * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update typo in test case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix metadata test error of value fields The management of metadata fields may be flawed, necessitating an examination of the eq method or the manner in which fields are retrieved. We will open a separate pull request to address this issue. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add const_columns in default * Refreshing the test cases Addressing issues in pytest where erroneous references to certain pytest.fixture instances arise can be resolved through the utilization of deepcopy. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revise the unit tests for the module handling two constant collections to ensure they are comprehensive and reflect the latest functionality. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add deepcopy in const.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restore metadata, modify const inspector and transformer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove const_values in inspector's unit test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove multiple metadata in unit test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set all inspectors not ready before fit chunk * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * try reset the const columns after inspect * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clear column set before inspector fit * change test func name * add const type in test case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
3e0366c
commit 7d37e58
Showing
12 changed files
with
318 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from __future__ import annotations | ||
|
||
import copy | ||
from typing import Any | ||
|
||
import pandas as pd | ||
|
||
from sdgx.data_models.inspectors.base import Inspector | ||
from sdgx.data_models.inspectors.extension import hookimpl | ||
|
||
|
||
class ConstInspector(Inspector): | ||
""" | ||
ConstInspector is a class designed to identify columns in a DataFrame that contain constant values. | ||
It extends the base Inspector class and is used to fit the data and inspect it for constant columns. | ||
Attributes: | ||
const_columns (set[str]): A set of column names that contain constant values. | ||
const_values (dict[Any]): A dictionary mapping column names to their constant values. | ||
_inspect_level (int): The inspection level for this inspector, set to 80. | ||
""" | ||
|
||
const_columns: set[str] = set() | ||
""" | ||
A set of column names that contain constant values. This attribute is populated during the fit method by identifying columns in the DataFrame where all values are the same. | ||
""" | ||
|
||
const_values: dict[Any] = {} | ||
""" | ||
A dictionary mapping column names to their constant values. This attribute is populated during the fit method by storing the unique value found in each constant column. | ||
""" | ||
|
||
_inspect_level = 80 | ||
""" | ||
The inspection level for this inspector, set to 80. This attribute indicates the priority or depth of inspection that this inspector performs relative to other inspectors. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
def fit(self, raw_data: pd.DataFrame, *args, **kwargs): | ||
""" | ||
Fit the inspector to the raw data. | ||
This method identifies columns in the DataFrame that contain constant values. It populates the `const_columns` set with the names of these columns and the `const_values` dictionary with the constant values found in each column. | ||
Args: | ||
raw_data (pd.DataFrame): The raw data to be inspected. | ||
Returns: | ||
None | ||
""" | ||
self.const_columns = set() | ||
# iterate each column | ||
for column in raw_data.columns: | ||
if len(raw_data[column].value_counts(normalize=True)) == 1: | ||
self.const_columns.add(column) | ||
# self.const_values[column] = raw_data[column][0] | ||
|
||
self.ready = True | ||
|
||
def inspect(self, *args, **kwargs) -> dict[str, Any]: | ||
"""Inspect raw data and generate metadata.""" | ||
|
||
return {"const_columns": self.const_columns} | ||
|
||
|
||
@hookimpl | ||
def register(manager): | ||
manager.register("ConstInspector", ConstInspector) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from __future__ import annotations | ||
|
||
import copy | ||
from typing import Any | ||
|
||
import pandas as pd | ||
|
||
from sdgx.data_models.metadata import Metadata | ||
from sdgx.data_processors.extension import hookimpl | ||
from sdgx.data_processors.transformers.base import Transformer | ||
from sdgx.utils import logger | ||
|
||
|
||
class ConstValueTransformer(Transformer): | ||
""" | ||
A transformer that replaces the input with a constant value. | ||
This class is used to transform any input data into a predefined constant value. | ||
It is particularly useful in scenarios where a consistent output is required regardless of the input. | ||
Attributes: | ||
const_value (dict[Any]): The constant value that will be returned. | ||
""" | ||
|
||
const_columns: list = [] | ||
|
||
const_values: dict[Any] = {} | ||
|
||
def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]): | ||
""" | ||
Fit method for the transformer. | ||
This method processes the metadata to identify columns that should be replaced with a constant value. | ||
It updates the internal state of the transformer with the columns and their corresponding constant values. | ||
Args: | ||
metadata (Metadata | None): The metadata object containing information about the columns and their data types. | ||
**kwargs (dict[str, Any]): Additional keyword arguments. | ||
Returns: | ||
None | ||
""" | ||
|
||
for each_col in metadata.column_list: | ||
if metadata.get_column_data_type(each_col) == "const": | ||
self.const_columns.append(each_col) | ||
# self.const_values[each_col] = metadata.get("const_values")[each_col] | ||
|
||
logger.info("ConstValueTransformer Fitted.") | ||
|
||
self.fitted = True | ||
|
||
def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Convert method to handle missing values in the input data by replacing specified columns with constant values. | ||
This method iterates over the columns identified for replacement with constant values and removes them from the input DataFrame. | ||
The removal is based on the columns specified during the fitting process. | ||
Args: | ||
raw_data (pd.DataFrame): The input DataFrame containing the data to be processed. | ||
Returns: | ||
pd.DataFrame: A DataFrame with the specified columns removed. | ||
""" | ||
|
||
processed_data = copy.deepcopy(raw_data) | ||
|
||
logger.info("Converting data using ConstValueTransformer...") | ||
|
||
for each_col in self.const_columns: | ||
# record values here | ||
if each_col not in self.const_values.keys(): | ||
self.const_values[each_col] = processed_data[each_col].unique()[0] | ||
processed_data = self.remove_columns(processed_data, [each_col]) | ||
|
||
logger.info("Converting data using ConstValueTransformer... Finished.") | ||
|
||
return processed_data | ||
|
||
def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Reverse_convert method for the transformer. | ||
This method restores the original columns that were replaced with constant values during the conversion process. | ||
It iterates over the columns identified for replacement with constant values and adds them back to the DataFrame | ||
with the predefined constant values. | ||
Args: | ||
processed_data (pd.DataFrame): The input DataFrame containing the processed data. | ||
Returns: | ||
pd.DataFrame: A DataFrame with the original columns restored, filled with their corresponding constant values. | ||
""" | ||
df_length = processed_data.shape[0] | ||
|
||
for each_col_name in self.const_columns: | ||
each_value = self.const_values[each_col_name] | ||
each_const_col = [each_value for _ in range(df_length)] | ||
each_const_df = pd.DataFrame({each_col_name: each_const_col}) | ||
processed_data = self.attach_columns(processed_data, each_const_df) | ||
|
||
logger.info("Data reverse-converted by ConstValueTransformer.") | ||
|
||
return processed_data | ||
|
||
|
||
@hookimpl | ||
def register(manager): | ||
manager.register("ConstValueTransformer", ConstValueTransformer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import copy | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from sdgx.data_models.inspectors.const import ConstInspector | ||
|
||
|
||
@pytest.fixture | ||
def test_const_data(demo_single_table_path): | ||
const_col_df = pd.read_csv(demo_single_table_path) | ||
|
||
# Convert the columns to float to allow None values | ||
const_col_df["age"] = const_col_df["age"].astype(float) | ||
const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float) | ||
|
||
# Set the values to None | ||
const_col_df["age"].values[:] = 100 | ||
const_col_df["fnlwgt"].values[:] = 3.14 | ||
const_col_df["workclass"].values[:] = "President" | ||
|
||
yield const_col_df | ||
|
||
|
||
def test_const_inspector(test_const_data: pd.DataFrame): | ||
inspector = ConstInspector() | ||
inspector.fit(test_const_data) | ||
assert inspector.ready | ||
assert inspector.const_columns | ||
|
||
assert sorted(inspector.inspect()["const_columns"]) == sorted(["age", "fnlwgt", "workclass"]) | ||
assert inspector.inspect_level == 80 | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main(["-vv", "-s", __file__]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
tests/data_processors/transformers/test_transformers_const.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import copy | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from sdgx.data_models.metadata import Metadata | ||
from sdgx.data_processors.transformers.const import ConstValueTransformer | ||
|
||
|
||
@pytest.fixture | ||
def test_const_data(demo_single_table_path): | ||
|
||
const_col_df = pd.read_csv(demo_single_table_path) | ||
# Convert the columns to float to allow None values | ||
const_col_df["age"] = const_col_df["age"].astype(float) | ||
const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float) | ||
|
||
# Set the values to None | ||
const_col_df["age"].values[:] = 100 | ||
const_col_df["fnlwgt"].values[:] = 1.41421 | ||
const_col_df["workclass"].values[:] = "President" | ||
|
||
yield const_col_df | ||
|
||
|
||
def test_const_handling_test_df(test_const_data: pd.DataFrame): | ||
""" | ||
Test the handling of const columns in a DataFrame. | ||
This function tests the behavior of a DataFrame when it contains const columns. | ||
It is designed to be used in a testing environment, where the DataFrame is passed as an argument. | ||
Parameters: | ||
test_const_data (pd.DataFrame): The DataFrame to test. | ||
Returns: | ||
None | ||
Raises: | ||
AssertionError: If the DataFrame does not handle const columns as expected. | ||
""" | ||
|
||
metadata = Metadata.from_dataframe(test_const_data) | ||
|
||
# Initialize the ConstValueTransformer. | ||
const_transformer = ConstValueTransformer() | ||
# Check if the transformer has not been fitted yet. | ||
assert const_transformer.fitted is False | ||
|
||
# Fit the transformer with the DataFrame. | ||
const_transformer.fit(metadata) | ||
|
||
# Check if the transformer has been fitted after the fit operation. | ||
assert const_transformer.fitted | ||
|
||
# Check the const column | ||
assert sorted(const_transformer.const_columns) == [ | ||
"age", | ||
"fnlwgt", | ||
"workclass", | ||
] | ||
|
||
# Transform the DataFrame using the transformer. | ||
transformed_df = const_transformer.convert(test_const_data) | ||
|
||
assert "age" not in transformed_df.columns | ||
assert "fnlwgt" not in transformed_df.columns | ||
assert "workclass" not in transformed_df.columns | ||
|
||
# reverse convert the df | ||
reverse_converted_df = const_transformer.reverse_convert(transformed_df) | ||
|
||
assert "age" in reverse_converted_df.columns | ||
assert "fnlwgt" in reverse_converted_df.columns | ||
assert "workclass" in reverse_converted_df.columns | ||
|
||
assert reverse_converted_df["age"][0] == 100 | ||
assert reverse_converted_df["fnlwgt"][0] == 1.41421 | ||
assert reverse_converted_df["workclass"][0] == "President" | ||
|
||
assert len(reverse_converted_df["age"].unique()) == 1 | ||
assert len(reverse_converted_df["fnlwgt"].unique()) == 1 | ||
assert len(reverse_converted_df["workclass"].unique()) == 1 |