55from __future__ import annotations
66
77import weakref
8+ from contextlib import contextmanager
89from dataclasses import dataclass
910from typing import TYPE_CHECKING , Union
1011from warnings import warn
2021 _handle_boolean_option ,
2122 check_or_create_options ,
2223 driver ,
24+ get_binding_version ,
2325 handle_return ,
2426 is_nested_sequence ,
2527 is_sequence ,
2628 nvrtc ,
2729)
2830
2931
@contextmanager
def _nvvm_exception_manager(self):
    """Re-raise any exception from the ``with`` body, annotated with the NVVM program log.

    Taken from _linker.py.

    When the body raises, the NVVM program log is fetched (best effort) from
    ``self._mnff.handle`` and attached to the exception before it propagates.
    The exception object and its type are never replaced.

    Args:
        self: The object owning the NVVM program. The log is only queried when
            it has a ``_mnff`` attribute.

    Raises:
        Exception: Whatever the ``with`` body raised, with the log text
            attached to ``args`` when a non-empty log could be retrieved.
    """
    try:
        yield
    except Exception as e:
        error_log = ""
        if hasattr(self, "_mnff"):
            try:
                nvvm = _get_nvvm_module()
                logsize = nvvm.get_program_log_size(self._mnff.handle)
                if logsize > 1:
                    log = bytearray(logsize)
                    nvvm.get_program_log(self._mnff.handle, log)
                    error_log = log.decode("utf-8", errors="backslashreplace")
            except Exception:
                # Fetching the log is best effort only: a failure here must never
                # hide the exception that is actually being propagated.
                error_log = ""
        if error_log:
            # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but
            # unfortunately we are still supporting Python 3.9/3.10...
            note = f"\nNVVM program log: {error_log}"
            if e.args and isinstance(e.args[0], str):
                e.args = (e.args[0] + note, *e.args[1:])
            else:
                # No message, or a non-string first argument: indexing/concatenating
                # would itself raise, so attach the log as an additional argument.
                e.args = (*e.args, note)
        # Bare ``raise`` keeps the original traceback intact.
        raise
55+
56+
# Lazily-populated cache for the NVVM bindings module (see _get_nvvm_module).
_nvvm_module = None
# True once an import has been attempted, successful or not, so that the
# version/library probing is performed at most once per process.
_nvvm_import_attempted = False
# The exception raised by the first failed import attempt, kept so that later
# calls can report the real cause instead of only a generic message.
_nvvm_import_error = None


def _get_nvvm_module():
    """
    Handles the import of NVVM module with version and availability checks.
    NVVM bindings were added in cuda-bindings 12.9.0, so we need to handle cases where:
    1. cuda.bindings is not new enough (< 12.9.0)
    2. libnvvm is not found in the Python environment

    The outcome of the first call is cached: a successful import returns the same
    module on every later call, and a failed import makes every later call raise
    immediately, chained to the original failure.

    Returns:
        The nvvm module if available and working

    Raises:
        RuntimeError: If NVVM is not available due to version or library issues,
            or if a previous import attempt already failed.
    """
    global _nvvm_module, _nvvm_import_attempted, _nvvm_import_error

    if _nvvm_import_attempted:
        if _nvvm_module is None:
            raise RuntimeError(
                "NVVM module is not available (previous import attempt failed)"
            ) from _nvvm_import_error
        return _nvvm_module

    _nvvm_import_attempted = True

    try:
        version = get_binding_version()
        if version < (12, 9):
            raise RuntimeError(
                f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. "
                "Please update cuda-bindings to use NVVM features."
            )

        from cuda.bindings import nvvm
        from cuda.bindings._internal.nvvm import _inspect_function_pointer

        # A null function pointer means the bindings are importable but the
        # underlying shared library could not be located/loaded.
        if _inspect_function_pointer("__nvvmCreateProgram") == 0:
            raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ")
    except Exception as exc:
        # Remember why the attempt failed (not only RuntimeError: the imports
        # above may fail too) and let the original exception propagate as-is.
        _nvvm_import_error = exc
        raise

    _nvvm_module = nvvm
    return _nvvm_module
