Add support for launch bounds
rdspring1 committed Jan 3, 2025
1 parent 0d2fb8d commit e3a9e63
Showing 3 changed files with 32 additions and 19 deletions.
csrc/codegen.cpp: 22 changes (17 additions, 5 deletions)
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
  public:
   static std::string generateKernelDefinition(
       const kir::Kernel* kernel,
-      const std::string& kernel_name) {
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     CudaKernelGenerator codegen(kernel);
-    codegen.genDeclaration(kernel_name);
+    codegen.genDeclaration(kernel_name, num_threads_per_cta);
     codegen.startBlock();
     codegen.genPrologue();
     codegen.genBody();
@@ -272,8 +273,17 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }

   // Generates the kernel function declaration
-  void genDeclaration(const std::string& kernel_name) {
+  void genDeclaration(
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
+    if (kernel_->hasManaged("warp_specialized_num_registers")) {
+      NVF_ERROR(
+          num_threads_per_cta.has_value(),
+          "__launch_bounds__ must be set for register sharing warp specialization");
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims =
           kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
@@ -3542,9 +3552,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

 std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name) {
+    const std::string& kernel_name,
+    std::optional<int64_t> num_threads_per_cta) {
   FUSER_PERF_SCOPE("generateCudaKernel");
-  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
+  return CudaKernelGenerator::generateKernelDefinition(
+      kernel, kernel_name, num_threads_per_cta);
 }

 } // namespace codegen
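For reference, this is the CUDA attribute the new branch emits. A minimal standalone sketch, not part of this commit; the kernel body and the 384-thread cap are made up for illustration:

    // __launch_bounds__ promises the compiler that no launch will use more
    // than 384 threads per block, so it may assign more registers per thread,
    // which is what register sharing warp specialization relies on.
    __global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/384)
    exampleKernel(float* out, const float* in, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = in[i] * 2.0f;
      }
    }

The string produced by genDeclaration above has exactly this shape, with the kernel's actual name and the num_threads_per_cta value in place of the illustrative constants.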
csrc/codegen.h: 3 changes (2 additions, 1 deletion)
@@ -19,7 +19,8 @@ namespace codegen {
 //! Generates a CUDA kernel definition for the given kernel
 NVF_API std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name = "CUDAGeneratedKernel");
+    const std::string& kernel_name = "CUDAGeneratedKernel",
+    std::optional<int64_t> num_threads_per_cta = std::nullopt);

 } // namespace codegen
 } // namespace nvfuser
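Because num_threads_per_cta defaults to std::nullopt, existing call sites keep compiling unchanged; only callers that know the CTA size opt in. A hedged usage sketch, where the kernel pointer and the 384 value are hypothetical:

    // Previous behavior: no __launch_bounds__ emitted.
    std::string src = codegen::generateCudaKernel(kernel);

    // Opt in by passing the number of threads per CTA.
    std::string bounded_src =
        codegen::generateCudaKernel(kernel, "CUDAGeneratedKernel", 384);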
csrc/runtime/executor.cpp: 26 changes (13 additions, 13 deletions)
@@ -461,7 +461,19 @@ void KernelExecutor::compile(
     }
   }

-  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
+  // TODO: pass block_size here;
+  std::optional<int64_t> dynamic_smem = std::nullopt;
+  std::optional<int64_t> block_size = std::nullopt;
+  if (!args.empty()) {
+    auto expr_eval = executor_utils::bindInputs(args, kernel);
+    auto launch_params = computeLaunchParams(
+        launch_constraints, expr_eval, warp_size_, kernel->indexType());
+    block_size = launch_params.nThreads();
+    dynamic_smem = launch_params.smem();
+    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
+  }
+
+  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);

   // If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
   // If the loaded external source code is empty, revert to the default codegen.
@@ -525,18 +537,6 @@ void KernelExecutor::compile(
     NVF_THROW(ss.str());
   }

-  // TODO: pass block_size here;
-  std::optional<int64_t> dynamic_smem = std::nullopt;
-  std::optional<int64_t> block_size = std::nullopt;
-  if (!args.empty()) {
-    auto expr_eval = executor_utils::bindInputs(args, kernel);
-    auto launch_params = computeLaunchParams(
-        launch_constraints, expr_eval, warp_size_, kernel->indexType());
-    block_size = launch_params.nThreads();
-    dynamic_smem = launch_params.smem();
-    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
-  }
-
   // TODO: high water mark should be computed via occupancy API after
   // compilation.
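The executor change is purely an ordering move: launch-parameter inference, previously run after code generation, now runs before it, so launch_params.nThreads() can feed __launch_bounds__. The runtime constraint this must respect is that a kernel compiled with launch bounds cannot be launched with a larger block. A small standalone CUDA check of that behavior, with a hypothetical kernel; the exact error code can vary by driver, so it is only printed:

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/128)
    boundedKernel() {}

    int main() {
      boundedKernel<<<1, 256>>>();  // 256 > 128: expected to fail at launch
      printf("oversized launch: %s\n", cudaGetErrorString(cudaGetLastError()));
      boundedKernel<<<1, 128>>>();  // within the declared bound: should launch
      printf("bounded launch: %s\n", cudaGetErrorString(cudaGetLastError()));
      return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
    }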
