From 5c2ec1bba2db3641242dd5d2753e9c697572fdfd Mon Sep 17 00:00:00 2001 From: Tim Chen <115333718+951378644@users.noreply.github.com> Date: Wed, 20 Dec 2023 13:49:32 -0500 Subject: [PATCH] Reformatted --- menelaus/data_drift/__init__.py | 8 +++++- menelaus/data_drift/stat_test.py | 48 ++++++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/menelaus/data_drift/__init__.py b/menelaus/data_drift/__init__.py index beaf091..ced65d8 100644 --- a/menelaus/data_drift/__init__.py +++ b/menelaus/data_drift/__init__.py @@ -19,4 +19,10 @@ from menelaus.data_drift.nndvi import NNDVI from menelaus.data_drift.cdbd import CDBD from menelaus.data_drift.histogram_density_method import HistogramDensityMethod -from menelaus.data_drift.stat_test import GenericDetector, CHIDetector, KSDetector, CVMDetector, FETDetector +from menelaus.data_drift.stat_test import ( + GenericDetector, + CHIDetector, + KSDetector, + CVMDetector, + FETDetector, +) diff --git a/menelaus/data_drift/stat_test.py b/menelaus/data_drift/stat_test.py index ec88141..2dc1549 100644 --- a/menelaus/data_drift/stat_test.py +++ b/menelaus/data_drift/stat_test.py @@ -3,6 +3,7 @@ from menelaus.detector import BatchDetector from scipy.stats import chi2_contingency, kstest, cramervonmises_2samp, fisher_exact + class GenericDetector(BatchDetector): """ Generic Detector class for detecting batched data drift. @@ -37,8 +38,10 @@ def set_reference(self, X, y_true=None, y_pred=None): y_pred (numpy.array): predicted labels, not used in GenericDetector """ X, _, _ = super()._validate_input(X, None, None) - X = X.reshape(len(X),) - X = self.representation.fit(X)#, y_true) + X = X.reshape( + len(X), + ) + X = self.representation.fit(X) # , y_true) self.reference = X def update(self, X, y_true=None, y_pred=None, alt=None): @@ -63,7 +66,9 @@ def update(self, X, y_true=None, y_pred=None, alt=None): self.reference = self.test X, _, _ = super()._validate_input(X, None, None) - X = X.reshape(len(X),) + X = X.reshape( + len(X), + ) X = self.representation.fit(X) self.test = X # , y_true, y_pred) if alt == None: @@ -88,8 +93,10 @@ def reset(self): self.div = None pass + # region Validations + class IdentityValidation: """ Identity transform for numerical features. @@ -116,7 +123,10 @@ def fit(self, X_ref, y_true=None): if any(np.issubdtype(X_ref.dtype, dtype) for dtype in numeric_dtypes): return X_ref else: - raise ValueError("No numerical data detected. Please pass numerical features.") + raise ValueError( + "No numerical data detected. Please pass numerical features." + ) + class CategoricalValidation: """ @@ -138,10 +148,15 @@ def fit(self, X_ref, y_true=None): X_ref: Confirmed categorical data. """ categorical_dtypes = [np.object_] - if (any(np.issubdtype(X_ref.dtype, dtype) for dtype in categorical_dtypes)) or len(np.unique(X_ref)) < 10: + if ( + any(np.issubdtype(X_ref.dtype, dtype) for dtype in categorical_dtypes) + ) or len(np.unique(X_ref)) < 10: return X_ref else: - raise ValueError("No categorical columns detected. Please pass categorical features.") + raise ValueError( + "No categorical columns detected. Please pass categorical features." + ) + class BinaryValidation: """ @@ -167,12 +182,16 @@ def fit(self, X_ref, y_true=None): if (set(unique_values) == {0, 1}) or (set(unique_values) == {False, True}): return X_ref else: - raise ValueError("The X_ref data must consist of only (0,1)'s or (False,True)'s for the FETDrift detector.") + raise ValueError( + "The X_ref data must consist of only (0,1)'s or (False,True)'s for the FETDrift detector." + ) + # endregion # region Hypothesis test wrappers + def chi2Divergence(rep_ref, rep_test): """ Calculate the p-value for the chi-squared test of independence between two categorical distributions. @@ -192,6 +211,7 @@ def chi2Divergence(rep_ref, rep_test): _, pval, _, _ = chi2_contingency(contingency_table) return pval + def ksDivergence(rep_ref, rep_test): """ Calculate the p-value for the Kolmogorov-Smirnov test between two distributions. @@ -206,6 +226,7 @@ def ksDivergence(rep_ref, rep_test): pval = kstest(rep_ref, rep_test).pvalue return pval + def cvmDivergence(rep_ref, rep_test): """ Calculate the p-value for the Cramér-von Mises test between two distributions. @@ -220,6 +241,7 @@ def cvmDivergence(rep_ref, rep_test): pval = cramervonmises_2samp(rep_ref, rep_test).pvalue return pval + def fetDivergence(ref, test, alternative): """ Calculate the p-value for the Fisher's Exact Test between two binary distributions. @@ -239,15 +261,19 @@ def fetDivergence(ref, test, alternative): n_ref, n = ref.shape[0], test.shape[0] p_val, odds_ratio = np.empty(1), np.empty(1) - table = np.array([[np.sum(test), np.sum(ref)], [n - np.sum(test), n_ref - np.sum(ref)]]) + table = np.array( + [[np.sum(test), np.sum(ref)], [n - np.sum(test), n_ref - np.sum(ref)]] + ) odds_ratio[0], p_val[0] = fisher_exact(table, alternative) return p_val + # endregion # region Critical value functions + def crit(pval, alpha=0.05): """ Compare p-value with alpha to determine statistical significance. @@ -264,10 +290,12 @@ def crit(pval, alpha=0.05): else: return False + # endregion # region Implemented classes + class CHIDetector(GenericDetector): """ Chi-squared Test Data drift detector using the chi-squared test for categorical data. @@ -283,6 +311,7 @@ def __init__(self): crit_function=crit, ) + class KSDetector(GenericDetector): """ Kolmogorov-Smirnov Test Data drift detector using the Kolmogorov-Smirnov test for numerical data. @@ -298,6 +327,7 @@ def __init__(self): crit_function=crit, ) + class CVMDetector(GenericDetector): """ Cramér-von MisesData Test drift detector using the Cramér-von Mises test for numerical data. @@ -313,6 +343,7 @@ def __init__(self): crit_function=crit, ) + class FETDetector(GenericDetector): """ Fisher's Exact Test Data drift detector using Fisher's Exact Test for binary data. @@ -328,4 +359,5 @@ def __init__(self): crit_function=crit, ) + # endregion