Skip to content

Commit bb1fe80

Browse files
authored
Fix LaunchConfig.grid unit conversion when cluster is set (#868)
1 parent 5fb3fb6 commit bb1fe80

File tree

5 files changed

+204
-11
lines changed

5 files changed

+204
-11
lines changed

cuda_core/cuda/core/experimental/_launch_config.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ def _lazy_init():
3535
class LaunchConfig:
3636
"""Customizable launch options.
3737
38+
Note
39+
----
40+
When cluster is specified, the grid parameter represents the number of
41+
clusters (not blocks). The hierarchy is: grid (clusters) -> cluster (blocks) ->
42+
block (threads). Each dimension in grid specifies clusters in the grid, each dimension in
43+
cluster specifies blocks per cluster, and each dimension in block specifies
44+
threads per block.
45+
3846
Attributes
3947
----------
4048
grid : Union[tuple, int]
41-
Collection of threads that will execute a kernel function.
49+
Collection of threads that will execute a kernel function. When cluster
50+
is not specified, this represents the number of blocks, otherwise
51+
this represents the number of clusters.
4252
cluster : Union[tuple, int]
4353
Group of blocks (Thread Block Cluster) that will execute on the same
4454
GPU Processing Cluster (GPC). Blocks within a cluster have access to
@@ -89,16 +99,29 @@ def __post_init__(self):
8999
def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
90100
_lazy_init()
91101
drv_cfg = driver.CUlaunchConfig()
92-
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
93-
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
94-
drv_cfg.sharedMemBytes = config.shmem_size
95-
attrs = [] # TODO: support more attributes
102+
103+
# Handle grid dimensions and cluster configuration
96104
if config.cluster:
105+
# Convert grid from cluster units to block units
106+
grid_blocks = (
107+
config.grid[0] * config.cluster[0],
108+
config.grid[1] * config.cluster[1],
109+
config.grid[2] * config.cluster[2],
110+
)
111+
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = grid_blocks
112+
113+
# Set up cluster attribute
97114
attr = driver.CUlaunchAttribute()
98115
attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
99116
dim = attr.value.clusterDim
100117
dim.x, dim.y, dim.z = config.cluster
101-
attrs.append(attr)
118+
attrs = [attr]
119+
else:
120+
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
121+
attrs = []
122+
123+
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
124+
drv_cfg.sharedMemBytes = config.shmem_size
102125
if config.cooperative_launch:
103126
attr = driver.CUlaunchAttribute()
104127
attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE

cuda_core/docs/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Release Notes
77
.. toctree::
88
:maxdepth: 3
99

10+
release/0.X.Y-notes
1011
release/0.3.2-notes
1112
release/0.3.1-notes
1213
release/0.3.0-notes
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
.. SPDX-License-Identifier: Apache-2.0
3+
4+
.. currentmodule:: cuda.core.experimental
5+
6+
``cuda.core`` 0.X.Y Release Notes
7+
=================================
8+
9+
Released on TBD
10+
11+
12+
Highlights
13+
----------
14+
15+
- Fix for :class:`LaunchConfig` grid parameter unit conversion when thread block clusters are used.
16+
17+
18+
Breaking Changes
19+
----------------
20+
21+
- **LaunchConfig grid parameter interpretation**: When :attr:`LaunchConfig.cluster` is specified, the :attr:`LaunchConfig.grid` parameter now correctly represents the number of clusters instead of blocks. Previously, the grid parameter was incorrectly interpreted as blocks, causing a mismatch with the expected C++ behavior. This change ensures that ``LaunchConfig(grid=4, cluster=2, block=32)`` correctly produces 4 clusters × 2 blocks/cluster = 8 total blocks, matching the C++ equivalent ``cudax::make_hierarchy(cudax::grid_dims(4), cudax::cluster_dims(2), cudax::block_dims(32))``.
22+
23+
24+
New features
25+
------------
26+
27+
None.
28+
29+
30+
New examples
31+
------------
32+
33+
None.
34+
35+
36+
Fixes and enhancements
37+
----------------------
38+
39+
- Fix :class:`LaunchConfig` grid unit conversion when cluster is set (addresses issue #867).

cuda_core/examples/thread_block_cluster.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,23 @@
55
# ################################################################################
66
#
77
# This demo illustrates the use of thread block clusters in the CUDA launch
8-
# configuration.
8+
# configuration and verifies that the correct grid size is passed to the kernel.
99
#
1010
# ################################################################################
1111

1212
import os
1313
import sys
1414

15-
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
15+
import numpy as np
16+
17+
from cuda.core.experimental import (
18+
Device,
19+
LaunchConfig,
20+
LegacyPinnedMemoryResource,
21+
Program,
22+
ProgramOptions,
23+
launch,
24+
)
1625

1726
# prepare include
1827
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
@@ -26,17 +35,34 @@
2635
if os.path.isdir(cccl_include):
2736
include_path.insert(0, cccl_include)
2837

29-
# print cluster info using a kernel
38+
# print cluster info using a kernel and store results in pinned memory
3039
code = r"""
3140
#include <cooperative_groups.h>
3241
3342
namespace cg = cooperative_groups;
3443
3544
extern "C"
36-
__global__ void check_cluster_info() {
45+
__global__ void check_cluster_info(unsigned int* grid_dims, unsigned int* cluster_dims, unsigned int* block_dims) {
3746
auto g = cg::this_grid();
3847
auto b = cg::this_thread_block();
48+
3949
if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) {
50+
// Store grid dimensions (in blocks)
51+
grid_dims[0] = g.dim_blocks().x;
52+
grid_dims[1] = g.dim_blocks().y;
53+
grid_dims[2] = g.dim_blocks().z;
54+
55+
// Store cluster dimensions
56+
cluster_dims[0] = g.dim_clusters().x;
57+
cluster_dims[1] = g.dim_clusters().y;
58+
cluster_dims[2] = g.dim_clusters().z;
59+
60+
// Store block dimensions (in threads)
61+
block_dims[0] = b.dim_threads().x;
62+
block_dims[1] = b.dim_threads().y;
63+
block_dims[2] = b.dim_threads().z;
64+
65+
// Also print to console
4066
printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z);
4167
printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z);
4268
printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z);
@@ -70,8 +96,49 @@
7096
block = 32
7197
config = LaunchConfig(grid=grid, cluster=cluster, block=block)
7298

