Skip to content

Commit

Permalink
Update test_psi_detector.py
Browse files Browse the repository at this point in the history
  • Loading branch information
951378644 committed Nov 22, 2023
1 parent 46c72e7 commit 4fbbe4c
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/menelaus/data_drift/test_psi_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_psi_init():
def test_psi_set_reference():
"""Assert PSI.set_reference works as intended"""
det = PSI()
ref = np.random.randint(0, 5, (3, 3))
ref = np.random.randint(0, 5, (9, 1))
det.set_reference(ref)
assert np.array_equal(ref, det.reference_batch)

def test_psi_update_1():
"""Ensure PSI can update with small random batches"""
det = PSI()
det.set_reference(np.random.randint(0, 5, (10, 10)))
det.update(X=np.random.randint(0, 5, (10, 10)))
det.set_reference(np.random.randint(0, 5, (100, 1)))
det.update(X=np.random.randint(0, 5, (100, 1)))

def test_psi_update_2():
"""Ensure PSI can update with drift actions triggered"""
Expand All @@ -30,24 +30,24 @@ def test_psi_update_2():
# is otherwise hard to induce drift in, for small data
# examples. More stable alternatives may exist
np.random.seed(123)
det.set_reference(np.random.randint(0, 5, (10, 10)))
det.update(X=np.random.randint(10, 40, (10, 10)))
det.set_reference(np.random.randint(0, 5, (100, 1)))
det.update(X=np.random.randint(10, 40, (100, 1)))
assert det.drift_state is not None

def test_psi_update_3():
"""Check PSI.update behavior after drift alarm"""
det = PSI()
det.set_reference(np.random.randint(0, 5, (5, 5)))
det.set_reference(np.random.randint(0, 5, (25, 1)))
det._drift_state = "drift"
det.update(X=np.random.randint(0, 5, (5, 5)))
det.update(X=np.random.randint(0, 5, (25, 1)))
assert det.drift_state is None

def test_psi_update_4():
"""Check failure when batch shapes don't match"""
det = PSI()
det.set_reference(np.random.randint(0, 5, (5, 6)))
det.set_reference(np.random.randint(0, 5, (30, 1)))
with pytest.raises(ValueError):
det.update(np.random.randint(0, 5, (5, 5)))
det.update(np.random.randint(0, 5, (25, 1)))

def test_psi_reset():
"""Check psi.reset works as intended"""
Expand All @@ -66,7 +66,7 @@ def test_psi_compute_PSI():
# dynamic way to test this function may be used
np.random.seed(123)
threshold = det._PSI(
v_ref=np.random.randint(0, 2, 5),
v_test=np.random.randint(0, 2, 5),
reference_feature=np.random.randint(0, 2, 5),
test_feature=np.random.randint(0, 2, 5),
)
assert threshold >= 0 and threshold <= 1

0 comments on commit 4fbbe4c

Please sign in to comment.