Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

160 psi detector #162

Open
wants to merge 34 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2ff709c
Create PSI Detector.py
951378644 Sep 19, 2023
2c4a9e3
Move file to a different location
951378644 Sep 19, 2023
360f552
Update PSI Detector.py
951378644 Sep 19, 2023
fd3cac5
Update PSI Detector.py
951378644 Sep 19, 2023
b806132
Update PSI Detector.py
951378644 Sep 19, 2023
6e18b84
Update PSI Detector.py
951378644 Sep 19, 2023
5c7de0c
rename
951378644 Sep 19, 2023
5c0039f
Reformatted code with black
951378644 Sep 19, 2023
2b808ab
add PSI to data_drift.init, add skeleton for unit tests
tms-bananaquit Nov 17, 2023
9541574
Update refs.bib
951378644 Nov 21, 2023
395d840
Update psi_detector.py
951378644 Nov 21, 2023
6a131c9
formating code with black
951378644 Nov 21, 2023
9fed735
Update test_psi_detector.py
951378644 Nov 22, 2023
36f96e2
Update psi_detector.py
951378644 Nov 22, 2023
3db9642
Update test_psi_detector.py
951378644 Nov 22, 2023
46c72e7
Update test_psi_detector.py
951378644 Nov 22, 2023
4fbbe4c
Update test_psi_detector.py
951378644 Nov 22, 2023
3266872
Update test_psi_detector.py
951378644 Nov 22, 2023
4d4d532
Update test_psi_detector.py
951378644 Nov 22, 2023
6bb398d
Update test_psi_detector.py
951378644 Nov 22, 2023
38d15e4
Update test_psi_detector.py
951378644 Nov 22, 2023
67dd5ae
Update test_psi_detector.py
951378644 Nov 22, 2023
4e1cd44
Update psi_detector.py
951378644 Nov 22, 2023
cc85fac
Update psi_detector.py
951378644 Nov 22, 2023
5d9f06f
Update test_psi_detector.py
951378644 Nov 22, 2023
9e91dcc
Update test_psi_detector.py
951378644 Nov 22, 2023
a46710c
Update test_psi_detector.py
951378644 Nov 22, 2023
dab9abf
Update test_psi_detector.py
951378644 Nov 22, 2023
2e9656a
Update psi_detector.py
951378644 Nov 22, 2023
4d8cbc2
Update test_psi_detector.py
951378644 Nov 22, 2023
1530e7f
Update test_psi_detector.py
951378644 Nov 22, 2023
9bf792a
Update psi_detector.py
951378644 Nov 22, 2023
b3a724e
Update test_psi_detector.py
951378644 Nov 22, 2023
614eb0f
update
951378644 Nov 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,20 @@ @misc{souza2020
year={2020},
howpublished="\url{https://arxiv.org/abs/2005.00113}",
note={Online; accessed 20-July-2022},
}
}

@misc{Psi2022,
title={Is your ML model stable? Checking model stability and population drift with PSI and CSI},
author={Vinícius Trevisan},
year={2022},
howpublished="\url{https://towardsdatascience.com/checking-model-stability-and-population-shift-with-psi-and-csi-6d12af008783}",
note={Online; accessed 20-June-2023},
}

