Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 55 additions & 50 deletions cuda_bindings/tests/test_nvjitlink.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

from contextlib import contextmanager

import pytest

from cuda.bindings import nvjitlink, nvrtc


@contextmanager
def nvjitlink_session(num_options, options):
"""Create an nvJitLink handle and always destroy it (including on test failure)."""
handle = nvjitlink.create(num_options, options)
try:
yield handle
finally:
if handle != 0:
nvjitlink.destroy(handle)


# Establish a handful of compatible architectures and PTX versions to test with
ARCHITECTURES = ["sm_75", "sm_80", "sm_90", "sm_100"]
PTX_VERSIONS = ["6.4", "7.0", "8.5", "8.8"]
Expand Down Expand Up @@ -95,87 +109,78 @@ def test_invalid_arch_error():

@pytest.mark.parametrize("option", ARCHITECTURES)
def test_create_and_destroy(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
assert handle != 0
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={option}"]) as handle:
assert handle != 0


def test_create_and_destroy_bytes_options():
handle = nvjitlink.create(1, [b"-arch=sm_80"])
assert handle != 0
nvjitlink.destroy(handle)
with nvjitlink_session(1, [b"-arch=sm_80"]) as handle:
assert handle != 0


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_complete_empty(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.complete(handle)
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={option}"]) as handle:
nvjitlink.complete(handle)


@arch_ptx_parametrized
def test_add_data(arch, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)


@arch_ptx_parametrized
def test_add_file(arch, ptx_bytes, tmp_path):
handle = nvjitlink.create(1, [f"-arch={arch}"])
file_path = tmp_path / "test_file.cubin"
file_path.write_bytes(ptx_bytes)
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
nvjitlink.complete(handle)
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
file_path = tmp_path / "test_file.cubin"
file_path.write_bytes(ptx_bytes)
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
nvjitlink.complete(handle)


@pytest.mark.parametrize("arch", ARCHITECTURES)
def test_get_error_log(arch):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.complete(handle)
log_size = nvjitlink.get_error_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_error_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
nvjitlink.complete(handle)
log_size = nvjitlink.get_error_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_error_log(handle, log)
assert len(log) == log_size


@arch_ptx_parametrized
def test_get_info_log(arch, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
log_size = nvjitlink.get_info_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_info_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
log_size = nvjitlink.get_info_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_info_log(handle, log)
assert len(log) == log_size


@arch_ptx_parametrized
def test_get_linked_cubin(arch, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={arch}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
cubin_size = nvjitlink.get_linked_cubin_size(handle)
cubin = bytearray(cubin_size)
nvjitlink.get_linked_cubin(handle, cubin)
assert len(cubin) == cubin_size
nvjitlink.destroy(handle)
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
cubin_size = nvjitlink.get_linked_cubin_size(handle)
cubin = bytearray(cubin_size)
nvjitlink.get_linked_cubin(handle, cubin)
assert len(cubin) == cubin_size


@pytest.mark.parametrize("arch", ARCHITECTURES)
def test_get_linked_ptx(arch, get_dummy_ltoir):
handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"])
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
nvjitlink.complete(handle)
ptx_size = nvjitlink.get_linked_ptx_size(handle)
ptx = bytearray(ptx_size)
nvjitlink.get_linked_ptx(handle, ptx)
assert len(ptx) == ptx_size
nvjitlink.destroy(handle)
with nvjitlink_session(3, [f"-arch={arch}", "-lto", "-ptx"]) as handle:
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
nvjitlink.complete(handle)
ptx_size = nvjitlink.get_linked_ptx_size(handle)
ptx = bytearray(ptx_size)
nvjitlink.get_linked_ptx(handle, ptx)
assert len(ptx) == ptx_size


def test_package_version():
Expand Down
Loading