From 31904407d18c806f18c232de08b58d419484d122 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 2 Dec 2024 21:12:24 +0000
Subject: [PATCH] init

---
 CMakeLists.txt                   |   3 +-
 csrc/activation_quant_kernels.cu | 127 +++++++++++++++++++++++++++++++
 csrc/core/math.hpp               |  18 +++++
 csrc/ops.h                       |   2 +
 csrc/torch_bindings.cpp          |   4 +-
 vllm/compilation/backends.py     |   7 +-
 vllm/compilation/fusion.py       |  51 ++++++++++---
 vllm/compilation/pass_manager.py |   2 +-
 8 files changed, 198 insertions(+), 16 deletions(-)
 create mode 100644 csrc/activation_quant_kernels.cu
 create mode 100644 csrc/core/math.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c78cdc77a7e42..78331857177a3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,7 +240,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/activation_quant_kernels.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
diff --git a/csrc/activation_quant_kernels.cu b/csrc/activation_quant_kernels.cu
new file mode 100644
index 0000000000000..03edb37450a20
--- /dev/null
+++ b/csrc/activation_quant_kernels.cu
@@ -0,0 +1,127 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+#include <c10/util/Float8_e4m3fn.h>
+
+#include <cmath>
+#include "core/math.hpp"
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+// using FP8_TYPE = c10::Float8_e4m3fnuz;
+namespace vllm {
+
+template <typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
+}
+
+template <bool is_scale_inverted>
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
+  float x = 0.0f;
+  if constexpr (is_scale_inverted) {
+    x = val * scale;
+  }
+  else {
+    x = val / scale;
+  }
+  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  return static_cast<FP8_TYPE>(r);
+}
+
+// Activation and gating kernel template.
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_quant_kernel(
+    FP8_TYPE* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const float* scale,
+    const int d) {
+
+  const int32_t token_idx = blockIdx.x;
+  const int32_t blocks_per_token = gridDim.y;
+
+  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
+
+  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
+  const int32_t elems_per_block =
+      next_multiple_of(elems_per_128bit_load, tgt_elems_per_block);
+  const int32_t block_start = blockIdx.y * elems_per_block;
+  int32_t block_end = block_start + elems_per_block;
+  block_end = block_end > d ? d : block_end;
+
+  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
+  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
+  FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;
+
+  // 128-bit vectorized code
+  const int32_t vec_loop_end =
+      prev_multiple_of(elems_per_128bit_load, block_end);
+  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
+  const int32_t vec_start_idx = block_start / elems_per_128bit_load;
+
+  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
+  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
+  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
+
+  float inverted_scale = 1 / *scale;
+#pragma unroll
+  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
+       vec_idx += blockDim.x) {
+    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
+    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
+    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
+    using scalar_64bit_vec_t = std::array<FP8_TYPE, elems_per_128bit_load>;
+
+    scalar_64bit_vec_t out_vec;
+    const auto x_vec = reinterpret_cast<const scalar_128bit_vec_t&>(x_128bit);
+    const auto y_vec = reinterpret_cast<const scalar_128bit_vec_t&>(y_128bit);
+
+#pragma unroll
+    for (int i = 0; i < elems_per_128bit_load; i++) {
+      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
+    }
+
+    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
+  }
+
+  // Scalar cleanup code
+  if (block_end > vec_loop_end) {
+    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
+         idx += blockDim.x) {
+      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
+      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
+      // out_ptr[idx] = ACT_FN(x) * y;
+      out_ptr[idx] = scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
+    }
+  }
+}
+}  // namespace vllm
+
+
+// Launch activation, gating, and quantize kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                             \
+  int d = input.size(-1) / 2;                                             \
+  int64_t num_tokens = input.numel() / input.size(-1);                    \
+  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);   \
+  dim3 block(std::min(d, 512));                                           \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));       \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();           \
+  VLLM_DISPATCH_FLOATING_TYPES(                                           \
+      input.scalar_type(), "act_and_mul_kernel", [&] {                    \
+        vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>>        \
+            <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),        \
+                                         input.data_ptr<scalar_t>(),      \
+                                         scale.data_ptr<float>(),         \
+                                         d);                              \
+      });
+
+void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
+                        torch::Tensor& input,  // [..., 2 * d]
+                        torch::Tensor& scale)
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
\ No newline at end of file
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
new file mode 100644
index 0000000000000..bd5241c5703fc
--- /dev/null
+++ b/csrc/core/math.hpp
@@ -0,0 +1,18 @@
+#pragma once
+
+template <typename A, typename B>
+static inline constexpr auto div_ceil(A a, B b) {
+  return (a + b - 1) / b;
+}
+
+// Compute the next multiple of a that is greater than or equal to b
+template <typename A, typename B>
+static inline constexpr auto next_multiple_of(A a, B b) {
+  return div_ceil(b, a) * a;
+}
+
+// Compute the largest multiple of a that is less than or equal to b
+template <typename A, typename B>
+static inline constexpr auto prev_multiple_of(A a, B b) {
+  return (b / a) * a;
+}
\ No newline at end of file
diff --git a/csrc/ops.h b/csrc/ops.h
index ea001190bc202..cdbf6297e816e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -78,6 +78,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 4e64b9c92773a..6db92eb98a8f9 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -52,9 +52,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Activation ops
   // Activation function used in SwiGLU.
