Skip to content

Commit

Permalink
Update test_psi_detector.py
Browse files Browse the repository at this point in the history
  • Loading branch information
951378644 committed Nov 22, 2023
1 parent 36f96e2 commit 3db9642
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/menelaus/data_drift/test_psi_detector.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
import pytest
import numpy as np
from menelaus.data_drift.psi_detector import PSI_Detector
from menelaus.data_drift.psi_detector import PSI

def test_psi_init():
"""Test correct default initialization for PSI"""
det = PSI_Detector()
det = PSI()
assert det.eps == 1e-4
assert det.threshold == 0.1
assert det.batches_since_reset == 0
assert det.drift_state is None

def test_psi_set_reference():
"""Assert PSI.set_reference works as intended"""
det = PSI_Detector()
det = PSI()
ref = np.random.randint(0, 5, (3, 3))
det.set_reference(ref)
assert np.array_equal(ref, det.reference_batch)

def test_psi_update_1():
"""Ensure PSI can update with small random batches"""
det = PSI_Detector()
det = PSI()
det.set_reference(np.random.randint(0, 5, (10, 10)))
det.update(X=np.random.randint(0, 5, (10, 10)))

def test_psi_update_2():
"""Ensure PSI can update with drift actions triggered"""
det = PSI_Detector()
det = PSI()
# XXX - AS added this method of forcing drift in psi, which
# is otherwise hard to induce drift in, for small data
# examples. More stable alternatives may exist
Expand All @@ -36,22 +36,22 @@ def test_psi_update_2():

def test_psi_update_3():
"""Check PSI.update behavior after drift alarm"""
det = PSI_Detector()
det = PSI()
det.set_reference(np.random.randint(0, 5, (5, 5)))
det._drift_state = "drift"
det.update(X=np.random.randint(0, 5, (5, 5)))
assert det.drift_state is None

def test_psi_update_4():
"""Check failure when batch shapes don't match"""
det = PSI_Detector()
det = PSI()
det.set_reference(np.random.randint(0, 5, (5, 6)))
with pytest.raises(ValueError):
det.update(np.random.randint(0, 5, (5, 5)))

def test_psi_reset():
"""Check psi.reset works as intended"""
det = PSI_Detector()
det = PSI()
det.batches_since_reset = 1
det.drift_state = "drift"
det.reset()
Expand All @@ -61,7 +61,7 @@ def test_psi_reset():

def test_psi_compute_PSI():
"""Check psi._compute_threshold works correctly"""
det = PSI_Detector()
det = PSI()
# XXX - Hardcoded known example added by AS, in the future a
# dynamic way to test this function may be used
np.random.seed(123)
Expand Down

0 comments on commit 3db9642

Please sign in to comment.