Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ def _lazy_init():
"fatbin": _nvjitlink.InputType.FATBIN,
"ltoir": _nvjitlink.InputType.LTOIR,
"object": _nvjitlink.InputType.OBJECT,
"library": _nvjitlink.InputType.LIBRARY,
}
else:
_driver_input_types = {
"ptx": _driver.CUjitInputType.CU_JIT_INPUT_PTX,
"cubin": _driver.CUjitInputType.CU_JIT_INPUT_CUBIN,
"fatbin": _driver.CUjitInputType.CU_JIT_INPUT_FATBINARY,
"object": _driver.CUjitInputType.CU_JIT_INPUT_OBJECT,
"library": _driver.CU_JIT_INPUT_LIBRARY,
}
_inited = True

Expand Down
65 changes: 64 additions & 1 deletion cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class ObjectCode:
"""

__slots__ = ("_handle", "_backend_version", "_code_type", "_module", "_loader", "_sym_map")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")

def __new__(self, *args, **kwargs):
raise RuntimeError(
Expand Down Expand Up @@ -334,6 +334,69 @@ def from_ptx(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None
"""
return ObjectCode._init(module, "ptx", symbol_mapping=symbol_mapping)

@staticmethod
def from_ltoir(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
"""Create an :class:`ObjectCode` instance from an existing LTOIR.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory ltoir code to load, or
a file path string pointing to the on-disk ltoir file to load.
symbol_mapping : Optional[dict]
A dictionary specifying how the unmangled symbol names (as keys)
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "ltoir", symbol_mapping=symbol_mapping)

@staticmethod
def from_fatbin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
"""Create an :class:`ObjectCode` instance from an existing fatbin.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory fatbin to load, or
a file path string pointing to the on-disk fatbin to load.
symbol_mapping : Optional[dict]
A dictionary specifying how the unmangled symbol names (as keys)
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "fatbin", symbol_mapping=symbol_mapping)

@staticmethod
def from_object(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
"""Create an :class:`ObjectCode` instance from an existing object code.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory object code to load, or
a file path string pointing to the on-disk object code to load.
symbol_mapping : Optional[dict]
A dictionary specifying how the unmangled symbol names (as keys)
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "object", symbol_mapping=symbol_mapping)

def from_library(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
"""Create an :class:`ObjectCode` instance from an existing library.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory library to load, or
a file path string pointing to the on-disk library to load.
symbol_mapping : Optional[dict]
A dictionary specifying how the unmangled symbol names (as keys)
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "library", symbol_mapping=symbol_mapping)

# TODO: do we want to unload in a finalizer? Probably not..

def _lazy_load_module(self, *args, **kwargs):
Expand Down