-  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
+  ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 464bc2af8fd6d..614cce6903258 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -199,6 +199,7 @@ def __init__(
         self,
         compilation_configs: CompilationConfig,
     ):
+        print("GETTING TO BACKEND")
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = torch.cuda.graph_pool_handle()
@@ -253,8 +254,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
-        logger.debug("%s", lazy_format_graph_code("after split",
-                                                  self.split_gm))
+        logger.debug("%s", lazy_format_graph_code("after split", self.split_gm))
 
         compilation_counter.num_piecewise_graphs_seen += len(
             self.piecewise_graphs)
@@ -480,8 +480,7 @@ def __call__(self, *args) -> Any:
             ]
             assert new_input_addresses == entry.input_addresses, (
                 "Input addresses for cudagraphs are different during replay."
- f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) + f" Expected {entry.input_addresses}, got {new_input_addresses}") entry.cudagraph.replay() return entry.output diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 5efa410fab6a0..e5f9f0f6f10e8 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -14,6 +14,31 @@ logger = init_logger(__name__) +def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + # result + return at2[1] + + +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + print("REPLACEMENT RUNNING") + at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + # result, residual + return at[1] + + def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -171,8 +196,8 @@ def __init__(self, config: CompilationConfig.PassConfig): empty_bf16(1, 5), empty_fp32(1, 1) ] - register_replacement(rms_pattern_static, rms_replacement_static, - inputs, fwd_only, self.patterns) + register_replacement(rms_pattern_static, rms_replacement_static, inputs, + fwd_only, self.patterns) # Fuse fused_add_rms_norm + static_scaled_fp8_quant into # fused_add_rms_norm_static_fp8_quant @@ -192,6 +217,16 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns, extra_check=lambda m: self.record_match(m)) + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_fp32(1, 1) + ] + register_replacement(silu_mul_pattern_static, + silu_mul_replacement_static, inputs, fwd_only, + self.patterns) + def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and # save it for post-processing. @@ -229,17 +264,15 @@ def process_matches(self, graph: torch.fx.Graph): kwargs = match.kwargs kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), - kwargs=kwargs) + fused_node = graph.call_function(auto_functionalized, ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - residual_node_new = graph.call_function( - operator.getitem, (fused_node, 2)) + residual_node_new = graph.call_function(operator.getitem, + (fused_node, 2)) # Last part of replacement is rebinding the users of nodes in the # match to use the new nodes. diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fb522ae053e97..26defd035688b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -44,7 +44,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if pass_config.enable_reshape: self.passes += [RedundantReshapesPass(pass_config)] - if pass_config.enable_fusion: + if True: self.passes += [FusionPass.instance(pass_config)] self.fix_functionalization = FixFunctionalizationPass(pass_config)