Skip to content

Commit

Permalink
Reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
951378644 committed Dec 20, 2023
1 parent a7bd06b commit 5c2ec1b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
8 changes: 7 additions & 1 deletion menelaus/data_drift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,10 @@
from menelaus.data_drift.nndvi import NNDVI
from menelaus.data_drift.cdbd import CDBD
from menelaus.data_drift.histogram_density_method import HistogramDensityMethod
from menelaus.data_drift.stat_test import GenericDetector, CHIDetector, KSDetector, CVMDetector, FETDetector
from menelaus.data_drift.stat_test import (
GenericDetector,
CHIDetector,
KSDetector,
CVMDetector,
FETDetector,
)
48 changes: 40 additions & 8 deletions menelaus/data_drift/stat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from menelaus.detector import BatchDetector
from scipy.stats import chi2_contingency, kstest, cramervonmises_2samp, fisher_exact


class GenericDetector(BatchDetector):
"""
Generic Detector class for detecting batched data drift.
Expand Down Expand Up @@ -37,8 +38,10 @@ def set_reference(self, X, y_true=None, y_pred=None):
y_pred (numpy.array): predicted labels, not used in GenericDetector
"""
X, _, _ = super()._validate_input(X, None, None)
X = X.reshape(len(X),)
X = self.representation.fit(X)#, y_true)
X = X.reshape(
len(X),
)
X = self.representation.fit(X) # , y_true)
self.reference = X

def update(self, X, y_true=None, y_pred=None, alt=None):
Expand All @@ -63,7 +66,9 @@ def update(self, X, y_true=None, y_pred=None, alt=None):
self.reference = self.test

X, _, _ = super()._validate_input(X, None, None)
X = X.reshape(len(X),)
X = X.reshape(
len(X),
)
X = self.representation.fit(X)
self.test = X # , y_true, y_pred)
if alt == None:
Expand All @@ -88,8 +93,10 @@ def reset(self):
self.div = None
pass


# region Validations


class IdentityValidation:
"""
Identity transform for numerical features.
Expand All @@ -116,7 +123,10 @@ def fit(self, X_ref, y_true=None):
if any(np.issubdtype(X_ref.dtype, dtype) for dtype in numeric_dtypes):
return X_ref
else:
raise ValueError("No numerical data detected. Please pass numerical features.")
raise ValueError(
"No numerical data detected. Please pass numerical features."
)


class CategoricalValidation:
"""
Expand All @@ -138,10 +148,15 @@ def fit(self, X_ref, y_true=None):
X_ref: Confirmed categorical data.
"""
categorical_dtypes = [np.object_]
if (any(np.issubdtype(X_ref.dtype, dtype) for dtype in categorical_dtypes)) or len(np.unique(X_ref)) < 10:
if (
any(np.issubdtype(X_ref.dtype, dtype) for dtype in categorical_dtypes)
) or len(np.unique(X_ref)) < 10:
return X_ref
else:
raise ValueError("No categorical columns detected. Please pass categorical features.")
raise ValueError(
"No categorical columns detected. Please pass categorical features."
)


class BinaryValidation:
"""
Expand All @@ -167,12 +182,16 @@ def fit(self, X_ref, y_true=None):
if (set(unique_values) == {0, 1}) or (set(unique_values) == {False, True}):
return X_ref
else:
raise ValueError("The X_ref data must consist of only (0,1)'s or (False,True)'s for the FETDrift detector.")
raise ValueError(
"The X_ref data must consist of only (0,1)'s or (False,True)'s for the FETDrift detector."
)


# endregion

# region Hypothesis test wrappers


def chi2Divergence(rep_ref, rep_test):
"""
Calculate the p-value for the chi-squared test of independence between two categorical distributions.
Expand All @@ -192,6 +211,7 @@ def chi2Divergence(rep_ref, rep_test):
_, pval, _, _ = chi2_contingency(contingency_table)
return pval


def ksDivergence(rep_ref, rep_test):
"""
Calculate the p-value for the Kolmogorov-Smirnov test between two distributions.
Expand All @@ -206,6 +226,7 @@ def ksDivergence(rep_ref, rep_test):
pval = kstest(rep_ref, rep_test).pvalue
return pval


def cvmDivergence(rep_ref, rep_test):
"""
Calculate the p-value for the Cramér-von Mises test between two distributions.
Expand All @@ -220,6 +241,7 @@ def cvmDivergence(rep_ref, rep_test):
pval = cramervonmises_2samp(rep_ref, rep_test).pvalue
return pval


def fetDivergence(ref, test, alternative):
"""
Calculate the p-value for the Fisher's Exact Test between two binary distributions.
Expand All @@ -239,15 +261,19 @@ def fetDivergence(ref, test, alternative):
n_ref, n = ref.shape[0], test.shape[0]
p_val, odds_ratio = np.empty(1), np.empty(1)

table = np.array([[np.sum(test), np.sum(ref)], [n - np.sum(test), n_ref - np.sum(ref)]])
table = np.array(
[[np.sum(test), np.sum(ref)], [n - np.sum(test), n_ref - np.sum(ref)]]
)
odds_ratio[0], p_val[0] = fisher_exact(table, alternative)

return p_val


# endregion

# region Critical value functions


def crit(pval, alpha=0.05):
"""
Compare p-value with alpha to determine statistical significance.
Expand All @@ -264,10 +290,12 @@ def crit(pval, alpha=0.05):
else:
return False


# endregion

# region Implemented classes


class CHIDetector(GenericDetector):
"""
Chi-squared Test Data drift detector using the chi-squared test for categorical data.
Expand All @@ -283,6 +311,7 @@ def __init__(self):
crit_function=crit,
)


class KSDetector(GenericDetector):
"""
Kolmogorov-Smirnov Test Data drift detector using the Kolmogorov-Smirnov test for numerical data.
Expand All @@ -298,6 +327,7 @@ def __init__(self):
crit_function=crit,
)


class CVMDetector(GenericDetector):
"""
Cramér-von Mises Test Data drift detector using the Cramér-von Mises test for numerical data.
Expand All @@ -313,6 +343,7 @@ def __init__(self):
crit_function=crit,
)


class FETDetector(GenericDetector):
"""
Fisher's Exact Test Data drift detector using Fisher's Exact Test for binary data.
Expand All @@ -328,4 +359,5 @@ def __init__(self):
crit_function=crit,
)


# endregion

0 comments on commit 5c2ec1b

Please sign in to comment.