99+
# allocate pinned memory to store kernel results
100+
pinned_mr = LegacyPinnedMemoryResource()
101+
element_size = np.dtype(np.uint32).itemsize
102+
103+
# allocate 3 uint32 values each for grid, cluster, and block dimensions
104+
grid_buffer = pinned_mr.allocate(3 * element_size)
105+
cluster_buffer = pinned_mr.allocate(3 * element_size)
106+
block_buffer = pinned_mr.allocate(3 * element_size)
107+
108+
# create NumPy arrays from the pinned memory
109+
grid_dims = np.from_dlpack(grid_buffer).view(dtype=np.uint32)
110+
cluster_dims = np.from_dlpack(cluster_buffer).view(dtype=np.uint32)
111+
block_dims = np.from_dlpack(block_buffer).view(dtype=np.uint32)
112+
113+
# initialize arrays to zero
114+
grid_dims[:] = 0
115+
cluster_dims[:] = 0
116+
block_dims[:] = 0
117+
73118
# launch kernel on the default stream
74-
launch(dev.default_stream, config, ker)
119+
launch(dev.default_stream, config, ker, grid_buffer, cluster_buffer, block_buffer)
75120
dev.sync()
76121

122+
# verify results
123+
print("\nResults stored in pinned memory:")
124+
print(f"Grid dimensions (blocks): {tuple(grid_dims)}")
125+
print(f"Cluster dimensions: {tuple(cluster_dims)}")
126+
print(f"Block dimensions (threads): {tuple(block_dims)}")
127+
128+
# verify that grid conversion worked correctly:
129+
# LaunchConfig(grid=4, cluster=2) should result in 8 total blocks (4 clusters * 2 blocks/cluster)
130+
expected_grid_blocks = grid * cluster # 4 * 2 = 8
131+
actual_grid_blocks = grid_dims[0]
132+
133+
print("\nVerification:")
134+
print(f"LaunchConfig specified: grid={grid} clusters, cluster={cluster} blocks/cluster")
135+
print(f"Expected total blocks: {expected_grid_blocks}")
136+
print(f"Actual total blocks: {actual_grid_blocks}")
137+
138+
if actual_grid_blocks == expected_grid_blocks:
139+
print("✓ Grid conversion is correct!")
140+
else:
141+
print("✗ Grid conversion failed!")
142+
sys.exit(1)
143+
77144
print("done!")

