Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Nov 14, 2024
1 parent ba27d1d commit 3e04c09
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 114 deletions.
116 changes: 2 additions & 114 deletions src/pytest_run_parallel/plugin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
import functools
import threading
import types

import _pytest.outcomes
import pytest
import inspect
import ast
from textwrap import dedent


try:
import numpy as np

numpy_available = True
except ImportError:
numpy_available = False
from pytest_run_parallel.utils import (
ThreadComparator, identify_warnings_handling)


def pytest_addoption(parser):
Expand Down Expand Up @@ -101,45 +92,6 @@ def closure(*args, **kwargs):
return inner


def identify_warnings_handling(fn):
src = inspect.getsource(fn)
tree = ast.parse(dedent(src))
catches_warns = False
blacklist = {('pytest', 'warns'),
('pytest', 'deprecated_call'),
('warnings', 'catch_warnings')}
modules = {mod for mod, _ in blacklist}
modules_aliases = {}
func_aliases = {}
for var_name in fn.__globals__:
value = fn.__globals__[var_name]
if inspect.ismodule(value):
if value.__name__ in modules:
modules_aliases[var_name] = value.__name__
elif inspect.isfunction(value):
real_name = value.__name__
for mod in modules:
if mod in value.__module__:
func_aliases[var_name] = (mod, real_name)

for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name):
real_mod = node.func.value.id
if real_mod in modules_aliases:
real_mod = modules_aliases[real_mod]
if (real_mod, node.func.attr) in blacklist:
catches_warns = True
break
elif isinstance(node.func, ast.Name):
if node.func.id in func_aliases:
if func_aliases[node.func.id] in blacklist:
catches_warns = True
break
return catches_warns


@pytest.hookimpl(trylast=True)
def pytest_itemcollected(item):
n_workers = item.config.option.parallel_threads
Expand Down Expand Up @@ -190,70 +142,6 @@ def num_iterations(request):
return n_iterations


class ThreadComparator:
def __init__(self, n_threads):
self._barrier = threading.Barrier(n_threads)
self._reset_evt = threading.Event()
self._entry_barrier = threading.Barrier(n_threads)

self._thread_ids = []
self._values = {}
self._entry_lock = threading.Lock()
self._entry_counter = 0

def __call__(self, **values):
"""
Compares a set of values across threads.
For each value, type equality as well as comparison takes place. If any
of the values is a function, then address comparison is performed.
Also, if any of the values is a `numpy.ndarray`, then approximate
numerical comparison is performed.
"""
tid = id(threading.current_thread())
self._entry_barrier.wait()
with self._entry_lock:
if self._entry_counter == 0:
# Reset state before comparison
self._barrier.reset()
self._reset_evt.clear()
self._thread_ids = []
self._values = {}
self._entry_barrier.reset()
self._entry_counter += 1

self._values[tid] = values
self._thread_ids.append(tid)
self._barrier.wait()

if tid == self._thread_ids[0]:
thread_ids = list(self._values)
try:
for value_name in values:
for i in range(1, len(thread_ids)):
tid_a = thread_ids[i - 1]
tid_b = thread_ids[i]
value_a = self._values[tid_a][value_name]
value_b = self._values[tid_b][value_name]
assert type(value_a) is type(value_b)
if numpy_available and isinstance(value_a, np.ndarray):
if len(value_a.shape) == 0:
assert value_a == value_b
else:
assert np.allclose(value_a, value_b, equal_nan=True)
elif isinstance(value_a, types.FunctionType):
assert id(value_a) == id(value_b)
elif value_a != value_a:
assert value_b != value_b
else:
assert value_a == value_b
finally:
self._entry_counter = 0
self._reset_evt.set()
else:
self._reset_evt.wait()


@pytest.fixture
def thread_comp(num_parallel_threads):
return ThreadComparator(num_parallel_threads)
121 changes: 121 additions & 0 deletions src/pytest_run_parallel/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

import ast
import inspect
import types
from textwrap import dedent
import threading

try:
import numpy as np

numpy_available = True
except ImportError:
numpy_available = False


class WarningNodeVisitor(ast.NodeVisitor):
def __init__(self, fn):
self.catches_warns = False
self.blacklist = {('pytest', 'warns'),
('pytest', 'deprecated_call'),
('warnings', 'catch_warnings')}
modules = {mod for mod, _ in self.blacklist}
self.modules_aliases = {}
self.func_aliases = {}
for var_name in fn.__globals__:
value = fn.__globals__[var_name]
if inspect.ismodule(value) and value.__name__ in modules:
self.modules_aliases[var_name] = value.__name__
elif inspect.isfunction(value):
real_name = value.__name__
for mod in modules:
if mod in value.__module__:
self.func_aliases[var_name] = (mod, real_name)
break

super().__init__()

def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name):
real_mod = node.func.value.id
if real_mod in self.modules_aliases:
real_mod = self.modules_aliases[real_mod]
if (real_mod, node.func.attr) in self.blacklist:
self.catches_warns = True
elif isinstance(node.func, ast.Name):
if node.func.id in self.func_aliases:
if self.func_aliases[node.func.id] in self.blacklist:
self.catches_warns = True


def identify_warnings_handling(fn):
src = inspect.getsource(fn)
tree = ast.parse(dedent(src))
visitor = WarningNodeVisitor(fn)
visitor.visit(tree)
return visitor.catches_warns


class ThreadComparator:
def __init__(self, n_threads):
self._barrier = threading.Barrier(n_threads)
self._reset_evt = threading.Event()
self._entry_barrier = threading.Barrier(n_threads)

self._thread_ids = []
self._values = {}
self._entry_lock = threading.Lock()
self._entry_counter = 0

def __call__(self, **values):
"""
Compares a set of values across threads.
For each value, type equality as well as comparison takes place. If any
of the values is a function, then address comparison is performed.
Also, if any of the values is a `numpy.ndarray`, then approximate
numerical comparison is performed.
"""
tid = id(threading.current_thread())
self._entry_barrier.wait()
with self._entry_lock:
if self._entry_counter == 0:
# Reset state before comparison
self._barrier.reset()
self._reset_evt.clear()
self._thread_ids = []
self._values = {}
self._entry_barrier.reset()
self._entry_counter += 1

self._values[tid] = values
self._thread_ids.append(tid)
self._barrier.wait()

if tid == self._thread_ids[0]:
thread_ids = list(self._values)
try:
for value_name in values:
for i in range(1, len(thread_ids)):
tid_a = thread_ids[i - 1]
tid_b = thread_ids[i]
value_a = self._values[tid_a][value_name]
value_b = self._values[tid_b][value_name]
assert type(value_a) is type(value_b)
if numpy_available and isinstance(value_a, np.ndarray):
if len(value_a.shape) == 0:
assert value_a == value_b
else:
assert np.allclose(value_a, value_b, equal_nan=True)
elif isinstance(value_a, types.FunctionType):
assert id(value_a) == id(value_b)
elif value_a != value_a:
assert value_b != value_b
else:
assert value_a == value_b
finally:
self._entry_counter = 0
self._reset_evt.set()
else:
self._reset_evt.wait()

0 comments on commit 3e04c09

Please sign in to comment.