Skip to content

Commit cc51a8a

Browse files
committed
Merge branch 'main' of https://github.com/NVIDIA/numba-cuda into nrt-memsys-2
2 parents bb1cf0f + 9479123 commit cc51a8a

File tree

11 files changed

+278
-47
lines changed

11 files changed

+278
-47
lines changed

ci/test_conda_pynvjitlink.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ set -euo pipefail
88
if [ "${CUDA_VER%.*.*}" = "11" ]; then
99
CTK_PACKAGES="cudatoolkit"
1010
else
11-
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc"
11+
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc cuda-cuobjdump"
1212
fi
1313

1414
rapids-logger "Install testing dependencies"

numba_cuda/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.19
1+
0.0.21

numba_cuda/numba/cuda/codegen.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import subprocess
1010
import tempfile
1111

12-
1312
CUDA_TRIPLE = 'nvptx64-nvidia-cuda'
1413

1514

@@ -181,17 +180,7 @@ def get_ltoir(self, cc=None):
181180

182181
return ltoir
183182

184-
def get_cubin(self, cc=None):
185-
cc = self._ensure_cc(cc)
186-
187-
cubin = self._cubin_cache.get(cc, None)
188-
if cubin:
189-
return cubin
190-
191-
linker = driver.Linker.new(
192-
max_registers=self._max_registers, cc=cc, lto=self._lto
193-
)
194-
183+
def _link_all(self, linker, cc, ignore_nonlto=False):
195184
if linker.lto:
196185
ltoir = self.get_ltoir(cc=cc)
197186
linker.add_ltoir(ltoir)
@@ -200,11 +189,44 @@ def get_cubin(self, cc=None):
200189
linker.add_ptx(ptx.encode())
201190

202191
for path in self._linking_files:
203-
linker.add_file_guess_ext(path)
192+
linker.add_file_guess_ext(path, ignore_nonlto)
204193
if self.needs_cudadevrt:
205-
linker.add_file_guess_ext(get_cudalib('cudadevrt', static=True))
194+
linker.add_file_guess_ext(
195+
get_cudalib('cudadevrt', static=True), ignore_nonlto
196+
)
197+
198+
def get_cubin(self, cc=None):
199+
cc = self._ensure_cc(cc)
206200

201+
cubin = self._cubin_cache.get(cc, None)
202+
if cubin:
203+
return cubin
204+
205+
if self._lto and config.DUMP_ASSEMBLY:
206+
linker = driver.Linker.new(
207+
max_registers=self._max_registers,
208+
cc=cc,
209+
additional_flags=["-ptx"],
210+
lto=self._lto
211+
)
212+
# `-ptx` flag is meant to view the optimized PTX for LTO objects.
213+
# Non-LTO objects are not passed to linker.
214+
self._link_all(linker, cc, ignore_nonlto=True)
215+
216+
ptx = linker.get_linked_ptx().decode('utf-8')
217+
218+
print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
219+
print(ptx)
220+
print('=' * 80)
221+
222+
linker = driver.Linker.new(
223+
max_registers=self._max_registers,
224+
cc=cc,
225+
lto=self._lto
226+
)
227+
self._link_all(linker, cc, ignore_nonlto=False)
207228
cubin = linker.complete()
229+
208230
self._cubin_cache[cc] = cubin
209231
self._linkerinfo_cache[cc] = linker.info_log
210232

numba_cuda/numba/cuda/cuda_paths.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,9 @@ def get_conda_include_dir():
310310
# though usually it shouldn't.
311311
include_dir = os.path.join(sys.prefix, 'include')
312312

313-
if os.path.exists(include_dir):
313+
if (os.path.exists(include_dir) and os.path.isdir(include_dir)
314+
and os.path.exists(os.path.join(include_dir,
315+
'cuda_device_runtime_api.h'))):
314316
return include_dir
315317
return
316318

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import threading
2121
import asyncio
2222
import pathlib
23+
import subprocess
24+
import tempfile
25+
import re
2326
from itertools import product
2427
from abc import ABCMeta, abstractmethod
2528
from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof,
@@ -35,7 +38,7 @@
3538
from .drvapi import API_PROTOTYPES
3639
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
3740
from .mappings import FILE_EXTENSION_MAP
38-
from .linkable_code import LinkableCode
41+
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
3942
from numba.cuda.utils import _readenv
4043
from numba.cuda.cudadrv import enums, drvapi, nvrtc
4144

