Skip to content

Commit

Permalink
Update test_psi_detector.py
Browse files Browse the repository at this point in the history
  • Loading branch information
951378644 committed Nov 22, 2023
1 parent 4fbbe4c commit 3266872
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions tests/menelaus/data_drift/test_psi_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
import pandas as pd
from menelaus.data_drift import PSI

def test_psi_init():
Expand All @@ -15,7 +16,7 @@ def test_psi_set_reference():
det = PSI()
ref = np.random.randint(0, 5, (9, 1))
det.set_reference(ref)
assert np.array_equal(ref, det.reference_batch)
assert np.array_equal(ref, det.reference)

def test_psi_update_1():
"""Ensure PSI can update with small random batches"""
Expand All @@ -42,13 +43,6 @@ def test_psi_update_3():
det.update(X=np.random.randint(0, 5, (25, 1)))
assert det.drift_state is None

def test_psi_update_4():
"""Check failure when batch shapes don't match"""
det = PSI()
det.set_reference(np.random.randint(0, 5, (30, 1)))
with pytest.raises(ValueError):
det.update(np.random.randint(0, 5, (25, 1)))

def test_psi_reset():
"""Check psi.reset works as intended"""
det = PSI()
Expand All @@ -62,11 +56,8 @@ def test_psi_reset():
def test_psi_compute_PSI():
"""Check psi._compute_threshold works correctly"""
det = PSI()
# XXX - Hardcoded known example added by AS, in the future a
# dynamic way to test this function may be used
np.random.seed(123)
threshold = det._PSI(
reference_feature=np.random.randint(0, 2, 5),
test_feature=np.random.randint(0, 2, 5),
)
PSI.set_reference(np.random.randint(0,100,100))
PSI.update(np.random.randint(0,100,100))
threshold = PSI.PSI_value
assert threshold >= 0 and threshold <= 1

0 comments on commit 3266872

Please sign in to comment.