Skip to content

Commit e5b5ea4

Browse files
committed
avoid extra copy + ensure compiled objcode loadable
1 parent 4d32276 commit e5b5ea4

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from cuda.core.experimental._stream import Stream
1212
from cuda.core.experimental._utils.clear_error_support import (
1313
assert_type,
14-
assert_type_str_or_bytes,
14+
assert_type_str_or_bytes_like,
1515
raise_code_path_meant_to_be_unreachable,
1616
)
1717
from cuda.core.experimental._utils.cuda_utils import driver, get_binding_version, handle_return, precondition
@@ -615,14 +615,14 @@ def _lazy_load_module(self, *args, **kwargs):
615615
if self._handle is not None:
616616
return
617617
module = self._module
618-
assert_type_str_or_bytes(module)
618+
assert_type_str_or_bytes_like(module)
619619
if isinstance(module, str):
620620
if self._backend_version == "new":
621621
self._handle = handle_return(self._loader["file"](module.encode(), [], [], 0, [], [], 0))
622622
else: # "old" backend
623623
self._handle = handle_return(self._loader["file"](module.encode()))
624624
return
625-
if isinstance(module, bytes):
625+
if isinstance(module, (bytes, bytearray)):
626626
if self._backend_version == "new":
627627
self._handle = handle_return(self._loader["data"](module, [], [], 0, [], [], 0))
628628
else: # "old" backend

cuda_core/cuda/core/experimental/_program.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,8 +657,7 @@ def compile(self, target_type, name_expressions=(), logs=None):
657657
nvvm.get_program_log(self._mnff.handle, log)
658658
logs.write(log.decode("utf-8", errors="backslashreplace"))
659659

660-
data_bytes = bytes(data)
661-
return ObjectCode._init(data_bytes, target_type, name=self._options.name)
660+
return ObjectCode._init(data, target_type, name=self._options.name)
662661

663662
supported_backends = ("nvJitLink", "driver")
664663
if self._backend not in supported_backends:

cuda_core/cuda/core/experimental/_utils/clear_error_support.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ def assert_type(obj, expected_type):
99
raise TypeError(f"Expected type {expected_type.__name__}, but got {type(obj).__name__}")
1010

1111

12-
def assert_type_str_or_bytes(obj):
12+
def assert_type_str_or_bytes_like(obj):
1313
"""Ensure obj is of type str or bytes, else raise AssertionError with a clear message."""
14-
if not isinstance(obj, (str, bytes)):
15-
raise TypeError(f"Expected type str or bytes, but got {type(obj).__name__}")
14+
if not isinstance(obj, (str, bytes, bytearray)):
15+
raise TypeError(f"Expected type str or bytes or bytearray, but got {type(obj).__name__}")
1616

1717

1818
def raise_code_path_meant_to_be_unreachable():

cuda_core/tests/test_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,13 @@ def test_nvvm_deferred_import():
312312

313313

314314
@nvvm_available
315-
def test_nvvm_program_creation(nvvm_ir):
315+
def test_nvvm_program_creation_compilation(nvvm_ir):
316316
"""Test basic NVVM program creation"""
317317
program = Program(nvvm_ir, "nvvm")
318318
assert program.backend == "NVVM"
319319
assert program.handle is not None
320+
obj = program.compile("ptx")
321+
ker = obj.get_kernel("simple") # noqa: F841
320322
program.close()
321323

322324

0 commit comments

Comments
 (0)