diff --git a/cuda_bindings/tests/test_nvjitlink.py b/cuda_bindings/tests/test_nvjitlink.py index 66de16c56d..e221dd72ef 100644 --- a/cuda_bindings/tests/test_nvjitlink.py +++ b/cuda_bindings/tests/test_nvjitlink.py @@ -1,10 +1,24 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +from contextlib import contextmanager + import pytest from cuda.bindings import nvjitlink, nvrtc + +@contextmanager +def nvjitlink_session(num_options, options): + """Create an nvJitLink handle and always destroy it (including on test failure).""" + handle = nvjitlink.create(num_options, options) + try: + yield handle + finally: + if handle != 0: + nvjitlink.destroy(handle) + + # Establish a handful of compatible architectures and PTX versions to test with ARCHITECTURES = ["sm_75", "sm_80", "sm_90", "sm_100"] PTX_VERSIONS = ["6.4", "7.0", "8.5", "8.8"] @@ -95,87 +109,78 @@ def test_invalid_arch_error(): @pytest.mark.parametrize("option", ARCHITECTURES) def test_create_and_destroy(option): - handle = nvjitlink.create(1, [f"-arch={option}"]) - assert handle != 0 - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={option}"]) as handle: + assert handle != 0 def test_create_and_destroy_bytes_options(): - handle = nvjitlink.create(1, [b"-arch=sm_80"]) - assert handle != 0 - nvjitlink.destroy(handle) + with nvjitlink_session(1, [b"-arch=sm_80"]) as handle: + assert handle != 0 @pytest.mark.parametrize("option", ARCHITECTURES) def test_complete_empty(option): - handle = nvjitlink.create(1, [f"-arch={option}"]) - nvjitlink.complete(handle) - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={option}"]) as handle: + nvjitlink.complete(handle) @arch_ptx_parametrized def test_add_data(arch, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={arch}"]) - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") - nvjitlink.complete(handle) - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") + nvjitlink.complete(handle) @arch_ptx_parametrized def test_add_file(arch, ptx_bytes, tmp_path): - handle = nvjitlink.create(1, [f"-arch={arch}"]) - file_path = tmp_path / "test_file.cubin" - file_path.write_bytes(ptx_bytes) - nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path)) - nvjitlink.complete(handle) - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: + file_path = tmp_path / "test_file.cubin" + file_path.write_bytes(ptx_bytes) + nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path)) + nvjitlink.complete(handle) @pytest.mark.parametrize("arch", ARCHITECTURES) def test_get_error_log(arch): - handle = nvjitlink.create(1, [f"-arch={arch}"]) - nvjitlink.complete(handle) - log_size = nvjitlink.get_error_log_size(handle) - log = bytearray(log_size) - nvjitlink.get_error_log(handle, log) - assert len(log) == log_size - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: + nvjitlink.complete(handle) + log_size = nvjitlink.get_error_log_size(handle) + log = bytearray(log_size) + nvjitlink.get_error_log(handle, log) + assert len(log) == log_size @arch_ptx_parametrized def test_get_info_log(arch, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={arch}"]) - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") - nvjitlink.complete(handle) - log_size = nvjitlink.get_info_log_size(handle) - log = bytearray(log_size) - nvjitlink.get_info_log(handle, log) - assert len(log) == log_size - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") + nvjitlink.complete(handle) + log_size = nvjitlink.get_info_log_size(handle) + log = bytearray(log_size) + nvjitlink.get_info_log(handle, log) + assert len(log) == log_size @arch_ptx_parametrized def test_get_linked_cubin(arch, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={arch}"]) - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") - nvjitlink.complete(handle) - cubin_size = nvjitlink.get_linked_cubin_size(handle) - cubin = bytearray(cubin_size) - nvjitlink.get_linked_cubin(handle, cubin) - assert len(cubin) == cubin_size - nvjitlink.destroy(handle) + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") + nvjitlink.complete(handle) + cubin_size = nvjitlink.get_linked_cubin_size(handle) + cubin = bytearray(cubin_size) + nvjitlink.get_linked_cubin(handle, cubin) + assert len(cubin) == cubin_size @pytest.mark.parametrize("arch", ARCHITECTURES) def test_get_linked_ptx(arch, get_dummy_ltoir): - handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"]) - nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data") - nvjitlink.complete(handle) - ptx_size = nvjitlink.get_linked_ptx_size(handle) - ptx = bytearray(ptx_size) - nvjitlink.get_linked_ptx(handle, ptx) - assert len(ptx) == ptx_size - nvjitlink.destroy(handle) + with nvjitlink_session(3, [f"-arch={arch}", "-lto", "-ptx"]) as handle: + nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data") + nvjitlink.complete(handle) + ptx_size = nvjitlink.get_linked_ptx_size(handle) + ptx = bytearray(ptx_size) + nvjitlink.get_linked_ptx(handle, ptx) + assert len(ptx) == ptx_size def test_package_version():