Skip to content

Commit

Permalink
free-threading: Test suite
Browse files Browse the repository at this point in the history
Several parallel tests to check that locking works as expected
  • Loading branch information
wjakob committed Sep 8, 2024
1 parent 47d9b76 commit 5c5fdb1
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ set(TEST_NAMES
typing
issue
intrusive
thread
)

foreach (NAME ${TEST_NAMES})
Expand Down Expand Up @@ -138,6 +139,7 @@ set(TEST_FILES
test_ndarray.py
test_stubs.py
test_typing.py
test_thread.py

# Stub reference files
test_classes_ext.pyi.ref
Expand Down
37 changes: 37 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <nanobind/nanobind.h>

namespace nb = nanobind;
using namespace nb::literals;

struct Counter {
size_t value = 0;
void inc() { value++; }
void merge(Counter &o) {
value += o.value;
o.value = 0;
}
};

nb::ft_mutex mutex;

NB_MODULE(test_thread_ext, m) {
nb::class_<Counter>(m, "Counter")
.def(nb::init<>())
.def_ro("value", &Counter::value)
.def("inc_unsafe", &Counter::inc)
.def("inc_safe", &Counter::inc, nb::lock_self())
.def("merge_unsafe", &Counter::merge)
.def("merge_safe", &Counter::merge, nb::lock_self(), "o"_a.lock());

m.def("return_self", [](Counter *c) -> Counter * { return c; });

m.def("inc_safe",
[](Counter &c) { c.inc(); },
"counter"_a.lock());

m.def("inc_global",
[](Counter &c) {
nb::ft_lock_guard guard(mutex);
c.inc();
}, "counter");
}
100 changes: 100 additions & 0 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import test_thread_ext as t
from test_thread_ext import Counter

import threading

# Helper function to parallelize execution of a function. We intentionally
# don't use the Python threads pools here to have threads shut down / start
# between test cases.
def parallelize(func, n_threads):
barrier = threading.Barrier(n_threads)
result = [None]*n_threads

def wrapper(i):
barrier.wait()
result[i] = func()

workers = []
for i in range(n_threads):
t = threading.Thread(target=wrapper, args=(i,))
t.start()
workers.append(t)

for worker in workers:
worker.join()
return result


def test01_object_creation(n_threads=8):
# This test hammers 'inst_c2p' from multiple threads, and
# checks that the locking of internal data structures works

n = 100000
def f():
r = [None]*n
for i in range(n):
c = Counter()
c.inc_unsafe()
r[i] = c
for i in range(n):
assert t.return_self(r[i]) is r[i]
return r

v = parallelize(f, n_threads=n_threads)
assert len(v) == n_threads
for v2 in v:
assert len(v2) == n
for v3 in v2:
assert v3.value == 1

def test02_global_lock(n_threads=8):
# Test that a global PyMutex protects the counter
n = 100000
c = Counter()
def f():
for i in range(n):
t.inc_global(c)

parallelize(f, n_threads=n_threads)
assert c.value == n * n_threads


def test03_locked_method(n_threads=8):
# Checks that nb::lock_self() protects an internal counter
n = 100000
c = Counter()
def f():
for i in range(n):
c.inc_safe()

parallelize(f, n_threads=n_threads)
assert c.value == n * n_threads


def test04_locked_function(n_threads=8):
# Checks that nb::lock_self() protects an internal counter
n = 100000
c = Counter()
def f():
for i in range(n):
t.inc_safe(c)

parallelize(f, n_threads=n_threads)
assert c.value == n * n_threads


def test05_locked_twoargs(n_threads=8):
# Check two-argument locking
n = 100000
c = Counter()
def f():
c2 = Counter()
for i in range(n):
c2.inc_unsafe()
if i & 1 == 0:
c2.merge_safe(c)
else:
c.merge_safe(c2)

parallelize(f, n_threads=n_threads)
assert c.value == n * n_threads

0 comments on commit 5c5fdb1

Please sign in to comment.