Skip to content

Commit 3c294f3

Browse files
committed
Support compilation from SYCL source code
1 parent ae1e532 commit 3c294f3

12 files changed

+785
-8
lines changed

dpctl/_backend.pxd

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ cdef extern from "syclinterface/dpctl_sycl_device_interface.h":
278278
cdef DPCTLDeviceVectorRef DPCTLDevice_GetComponentDevices(
279279
const DPCTLSyclDeviceRef DRef
280280
)
281+
cdef bool DPCTLDevice_CanCompileSPIRV(const DPCTLSyclDeviceRef DRef)
282+
cdef bool DPCTLDevice_CanCompileOpenCL(const DPCTLSyclDeviceRef DRef)
283+
cdef bool DPCTLDevice_CanCompileSYCL(const DPCTLSyclDeviceRef DRef)
281284

282285

283286
cdef extern from "syclinterface/dpctl_sycl_device_manager.h":
@@ -441,6 +444,43 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
441444
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_Copy(
442445
const DPCTLSyclKernelBundleRef KBRef)
443446

447+
cdef struct DPCTLBuildOptionList
448+
cdef struct DPCTLKernelNameList
449+
cdef struct DPCTLVirtualHeaderList
450+
ctypedef DPCTLBuildOptionList* DPCTLBuildOptionListRef
451+
ctypedef DPCTLKernelNameList* DPCTLKernelNameListRef
452+
ctypedef DPCTLVirtualHeaderList* DPCTLVirtualHeaderListRef
453+
454+
cdef DPCTLBuildOptionListRef DPCTLBuildOptionList_Create()
455+
cdef void DPCTLBuildOptionList_Delete(DPCTLBuildOptionListRef Ref)
456+
cdef void DPCTLBuildOptionList_Append(DPCTLBuildOptionListRef Ref,
457+
const char *Option)
458+
459+
cdef DPCTLKernelNameListRef DPCTLKernelNameList_Create()
460+
cdef void DPCTLKernelNameList_Delete(DPCTLKernelNameListRef Ref)
461+
cdef void DPCTLKernelNameList_Append(DPCTLKernelNameListRef Ref,
462+
const char *Option)
463+
464+
cdef DPCTLVirtualHeaderListRef DPCTLVirtualHeaderList_Create()
465+
cdef void DPCTLVirtualHeaderList_Delete(DPCTLVirtualHeaderListRef Ref)
466+
cdef void DPCTLVirtualHeaderList_Append(DPCTLVirtualHeaderListRef Ref,
467+
const char *Name,
468+
const char *Content)
469+
470+
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromSYCLSource(
471+
const DPCTLSyclContextRef Ctx,
472+
const DPCTLSyclDeviceRef Dev,
473+
const char *Source,
474+
DPCTLVirtualHeaderListRef Headers,
475+
DPCTLKernelNameListRef Names,
476+
DPCTLBuildOptionListRef BuildOptions)
477+
478+
cdef DPCTLSyclKernelRef DPCTLKernelBundle_GetSyclKernel(DPCTLSyclKernelBundleRef KBRef,
479+
const char *KernelName)
480+
481+
cdef bool DPCTLKernelBundle_HasSyclKernel(DPCTLSyclKernelBundleRef KBRef,
482+
const char *KernelName);
483+
444484

445485
cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
446486
ctypedef struct _md_local_accessor "MDLocalAccessor":

