Skip to content

Commit 110d6de

Browse files
committed
Replace _reduce_3_tuple with math.prod in _launcher.pyx
Remove the now-dead _reduce_3_tuple helper from cuda_utils.pyx.

Made-with: Cursor
1 parent 51b8f63 commit 110d6de

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

cuda_core/cuda/core/_launcher.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from cuda.core._utils.cuda_utils cimport (
1717
)
1818
from cuda.core._module import Kernel
1919
from cuda.core._stream import Stream
20-
from cuda.core._utils.cuda_utils import _reduce_3_tuple
20+
from math import prod
2121

2222

2323
def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args):
@@ -62,9 +62,9 @@ cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Str
6262
dev = stream.device
6363
num_sm = dev.properties.multiprocessor_count
6464
max_grid_size = (
65-
kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm
65+
kernel.occupancy.max_active_blocks_per_multiprocessor(prod(config.block), config.shmem_size) * num_sm
6666
)
67-
if _reduce_3_tuple(config.grid) > max_grid_size:
67+
if prod(config.grid) > max_grid_size:
6868
# For now let's try not to be smart and adjust the grid size behind users' back.
6969
# We explicitly ask users to adjust.
7070
x, y, z = config.grid

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def cast_to_3_tuple(label, cfg):
6060
return cfg + (1,) * (3 - len(cfg))
6161

6262

63-
def _reduce_3_tuple(t: tuple):
64-
return t[0] * t[1] * t[2]
65-
66-
6763
cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil:
6864
if err != cydriver.CUresult.CUDA_SUCCESS:
6965
return _check_driver_error(err)

0 commit comments

Comments
 (0)