@misc{Psi2022b,
title={Population Stability Index (PSI)},
author={Selva Prabhakaran},
year={2022},
howpublished="\url{https://www.machinelearningplus.com/deployment/population-stability-index-psi/}",
note={Online; accessed 20-June-2023},
}
7 changes: 4 additions & 3 deletions menelaus/data_drift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
are applied and when the results are verified. Data drift detection is also
applicable in unsupervised learning settings.
"""
from menelaus.data_drift.cdbd import CDBD
from menelaus.data_drift.hdddm import HDDDM
from menelaus.data_drift.histogram_density_method import HistogramDensityMethod
from menelaus.data_drift.kdq_tree import KdqTreeStreaming, KdqTreeBatch
from menelaus.data_drift.pca_cd import PCACD
from menelaus.data_drift.nndvi import NNDVI
from menelaus.data_drift.cdbd import CDBD
from menelaus.data_drift.histogram_density_method import HistogramDensityMethod
from menelaus.data_drift.pca_cd import PCACD
from menelaus.data_drift.psi_detector import PSI
145 changes: 145 additions & 0 deletions menelaus/data_drift/psi_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from menelaus.detector import BatchDetector
import pandas as pd
import numpy as np


class PSI(BatchDetector):
    """
    Batch drift detector based on the Population Stability Index (PSI).

    A single numeric feature is split into equal-width bins spanning the
    combined range of the reference batch and the incoming test batch. The
    share of observations falling into each bin is computed for both batches,
    and the per-bin terms ``(p_ref - p_test) * ln(p_ref / p_test)`` are
    aggregated into one PSI value. When that value reaches ``threshold``, the
    detector enters the ``"drift"`` state and adopts the test batch as its new
    reference.

    Ref. :cite:t:`Psi2022`
    """

    input_type = "batch"

    def __init__(self, eps=1e-4, threshold=0.1):
        """
        Args:
            eps (float): small constant used in two places: it widens the
                outermost bin edges so the extreme values fall inside a bin,
                and it replaces zero bin shares so the logarithm and the
                division in the PSI formula stay finite. Default ``1e-4``.
            threshold (float): drift is signalled when the computed PSI value
                is greater than or equal to this number. Default ``0.1``.
        """
        super().__init__()
        self.eps = eps
        self.threshold = threshold

    def set_reference(self, X, y_true=None, y_pred=None):
        """
        Set the detector's reference batch to an updated value; typically
        used in ``update``.

        Args:
            X (numpy.array): updated reference batch; it is flattened to a
                1-D array of length ``len(X)``.
            y_true (numpy.array): true labels, not used in PSI
            y_pred (numpy.array): predicted labels, not used in PSI
        """
        X, _, _ = super()._validate_input(X, None, None)
        self.reference = X.reshape(len(X))

    def reset(self):
        """
        Initialize relevant attributes to original values, to ensure information
        only stored from samples_since_reset onwards. Intended for use
        after ``drift_state == 'drift'``.
        """
        super().reset()

    def update(self, X: np.ndarray, y_true=None, y_pred=None):
        """
        Update the detector with a new test batch. If drift is detected, new
        reference batch becomes most recent test batch.

        Args:
            X (numpy.array): next batch of data to detect drift on.
            y_true (numpy.array): true labels, not used in PSI
            y_pred (numpy.array): predicted labels, not used in PSI

        Returns:
            float: the PSI value computed between the reference batch and
            ``X``.
        """
        # A drift alarm raised by the previous batch is cleared before the
        # new batch is evaluated.
        if self._drift_state == "drift":
            self.reset()

        X, _, _ = super()._validate_input(X, None, None)
        super().update(X=X, y_true=None, y_pred=None)

        test_batch = np.array(X).reshape(len(X))

        # Bin edges must cover both batches so that no observation falls
        # outside the binning and gets silently dropped by ``pd.cut``.
        low = min(np.min(self.reference), np.min(test_batch))
        high = max(np.max(self.reference), np.max(test_batch))
        bins = self._bin_data(self.reference, low, high)

        grp_initial = self._bin_shares(self.reference, bins, "initial")
        grp_new = self._bin_shares(test_batch, bins, "new")

        psi_value = self._PSI(grp_initial, grp_new)
        if psi_value >= self.threshold:
            self._drift_state = "drift"
            self.set_reference(test_batch)
        return psi_value

    def _bin_shares(self, values, bins, name):
        """
        Count how many of ``values`` fall into each bin and convert the counts
        to shares of the batch.

        Args:
            values (numpy.array): 1-D batch to be binned.
            bins (list): bin edges as produced by ``_bin_data``.
            name (str): column name for the counts; the shares are stored in
                the column ``percent_<name>``.

        Returns:
            pandas.DataFrame: one row per bin (index named ``bin``) with the
            columns ``<name>`` (count) and ``percent_<name>`` (share).
        """
        assigned = pd.cut(values, bins=bins, labels=range(1, len(bins)))
        frame = pd.DataFrame({name: values, "bin": assigned})
        # ``observed=False`` is spelled out so that empty bins are kept as
        # zero-count rows regardless of the pandas default; the PSI
        # aggregation relies on every bin being present.
        grouped = frame.groupby("bin", observed=False).count()
        grouped[f"percent_{name}"] = grouped[name] / grouped[name].sum()
        return grouped

    def _bin_data(self, feature, min, max):
        """
        Build equal-width bin edges between the given minimum and maximum.

        The number of bins equals the number of unique values in ``feature``,
        capped at 10. The first and last edges are pushed outwards by ``eps``
        so that the extreme values are included.

        Args:
            feature (numpy.array): The feature to be binned.
            min (float): The minimum value for binning.
            max (float): The maximum value for binning.

        Returns:
            list: A list of bin edges for the given feature.
        """
        # NOTE(review): the bin count depends only on ``feature`` (the
        # reference batch). A constant reference yields a single bin, in
        # which case any test batch produces a PSI of 0 - confirm intended.
        n_unique = len(np.unique(feature))
        n_bins = n_unique if n_unique < 10 else 10
        bins = [min + (max - min) * i / n_bins for i in range(n_bins + 1)]
        bins[0] = min - self.eps
        bins[-1] = max + self.eps
        return bins

    def _PSI(self, reference_feature, test_feature):
        """
        Calculate the Population Stability Index (PSI) between reference and test features.

        Args:
            reference_feature (pandas.DataFrame): Reference feature
                distribution, indexed by ``bin`` with a ``percent_initial``
                column.
            test_feature (pandas.DataFrame): Test feature distribution,
                indexed by ``bin`` with a ``percent_new`` column.

        Returns:
            float: The calculated PSI value indicating distributional change.
        """
        psi_df = reference_feature.join(test_feature, on="bin", how="inner")
        # Zero shares would make the log term undefined; substitute eps.
        for column in ("percent_initial", "percent_new"):
            psi_df[column] = psi_df[column].where(psi_df[column] != 0, self.eps)
        psi_df["psi"] = (psi_df["percent_initial"] - psi_df["percent_new"]) * np.log(
            psi_df["percent_initial"] / psi_df["percent_new"]
        )
        # NOTE(review): the per-bin terms are averaged here, whereas PSI is
        # commonly defined as their sum - confirm the mean is intended, since
        # it scales the value (and thus the meaning of ``threshold``) by the
        # number of bins.
        return np.mean(psi_df["psi"])
50 changes: 50 additions & 0 deletions tests/menelaus/data_drift/test_psi_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import numpy as np
import pandas as pd
from menelaus.data_drift import PSI

def test_psi_init():
    """A freshly built PSI detector carries its default settings and a clean state."""
    detector = PSI()

    # Constructor defaults.
    assert detector.eps == 1e-4
    assert detector.threshold == 0.1

    # Nothing has been processed yet, so no batches and no drift state.
    assert detector.batches_since_reset == 0
    assert detector.drift_state is None

def test_psi_set_reference():
    """PSI.set_reference stores an (n, 1) batch as a one-dimensional array."""
    column_batch = np.random.randint(0, 5, (100, 1))
    detector = PSI()
    detector.set_reference(column_batch)
    stored = detector.reference
    assert stored.ndim == 1

def test_psi_update_1():
    """Smoke test: PSI.update runs on small random batches without raising."""
    batch_shape = (10, 1)
    detector = PSI()
    reference_batch = np.random.randint(0, 5, batch_shape)
    detector.set_reference(reference_batch)
    test_batch = np.random.randint(0, 5, batch_shape)
    detector.update(X=test_batch)

def test_psi_update_2():
    """Ensure PSI can update with drift actions triggered"""
    det = PSI()
    np.random.seed(123)
    # Reference values lie in [0, 100) and test values in [150, 200): the two
    # ranges are disjoint, so every occupied bin is empty in one of the two
    # batches and the PSI is far above the default threshold.
    det.set_reference(np.random.randint(0, 100, (200, 1)))
    det.update(X=np.random.randint(150, 200, (100, 1)))
    # Pin the exact state rather than "anything but None".
    assert det.drift_state == "drift"

def test_psi_update_3():
    """Check PSI.update behavior after drift alarm"""
    det = PSI()
    # Use the same deterministic batch as reference and as test data: the
    # bin shares are then identical, the PSI is exactly 0, and the outcome
    # no longer depends on unseeded random draws (two independent random
    # batches of 25 samples can reach the 0.1 threshold by chance).
    batch = (np.arange(25) % 5).reshape(-1, 1)
    det.set_reference(batch)
    det._drift_state = "drift"
    det.update(X=batch.copy())
    # The stale alarm must have been cleared and no new drift raised.
    assert det.drift_state is None

def test_psi_reset():
    """PSI.reset returns the batch counter and drift state to their initial values."""
    detector = PSI()

    # Put the detector into a post-drift state by hand.
    detector.batches_since_reset = 1
    detector.drift_state = "drift"

    detector.reset()

    assert detector.batches_since_reset == 0
    assert detector.drift_state is None