Skip to content

Commit

Permalink
WAR: mark PTX test xfail due to CI condition
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Dec 8, 2024
1 parent c01e015 commit f1d0e40
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@

import pytest

from cuda import cuda, nvrtc
from cuda.core.experimental import Program
from cuda.core.experimental._module import Kernel, ObjectCode


@pytest.fixture
def can_load_generated_ptx():
_, driver_ver = cuda.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
if nvrtc_major * 1000 + nvrtc_minor * 10 > driver_ver:
return False
return True


def test_program_init_valid_code_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
Expand All @@ -31,6 +41,8 @@ def test_program_init_invalid_code_format():
Program(code, "c++")


# TODO: incorporate this check in Program
@pytest.mark.xfail(not can_load_generated_ptx, reason="PTX version too new")
def test_program_compile_valid_target_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
Expand Down

0 comments on commit f1d0e40

Please sign in to comment.