dpctl/_sycl_device.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ cdef public api class SyclDevice(_SyclDevice) [
6161
cdef int get_overall_ordinal(self)
6262
cdef int get_backend_ordinal(self)
6363
cdef int get_backend_and_device_type_ordinal(self)
64+
cpdef bint can_compile(self, str language)

dpctl/_sycl_device.pyx

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ from ._backend cimport ( # noqa: E211
2525
DPCTLCString_Delete,
2626
DPCTLDefaultSelector_Create,
2727
DPCTLDevice_AreEq,
28+
DPCTLDevice_CanCompileOpenCL,
29+
DPCTLDevice_CanCompileSPIRV,
30+
DPCTLDevice_CanCompileSYCL,
2831
DPCTLDevice_Copy,
2932
DPCTLDevice_CreateFromSelector,
3033
DPCTLDevice_CreateSubDevicesByAffinity,
@@ -2160,6 +2163,35 @@ cdef class SyclDevice(_SyclDevice):
21602163
raise ValueError("device could not be found")
21612164
return dev_id
21622165

2166+
cpdef bint can_compile(self, str language):
2167+
"""
2168+
Check whether it is possible to create an executable kernel_bundle
2169+
for this device from the given source language.
2170+
2171+
Parameters:
2172+
language
2173+
Input language. Possible values are "spirv" for SPIR-V binary
2174+
files, "opencl" for OpenCL C device code and "sycl" for SYCL
2175+
device code.
2176+
2177+
Returns:
2178+
bool:
2179+
True if compilation is supported, False otherwise.
2180+
2181+
Raises:
2182+
ValueError:
2183+
If an unknown source language is used.
2184+
"""
2185+
if language == "spirv" or language == "spv":
2186+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2187+
if language == "opencl" or language == "ocl":
2188+
return DPCTLDevice_CanCompileOpenCL(self._device_ref)
2189+
if language == "sycl":
2190+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2191+
2192+
raise ValueError(f"Unknown source language {language}")
2193+
2194+
21632195

21642196
cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
21652197
"""

dpctl/program/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
SyclProgramCompilationError,
2727
create_program_from_source,
2828
create_program_from_spirv,
29+
create_program_from_sycl_source,
2930
)
3031

3132
__all__ = [
3233
"create_program_from_source",
3334
"create_program_from_spirv",
35+
"create_program_from_sycl_source",
3436
"SyclKernel",
3537
"SyclProgram",
3638
"SyclProgramCompilationError",

dpctl/program/_program.pxd

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,17 @@ cdef api class SyclProgram [object PySyclProgramObject, type PySyclProgramType]:
4949
binary file.
5050
"""
5151
cdef DPCTLSyclKernelBundleRef _program_ref
52+
cdef bint _is_sycl_source
5253

5354
@staticmethod
54-
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref)
55+
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref, bint _is_sycl_source)
5556
cdef DPCTLSyclKernelBundleRef get_program_ref (self)
5657
cpdef SyclKernel get_sycl_kernel(self, str kernel_name)
5758

5859

5960
cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*)
6061
cpdef create_program_from_spirv (SyclQueue q, const unsigned char[:] IL,
6162
unicode copts=*)
63+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source,
64+
list headers=*, list registered_names=*,
65+
list copts=*)

dpctl/program/_program.pyx

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
2828
from libc.stdint cimport uint32_t
2929

3030
from dpctl._backend cimport ( # noqa: E211, E402;
31+
DPCTLBuildOptionList_Append,
32+
DPCTLBuildOptionList_Create,
33+
DPCTLBuildOptionList_Delete,
34+
DPCTLBuildOptionListRef,
3135
DPCTLKernel_Copy,
3236
DPCTLKernel_Delete,
3337
DPCTLKernel_GetCompileNumSubGroups,
@@ -41,13 +45,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4145
DPCTLKernelBundle_Copy,
4246
DPCTLKernelBundle_CreateFromOCLSource,
4347
DPCTLKernelBundle_CreateFromSpirv,
48+
DPCTLKernelBundle_CreateFromSYCLSource,
4449
DPCTLKernelBundle_Delete,
4550
DPCTLKernelBundle_GetKernel,
51+
DPCTLKernelBundle_GetSyclKernel,
4652
DPCTLKernelBundle_HasKernel,
53+
DPCTLKernelBundle_HasSyclKernel,
54+
DPCTLKernelNameList_Append,
55+
DPCTLKernelNameList_Create,
56+
DPCTLKernelNameList_Delete,
57+
DPCTLKernelNameListRef,
4758
DPCTLSyclContextRef,
4859
DPCTLSyclDeviceRef,
4960
DPCTLSyclKernelBundleRef,
5061
DPCTLSyclKernelRef,
62+
DPCTLVirtualHeaderList_Append,
63+
DPCTLVirtualHeaderList_Create,
64+
DPCTLVirtualHeaderList_Delete,
65+
DPCTLVirtualHeaderListRef,
5166
)
5267

