Skip to content

Commit

Permalink
Merge pull request #19 from Quansight-Labs/detect_warnings_handling
Browse files Browse the repository at this point in the history
Mark tests as single-threaded if warning handling takes place
  • Loading branch information
andfoy authored Nov 18, 2024
2 parents c1c5fa4 + 9f4531a commit 76a04a2
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 71 deletions.
76 changes: 5 additions & 71 deletions src/pytest_run_parallel/plugin.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import functools
import threading
import types

import _pytest.outcomes
import pytest

try:
import numpy as np

numpy_available = True
except ImportError:
numpy_available = False
from pytest_run_parallel.utils import ThreadComparator, identify_warnings_handling


def pytest_addoption(parser):
Expand Down Expand Up @@ -115,6 +109,10 @@ def pytest_itemcollected(item):
n_workers = 1
item.add_marker(pytest.mark.parallel_threads(1))

if identify_warnings_handling(item.obj):
n_workers = 1
item.add_marker(pytest.mark.parallel_threads(1))

if n_workers > 1 or n_iterations > 1:
original_globals = item.obj.__globals__
item.obj = wrap_function_parallel(item.obj, n_workers, n_iterations)
Expand Down Expand Up @@ -143,70 +141,6 @@ def num_iterations(request):
return n_iterations


class ThreadComparator:
def __init__(self, n_threads):
self._barrier = threading.Barrier(n_threads)
self._reset_evt = threading.Event()
self._entry_barrier = threading.Barrier(n_threads)

self._thread_ids = []
self._values = {}
self._entry_lock = threading.Lock()
self._entry_counter = 0

def __call__(self, **values):
"""
Compares a set of values across threads.
For each value, type equality as well as comparison takes place. If any
of the values is a function, then address comparison is performed.
Also, if any of the values is a `numpy.ndarray`, then approximate
numerical comparison is performed.
"""
tid = id(threading.current_thread())
self._entry_barrier.wait()
with self._entry_lock:
if self._entry_counter == 0:
# Reset state before comparison
self._barrier.reset()
self._reset_evt.clear()
self._thread_ids = []
self._values = {}
self._entry_barrier.reset()
self._entry_counter += 1

self._values[tid] = values
self._thread_ids.append(tid)
self._barrier.wait()

if tid == self._thread_ids[0]:
thread_ids = list(self._values)
try:
for value_name in values:
for i in range(1, len(thread_ids)):
tid_a = thread_ids[i - 1]
tid_b = thread_ids[i]
value_a = self._values[tid_a][value_name]
value_b = self._values[tid_b][value_name]
assert type(value_a) is type(value_b)
if numpy_available and isinstance(value_a, np.ndarray):
if len(value_a.shape) == 0:
assert value_a == value_b
else:
assert np.allclose(value_a, value_b, equal_nan=True)
elif isinstance(value_a, types.FunctionType):
assert id(value_a) == id(value_b)
elif value_a != value_a:
assert value_b != value_b
else:
assert value_a == value_b
finally:
self._entry_counter = 0
self._reset_evt.set()
else:
self._reset_evt.wait()


@pytest.fixture
def thread_comp(num_parallel_threads):
return ThreadComparator(num_parallel_threads)
129 changes: 129 additions & 0 deletions src/pytest_run_parallel/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import ast
import inspect
import threading
import types
from textwrap import dedent

try:
import numpy as np

numpy_available = True
except ImportError:
numpy_available = False


class WarningNodeVisitor(ast.NodeVisitor):
def __init__(self, fn):
self.catches_warns = False
self.blacklist = {
("pytest", "warns"),
("pytest", "deprecated_call"),
("_pytest.recwarn", "warns"),
("_pytest.recwarn", "deprecated_call"),
("warnings", "catch_warnings"),
}
modules = {mod.split(".")[0] for mod, _ in self.blacklist}
modules |= {mod for mod, _ in self.blacklist}

self.modules_aliases = {}
self.func_aliases = {}
for var_name in fn.__globals__:
value = fn.__globals__[var_name]
if inspect.ismodule(value) and value.__name__ in modules:
self.modules_aliases[var_name] = value.__name__
elif inspect.isfunction(value):
real_name = value.__name__
for mod in modules:
if mod == value.__module__:
self.func_aliases[var_name] = (mod, real_name)
break

