Skip to content

Commit

Permalink
reference compute capability instead of chip gen
Browse files: browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Dec 13, 2024
1 parent 7d117f2 commit 3c98f14
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/experimental/_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __post_init__(self):
if not _use_ex:
raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
if Device().compute_capability < (9, 0):
raise CUDAError("thread block clusters are not supported below Hopper")
raise CUDAError("thread block clusters are not supported on devices with compute capability < 9.0")
self.cluster = self._cast_to_3_tuple(self.cluster)
# we handle "stream=None" in the launch API
if self.stream is not None and not isinstance(self.stream, Stream):
Expand Down
9 changes: 7 additions & 2 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,27 @@ def close(self):
self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend")
_supported_code_type = ("c++",)
_supported_code_type = ("c++", "ptx")
_supported_target_type = ("ptx", "cubin", "ltoir")

def __init__(self, code, code_type):
self._mnff = Program._MembersNeededForFinalize(self, None)
code_type = code_type.lower()

if code_type not in self._supported_code_type:
raise NotImplementedError

if code_type.lower() == "c++":
if code_type == "c++":
if not isinstance(code, str):
raise TypeError
# TODO: support pre-loaded headers & include names
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
self._backend = "nvrtc"

if code_type == "ptx":
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
self._backend = "nvrtc"
else:
raise NotImplementedError

Expand Down
5 changes: 4 additions & 1 deletion cuda_core/examples/thread_block_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
dev = Device()
arch = dev.compute_capability
if arch < (9, 0):
print("this demo requires a Hopper GPU (since thread block cluster is a hardware feature)", file=sys.stderr)
print(
"this demo requires compute capability >= 9.0 (since thread block cluster is a hardware feature)",
file=sys.stderr,
)
sys.exit(0)
arch = "".join(f"{i}" for i in arch)

Expand Down

0 comments on commit 3c98f14

Please sign in to comment.