5368
__all__ = [
@@ -196,9 +211,10 @@ cdef class SyclProgram:
196211
"""
197212

198213
@staticmethod
199-
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
214+
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source):
200215
cdef SyclProgram ret = SyclProgram.__new__(SyclProgram)
201216
ret._program_ref = KBRef
217+
ret._is_sycl_source = is_sycl_source
202218
return ret
203219

204220
def __dealloc__(self):
@@ -209,13 +225,19 @@ cdef class SyclProgram:
209225

210226
cpdef SyclKernel get_sycl_kernel(self, str kernel_name):
211227
name = kernel_name.encode("utf8")
228+
if self._is_sycl_source:
229+
return SyclKernel._create(
230+
DPCTLKernelBundle_GetSyclKernel(self._program_ref, name),
231+
kernel_name)
212232
return SyclKernel._create(
213233
DPCTLKernelBundle_GetKernel(self._program_ref, name),
214234
kernel_name
215235
)
216236

217237
def has_sycl_kernel(self, str kernel_name):
218238
name = kernel_name.encode("utf8")
239+
if self._is_sycl_source:
240+
return DPCTLKernelBundle_HasSyclKernel(self._program_ref, name)
219241
return DPCTLKernelBundle_HasKernel(self._program_ref, name)
220242

221243
def addressof_ref(self):
@@ -271,7 +293,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
271293
if KBref is NULL:
272294
raise SyclProgramCompilationError()
273295

274-
return SyclProgram._create(KBref)
296+
return SyclProgram._create(KBref, False)
275297

276298

277299
cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
@@ -317,7 +339,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
317339
if KBref is NULL:
318340
raise SyclProgramCompilationError()
319341

320-
return SyclProgram._create(KBref)
342+
return SyclProgram._create(KBref, False)
343+
344+
345+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers=[], list registered_names=[], list copts=[]):
346+
"""
347+
Creates an executable SYCL kernel_bundle from SYCL source code.
348+
349+
This uses the DPC++ ``kernel_compiler`` extension to create a
350+
``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
351+
SYCL source code.
352+
353+
Parameters:
354+
q (:class:`dpctl.SyclQueue`)
355+
The :class:`dpctl.SyclQueue` for which the
356+
:class:`.SyclProgram` is going to be built.
357+
source (unicode)
358+
SYCL source code string.
359+
headers (list)
360+
Optional list of virtual headers, where each entry in the list
361+
needs to be a tuple of header name and header content. See the
362+
documentation of the ``include_files`` property in the DPC++
363+
``kernel_compiler`` extension for more information.
364+
Default: []
365+
registered_names (list, optional)
366+
Optional list of kernel names to register. See the
367+
documentation of the ``registered_names`` property in the DPC++
368+
``kernel_compiler`` extension for more information.
369+
Default: []
370+
copts (list)
371+
Optional list of compilation flags that will be used
372+
when compiling the program. Default: ``""``.
373+
374+
Returns:
375+
program (:class:`.SyclProgram`)
376+
A :class:`.SyclProgram` object wrapping the
377+
``sycl::kernel_bundle<sycl::bundle_state::executable>``
378+
returned by the C API.
379+
380+
Raises:
381+
SyclProgramCompilationError
382+
If a SYCL kernel bundle could not be created.
383+
"""
384+
cdef DPCTLSyclKernelBundleRef KBref
385+
cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
386+
cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
387+
cdef bytes bSrc = source.encode('utf8')
388+
cdef const char *Src = <const char*>bSrc
389+
cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
390+
cdef bytes bOpt
391+
cdef const char* sOpt
392+
cdef bytes bName
393+
cdef const char* sName
394+
cdef bytes bContent
395+
cdef const char* sContent
396+
for opt in copts:
397+
if not isinstance(opt, unicode):
398+
DPCTLBuildOptionList_Delete(BuildOpts)
399+
raise SyclProgramCompilationError()
400+
bOpt = opt.encode('utf8')
401+
sOpt = <const char*>bOpt
402+
DPCTLBuildOptionList_Append(BuildOpts, sOpt)
403+
404+
cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
405+
for name in registered_names:
406+
if not isinstance(name, unicode):
407+
DPCTLBuildOptionList_Delete(BuildOpts)
408+
DPCTLKernelNameList_Delete(KernelNames)
409+
raise SyclProgramCompilationError()
410+
bName = name.encode('utf8')
411+
sName = <const char*>bName
412+
DPCTLKernelNameList_Append(KernelNames, sName)
413+
414+
415+
cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
416+
for name, content in headers:
417+
if not isinstance(name, unicode) or not isinstance(content, unicode):
418+
DPCTLBuildOptionList_Delete(BuildOpts)
419+
DPCTLKernelNameList_Delete(KernelNames)
420+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
421+
raise SyclProgramCompilationError()
422+
bName = name.encode('utf8')
423+
sName = <const char*>bName
424+
bContent = content.encode('utf8')
425+
sContent = <const char*>bContent
426+
DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
427+
428+
KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
429+
VirtualHeaders, KernelNames,
430+
BuildOpts)
431+
432+
if KBref is NULL:
433+
DPCTLBuildOptionList_Delete(BuildOpts)
434+
DPCTLKernelNameList_Delete(KernelNames)
435+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
436+
raise SyclProgramCompilationError()
437+
438+
DPCTLBuildOptionList_Delete(BuildOpts)
439+
DPCTLKernelNameList_Delete(KernelNames)
440+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
441+
442+
return SyclProgram._create(KBref, True)
321443

322444

323445
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(
@@ -336,4 +458,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
336458
reference.
337459
"""
338460
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
339-
return SyclProgram._create(copied_KBRef)
461+
return SyclProgram._create(copied_KBRef, False)

0 commit comments

Comments
 (0)