11# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33
4+ import importlib
45import importlib .metadata
56
7+ import pytest
8+
69from cuda .bindings import driver , runtime
7- from cuda .bindings ._utils import driver_cu_result_explanations , runtime_cuda_error_explanations
10+
11+ _EXPLANATION_MODULES = [
12+ ("driver_cu_result_explanations" , driver .CUresult ),
13+ ("runtime_cuda_error_explanations" , runtime .cudaError_t ),
14+ ]
815
916
1017def _get_binding_version ():
@@ -15,25 +22,13 @@ def _get_binding_version():
1522 return tuple (int (v ) for v in major_minor )
1623
1724
18- def test_driver_cu_result_explanations_health ():
19- expl_dict = driver_cu_result_explanations ._EXPLANATIONS
20-
21- known_codes = set ()
22- for error in driver .CUresult :
23- code = int (error )
24- assert code in expl_dict
25- known_codes .add (code )
26-
27- if _get_binding_version () >= (13 , 0 ):
28- extra_expl = sorted (set (expl_dict .keys ()) - known_codes )
29- assert not extra_expl
30-
31-
32- def test_runtime_cuda_error_explanations_health ():
33- expl_dict = runtime_cuda_error_explanations ._EXPLANATIONS
25+ @pytest .mark .parametrize ("module_name,enum_type" , _EXPLANATION_MODULES )
26+ def test_explanations_health (module_name , enum_type ):
27+ mod = importlib .import_module (f"cuda.bindings._utils.{ module_name } " )
28+ expl_dict = mod ._EXPLANATIONS
3429
3530 known_codes = set ()
36- for error in runtime . cudaError_t :
31+ for error in enum_type :
3732 code = int (error )
3833 assert code in expl_dict
3934 known_codes .add (code )
0 commit comments