Skip to content

Commit 56edbb0

Browse files
authored
Extract requires() test mark to eliminate repeated numpy version checks (#1844)
* Extract requires() mark to eliminate repeated version checks Add helpers/marks.py with a reusable requires() decorator and replace all inline numpy version skipif patterns across test files. Made-with: Cursor * Rework requires() mark: rename to requires_module, use importorskip Rename the mark to requires_module and reimplement it as a thin wrapper around pytest.importorskip, forwarding *args/**kwargs directly. Version arguments are now strings (matching importorskip's minversion parameter) rather than integer tuples. Update all call sites accordingly. Made-with: Cursor * Restore numpy GH #28632 reference in skip reason Made-with: Cursor
1 parent fa25626 commit 56edbb0

File tree

7 files changed

+63
-17
lines changed

7 files changed

+63
-17
lines changed

cuda_core/tests/graph/test_device_launch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import pytest
8+
from helpers.marks import requires_module
89

910
from cuda.core import (
1011
Device,
@@ -75,7 +76,7 @@ def _compile_device_launcher_kernel():
7576
Device().compute_capability.major < 9,
7677
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
7778
)
78-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
79+
@requires_module(np, "2.1")
7980
def test_device_launch_basic(init_cuda):
8081
"""Test basic device-side graph launch functionality.
8182
@@ -127,7 +128,7 @@ def test_device_launch_basic(init_cuda):
127128
Device().compute_capability.major < 9,
128129
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
129130
)
130-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
131+
@requires_module(np, "2.1")
131132
def test_device_launch_multiple(init_cuda):
132133
"""Test that device-side graph launch can be executed multiple times.
133134

cuda_core/tests/graph/test_graph_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pytest
88
from helpers.graph_kernels import compile_common_kernels
9+
from helpers.marks import requires_module
910

1011
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
1112

@@ -116,7 +117,7 @@ def test_graph_is_join_required(init_cuda):
116117
gb.end_building().complete()
117118

118119

119-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
120+
@requires_module(np, "2.1")
120121
def test_graph_repeat_capture(init_cuda):
121122
mod = compile_common_kernels()
122123
add_one = mod.get_kernel("add_one")

cuda_core/tests/graph/test_graph_builder_conditional.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
import numpy as np
99
import pytest
1010
from helpers.graph_kernels import compile_conditional_kernels
11+
from helpers.marks import requires_module
1112

1213
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
1314

1415

1516
@pytest.mark.parametrize(
1617
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
1718
)
18-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
19+
@requires_module(np, "2.1")
1920
def test_graph_conditional_if(init_cuda, condition_value):
2021
mod = compile_conditional_kernels(type(condition_value))
2122
add_one = mod.get_kernel("add_one")
@@ -79,7 +80,7 @@ def test_graph_conditional_if(init_cuda, condition_value):
7980
@pytest.mark.parametrize(
8081
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
8182
)
82-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
83+
@requires_module(np, "2.1")
8384
def test_graph_conditional_if_else(init_cuda, condition_value):
8485
mod = compile_conditional_kernels(type(condition_value))
8586
add_one = mod.get_kernel("add_one")
@@ -151,7 +152,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):
151152

152153

153154
@pytest.mark.parametrize("condition_value", [0, 1, 2, 3])
154-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
155+
@requires_module(np, "2.1")
155156
def test_graph_conditional_switch(init_cuda, condition_value):
156157
mod = compile_conditional_kernels(type(condition_value))
157158
add_one = mod.get_kernel("add_one")
@@ -242,7 +243,7 @@ def test_graph_conditional_switch(init_cuda, condition_value):
242243

243244

244245
@pytest.mark.parametrize("condition_value", [True, False, 1, 0])
245-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
246+
@requires_module(np, "2.1")
246247
def test_graph_conditional_while(init_cuda, condition_value):
247248
mod = compile_conditional_kernels(type(condition_value))
248249
add_one = mod.get_kernel("add_one")

cuda_core/tests/graph/test_graph_update.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import numpy as np
77
import pytest
88
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
9+
from helpers.marks import requires_module
910

1011
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
1112
from cuda.core._graph._graph_def import GraphDef
1213
from cuda.core._utils.cuda_utils import CUDAError
1314

1415

1516
@pytest.mark.parametrize("builder", ["GraphBuilder", "GraphDef"])
16-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
17+
@requires_module(np, "2.1")
1718
def test_graph_update_kernel_args(init_cuda, builder):
1819
"""Update redirects a kernel to write to a different pointer."""
1920
mod = compile_common_kernels()
@@ -59,7 +60,7 @@ def build(ptr):
5960
b.close()
6061

6162

62-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
63+
@requires_module(np, "2.1")
6364
def test_graph_update_conditional(init_cuda):
6465
"""Update swaps conditional switch graphs with matching topology."""
6566
mod = compile_conditional_kernels(int)

cuda_core/tests/helpers/marks.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
"""Reusable pytest marks for cuda_core tests."""
5+
6+
import inspect
7+
8+
import pytest
9+
10+
11+
def requires_module(module, *args, **kwargs):
12+
"""Skip the test if a module is missing or older than required.
13+
14+
Thin wrapper around :func:`pytest.importorskip`. The first argument
15+
may be a module object or a string; all remaining positional and
16+
keyword arguments (``minversion``, ``reason``, ``exc_type``) are
17+
forwarded.
18+
19+
Prefer this over ``pytest.importorskip`` when:
20+
21+
- You need finer granularity than module scope or a test body; this
22+
mark can decorate classes, individual tests, or ``pytest.param`` entries.
23+
- You want to skip before fixtures run, avoiding setup costs.
24+
- The module is already imported and you want to pass it directly.
25+
26+
Usage::
27+
28+
@requires_module("numpy", "2.1")
29+
def test_foo(): ...
30+
31+
32+
@requires_module(np, minversion="2.1")
33+
def test_bar(): ...
34+
"""
35+
if inspect.ismodule(module):
36+
module = module.__name__
37+
elif not isinstance(module, str):
38+
raise TypeError(f"expected module or string, got {type(module).__name__}")
39+
40+
try:
41+
pytest.importorskip(module, *args, **kwargs)
42+
except pytest.skip.Skipped as exc:
43+
return pytest.mark.skipif(True, reason=str(exc))
44+
else:
45+
return pytest.mark.skipif(False, reason="")

cuda_core/tests/test_launcher.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ctypes
55

66
import helpers
7+
from helpers.marks import requires_module
78
from helpers.misc import StreamWrapper
89

910
try:
@@ -190,7 +191,7 @@ def test_launch_invalid_values(init_cuda):
190191

191192

192193
@pytest.mark.parametrize("python_type, cpp_type, init_value", PARAMS)
193-
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
194+
@requires_module(np, "2.1")
194195
def test_launch_scalar_argument(python_type, cpp_type, init_value):
195196
dev = Device()
196197
dev.set_current()
@@ -289,10 +290,7 @@ def test_cooperative_launch():
289290
"device_memory_resource", # kludgy, but can go away after #726 is resolved
290291
pytest.param(
291292
LegacyPinnedMemoryResource,
292-
marks=pytest.mark.skipif(
293-
tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5),
294-
reason="need numpy 2.2.5+, numpy GH #28632",
295-
),
293+
marks=requires_module(np, "2.2.5", reason="need numpy 2.2.5+ (numpy GH #28632)"),
296294
),
297295
],
298296
)

cuda_core/tests/test_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ml_dtypes = None
2727
import numpy as np
2828
import pytest
29+
from helpers.marks import requires_module
2930

3031
from cuda.core import Device
3132
from cuda.core._dlpack import DLDeviceType
@@ -85,9 +86,7 @@ def convert_strides_to_counts(strides, itemsize):
8586
# readonly is fixed recently (numpy/numpy#26501)
8687
pytest.param(
8788
np.frombuffer(b""),
88-
marks=pytest.mark.skipif(
89-
tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+"
90-
),
89+
marks=requires_module(np, "2.1"),
9190
),
9291
),
9392
)

0 commit comments

Comments
 (0)