super().__init__()

def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name):
real_mod = node.func.value.id
if real_mod in self.modules_aliases:
real_mod = self.modules_aliases[real_mod]
if (real_mod, node.func.attr) in self.blacklist:
self.catches_warns = True
elif isinstance(node.func, ast.Name):
if node.func.id in self.func_aliases:
if self.func_aliases[node.func.id] in self.blacklist:
self.catches_warns = True


def identify_warnings_handling(fn):
try:
src = inspect.getsource(fn)
tree = ast.parse(dedent(src))
except Exception:
return False
visitor = WarningNodeVisitor(fn)
visitor.visit(tree)
return visitor.catches_warns


class ThreadComparator:
def __init__(self, n_threads):
self._barrier = threading.Barrier(n_threads)
self._reset_evt = threading.Event()
self._entry_barrier = threading.Barrier(n_threads)

self._thread_ids = []
self._values = {}
self._entry_lock = threading.Lock()
self._entry_counter = 0

def __call__(self, **values):
"""
Compares a set of values across threads.
For each value, type equality as well as comparison takes place. If any
of the values is a function, then address comparison is performed.
Also, if any of the values is a `numpy.ndarray`, then approximate
numerical comparison is performed.
"""
tid = id(threading.current_thread())
self._entry_barrier.wait()
with self._entry_lock:
if self._entry_counter == 0:
# Reset state before comparison
self._barrier.reset()
self._reset_evt.clear()
self._thread_ids = []
self._values = {}
self._entry_barrier.reset()
self._entry_counter += 1

self._values[tid] = values
self._thread_ids.append(tid)
self._barrier.wait()

if tid == self._thread_ids[0]:
thread_ids = list(self._values)
try:
for value_name in values:
for i in range(1, len(thread_ids)):
tid_a = thread_ids[i - 1]
tid_b = thread_ids[i]
value_a = self._values[tid_a][value_name]
value_b = self._values[tid_b][value_name]
assert type(value_a) is type(value_b)
if numpy_available and isinstance(value_a, np.ndarray):
if len(value_a.shape) == 0:
assert value_a == value_b
else:
assert np.allclose(value_a, value_b, equal_nan=True)
elif isinstance(value_a, types.FunctionType):
assert id(value_a) == id(value_b)
elif value_a != value_a:
assert value_b != value_b
else:
assert value_a == value_b
finally:
self._entry_counter = 0
self._reset_evt.set()
else:
self._reset_evt.wait()
47 changes: 47 additions & 0 deletions tests/test_run_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,50 @@ def test_should_run_single(num_parallel_threads):
"*::test_should_run_single PASSED*",
]
)


def test_pytest_warns_detection(pytester):
# create a temporary pytest test module
pytester.makepyfile("""
import pytest
import warnings
import pytest as pyt
import warnings as w
from pytest import warns, deprecated_call
from warnings import catch_warnings
warns_alias = warns
def test_single_thread_warns_1(num_parallel_threads):
with pytest.warns(UserWarning):
warnings.warn('example', UserWarning)
assert num_parallel_threads == 1
def test_single_thread_warns_2(num_parallel_threads):
with warns(UserWarning):
warnings.warn('example', UserWarning)
assert num_parallel_threads == 1
def test_single_thread_warns_3(num_parallel_threads):
with pyt.warns(UserWarning):
warnings.warn('example', UserWarning)
assert num_parallel_threads == 1
def test_single_thread_warns_4(num_parallel_threads):
with warns_alias(UserWarning):
warnings.warn('example', UserWarning)
assert num_parallel_threads == 1
""")

# run pytest with the following cmd args
result = pytester.runpytest("--parallel-threads=10", "-v")

# fnmatch_lines does an assertion internally
result.stdout.fnmatch_lines(
[
"*::test_single_thread_warns_1 PASSED*",
"*::test_single_thread_warns_2 PASSED*",
"*::test_single_thread_warns_3 PASSED*",
"*::test_single_thread_warns_4 PASSED*",
]
)

0 comments on commit 76a04a2

Please sign in to comment.