diff --git a/AUTHORS b/AUTHORS index 374e6ad9bcc..b1dd40dbd4c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -400,6 +400,7 @@ Stefanie Molin Stefano Taschini Steffen Allner Stephan Obermann +Sven Sven-Hendrik Haase Sviatoslav Sydorenko Sylvain MariƩ diff --git a/changelog/12749.feature.rst b/changelog/12749.feature.rst new file mode 100644 index 00000000000..cda0db6c930 --- /dev/null +++ b/changelog/12749.feature.rst @@ -0,0 +1,5 @@ +New :confval:`collect_imported_tests`: when enabled (the default) pytest will collect classes/functions in test modules even if they are imported from another file. + +Setting this to False will make pytest collect classes/functions from test files only if they are defined in that file (as opposed to imported there). + +-- by :user:`FreerGit` diff --git a/doc/en/reference/reference.rst b/doc/en/reference/reference.rst index f7dfb3ffa71..53f9470f756 100644 --- a/doc/en/reference/reference.rst +++ b/doc/en/reference/reference.rst @@ -1301,6 +1301,20 @@ passed multiple times. The expected format is ``name=value``. For example:: variables, that will be expanded. For more information about cache plugin please refer to :ref:`cache_provider`. +.. confval:: collect_imported_tests + + .. versionadded:: 8.4 + + Setting this to ``false`` will make pytest collect classes/functions from test + files only if they are defined in that file (as opposed to imported there). + + .. code-block:: ini + + [pytest] + collect_imported_tests = false + + Default: ``true`` + .. confval:: consider_namespace_packages Controls if pytest should attempt to identify `namespace packages `__ @@ -1838,11 +1852,8 @@ passed multiple times. The expected format is ``name=value``. For example:: pytest testing doc - .. confval:: tmp_path_retention_count - - How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`. diff --git a/src/_pytest/main.py b/src/_pytest/main.py index e5534e98d69..41063a9bc18 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -78,6 +78,12 @@ def pytest_addoption(parser: Parser) -> None: type="args", default=[], ) + parser.addini( + "collect_imported_tests", + "Whether to collect tests in imported modules outside `testpaths`", + type="bool", + default=True, + ) group = parser.getgroup("general", "Running and selection options") group._addoption( "-x", diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 9c54dd20f80..9c00252c5f8 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -416,6 +416,15 @@ def collect(self) -> Iterable[nodes.Item | nodes.Collector]: if name in seen: continue seen.add(name) + + if not self.session.config.getini("collect_imported_tests"): + # Do not collect imported functions + if inspect.isfunction(obj) and isinstance(self, Module): + fn_defined_at = obj.__module__ + in_module = self._getobj().__name__ + if fn_defined_at != in_module: + continue + res = ihook.pytest_pycollect_makeitem( collector=self, name=name, obj=obj ) @@ -741,6 +750,16 @@ def newinstance(self): return self.obj() def collect(self) -> Iterable[nodes.Item | nodes.Collector]: + if not self.config.getini("collect_imported_tests"): + # This entire branch will discard (not collect) a class + # if it is imported (defined in a different module) + if isinstance(self, Class) and isinstance(self.parent, Module): + if inspect.isclass(self._getobj()): + class_defined_at = self._getobj().__module__ + in_module = self.parent._getobj().__name__ + if class_defined_at != in_module: + return [] + if not safe_getattr(self.obj, "__test__", True): return [] if hasinit(self.obj): diff --git a/testing/test_collect_imports.py b/testing/test_collect_imports.py new file mode 100644 index 00000000000..1c56c9155e5 --- /dev/null +++ b/testing/test_collect_imports.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import textwrap +from typing import Any + +from _pytest.fixtures import FixtureRequest +from _pytest.main import Session +from _pytest.pytester import Pytester +from _pytest.pytester import RecordedHookCall +from _pytest.pytester import RunResult +from _pytest.reports import CollectReport +import pytest + + +# Start of tests for classes + + +def run_import_class_test( + pytester: Pytester, passed: int = 0, errors: int = 0 +) -> RunResult: + src_dir = pytester.mkdir("src") + tests_dir = pytester.mkdir("tests") + src_file = src_dir / "foo.py" + + src_file.write_text( + textwrap.dedent("""\ + class Testament(object): + def __init__(self): + super().__init__() + self.collections = ["stamp", "coin"] + + def personal_property(self): + return [f"my {x} collection" for x in self.collections] + """), + encoding="utf-8", + ) + + test_file = tests_dir / "foo_test.py" + test_file.write_text( + textwrap.dedent("""\ + import sys + import os + + current_file = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file) + parent_dir = os.path.abspath(os.path.join(current_dir, '..')) + sys.path.append(parent_dir) + + from src.foo import Testament + + class TestDomain: + def test_testament(self): + testament = Testament() + assert testament.personal_property() + """), + encoding="utf-8", + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=passed, errors=errors) + return result + + +def test_collect_imports_disabled(pytester: Pytester) -> None: + pytester.makeini(""" + [pytest] + testpaths = "tests" + collect_imported_tests = false + """) + + run_import_class_test(pytester, passed=1) + + # Verify that the state of hooks + reprec = pytester.inline_run() + reports = reprec.getreports("pytest_collectreport") + modified = reprec.getcalls("pytest_collection_modifyitems") + items_collected = reprec.getcalls("pytest_itemcollected") + + assert len(reports) == 5 + assert len(modified) == 1 + assert len(items_collected) == 1 + for x in items_collected: + assert x.item._getobj().__name__ == "test_testament" + + +def test_collect_imports_default(pytester: Pytester) -> None: + run_import_class_test(pytester, errors=1) + + +def test_collect_imports_enabled(pytester: Pytester) -> None: + pytester.makeini(""" + [pytest] + collect_imported_tests = true + """) + + run_import_class_test(pytester, errors=1) + + +# End of tests for classes +################################# +# Start of tests for functions + + +def run_import_functions_test( + pytester: Pytester, passed: int, errors: int, failed: int +) -> RunResult: + # Note that these "tests" should _not_ be treated as tests if `collect_imported_tests = false` + # They are normal functions in that case, that happens to have test_* or *_test in the name. + # Thus should _not_ be collected! + pytester.makepyfile( + **{ + "src/foo.py": textwrap.dedent( + """\ + def test_function(): + some_random_computation = 5 + return some_random_computation + + def test_bar(): + pass + """ + ) + } + ) + + # Inferred from the comment above, this means that there is _only_ one actual test + # which should result in only 1 passing test being ran. + pytester.makepyfile( + **{ + "tests/foo_test.py": textwrap.dedent( + """\ + import sys + import os + + current_file = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file) + parent_dir = os.path.abspath(os.path.join(current_dir, '..')) + sys.path.append(parent_dir) + + from src.foo import * + + class TestDomain: + def test_important(self): + res = test_function() + if res == 5: + pass + """ + ) + } + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=passed, errors=errors, failed=failed) + return result + + +def test_collect_function_imports_enabled(pytester: Pytester) -> None: + pytester.makeini(""" + [pytest] + testpaths = "tests" + collect_imported_tests = true + """) + + run_import_functions_test(pytester, passed=2, errors=0, failed=1) + reprec = pytester.inline_run() + items_collected = reprec.getcalls("pytest_itemcollected") + # Recall that the default is `collect_imported_tests = true`. + # Which means that the normal functions are now interpreted as + # valid tests and `test_function()` will fail. + assert len(items_collected) == 3 + for x in items_collected: + assert x.item._getobj().__name__ in [ + "test_important", + "test_bar", + "test_function", + ] + + +def test_behaviour_without_testpaths_set_and_false(pytester: Pytester) -> None: + # Make sure `collect_imported_tests` has no dependence on `testpaths` + pytester.makeini(""" + [pytest] + collect_imported_tests = false + """) + + run_import_functions_test(pytester, passed=1, errors=0, failed=0) + reprec = pytester.inline_run() + items_collected = reprec.getcalls("pytest_itemcollected") + assert len(items_collected) == 1 + for x in items_collected: + assert x.item._getobj().__name__ == "test_important" + + +def test_behaviour_without_testpaths_set_and_true(pytester: Pytester) -> None: + # Make sure `collect_imported_tests` has no dependence on `testpaths` + pytester.makeini(""" + [pytest] + collect_imported_tests = true + """) + + run_import_functions_test(pytester, passed=2, errors=0, failed=1) + reprec = pytester.inline_run() + items_collected = reprec.getcalls("pytest_itemcollected") + assert len(items_collected) == 3 + + +class TestHookBehaviour: + collect_outcomes: dict[str, Any] = {} + + @pytest.mark.parametrize("step", [1, 2, 3]) + def test_hook_behaviour(self, pytester: Pytester, step: int) -> None: + if step == 1: + self._test_hook_default_behaviour(pytester) + elif step == 2: + self._test_hook_behaviour_when_collect_off(pytester) + elif step == 3: + self._test_hook_behaviour() + + @pytest.fixture(scope="class", autouse=True) + def setup_collect_outcomes(self, request: FixtureRequest) -> None: + request.cls.collect_outcomes = {} + + def _test_hook_default_behaviour(self, pytester: Pytester) -> None: + pytester.makepyfile( + **{ + "tests/foo_test.py": textwrap.dedent( + """\ + class TestDomain: + def test_important(self): + pass + """ + ) + } + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=1) + reprec = pytester.inline_run() + reports = reprec.getreports("pytest_collectreport") + modified = reprec.getcalls("pytest_collection_modifyitems") + items_collected = reprec.getcalls("pytest_itemcollected") + + self.collect_outcomes["default"] = { + "result": result.parseoutcomes(), + "modified": modified, + "items_collected": items_collected, + "reports": reports, + } + + def _test_hook_behaviour_when_collect_off(self, pytester: Pytester) -> None: + pytester.makeini(""" + [pytest] + collect_imported_tests = false + """) + res = run_import_functions_test(pytester, passed=1, errors=0, failed=0) + reprec = pytester.inline_run() + reports = reprec.getreports("pytest_collectreport") + modified = reprec.getcalls("pytest_collection_modifyitems") + items_collected = reprec.getcalls("pytest_itemcollected") + + self.collect_outcomes["collect_off"] = { + "result": res.parseoutcomes(), + "modified": modified, + "items_collected": items_collected, + "reports": reports, + } + + def _test_hook_behaviour(self) -> None: + default = self.collect_outcomes["default"] + collect_off = self.collect_outcomes["collect_off"] + + # Check that the two tests above did indeed result in the same outcome. + assert default["result"] == collect_off["result"] + + assert len(default["modified"]) == len(collect_off["modified"]) == 1 + + def_modified_record: RecordedHookCall = default["modified"][0] + off_modified_record: RecordedHookCall = collect_off["modified"][0] + def_sess: Session = def_modified_record.__dict__["session"] + off_sess: Session = off_modified_record.__dict__["session"] + + assert def_sess.exitstatus == off_sess.exitstatus + assert def_sess.testsfailed == off_sess.testsfailed + assert def_sess.testscollected == off_sess.testscollected + + def_items = def_modified_record.__dict__["items"] + off_items = off_modified_record.__dict__["items"] + assert len(def_items) == len(off_items) == 1 + assert def_items[0].name == off_items[0].name + + assert ( + len(default["items_collected"]) == len(collect_off["items_collected"]) == 1 + ) + + # Check if the same tests got collected + def_items_record: RecordedHookCall = default["items_collected"][0] + off_items_record: RecordedHookCall = collect_off["items_collected"][0] + def_items = def_items_record.__dict__["item"] + off_items = off_items_record.__dict__["item"] + assert def_items.name == off_items.name + + def compare_report(r1: CollectReport, r2: CollectReport) -> None: + assert r1.result[0].name == r2.result[0].name + assert len(r1.result) == len(r2.result) + assert r1.outcome == r2.outcome + + # Function test_important + compare_report(default["reports"][1], collect_off["reports"][2]) + # Class TestDomain + compare_report(default["reports"][2], collect_off["reports"][3]) + # Module foo_test.py + compare_report(default["reports"][3], collect_off["reports"][4]) + + # + 1 since src dir is collected + assert len(default["reports"]) + 1 == len(collect_off["reports"]) + + # Two Dirs will be collected, src and test. + assert len(collect_off["reports"][5].result) == 2