Skip to content

Commit

Permalink
Merge pull request #768 from a-szulc/master
Browse files Browse the repository at this point in the history
add catching of expected errors in tests (fixes #752, #753, #761, #762)
  • Loading branch information
pplonski authored Aug 29, 2024
2 parents 19272cd + 617d6f0 commit 3f65896
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 7 deletions.
3 changes: 2 additions & 1 deletion supervised/preprocessing/exclude_missing_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def transform(
logger.debug("Exclude rows with missing target values")
if warn:
warnings.warn(
"There are samples with missing target values in the data which will be excluded for further analysis"
"There are samples with missing target values in the data which will be excluded for further analysis",
UserWarning
)
y = y.drop(y.index[y_missing])
y.reset_index(drop=True, inplace=True)
Expand Down
4 changes: 3 additions & 1 deletion supervised/validation/validator_kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __init__(self, params):
self.repeats = self.params.get("repeats", 1)

if not self.shuffle and self.repeats > 1:
warnings.warn("Disable repeats in validation because shuffle is disabled")
warnings.warn(
"Disable repeats in validation because shuffle is disabled", UserWarning
)
self.repeats = 1

self.skf = []
Expand Down
4 changes: 3 additions & 1 deletion supervised/validation/validator_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, params):
self.repeats = self.params.get("repeats", 1)

if not self.shuffle and self.repeats > 1:
warnings.warn("Disable repeats in validation because shuffle is disabled")
warnings.warn(
"Disable repeats in validation because shuffle is disabled", UserWarning
)
self.repeats = 1

self._results_path = self.params.get("results_path")
Expand Down
23 changes: 21 additions & 2 deletions tests/tests_automl/test_targets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import shutil
import unittest
import pytest

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -100,7 +101,16 @@ def test_bin_class_AB_missing_targets(self):
explain_level=0,
start_random_models=1,
)
automl.fit(X, y)

with pytest.warns(
expected_warning=UserWarning,
match="There are samples with missing target values in the data which will be excluded for further analysis",
) as record:
automl.fit(X, y)

# check that only one warning was raised
self.assertEqual(len(record), 1)

p = automl.predict(X)
pred = automl.predict(X)

Expand Down Expand Up @@ -256,7 +266,16 @@ def test_multi_class_abcd_missing_target(self):
explain_level=0,
start_random_models=1,
)
automl.fit(X, y)

with pytest.warns(
expected_warning=UserWarning,
match="There are samples with missing target values in the data which will be excluded for further analysis",
) as record:
automl.fit(X, y)

# check that only one warning was raised
self.assertEqual(len(record), 1)

pred = automl.predict(X)

u = np.unique(pred)
Expand Down
11 changes: 10 additions & 1 deletion tests/tests_validation/test_validator_kfold.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile
import unittest
import pytest

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -194,7 +195,15 @@ def test_disable_repeats_when_disabled_shuffle(self):
"y_path": y_path,
"random_seed": 1,
}
vl = KFoldValidator(params)

with pytest.warns(
expected_warning=UserWarning,
match="Disable repeats in validation because shuffle is disabled",
) as record:
vl = KFoldValidator(params)

# check that only one warning was raised
self.assertEqual(len(record), 1)

self.assertEqual(params["k_folds"], vl.get_n_splits())
self.assertEqual(1, vl.get_repeats())
11 changes: 10 additions & 1 deletion tests/tests_validation/test_validator_split.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile
import unittest
import pytest

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -211,7 +212,15 @@ def test_disable_repeats_when_disabled_shuffle(self):
"y_path": y_path,
"repeats": 3,
}
vl = SplitValidator(params)

with pytest.warns(
expected_warning=UserWarning,
match="Disable repeats in validation because shuffle is disabled",
) as record:
vl = SplitValidator(params)

# check that only one warning was raised
self.assertEqual(len(record), 1)

self.assertEqual(1, vl.get_n_splits())
self.assertEqual(1, vl.get_repeats())

0 comments on commit 3f65896

Please sign in to comment.