Skip to content

Commit

Permalink
Merge pull request #18 from Quansight-Labs/add_thread_unsafe_marker
Browse files Browse the repository at this point in the history
Add thread_unsafe marker as an alias for parallel_threads(1)
  • Loading branch information
andfoy authored Nov 13, 2024
2 parents 1ea2c91 + 269aa26 commit 052a35d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ Features
* Two global CLI flags:
* ``--parallel-threads`` to run a test suite in parallel
* ``--iterations`` to run multiple times in each thread
* Two corresponding markers:
* Three corresponding markers:
* ``pytest.mark.parallel_threads(n)`` to mark a single test to run in
parallel in ``n`` threads
* ``pytest.mark.thread_unsafe`` to mark a single test to run in a single
thread. It is equivalent to using ``pytest.mark.parallel_threads(1)``
* ``pytest.mark.iterations(n)`` to mark a single test to run ``n`` times
in each thread
* And the corresponding fixtures:
Expand Down
9 changes: 9 additions & 0 deletions src/pytest_run_parallel/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def pytest_configure(config):
"markers",
"iterations(n): run the given test function `n` times in each thread",
)
config.addinivalue_line(
"markers",
"thread_unsafe: mark the test function as single-threaded",
)


def wrap_function_parallel(fn, n_workers, n_iterations):
Expand Down Expand Up @@ -106,6 +110,11 @@ def pytest_itemcollected(item):
if m is not None:
n_iterations = int(m.args[0])

m = item.get_closest_marker("thread_unsafe")
if m is not None:
n_workers = 1
item.add_marker(pytest.mark.parallel_threads(1))

if n_workers > 1 or n_iterations > 1:
original_globals = item.obj.__globals__
item.obj = wrap_function_parallel(item.obj, n_workers, n_iterations)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_run_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,24 @@ def test_should_skip():
"*::test_should_skip SKIPPED*",
]
)


def test_thread_unsafe_marker(pytester):
# create a temporary pytest test module
pytester.makepyfile("""
import pytest
@pytest.mark.thread_unsafe
def test_should_run_single(num_parallel_threads):
assert num_parallel_threads == 1
""")

# run pytest with the following cmd args
result = pytester.runpytest("--parallel-threads=10", "-v")

# fnmatch_lines does an assertion internally
result.stdout.fnmatch_lines(
[
"*::test_should_run_single PASSED*",
]
)

0 comments on commit 052a35d

Please sign in to comment.