Commit 39bb45e

fix(datarelations): drop_nan before VIF computation (#62)

* drop_nan before VIF computation
* fix linalg error
* codacy linting
* finish linting
* add test suite for data relations engine
* lint test file
* pytest correct fixture definition
* add encoding to open

1 parent b2b05a8 commit 39bb45e
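For context on the fix itself, the sketch below reproduces the failure mode with a small hypothetical frame (the column names are illustrative; variance_inflation_factor is the statsmodels helper the library aliases as vif). A NaN that reaches the VIF design matrix makes the per-column OLS fits fail, surfacing as the linear algebra error this commit refers to; dropping NaN rows first, as the commit does, avoids it.

# Sketch of the failure mode addressed by this commit (hypothetical data).
from numpy import nan
from pandas import DataFrame
from statsmodels.stats.outliers_influence import variance_inflation_factor

df = DataFrame({
    'a': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    'b': [2.0, 3.9, 6.1, 8.0, 9.8, 12.1],
    'c': [5.0, nan, 2.0, 8.0, 1.0, 3.0],  # a single NaN used to break every VIF fit
})

# Pre-fix: passing df.values directly lets the NaN reach the design matrix
# and the computation errors out. Post-fix: NaN rows are dropped beforehand.
clean = df.dropna()
vifs = {col: variance_inflation_factor(clean.values, i)
        for i, col in enumerate(clean.columns)}
print(vifs)  # finite scores computed on the complete rows only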

File tree

5 files changed: +126 −33 lines


requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -4,3 +4,5 @@ pytest
 sphinx
 myst-parser
 twine
+pytest
+nbconvert

src/ydata_quality/data_relations/engine.py

Lines changed: 31 additions & 20 deletions
@@ -3,9 +3,8 @@
 """
 from typing import List, Optional, Tuple
 
+from numpy import argwhere, ones, tril
 from pandas import DataFrame
-from numpy import ones, tril, argwhere
-
 from src.ydata_quality.core.warnings import Priority
 
 from ..core import QualityEngine, QualityWarning
@@ -40,7 +39,7 @@ def dtypes(self):
     def dtypes(self, df_dtypes: Tuple[DataFrame, dict]):
         df, dtypes = df_dtypes
         if not isinstance(dtypes, dict):
-            self._logger.warning("Property 'dtypes' should be a dictionary. Defaulting to all column dtypes inference.")
+            self._logger.debug("Property 'dtypes' should be a dictionary. Defaulting to all column dtypes inference.")
             dtypes = {}
         cols_not_in_df = [col for col in dtypes if col not in df.columns]
         if len(cols_not_in_df) > 0:
@@ -49,7 +48,7 @@ def dtypes(self, df_dtypes: Tuple[DataFrame, dict]):
         wrong_dtypes = [col for col, dtype in dtypes.items() if dtype not in supported_dtypes]
         if len(wrong_dtypes) > 0:
             self._logger.warning(
-                "Columns %s of dtypes where not defined with a supported dtype and will be inferred.", wrong_dtypes)
+                "Columns %s have no valid dtypes. Supported dtypes will be inferred.", wrong_dtypes)
         dtypes = {key: val for key, val in dtypes.items() if key not in cols_not_in_df + wrong_dtypes}
         df_col_set = set(df.columns)
         dtypes_col_set = set(dtypes.keys())
@@ -64,7 +63,7 @@ def dtypes(self, df_dtypes: Tuple[DataFrame, dict]):
     def evaluate(self, df: DataFrame, dtypes: Optional[dict] = None, label: str = None, corr_th: float = 0.8,
                  vif_th: float = 5, p_th: float = 0.05, plot: bool = True, summary: bool = True) -> dict:
         """Runs tests to the validation run results and reports based on found errors.
-        We perform standard normalization of numerical features in order to unbias VIF and partial correlation methods.
+        Standard normalization of numerical features is performed as a preprocessing operation.
         This bias correction produces results equivalent to adding a constant feature to the dataset.
 
         Args:
@@ -74,17 +73,25 @@ def evaluate(self, df: DataFrame, dtypes: Optional[dict] = None, label: str = No
             label (Optional[str]): A string identifying the label feature column
             corr_th (float): Absolute threshold for high correlation detection. Defaults to 0.8.
             vif_th (float): Variance Inflation Factor threshold for numerical independence test.
-                Typically 5-10 is recommended. Defaults to 5.
-            p_th (float): Fraction of the right tail of the chi squared CDF.
-                Defines threshold for categorical independence test. Defaults to 0.05.
+                Typically a minimum of 5-10 is recommended. Defaults to 5.
+            p_th (float): Fraction of the right tail of the chi squared CDF defining threshold for categorical
+                independence test. Defaults to 0.05.
             plot (bool): Pass True to produce all available graphical outputs, False to suppress all graphical output.
             summary (bool): Print a report containing all the warnings detected during the data quality analysis.
         """
-        assert label in df.columns or not label, "The provided label name does not exist as a column in the dataset"
+        results = {}
+        nan_or_const = df.nunique() < 2  # Constant columns or all nan columns
+        label = None if label in nan_or_const else label
+        self._logger.warning('The columns %s are constant or all NaNs and \
+were dropped from this evaluation.', list(nan_or_const.index[nan_or_const]))
+        df = df.drop(columns=nan_or_const.index[nan_or_const])  # Constant columns or all nan columns are dropped
+        if df.shape[1] < 2:
+            self._logger.warning('There are fewer than 2 columns on the dataset where correlations can be computed. \
+Skipping the DataRelations engine execution.')
+            return results
         self.dtypes = (df, dtypes)  # Consider refactoring QualityEngine dtypes (df as argument of setter)
         df = standard_normalize(df, self.dtypes)
-        results = {}
-        corr_mat, _ = correlation_matrix(df, self.dtypes, True)
+        corr_mat, _ = correlation_matrix(df, self.dtypes, label, True)
         p_corr_mat = partial_correlation_matrix(corr_mat)
         results['Correlations'] = {'Correlation matrix': corr_mat, 'Partial correlation matrix': p_corr_mat}
         if plot:
@@ -96,9 +103,12 @@ def evaluate(self, df: DataFrame, dtypes: Optional[dict] = None, label: str = No
             results['Colliders'] = self._collider_detection(corr_mat, p_corr_mat, corr_th)
         else:
             self._logger.warning('The partial correlation matrix is not computable for this dataset. \
-Skipping potential confounder and collider detection tests.')
+Skipped potential confounder and collider detection tests.')
         if label:
-            results['Feature Importance'] = self._feature_importance(corr_mat, p_corr_mat, label, corr_th)
+            try:
+                results['Feature Importance'] = self._feature_importance(corr_mat, p_corr_mat, label, corr_th)
+            except AssertionError as exception:
+                self._logger.warning(str(exception))
         results['High Collinearity'] = self._high_collinearity_detection(df, self.dtypes, label, vif_th, p_th=p_th)
         self._clean_warnings()
         if summary:
@@ -123,9 +133,9 @@ def _confounder_detection(self, corr_mat: DataFrame, par_corr_mat: DataFrame,
                 QualityWarning(
                     test=QualityWarning.Test.CONFOUNDED_CORRELATIONS, category=QualityWarning.Category.DATA_RELATIONS,
                     priority=Priority.P2, data=confounded_pairs,
-                    description=f"""
-Found {len(confounded_pairs)} independently correlated variable pairs that disappeared after controling\
-for the remaining variables. This is an indicator of potential confounder effects in the dataset."""))
+                    description=f"""Found {len(confounded_pairs)} independently correlated variable pairs that \
+disappeared after controling for the remaining variables. This is an indicator of potential confounder effects \
+in the dataset."""))
         return confounded_pairs
 
     def _collider_detection(self, corr_mat: DataFrame, par_corr_mat: DataFrame,
@@ -147,8 +157,8 @@ def _collider_detection(self, corr_mat: DataFrame, par_corr_mat: DataFrame,
                     test=QualityWarning.Test.COLLIDER_CORRELATIONS, category=QualityWarning.category.DATA_RELATIONS,
                     priority=Priority.P2, data=colliding_pairs,
                     description=f"Found {len(colliding_pairs)} independently uncorrelated variable pairs that showed \
-correlation after controling for the remaining variables. \
-This is an indicator of potential colliding bias with other covariates."))
+correlation after controling for the remaining variables. This is an indicator of potential colliding bias with other \
+covariates."))
         return colliding_pairs
 
     @staticmethod
@@ -159,7 +169,8 @@ def _feature_importance(corr_mat: DataFrame, par_corr_mat: DataFrame,
 
         This method returns a summary of all detected important features.
         The summary contains zero, full order partial correlation and a note regarding potential confounding."""
-        assert label in corr_mat.columns, f"The provided label {label} does not exist as a column in the DataFrame."
+        assert label in corr_mat.columns, f"The correlations of the label '{label}', required for the feature \
+importance test, were not computed (this column has less than the minimum of 2 unique values needed)."
         label_corrs = corr_mat.loc[label].drop(label)
         mask = ones(label_corrs.shape, dtype='bool')
         mask[label_corrs.abs() <= corr_th] = False  # Drop pairs with zero order correlation below threshold
@@ -204,7 +215,7 @@ def _high_collinearity_detection(self, df: DataFrame, dtypes: dict, label: str =
                     category=QualityWarning.Category.DATA_RELATIONS, priority=Priority.P2, data=inflated,
                     description=f"""Found {len(inflated)} numerical variables with high Variance Inflation Factor \
 (VIF>{vif_th:.1f}). The variables listed in results are highly collinear with other variables in the dataset. \
-These will make model explainability harder and potentially give way to issues like overfitting.\
+These will make model explainability harder and potentially give way to issues like overfitting. \
 Depending on your end goal you might want to remove the highest VIF variables."""))
         if len(cat_coll_scores) > 0:
             # TODO: Merge warning messages (make one warning for the whole test,
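As a quick illustration of the new guard in evaluate (a sketch with a hypothetical frame, not the project's test data): DataFrame.nunique() counts distinct non-null values per column, so constant columns and all-NaN columns both satisfy nunique() < 2 and are removed before normalization, correlations, or VIF are attempted.

# Sketch of the nunique() < 2 guard (hypothetical frame).
from numpy import nan
from pandas import DataFrame

df = DataFrame({
    'x': [1, 2, 3],
    'const': [7, 7, 7],        # constant column -> nunique() == 1
    'empty': [nan, nan, nan],  # all-NaN column  -> nunique() == 0
})

nan_or_const = df.nunique() < 2
dropped = list(nan_or_const.index[nan_or_const])
df = df.drop(columns=dropped)

print(dropped)          # ['const', 'empty']
print(df.shape[1] < 2)  # True -> evaluate() would skip the engine entirely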

src/ydata_quality/utils/auxiliary.py

Lines changed: 2 additions & 1 deletion
@@ -92,9 +92,10 @@ def find_duplicate_columns(df: DataFrame, is_close=False) -> dict:
     return dups
 
 
-def drop_column_list(df: DataFrame, column_list: dict):
+def drop_column_list(df: DataFrame, column_list: dict, label: str = None):
     "Drops from a DataFrame a duplicates mapping of columns to duplicate lists. Works inplace."
     for col, dup_list in column_list.items():
+        dup_list = [col for col in dup_list if col != label]
         if col in df.columns:  # Ensures we will not drop both members of duplicate pairs
             df.drop(columns=dup_list, index=dup_list, inplace=True)
 
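The new label argument keeps the label column out of every duplicate list, so deduplicating the correlation matrix can never remove the row the feature importance test later asserts on. A sketch with a hypothetical duplicates mapping:

# Sketch of the label guard (hypothetical mapping from find_duplicate_columns).
dup_map = {'education': ['education-num', 'income']}
label = 'income'

for col, dup_list in dup_map.items():
    dup_list = [dup for dup in dup_list if dup != label]  # the label always survives
    print(col, '->', dup_list)  # education -> ['education-num']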

src/ydata_quality/utils/correlations.py

Lines changed: 20 additions & 12 deletions
@@ -6,8 +6,7 @@
 from itertools import combinations
 from typing import List, Optional
 
-from pandas import DataFrame, Series, crosstab
-from numpy.linalg import pinv
+from matplotlib.pyplot import figure as pltfigure, show as pltshow
 from numpy import (
     nan,
     fill_diagonal,
@@ -30,11 +29,13 @@
     isnan,
     triu_indices_from,
 )
-from scipy.stats import pearsonr, chi2_contingency
+from numpy.linalg import pinv
+from pandas import DataFrame, Series, crosstab
+from scipy.stats import chi2_contingency, pearsonr
 from scipy.stats.distributions import chi2
-from statsmodels.stats.outliers_influence import variance_inflation_factor as vif
-from seaborn import heatmap, diverging_palette
-from matplotlib.pyplot import show as pltshow, figure as pltfigure
+from seaborn import diverging_palette, heatmap
+from statsmodels.stats.outliers_influence import \
+    variance_inflation_factor as vif
 
 from .auxiliary import drop_column_list, find_duplicate_columns
 
@@ -91,7 +92,8 @@ def unbiased_cramers_v(col1: ndarray, col2: ndarray) -> float:
     phi_sq_hat = npmax([0, phi_sq - ((r_vals - 1) * (k_vals - 1)) / (n_elements - 1)])
     k_hat = k_vals - square(k_vals - 1) / (n_elements - 1)
     r_hat = r_vals - square(r_vals - 1) / (n_elements - 1)
-    return sqrt(phi_sq_hat / npmin([k_hat - 1, r_hat - 1]))  # Note: this is strictly positive
+    den = npmin([k_hat - 1, r_hat - 1])
+    return sqrt(phi_sq_hat / den) if den != 0 else nan  # Note: this is strictly positive
 
 
 def correlation_ratio(col1: ndarray, col2: ndarray) -> float:
@@ -102,6 +104,8 @@ def correlation_ratio(col1: ndarray, col2: ndarray) -> float:
         col1 (ndarray): A categorical column with no null values
         col2 (ndarray): A numerical column with no null values"""
     uniques = unique(col1)
+    if len(uniques) < 2:
+        return nan
     y_x_hat = zeros(len(uniques))
     counts = zeros(len(uniques))
     for count, value in enumerate(uniques):
@@ -116,7 +120,7 @@ def correlation_ratio(col1: ndarray, col2: ndarray) -> float:
 
 
 # pylint: disable=too-many-locals
-def correlation_matrix(df: DataFrame, dtypes: dict, drop_dups: bool = False) -> DataFrame:
+def correlation_matrix(df: DataFrame, dtypes: dict, label: str, drop_dups: bool = False) -> DataFrame:
     """Returns the correlation matrix.
     The methods used for computing correlations are mapped according to the column dtypes of each pair."""
     corr_funcs = {  # Map supported correlation functions
@@ -146,8 +150,8 @@ def correlation_matrix(df: DataFrame, dtypes: dict, drop_dups: bool = False) ->
     if drop_dups:
         # Find duplicate row lists in absolute correlation matrix
         dup_pairs = find_duplicate_columns(corr_mat.abs(), True)
-        drop_column_list(corr_mat, dup_pairs)
-        drop_column_list(p_vals, dup_pairs)
+        drop_column_list(corr_mat, dup_pairs, label)
+        drop_column_list(p_vals, dup_pairs, label)
     return corr_mat, p_vals
 
 
@@ -195,10 +199,14 @@ def vif_collinearity(data: DataFrame, dtypes: dict, label: str = None) -> Series
     if label and label in data.columns:
         data = data.drop(columns=label)
     num_columns = [col for col in data.columns if dtypes[col] == 'numerical']
+    data = data.dropna(subset=num_columns)
     warnings.filterwarnings("ignore", category=RuntimeWarning)
-    vifs = [vif(data[num_columns].values, i) for i in range(len(data[num_columns].columns))]
+    if data.empty:
+        vifs = {}
+    else:
+        vifs = {num_columns[i]: vif(data[num_columns].values, i) for i in range(len(data[num_columns].columns))}
     warnings.resetwarnings()
-    return Series(data=vifs, index=num_columns).sort_values(ascending=False)
+    return Series(data=vifs, dtype=float).sort_values(ascending=False)
 
 
 # pylint: disable=too-many-locals
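Two of the changes above are defensive guards: unbiased_cramers_v now returns nan instead of dividing by zero when a column is effectively single-valued, and vif_collinearity builds its result from a name-to-score dict so that a frame emptied by dropna still yields a valid (empty) Series. A sketch of the latter with hypothetical scores:

# Sketch of the dict-based Series construction (hypothetical VIF scores).
from pandas import Series

vifs = {}  # what vif_collinearity produces when dropna() empties the frame
print(Series(data=vifs, dtype=float))  # empty float Series, no length mismatch

vifs = {'age': 6.3, 'hours-per-week': 1.4}
print(Series(data=vifs, dtype=float).sort_values(ascending=False))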

tests/engines/test_data_relations.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+"Tests for the DataRelations module."
+from pytest import fixture
+from pandas import read_csv
+import nbformat
+from nbconvert.preprocessors import ExecutePreprocessor
+
+from ydata_quality.data_relations.engine import DataRelationsDetector
+
+
+@fixture(name='data_relations')
+def fixture_data_relations():
+    return DataRelationsDetector()
+
+
+@fixture(name='example_dataset_transformed')
+def fixture_example_dataset_transformed():
+    dataset_path = 'datasets/transformed/census_10k.csv'
+    return read_csv(dataset_path)
+
+
+@fixture(name='ipynb_tutorial')
+def fixture_ipynb_tutorial():
+    path = "tutorials/data_relations.ipynb"
+    with open(path, encoding='utf8', errors='strict') as file:
+        ntb = nbformat.read(file, as_version=4)
+    return ntb
+
+
+@fixture(name='dr_results_no_pcorr')
+def fixture_dr_results_no_pcorr(data_relations, example_dataset_transformed):
+    results = data_relations.evaluate(df=example_dataset_transformed,
+                                      dtypes=None,
+                                      label='income',
+                                      plot=False)
+    return data_relations, results
+
+
+@fixture(name='dr_results_pc_corr')
+def fixture_dr_results_pc_corr(data_relations, example_dataset_transformed):
+    df = example_dataset_transformed.drop(columns=['education-num'])
+    results = data_relations.evaluate(df=df,
+                                      dtypes=None,
+                                      label='income',
+                                      plot=False)
+    return data_relations, results
+
+
+def test_get_warnings(dr_results_no_pcorr):
+    new_drd = DataRelationsDetector()
+    assert isinstance(new_drd.get_warnings(), list)
+    assert len(new_drd.get_warnings()) == 0
+
+    ran_data_relations, _ = dr_results_no_pcorr
+    assert isinstance(ran_data_relations.get_warnings(), list)
+    assert len(ran_data_relations.get_warnings()) > 0
+
+
+def test_results(dr_results_no_pcorr, dr_results_pc_corr):
+    _, results = dr_results_no_pcorr
+    assert isinstance(results, dict)
+    assert set(results.keys()) == set(['Correlations', 'Feature Importance', 'High Collinearity'])
+
+    _, results2 = dr_results_pc_corr
+    assert isinstance(results2, dict)
+    assert set(results2.keys()) == set(['Correlations', 'Confounders', 'Colliders',
+                                        'Feature Importance', 'High Collinearity'])
+
+
+def test_tutorial_notebook_execution(ipynb_tutorial):
+    exp = ExecutePreprocessor(timeout=600, kernel_name='python3')
+    assert exp.preprocess(ipynb_tutorial, {'metadata': {'path': "tutorials"}})
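Assuming the dev requirements above are installed and the working directory is the repository root (the fixtures use paths relative to it), the new suite can be run on its own; a minimal sketch:

# Minimal runner for just this suite (sketch; execute from the repository root).
import pytest

raise SystemExit(pytest.main(["-q", "tests/engines/test_data_relations.py"]))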

0 commit comments
