Skip to content

Commit 658bfb2

Browse files
rparolinclaude
andauthored
Fix get_nested_resource_ptr to accept both str and bytes inputs (#1665)
* Fix get_nested_resource_ptr to accept both str and bytes inputs The char resource path in get_nested_resource_ptr previously only handled str inputs via a Cython <str?> cast, which would reject bytes objects. This updates the logic to explicitly handle str, bytes, and raise a clear TypeError for other types, enabling users to pass bytes-encoded options to APIs like nvjitlink and nvvm. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * format * Adding utf-8 as the byte encoding --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f3e457e commit 658bfb2

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

cuda_bindings/cuda/bindings/_internal/utils.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,14 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
120120
nested_ptr.reset(nested_vec, True)
121121
for i, obj_i in enumerate(obj):
122122
if ResT is char:
123-
obj_i_bytes = (<str?>(obj_i)).encode()
123+
obj_i_type = type(obj_i)
124+
if obj_i_type is str:
125+
obj_i_bytes = obj_i.encode("utf-8")
126+
elif obj_i_type is bytes:
127+
obj_i_bytes = obj_i
128+
else:
129+
raise TypeError(
130+
f"Expected str or bytes, got {obj_i_type.__name__}")
124131
str_len = <size_t>(len(obj_i_bytes)) + 1 # including null termination
125132
deref(nested_res_vec)[i].resize(str_len)
126133
obj_i_ptr = <char*>(obj_i_bytes)

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ def test_create_and_destroy(option):
9999
nvjitlink.destroy(handle)
100100

101101

102+
@pytest.mark.parametrize("option", ARCHITECTURES)
103+
def test_create_and_destroy_bytes_options(option):
104+
handle = nvjitlink.create(1, [f"-arch={option}".encode()])
105+
assert handle != 0
106+
nvjitlink.destroy(handle)
107+
108+
102109
@pytest.mark.parametrize("option", ARCHITECTURES)
103110
def test_complete_empty(option):
104111
handle = nvjitlink.create(1, [f"-arch={option}"])

cuda_bindings/tests/test_nvvm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def test_get_buffer_empty(get_size, get_buffer):
115115
assert buffer == b"\x00"
116116

117117

118-
@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"]])
118+
@pytest.mark.parametrize(
119+
"options", [[], ["-opt=0"], ["-opt=3", "-g"], [b"-opt=0"], [b"-opt=3", b"-g"], ["-opt=3", b"-g"]]
120+
)
119121
def test_compile_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa: F401, F811
120122
with nvvm_program() as prog:
121123
nvvm.add_module_to_program(prog, minimal_nvvmir, len(minimal_nvvmir), "FileNameHere.ll")
@@ -135,7 +137,9 @@ def test_compile_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa:
135137
assert ".visible .entry kernel()" in buffer.decode()
136138

137139

138-
@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"]])
140+
@pytest.mark.parametrize(
141+
"options", [[], ["-opt=0"], ["-opt=3", "-g"], [b"-opt=0"], [b"-opt=3", b"-g"], ["-opt=3", b"-g"]]
142+
)
139143
def test_verify_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa: F401, F811
140144
with nvvm_program() as prog:
141145
nvvm.add_module_to_program(prog, minimal_nvvmir, len(minimal_nvvmir), "FileNameHere.ll")

0 commit comments

Comments
 (0)