103+
104+
30105def _process_define_macro_inner (formatted_options , macro ):
31106 if isinstance (macro , str ):
32107 formatted_options .append (f"--define-macro={ macro } " )
@@ -229,11 +304,10 @@ def __post_init__(self):
229304
230305 self ._formatted_options = []
231306 if self .arch is not None :
232- self ._formatted_options .append (f"--gpu-architecture ={ self .arch } " )
307+ self ._formatted_options .append (f"-arch ={ self .arch } " )
233308 else :
234- self ._formatted_options .append (
235- "--gpu-architecture=sm_" + "" .join (f"{ i } " for i in Device ().compute_capability )
236- )
309+ self .arch = f"sm_{ Device ().arch } "
310+ self ._formatted_options .append (f"-arch={ self .arch } " )
237311 if self .relocatable_device_code is not None :
238312 self ._formatted_options .append (
239313 f"--relocatable-device-code={ _handle_boolean_option (self .relocatable_device_code )} "
@@ -370,28 +444,33 @@ class Program:
370444 code : Any
371445 String of the CUDA Runtime Compilation program.
372446 code_type : Any
373- String of the code type. Currently ``"ptx"`` and ``"c++"`` are supported.
447+ String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm "`` are supported.
374448 options : ProgramOptions, optional
375449 A ProgramOptions object to customize the compilation process.
376450 See :obj:`ProgramOptions` for more information.
377451 """
378452
379453 class _MembersNeededForFinalize :
380- __slots__ = "handle"
454+ __slots__ = "handle" , "backend"
381455
382- def __init__ (self , program_obj , handle ):
456+ def __init__ (self , program_obj , handle , backend ):
383457 self .handle = handle
458+ self .backend = backend
384459 weakref .finalize (program_obj , self .close )
385460
386461 def close (self ):
387462 if self .handle is not None :
388- handle_return (nvrtc .nvrtcDestroyProgram (self .handle ))
463+ if self .backend == "NVRTC" :
464+ handle_return (nvrtc .nvrtcDestroyProgram (self .handle ))
465+ elif self .backend == "NVVM" :
466+ nvvm = _get_nvvm_module ()
467+ nvvm .destroy_program (self .handle )
389468 self .handle = None
390469
391470 __slots__ = ("__weakref__" , "_mnff" , "_backend" , "_linker" , "_options" )
392471
393472 def __init__ (self , code , code_type , options : ProgramOptions = None ):
394- self ._mnff = Program ._MembersNeededForFinalize (self , None )
473+ self ._mnff = Program ._MembersNeededForFinalize (self , None , None )
395474
396475 self ._options = options = check_or_create_options (ProgramOptions , options , "Program options" )
397476 code_type = code_type .lower ()
@@ -402,6 +481,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
402481 # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
403482
404483 self ._mnff .handle = handle_return (nvrtc .nvrtcCreateProgram (code .encode (), options ._name , 0 , [], []))
484+ self ._mnff .backend = "NVRTC"
405485 self ._backend = "NVRTC"
406486 self ._linker = None
407487
@@ -411,8 +491,22 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
411491 ObjectCode ._init (code .encode (), code_type ), options = self ._translate_program_options (options )
412492 )
413493 self ._backend = self ._linker .backend
494+
495+ elif code_type == "nvvm" :
496+ if isinstance (code , str ):
497+ code = code .encode ("utf-8" )
498+ elif not isinstance (code , (bytes , bytearray )):
499+ raise TypeError ("NVVM IR code must be provided as str, bytes, or bytearray" )
500+
501+ nvvm = _get_nvvm_module ()
502+ self ._mnff .handle = nvvm .create_program ()
503+ self ._mnff .backend = "NVVM"
504+ nvvm .add_module_to_program (self ._mnff .handle , code , len (code ), options ._name .decode ())
505+ self ._backend = "NVVM"
506+ self ._linker = None
507+
414508 else :
415- supported_code_types = ("c++" , "ptx" )
509+ supported_code_types = ("c++" , "ptx" , "nvvm" )
416510 assert code_type not in supported_code_types , f"{ code_type = } "
417511 raise RuntimeError (f"Unsupported { code_type = } ({ supported_code_types = } )" )
418512
@@ -433,6 +527,33 @@ def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions:
433527 ptxas_options = options .ptxas_options ,
434528 )
435529
530+ def _translate_program_options_to_nvvm (self , options : ProgramOptions ) -> list [str ]:
531+ """Translate ProgramOptions to NVVM-specific compilation options."""
532+ nvvm_options = []
533+
534+ assert options .arch is not None
535+ arch = options .arch
536+ if arch .startswith ("sm_" ):
537+ arch = f"compute_{ arch [3 :]} "
538+ nvvm_options .append (f"-arch={ arch } " )
539+ if options .debug :
540+ nvvm_options .append ("-g" )
541+ if options .device_code_optimize is False :
542+ nvvm_options .append ("-opt=0" )
543+ elif options .device_code_optimize is True :
544+ nvvm_options .append ("-opt=3" )
545+ # NVVM is not consistent with NVRTC, it uses 0/1 instead...
546+ if options .ftz is not None :
547+ nvvm_options .append (f"-ftz={ '1' if options .ftz else '0' } " )
548+ if options .prec_sqrt is not None :
549+ nvvm_options .append (f"-prec-sqrt={ '1' if options .prec_sqrt else '0' } " )
550+ if options .prec_div is not None :
551+ nvvm_options .append (f"-prec-div={ '1' if options .prec_div else '0' } " )
552+ if options .fma is not None :
553+ nvvm_options .append (f"-fma={ '1' if options .fma else '0' } " )
554+
555+ return nvvm_options
556+
436557 def close (self ):
437558 """Destroy this program."""
438559 if self ._linker :
@@ -513,6 +634,31 @@ def compile(self, target_type, name_expressions=(), logs=None):
513634
514635 return ObjectCode ._init (data , target_type , symbol_mapping = symbol_mapping , name = self ._options .name )
515636
637+ elif self ._backend == "NVVM" :
638+ if target_type not in ("ptx" , "ltoir" ):
639+ raise ValueError (f'NVVM backend only supports target_type="ptx", "ltoir", got "{ target_type } "' )
640+
641+ nvvm_options = self ._translate_program_options_to_nvvm (self ._options )
642+ if target_type == "ltoir" and "-gen-lto" not in nvvm_options :
643+ nvvm_options .append ("-gen-lto" )
644+ nvvm = _get_nvvm_module ()
645+ with _nvvm_exception_manager (self ):
646+ nvvm .verify_program (self ._mnff .handle , len (nvvm_options ), nvvm_options )
647+ nvvm .compile_program (self ._mnff .handle , len (nvvm_options ), nvvm_options )
648+
649+ size = nvvm .get_compiled_result_size (self ._mnff .handle )
650+ data = bytearray (size )
651+ nvvm .get_compiled_result (self ._mnff .handle , data )
652+
653+ if logs is not None :
654+ logsize = nvvm .get_program_log_size (self ._mnff .handle )
655+ if logsize > 1 :
656+ log = bytearray (logsize )
657+ nvvm .get_program_log (self ._mnff .handle , log )
658+ logs .write (log .decode ("utf-8" , errors = "backslashreplace" ))
659+
660+ return ObjectCode ._init (data , target_type , name = self ._options .name )
661+
516662 supported_backends = ("nvJitLink" , "driver" )
517663 if self ._backend not in supported_backends :
518664 raise ValueError (f'Unsupported backend="{ self ._backend } " ({ supported_backends = } )' )
0 commit comments