Skip to content

Commit

Permalink
support ptx code type for program
Browse files: browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Dec 19, 2024
1 parent 33b7366 commit b48762d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
11 changes: 9 additions & 2 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,29 @@ def close(self):
self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend")
_supported_code_type = ("c++",)
_supported_code_type = ("c++", "ptx")
_supported_target_type = ("ptx", "cubin", "ltoir")

def __init__(self, code, code_type):
    """Create a program from source text of the given code type.

    Args:
        code: Program source; must be a ``str``.
        code_type: Name of the source language. It is lower-cased and then
            checked against ``_supported_code_type``.

    Raises:
        NotImplementedError: If ``code_type`` is not a supported code type.
        TypeError: If ``code`` is not a ``str``.
    """
    self._mnff = Program._MembersNeededForFinalize(self, None)

    # Normalize once so the membership test and the dispatch below agree.
    code_type = code_type.lower()
    if code_type not in self._supported_code_type:
        raise NotImplementedError(
            f"code_type {code_type!r} is not supported; expected one of {self._supported_code_type}"
        )

    # "c++" and "ptx" currently take the exact same construction path, so
    # they share one branch instead of two copies of the same statements.
    if code_type in ("c++", "ptx"):
        if not isinstance(code, str):
            raise TypeError(f"code must be a str, got {type(code).__name__}")
        # TODO: support pre-loaded headers & include names
        # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
        # NOTE(review): NVRTC is documented as a runtime compiler for CUDA
        # C++. Confirm that PTX text is really accepted by
        # nvrtcCreateProgram, or route "ptx" through a PTX-capable backend.
        self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
        self._backend = "nvrtc"
    else:
        # Guards against a type being added to _supported_code_type without
        # a matching construction branch here.
        raise NotImplementedError(f"no backend is implemented for code_type {code_type!r}")

Expand Down
14 changes: 9 additions & 5 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def test_program_compile_valid_target_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
arch = "".join(str(i) for i in Device().compute_capability)
object_code = program.compile("ptx", options=(f"-arch=compute_{arch}",))
print(object_code._module.decode())
kernel = object_code.get_kernel("my_kernel")
assert isinstance(object_code, ObjectCode)
assert isinstance(kernel, Kernel)
ptx_object_code = program.compile("ptx", options=(f"-arch=compute_{arch}",))
program = Program(ptx_object_code.code, "ptx")
cubin_object_code = program.compile("cubin", options=(f"-arch=compute_{arch}",))
ptx_kernel = ptx_object_code.get_kernel("my_kernel")
cubin_kernel = cubin_object_code.get_kernel("my_kernel")
assert isinstance(ptx_object_code, ObjectCode)
assert isinstance(cubin_object_code, ObjectCode)
assert isinstance(ptx_kernel, Kernel)
assert isinstance(cubin_kernel, Kernel)


def test_program_compile_invalid_target_type():
Expand Down

0 comments on commit b48762d

Please sign in to comment.