Skip to content

Commit 45685f0

Browse files
committed
apply review feedback
copy inputs to SpecializationConstant to prevent dangling pointers
1 parent 7e6065f commit 45685f0

3 files changed

Lines changed: 34 additions & 19 deletions

File tree

dpctl/program/_program.pyx

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ from cpython.buffer cimport (
3737
from cpython.bytes cimport PyBytes_FromStringAndSize
3838
from libc.stdint cimport uint32_t
3939
from libc.stdlib cimport free, malloc
40-
from libc.string cimport memcmp
40+
from libc.string cimport memcmp, memcpy
4141

4242
import warnings
4343

@@ -298,12 +298,10 @@ cdef class SpecializationConstant:
298298
integers, the first argument is interpreted as the number of bytes and
299299
the second argument is interpreted as a pointer to the data.
300300
301-
Note that when constructing from a buffer, the
302-
:class:`.SpecializationConstant`, shares memory with the original object.
303-
Modifications to the original object's data after construction will be
304-
reflected when the :class:`.SpecializationConstant` is used to create a
305-
:class:`.SyclKernelBundle`. This is not the case when constructing from a
306-
raw pointer, as the data is copied.
301+
Note that construction of the :class:`.SpecializationConstant` copies the
302+
input, so modifications made after construction of the
303+
:class:`.SpecializationConstant` will not be reflected in the
304+
:class:`.SyclKernelBundle`.
307305
308306
Args:
309307
spec_id (int):
@@ -319,11 +317,11 @@ cdef class SpecializationConstant:
319317
"""
320318

321319
cdef _spec_const _spec_const
322-
cdef Py_buffer _buffer
323320

324321
def __cinit__(self, spec_id, *args):
325322
cdef int ret_code = 0
326323
cdef object target_obj = None
324+
cdef Py_buffer _local_buffer
327325

328326
if not isinstance(spec_id, numbers.Integral):
329327
raise TypeError(
@@ -348,16 +346,16 @@ cdef class SpecializationConstant:
348346
)
349347
elif isinstance(args[0], str):
350348
target_obj = np.ascontiguousarray(args[1], dtype=args[0])
349+
else:
350+
raise TypeError(
351+
"Invalid arguments."
352+
)
351353

352354
elif len(args) == 1:
353355
target_obj = args[0]
354356
if not PyObject_CheckBuffer(target_obj):
355357
# attempt to coerce to a numpy array
356358
target_obj = np.ascontiguousarray(target_obj)
357-
else:
358-
raise TypeError(
359-
"Invalid arguments."
360-
)
361359

362360
if isinstance(target_obj, np.ndarray):
363361
if target_obj.dtype.kind not in ("b", "i", "u", "f", "c"):
@@ -372,17 +370,28 @@ cdef class SpecializationConstant:
372370
)
373371

374372
ret_code = PyObject_GetBuffer(
375-
target_obj, &(self._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
373+
target_obj, &(_local_buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
376374
)
377375
if ret_code != 0:
378376
raise ValueError(
379377
"Failed to get buffer view for the provided object."
380378
)
381-
self._spec_const.value = <void*>self._buffer.buf
382-
self._spec_const.size = <size_t>self._buffer.len
379+
self._spec_const.value = <void*>malloc(self._spec_const.size)
380+
self._spec_const.size = <size_t>_local_buffer.len
381+
382+
if self._spec_const.value == NULL:
383+
PyBuffer_Release(&(_local_buffer))
384+
raise MemoryError(
385+
"Failed to allocate memory for specialization constant data."
386+
)
387+
388+
memcpy(self._spec_const.value, _local_buffer.buf, self._spec_const.size)
389+
390+
PyBuffer_Release(&(_local_buffer))
383391

384392
def __dealloc__(self):
385-
PyBuffer_Release(&(self._buffer))
393+
if self._spec_const.value != NULL:
394+
free(self._spec_const.value)
386395

387396
def __repr__(self):
388397
return f"SpecializationConstant({self._spec_const.id})"

dpctl/program/utils/_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def parse_spirv_specializations(
106106

107107
if word_count == 0:
108108
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
109+
if i + word_count > len(words):
110+
raise ValueError(
111+
f"Invalid SPIR-V instruction at offset {i} (extends beyond "
112+
"buffer)"
113+
)
109114

110115
if opcode == SpirvOpCode.OpFunction:
111116
# everything following is not relevant to specialization constant
@@ -173,12 +178,14 @@ def parse_spirv_specializations(
173178
dtype_str = type_info["dtype"]
174179
raw_default = defaults.get(target_id)
175180
default_value = None
176-
if isinstance(raw_default, bytes):
181+
if isinstance(raw_default, bool):
182+
default_value = raw_default
183+
elif isinstance(raw_default, bytes) and dtype_str != "unknown_type":
177184
try:
178185
default_value = np.frombuffer(raw_default, dtype=dtype_str)[
179186
0
180187
].item()
181-
except Exception:
188+
except (ValueError, TypeError):
182189
default_value = None
183190

184191
result.append(

libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,6 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
503503
backend_traits<ze_be>::return_type<device> ZeDevice;
504504
ZeDevice = get_native<ze_be>(SyclDev);
505505

506-
// Specialization constants are not supported by DPCTL at the moment
507506
std::vector<std::uint32_t> spec_ids;
508507
std::vector<const void *> spec_values;
509508

0 commit comments

Comments
 (0)