diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 3a5f31c74d5..bbbf7b70939 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -274,6 +274,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Generates the kernel function declaration void genDeclaration(const std::string& kernel_name) { code_ << "__global__ void "; + if (kernel_->hasManaged("warp_specialized_num_registers")) { + constexpr int64_t threads_per_cta = 384; + code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/" << threads_per_cta + << ") "; + } if (kernel_->hasManaged("cluster_dims")) { auto cluster_dims = kernel_->getManaged>(