diff --git a/tests/menelaus/data_drift/test_psi_detector.py b/tests/menelaus/data_drift/test_psi_detector.py index 07af81e..4d27ec7 100644 --- a/tests/menelaus/data_drift/test_psi_detector.py +++ b/tests/menelaus/data_drift/test_psi_detector.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import pandas as pd from menelaus.data_drift import PSI def test_psi_init(): @@ -15,7 +16,7 @@ def test_psi_set_reference(): det = PSI() ref = np.random.randint(0, 5, (9, 1)) det.set_reference(ref) - assert np.array_equal(ref, det.reference_batch) + assert np.array_equal(ref, det.reference) def test_psi_update_1(): """Ensure PSI can update with small random batches""" @@ -42,13 +43,6 @@ def test_psi_update_3(): det.update(X=np.random.randint(0, 5, (25, 1))) assert det.drift_state is None -def test_psi_update_4(): - """Check failure when batch shapes don't match""" - det = PSI() - det.set_reference(np.random.randint(0, 5, (30, 1))) - with pytest.raises(ValueError): - det.update(np.random.randint(0, 5, (25, 1))) - def test_psi_reset(): """Check psi.reset works as intended""" det = PSI() @@ -62,11 +56,8 @@ def test_psi_reset(): def test_psi_compute_PSI(): """Check psi._compute_threshold works correctly""" det = PSI() - # XXX - Hardcoded known example added by AS, in the future a - # dynamic way to test this function may be used np.random.seed(123) - threshold = det._PSI( - reference_feature=np.random.randint(0, 2, 5), - test_feature=np.random.randint(0, 2, 5), - ) + PSI.set_reference(np.random.randint(0,100,100)) + PSI.update(np.random.randint(0,100,100)) + threshold = PSI.PSI_value assert threshold >= 0 and threshold <= 1