Skip to content

Commit 4ab4049

Browse files
authored
Revert "Revert "Fix get_nested_resource_ptr to accept both str and bytes inputs"" (#1698)
* Revert "Revert "Fix get_nested_resource_ptr to accept both str and bytes inpu…" This reverts commit 2d85f3e. * compiler error fix * feedback * adding back parens * format
1 parent 10fcae6 commit 4ab4049

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

cuda_bindings/cuda/bindings/_internal/utils.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,14 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj,
120120
nested_ptr.reset(nested_vec, True)
121121
for i, obj_i in enumerate(obj):
122122
if ResT is char:
123-
obj_i_bytes = (<str?>(obj_i)).encode()
123+
obj_i_type = type(obj_i)
124+
if obj_i_type is str:
125+
obj_i_bytes = obj_i.encode("utf-8")
126+
elif obj_i_type is bytes:
127+
obj_i_bytes = obj_i
128+
else:
129+
raise TypeError(
130+
f"Expected str or bytes, got {obj_i_type.__name__}")
124131
str_len = <size_t>(len(obj_i_bytes)) + 1 # including null termination
125132
deref(nested_res_vec)[i].resize(str_len)
126133
obj_i_ptr = <char*>(obj_i_bytes)

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def test_create_and_destroy(option):
100100
nvjitlink.destroy(handle)
101101

102102

103+
def test_create_and_destroy_bytes_options():
104+
handle = nvjitlink.create(1, [b"-arch=sm_80"])
105+
assert handle != 0
106+
nvjitlink.destroy(handle)
107+
108+
103109
@pytest.mark.parametrize("option", ARCHITECTURES)
104110
def test_complete_empty(option):
105111
handle = nvjitlink.create(1, [f"-arch={option}"])

cuda_bindings/tests/test_nvvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_get_buffer_empty(get_size, get_buffer):
126126
assert buffer == b"\x00"
127127

128128

129-
@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"]])
129+
@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"], [b"-opt=0"]])
130130
def test_compile_program_with_minimal_nvvm_ir(minimal_nvvmir, options):
131131
with nvvm_program() as prog:
132132
nvvm.add_module_to_program(prog, minimal_nvvmir, len(minimal_nvvmir), "FileNameHere.ll")

0 commit comments

Comments
 (0)