Skip to content

Commit

Permalink
multi-threaded allocation test
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Aug 23, 2024
1 parent a32399d commit fdaddeb
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ nanobind_add_module(test_exception_ext test_exception.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_make_iterator_ext test_make_iterator.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_typing_ext test_typing.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_issue_ext test_issue.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_thread_ext test_thread.cpp ${NB_EXTRA_ARGS})

foreach (NAME functions classes ndarray stl enum typing make_iterator)
if (NAME STREQUAL typing)
Expand Down Expand Up @@ -120,6 +121,7 @@ set(TEST_FILES
test_ndarray.py
test_stubs.py
test_typing.py
test_thread.py

# Stub reference files
test_classes_ext.pyi.ref
Expand Down
16 changes: 16 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <nanobind/nanobind.h>
#include <atomic>

namespace nb = nanobind;

struct Counter {
std::atomic<size_t> counter { 0 };
void inc() { counter++; }
};

NB_MODULE(test_thread_ext, m) {
nb::class_<Counter>(m, "Counter")
.def(nb::init<>())
.def_prop_ro("counter", [](Counter &c) { return (size_t) c.counter; })
.def("inc", &Counter::inc);
}
40 changes: 40 additions & 0 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Temporarily turn off immortalization
try:
from test.support import suppress_immortalization
except ImportError:
from contextlib import nullcontext as suppress_immortalization

import test_thread_ext as t

import threading
import gc

def parallelize(func, n_threads):
with suppress_immortalization(True): # Avoid reference leak errors
barrier = threading.Barrier(n_threads)

def wrapper():
barrier.wait()
return func()

workers = []
for _ in range(n_threads):
t = threading.Thread(target=wrapper)
t.start()
workers.append(t)

for worker in workers:
worker.join()

def test01_object_creation():
from test_thread_ext import Counter

def f():
n = 1000000
r = [None]*n
for i in range(n):
r[i] = Counter()
del r

parallelize(f, n_threads=8)
gc.collect()

0 comments on commit fdaddeb

Please sign in to comment.