@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
28
28
from libc.stdint cimport uint32_t
29
29
30
30
from dpctl._backend cimport ( # noqa: E211, E402;
31
+ DPCTLBuildOptionList_Append,
32
+ DPCTLBuildOptionList_Create,
33
+ DPCTLBuildOptionList_Delete,
34
+ DPCTLBuildOptionListRef,
31
35
DPCTLKernel_Copy,
32
36
DPCTLKernel_Delete,
33
37
DPCTLKernel_GetCompileNumSubGroups,
@@ -41,13 +45,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
41
45
DPCTLKernelBundle_Copy,
42
46
DPCTLKernelBundle_CreateFromOCLSource,
43
47
DPCTLKernelBundle_CreateFromSpirv,
48
+ DPCTLKernelBundle_CreateFromSYCLSource,
44
49
DPCTLKernelBundle_Delete,
45
50
DPCTLKernelBundle_GetKernel,
51
+ DPCTLKernelBundle_GetSyclKernel,
46
52
DPCTLKernelBundle_HasKernel,
53
+ DPCTLKernelBundle_HasSyclKernel,
54
+ DPCTLKernelNameList_Append,
55
+ DPCTLKernelNameList_Create,
56
+ DPCTLKernelNameList_Delete,
57
+ DPCTLKernelNameListRef,
47
58
DPCTLSyclContextRef,
48
59
DPCTLSyclDeviceRef,
49
60
DPCTLSyclKernelBundleRef,
50
61
DPCTLSyclKernelRef,
62
+ DPCTLVirtualHeaderList_Append,
63
+ DPCTLVirtualHeaderList_Create,
64
+ DPCTLVirtualHeaderList_Delete,
65
+ DPCTLVirtualHeaderListRef,
51
66
)
52
67
53
68
__all__ = [
@@ -196,9 +211,10 @@ cdef class SyclProgram:
196
211
"""
197
212
198
213
@staticmethod
199
- cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
214
+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source ):
200
215
cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
201
216
ret._program_ref = KBRef
217
+ ret._is_sycl_source = is_sycl_source
202
218
return ret
203
219
204
220
def __dealloc__ (self ):
@@ -209,13 +225,19 @@ cdef class SyclProgram:
209
225
210
226
cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
211
227
name = kernel_name.encode(" utf8" )
228
+ if self ._is_sycl_source:
229
+ return SyclKernel._create(
230
+ DPCTLKernelBundle_GetSyclKernel(self ._program_ref, name),
231
+ kernel_name)
212
232
return SyclKernel._create(
213
233
DPCTLKernelBundle_GetKernel(self ._program_ref, name),
214
234
kernel_name
215
235
)
216
236
217
237
def has_sycl_kernel (self , str kernel_name ):
218
238
name = kernel_name.encode(" utf8" )
239
+ if self ._is_sycl_source:
240
+ return DPCTLKernelBundle_HasSyclKernel(self ._program_ref, name)
219
241
return DPCTLKernelBundle_HasKernel(self ._program_ref, name)
220
242
221
243
def addressof_ref (self ):
@@ -271,7 +293,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
271
293
if KBref is NULL :
272
294
raise SyclProgramCompilationError()
273
295
274
- return SyclProgram._create(KBref)
296
+ return SyclProgram._create(KBref, False )
275
297
276
298
277
299
cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
@@ -317,7 +339,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
317
339
if KBref is NULL :
318
340
raise SyclProgramCompilationError()
319
341
320
- return SyclProgram._create(KBref)
342
+ return SyclProgram._create(KBref, False )
343
+
344
+
345
+ cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers = [], list registered_names = [], list copts = []):
346
+ """
347
+ Creates an executable SYCL kernel_bundle from SYCL source code.
348
+
349
+ This uses the DPC++ ``kernel_compiler`` extension to create a
350
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
351
+ SYCL source code.
352
+
353
+ Parameters:
354
+ q (:class:`dpctl.SyclQueue`)
355
+ The :class:`dpctl.SyclQueue` for which the
356
+ :class:`.SyclProgram` is going to be built.
357
+ source (unicode)
358
+ SYCL source code string.
359
+ headers (list)
360
+ Optional list of virtual headers, where each entry in the list
361
+ needs to be a tuple of header name and header content. See the
362
+ documentation of the ``include_files`` property in the DPC++
363
+ ``kernel_compiler`` extension for more information.
364
+ Default: []
365
+ registered_names (list, optional)
366
+ Optional list of kernel names to register. See the
367
+ documentation of the ``registered_names`` property in the DPC++
368
+ ``kernel_compiler`` extension for more information.
369
+ Default: []
370
+ copts (list)
371
+ Optional list of compilation flags that will be used
372
+ when compiling the program. Default: ``""``.
373
+
374
+ Returns:
375
+ program (:class:`.SyclProgram`)
376
+ A :class:`.SyclProgram` object wrapping the
377
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
378
+ returned by the C API.
379
+
380
+ Raises:
381
+ SyclProgramCompilationError
382
+ If a SYCL kernel bundle could not be created.
383
+ """
384
+ cdef DPCTLSyclKernelBundleRef KBref
385
+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
386
+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
387
+ cdef bytes bSrc = source.encode(' utf8' )
388
+ cdef const char * Src = < const char * > bSrc
389
+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
390
+ cdef bytes bOpt
391
+ cdef const char * sOpt
392
+ cdef bytes bName
393
+ cdef const char * sName
394
+ cdef bytes bContent
395
+ cdef const char * sContent
396
+ for opt in copts:
397
+ if not isinstance (opt, unicode ):
398
+ DPCTLBuildOptionList_Delete(BuildOpts)
399
+ raise SyclProgramCompilationError()
400
+ bOpt = opt.encode(' utf8' )
401
+ sOpt = < const char * > bOpt
402
+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
403
+
404
+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
405
+ for name in registered_names:
406
+ if not isinstance (name, unicode ):
407
+ DPCTLBuildOptionList_Delete(BuildOpts)
408
+ DPCTLKernelNameList_Delete(KernelNames)
409
+ raise SyclProgramCompilationError()
410
+ bName = name.encode(' utf8' )
411
+ sName = < const char * > bName
412
+ DPCTLKernelNameList_Append(KernelNames, sName)
413
+
414
+
415
+ cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
416
+ for name, content in headers:
417
+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
418
+ DPCTLBuildOptionList_Delete(BuildOpts)
419
+ DPCTLKernelNameList_Delete(KernelNames)
420
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
421
+ raise SyclProgramCompilationError()
422
+ bName = name.encode(' utf8' )
423
+ sName = < const char * > bName
424
+ bContent = content.encode(' utf8' )
425
+ sContent = < const char * > bContent
426
+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
427
+
428
+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
429
+ VirtualHeaders, KernelNames,
430
+ BuildOpts)
431
+
432
+ if KBref is NULL :
433
+ DPCTLBuildOptionList_Delete(BuildOpts)
434
+ DPCTLKernelNameList_Delete(KernelNames)
435
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
436
+ raise SyclProgramCompilationError()
437
+
438
+ DPCTLBuildOptionList_Delete(BuildOpts)
439
+ DPCTLKernelNameList_Delete(KernelNames)
440
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
441
+
442
+ return SyclProgram._create(KBref, True )
321
443
322
444
323
445
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(
@@ -336,4 +458,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
336
458
reference.
337
459
"""
338
460
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
339
- return SyclProgram._create(copied_KBRef)
461
+ return SyclProgram._create(copied_KBRef, False )
0 commit comments