diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index b73e2d7742c30..ce6c07fbed2bc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -288,9 +288,8 @@ bool TryMatMul4Bits( if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { return false; } - const int kWarpSize = GPU_WARP_SIZE_HOST; dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); - dim3 threads(kWarpSize, kColsPerThreadBlock); + dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0);