cuda_core/tests/test_launcher.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
launch,
2424
)
2525
from cuda.core.experimental._memory import _SynchronousMemoryResource
26+
from cuda.core.experimental._utils.cuda_utils import CUDAError
2627

2728

2829
def test_launch_config_init(init_cuda):
@@ -59,6 +60,68 @@ def test_launch_config_shmem_size():
5960
assert config.shmem_size == 0
6061

6162

63+
def test_launch_config_cluster_grid_conversion(init_cuda):
64+
"""Test that LaunchConfig preserves original grid values and conversion happens in native config."""
65+
try:
66+
# Test case 1: 1D - Issue #867 example
67+
config = LaunchConfig(grid=4, cluster=2, block=32)
68+
assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}"
69+
assert config.cluster == (2, 1, 1), f"Expected (2, 1, 1), got {config.cluster}"
70+
assert config.block == (32, 1, 1), f"Expected (32, 1, 1), got {config.block}"
71+
72+
# Test case 2: 2D grid and cluster
73+
config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32)
74+
assert config.grid == (2, 3, 1), f"Expected (2, 3, 1), got {config.grid}"
75+
assert config.cluster == (2, 2, 1), f"Expected (2, 2, 1), got {config.cluster}"
76+
77+
# Test case 3: 3D full specification
78+
config = LaunchConfig(grid=(2, 2, 2), cluster=(3, 3, 3), block=(8, 8, 8))
79+
assert config.grid == (2, 2, 2), f"Expected (2, 2, 2), got {config.grid}"
80+
assert config.cluster == (3, 3, 3), f"Expected (3, 3, 3), got {config.cluster}"
81+
82+
# Test case 4: Identity case
83+
config = LaunchConfig(grid=1, cluster=1, block=32)
84+
assert config.grid == (1, 1, 1), f"Expected (1, 1, 1), got {config.grid}"
85+
86+
# Test case 5: No cluster (should not convert grid)
87+
config = LaunchConfig(grid=4, block=32)
88+
assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}"
89+
assert config.cluster is None
90+
91+
except CUDAError:
92+
pytest.skip("Driver or GPU not new enough for thread block clusters")
93+
94+
95+
def test_launch_config_native_conversion(init_cuda):
96+
"""Test that _to_native_launch_config correctly converts grid from cluster units to block units."""
97+
from cuda.core.experimental._launch_config import _to_native_launch_config
98+
99+
try:
100+
# Test case 1: 1D - Issue #867 example
101+
config = LaunchConfig(grid=4, cluster=2, block=32)
102+
native_config = _to_native_launch_config(config)
103+
assert native_config.gridDimX == 8, f"Expected gridDimX=8, got {native_config.gridDimX}"
104+
assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}"
105+
assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}"
106+
107+
# Test case 2: 2D grid and cluster
108+
config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32)
109+
native_config = _to_native_launch_config(config)
110+
assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}"
111+
assert native_config.gridDimY == 6, f"Expected gridDimY=6, got {native_config.gridDimY}"
112+
assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}"
113+
114+
# Test case 3: No cluster (should not convert grid)
115+
config = LaunchConfig(grid=4, block=32)
116+
native_config = _to_native_launch_config(config)
117+
assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}"
118+
assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}"
119+
assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}"
120+
121+
except CUDAError:
122+
pytest.skip("Driver or GPU not new enough for thread block clusters")
123+
124+
62125
def test_launch_invalid_values(init_cuda):
63126
code = 'extern "C" __global__ void my_kernel() {}'
64127
program = Program(code, "c++")

0 commit comments

Comments
 (0)