From ae10ca13d9a6d388007c169dbdff5a949e8d49af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81kos=20Hadnagy?=
Date: Tue, 7 Jan 2025 10:43:04 +0000
Subject: [PATCH] Move reduce shared memory to the back, query max. shmem size
 dynamically

---
 .../extensions/cuda/marlin/marlin_cuda_kernel.cu | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
index b18b0469..36c59989 100644
--- a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
+++ b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
@@ -374,9 +374,9 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_s = sh_b + (stages * b_sh_stage);
-  int4* sh_red = sh_s + (stages * s_sh_stage);
   // ADDED: shared memory storage for scaled zero points
   int4* sh_sz = sh_s + (stages * s_sh_stage);
+  int4* sh_red = sh_sz + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -728,7 +728,6 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0
 
 // ADDED: add scaled zero pointer
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
@@ -739,11 +738,11 @@ const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.
     cudaFuncSetAttribute( \
       Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
       cudaFuncAttributeMaxDynamicSharedMemorySize, \
-      SHARED_MEM \
+      max_shared_mem \
     ); \
     Marlin< \
       THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
-    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+    ><<<blocks, THREADS, max_shared_mem, stream>>>( \
       A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\
       prob_m, prob_n, prob_k, \
       locks \
@@ -789,6 +788,10 @@ int marlin_cuda(
     }
   }
 
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  TORCH_CHECK(max_shared_mem > 0);
+
   int thread_k_blocks = thread_k / 16;
   int thread_n_blocks = thread_n / 16;
   int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
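
The pattern the patch adopts can be exercised in isolation. Below is a minimal standalone sketch (illustrative only; the toy kernel and its names are not from the Marlin source) of the opt-in dynamic shared memory flow: query the per-block opt-in limit with cudaDeviceGetAttribute, raise the kernel's cap with cudaFuncSetAttribute, then pass the queried size as the third launch-configuration argument, which is what the CALL_IF macro above now does with max_shared_mem.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel: fills its dynamic shared memory, then writes one value out.
__global__ void fill_shared(int* out, int n) {
  extern __shared__ int smem[];
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    smem[i] = i;
  __syncthreads();
  if (threadIdx.x == 0)
    out[0] = smem[n - 1];
}

int main() {
  int dev = 0;
  cudaSetDevice(dev);

  // Query the opt-in per-block limit instead of hard-coding a
  // per-architecture constant such as 164 * 1000.
  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  printf("opt-in shared memory per block: %d bytes\n", max_shared_mem);

  // Kernels needing more than the default 48 KiB must opt in explicitly.
  cudaFuncSetAttribute(fill_shared, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);

  int* out = nullptr;
  cudaMalloc(&out, sizeof(int));
  // Third launch-configuration argument = dynamic shared memory in bytes.
  fill_shared<<<1, 256, max_shared_mem, 0>>>(out, max_shared_mem / (int)sizeof(int));
  printf("launch: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
  cudaFree(out);
  return 0;
}

The hard-coded 164 * 1000 bytes only fits compute capability 8.0 (A100-class parts); querying cudaDevAttrMaxSharedMemoryPerBlockOptin keeps the same launch working on devices with smaller opt-in limits (roughly 99 KB on consumer Ampere/Ada) and lets it use the larger limits of newer data-center GPUs.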