@@ -2664,12 +2667,18 @@ def add_cu_file(self, path):
26642667
cu = f.read()
26652668
self.add_cu(cu, os.path.basename(path))
26662669

2667-
def add_file_guess_ext(self, path_or_code):
2670+
def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
26682671
"""
26692672
Add a file or LinkableCode object to the link. If a file is
26702673
passed, the type will be inferred from the extension. A LinkableCode
26712674
object represents a file already in memory.
2675+
2676+
When `ignore_nonlto` is set to true, do not add code that will not
2677+
be LTO-ed in the linking process. This is useful in inspecting the
2678+
LTO-ed portion of the PTX when linker is added with objects that can be
2679+
both LTO-ed and not LTO-ed.
26722680
"""
2681+
26732682
if isinstance(path_or_code, str):
26742683
ext = pathlib.Path(path_or_code).suffix
26752684
if ext == '':
@@ -2685,6 +2694,26 @@ def add_file_guess_ext(self, path_or_code):
26852694
"Don't know how to link file with extension "
26862695
f"{ext}"
26872696
)
2697+
2698+
if ignore_nonlto:
2699+
warn_and_return = False
2700+
if kind in (
2701+
FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"]
2702+
):
2703+
entry_types = inspect_obj_content(path_or_code)
2704+
if "nvvm" not in entry_types:
2705+
warn_and_return = True
2706+
elif kind != FILE_EXTENSION_MAP["ltoir"]:
2707+
warn_and_return = True
2708+
2709+
if warn_and_return:
2710+
warnings.warn(
2711+
f"Not adding {path_or_code} as it is not "
2712+
"optimizable at link time, and `ignore_nonlto == "
2713+
"True`."
2714+
)
2715+
return
2716+
26882717
self.add_file(path_or_code, kind)
26892718
return
26902719
else:
@@ -2697,6 +2726,25 @@ def add_file_guess_ext(self, path_or_code):
26972726
if path_or_code.kind == "cu":
26982727
self.add_cu(path_or_code.data, path_or_code.name)
26992728
else:
2729+
if ignore_nonlto:
2730+
warn_and_return = False
2731+
if isinstance(path_or_code, (Fatbin, Object)):
2732+
with tempfile.NamedTemporaryFile("w") as fp:
2733+
fp.write(path_or_code.data)
2734+
entry_types = inspect_obj_content(fp.name)
2735+
if "nvvm" not in entry_types:
2736+
warn_and_return = True
2737+
elif not isinstance(path_or_code, LTOIR):
2738+
warn_and_return = True
2739+
2740+
if warn_and_return:
2741+
warnings.warn(
2742+
f"Not adding {path_or_code.name} as it is not "
2743+
"optimizable at link time, and `ignore_nonlto == "
2744+
"True`."
2745+
)
2746+
return
2747+
27002748
self.add_data(
27012749
path_or_code.data, path_or_code.kind, path_or_code.name
27022750
)
@@ -3046,6 +3094,28 @@ def add_file(self, path, kind):
30463094
name = pathlib.Path(path).name
30473095
self.add_data(data, kind, name)
30483096

3097+
def add_cu(self, cu, name):
3098+
"""Add CUDA source in a string to the link. The name of the source
3099+
file should be specified in `name`."""
3100+
with driver.get_active_context() as ac:
3101+
dev = driver.get_device(ac.devnum)
3102+
cc = dev.compute_capability
3103+
3104+
program, log = nvrtc.compile(cu, name, cc, ltoir=self.lto)
3105+
3106+
if not self.lto and config.DUMP_ASSEMBLY:
3107+
print(("ASSEMBLY %s" % name).center(80, "-"))
3108+
print(program)
3109+
print("=" * 80)
3110+
3111+
suffix = ".ltoir" if self.lto else ".ptx"
3112+
program_name = os.path.splitext(name)[0] + suffix
3113+
# Link the program's PTX or LTOIR using the normal linker mechanism
3114+
if self.lto:
3115+
self.add_ltoir(program, program_name)
3116+
else:
3117+
self.add_ptx(program.encode(), program_name)
3118+
30493119
def add_data(self, data, kind, name):
30503120
if kind == FILE_EXTENSION_MAP["cubin"]:
30513121
fn = self._linker.add_cubin
@@ -3067,6 +3137,12 @@ def add_data(self, data, kind, name):
30673137
except NvJitLinkError as e:
30683138
raise LinkerError from e
30693139

3140+
def get_linked_ptx(self):
3141+
try:
3142+
return self._linker.get_linked_ptx()
3143+
except NvJitLinkError as e:
3144+
raise LinkerError from e
3145+
30703146
def complete(self):
30713147
try:
30723148
return self._linker.get_linked_cubin()
@@ -3342,3 +3418,28 @@ def get_version():
33423418
Return the driver version as a tuple of (major, minor)
33433419
"""
33443420
return driver.get_version()
3421+
3422+
3423+
def inspect_obj_content(objpath: str):
3424+
"""
3425+
Given path to a fatbin or object, use `cuobjdump` to examine its content
3426+
Return the set of entries in the object.
3427+
"""
3428+
code_types :set[str] = set()
3429+
3430+
try:
3431+
out = subprocess.run(["cuobjdump", objpath], check=True,
3432+
capture_output=True)
3433+
except FileNotFoundError as e:
3434+
msg = ("cuobjdump has not been found. You may need "
3435+
"to install the CUDA toolkit and ensure that "
3436+
"it is available on your PATH.\n")
3437+
raise RuntimeError(msg) from e
3438+
3439+
objtable = out.stdout.decode('utf-8')
3440+
entry_pattern = r"Fatbin (.*) code"
3441+
for line in objtable.split("\n"):
3442+
if match := re.match(entry_pattern, line):
3443+
code_types.add(match.group(1))
3444+
3445+
return code_types

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ class NVRTC:
6161
NVVM interface. Initialization is protected by a lock and uses the standard
6262
(for Numba) open_cudalib function to load the NVRTC library.
6363
"""
64+
65+
_CU12ONLY_PROTOTYPES = {
66+
# nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
67+
"nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
68+
# nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
69+
"nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p)
70+
}
71+
6472
_PROTOTYPES = {
6573
# nvrtcResult nvrtcVersion(int *major, int *minor)
6674
'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
@@ -110,6 +118,10 @@ def __new__(cls):
110118
cls.__INSTANCE = None
111119
raise NvrtcSupportError("NVRTC cannot be loaded") from e
112120

121+
from numba.cuda.cudadrv.runtime import get_version
122+
if get_version() >= (12, 0):
123+
inst._PROTOTYPES |= inst._CU12ONLY_PROTOTYPES
124+
113125
# Find & populate functions
114126
for name, proto in inst._PROTOTYPES.items():
115127
func = getattr(lib, name)
@@ -208,17 +220,31 @@ def get_ptx(self, program):
208220

209221
return ptx.value.decode()
210222

223+
def get_lto(self, program):
224+
"""
225+
Get the compiled LTOIR as a Python bytes object.
226+
"""
227+
lto_size = c_size_t()
228+
self.nvrtcGetLTOIRSize(program.handle, byref(lto_size))
229+
230+
lto = b" " * lto_size.value
231+
self.nvrtcGetLTOIR(program.handle, lto)
232+
233+
return lto
211234

212-
def compile(src, name, cc):
235+
236+
def compile(src, name, cc, ltoir=False):
213237
"""
214-
Compile a CUDA C/C++ source to PTX for a given compute capability.
238+
Compile a CUDA C/C++ source to PTX or LTOIR for a given compute capability.
215239
216240
:param src: The source code to compile
217241
:type src: str
218242
:param name: The filename of the source (for information only)
219243
:type name: str
220244
:param cc: A tuple ``(major, minor)`` of the compute capability
221245
:type cc: tuple
246+
:param ltoir: Compile into LTOIR if True, otherwise into PTX
247+
:type ltoir: bool
222248
:return: The compiled PTX and compilation log
223249
:rtype: tuple
224250
"""
@@ -246,6 +272,9 @@ def compile(src, name, cc):
246272

247273
options = [arch, *cuda_include, numba_include, nrt_include, '-rdc', 'true']
248274

275+
if ltoir:
276+
options.append("-dlto")
277+
249278
if nvrtc.get_version() < (12, 0):
250279
options += ["-std=c++17"]
251280

@@ -265,5 +294,9 @@ def compile(src, name, cc):
265294
msg = (f"NVRTC log messages whilst compiling {name}:\n\n{log}")
266295
warnings.warn(msg)
267296

268-
ptx = nvrtc.get_ptx(program)
269-
return ptx, log
297+
if ltoir:
298+
ltoir = nvrtc.get_lto(program)
299+
return ltoir, log
300+
else:
301+
ptx = nvrtc.get_ptx(program)
302+
return ptx, log

0 commit comments

Comments
 (0)