Commit 3190440

init

SageMoore committed Dec 3, 2024
1 parent a4c4daf
Showing 8 changed files with 198 additions and 16 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -240,7 +240,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/activation_quant_kernels.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
127 changes: 127 additions & 0 deletions csrc/activation_quant_kernels.cu
@@ -0,0 +1,127 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"

using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
// using FP8_TYPE = c10::Float8_e4m3fnuz;
namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
}
else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale,
const int d) {

const int32_t token_idx = blockIdx.x;
const int32_t blocks_per_token = gridDim.y;

const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
next_multiple_of(elems_per_128bit_load, tgt_elems_per_block);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;

const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;

// 128-bit vectorized code
const int32_t vec_loop_end =
prev_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;

const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

float inverted_scale = 1 / *scale;
#pragma unroll
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
vec_idx += blockDim.x) {
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
using scalar_64bit_vec_t = std::array<FP8_TYPE, elems_per_128bit_load>;

scalar_64bit_vec_t out_vec;
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
}

out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
}

// Scalar cleanup code
if (block_end > vec_loop_end) {
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
// out_ptr[idx] = ACT_FN(x) * y;
      out_ptr[idx] = scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
}
}
}
}  // namespace vllm


// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), \
d); \
});

void silu_and_mul_quant(torch::Tensor& out,   // [..., d]
                        torch::Tensor& input, // [..., 2 * d]
                        torch::Tensor& scale) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
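For reference, the fused kernel above is meant to produce the same values as running SiLU-and-mul followed by static FP8 quantization; here is a minimal PyTorch sketch of that semantics (the function name and tensor shapes are illustrative, not part of this commit):

import torch

def silu_and_mul_quant_ref(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # input is [..., 2 * d]: the first half is the gate activation x, the second half is y.
    d = input.shape[-1] // 2
    x, y = input[..., :d], input[..., d:]
    gated = torch.nn.functional.silu(x) * y
    # Static quantization: divide by the (single-element) scale, clamp to the
    # representable E4M3 range, then cast to FP8.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    q = (gated / scale).clamp(-fp8_max, fp8_max)
    return q.to(torch.float8_e4m3fn)

The launch macro splits each token's d outputs across a small number of blocks (4, 2, or 1 blocks per token as the token count grows past 16 and 32), and each block walks its chunk with 128-bit vector loads, falling back to the scalar tail loop for any remainder.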
18 changes: 18 additions & 0 deletions csrc/core/math.hpp
@@ -0,0 +1,18 @@
#pragma once

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}

// Compute the next multiple of a that is greater than or equal to b
template <typename A, typename B>
static inline constexpr auto next_multiple_of(A a, B b) {
return div_ceil(b, a) * a;
}

// Compute the largest multiple of a that is less than or equal to b
template <typename A, typename B>
static inline constexpr auto prev_multiple_of(A a, B b) {
return (b / a) * a;
}
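These helpers drive the per-block work split in the activation kernel above; a quick Python mirror with worked numbers (illustrative only):

def div_ceil(a, b):
    return (a + b - 1) // b

def next_multiple_of(a, b):
    return div_ceil(b, a) * a

def prev_multiple_of(a, b):
    return (b // a) * a

# e.g. splitting d = 1000 elements over 2 blocks with 8-element (128-bit bf16) loads:
assert div_ceil(1000, 2) == 500         # ceiling division
assert next_multiple_of(8, 500) == 504  # round up to a multiple of 8
assert prev_multiple_of(8, 500) == 496  # round down to a multiple of 8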
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -78,6 +78,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
4 changes: 3 additions & 1 deletion csrc/torch_bindings.cpp
@@ -52,9 +52,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
7 changes: 3 additions & 4 deletions vllm/compilation/backends.py
@@ -199,6 +199,7 @@ def __init__(
self,
compilation_configs: CompilationConfig,
):
print("GETTING TO BACKEND")
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
@@ -253,8 +254,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
logger.debug("%s", lazy_format_graph_code("after split",
self.split_gm))
logger.debug("%s", lazy_format_graph_code("after split", self.split_gm))

compilation_counter.num_piecewise_graphs_seen += len(
self.piecewise_graphs)
@@ -480,8 +480,7 @@ def __call__(self, *args) -> Any:
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
f" Expected {entry.input_addresses}, got {new_input_addresses}")

entry.cudagraph.replay()
return entry.output
51 changes: 42 additions & 9 deletions vllm/compilation/fusion.py
@@ -14,6 +14,31 @@
logger = init_logger(__name__)


def silu_mul_pattern_static(result: torch.Tensor, result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at1[1],
scale=scale)
# result
return at2[1]


def silu_mul_replacement_static(result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
print("REPLACEMENT RUNNING")
at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
result=result,
input=input,
scale=scale)
    # result
return at[1]
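# Schematically, the rewrite registered further below replaces the unfused pair
#     torch.ops._C.silu_and_mul(result_silu_mul, input)
#     torch.ops._C.static_scaled_fp8_quant(result, result_silu_mul, scale)
# with the single fused op
#     torch.ops._C.silu_and_mul_quant(result, input, scale)
# (argument names mirror the pattern/replacement functions above).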


def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
@@ -171,8 +196,8 @@ def __init__(self, config: CompilationConfig.PassConfig):
empty_bf16(1, 5),
empty_fp32(1, 1)
]
register_replacement(rms_pattern_static, rms_replacement_static,
inputs, fwd_only, self.patterns)
register_replacement(rms_pattern_static, rms_replacement_static, inputs,
fwd_only, self.patterns)

# Fuse fused_add_rms_norm + static_scaled_fp8_quant into
# fused_add_rms_norm_static_fp8_quant
@@ -192,6 +217,16 @@ def __init__(self, config: CompilationConfig.PassConfig):
self.patterns,
extra_check=lambda m: self.record_match(m))

inputs = [
empty_fp8(5, 4),
empty_bf16(5, 4),
empty_bf16(5, 4),
empty_fp32(1, 1)
]
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)

def record_match(self, match: Match) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
@@ -229,17 +264,15 @@ def process_matches(self, graph: torch.fx.Graph):
kwargs = match.kwargs
kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm

fused_node = graph.call_function(
auto_functionalized,
(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
),
kwargs=kwargs)
fused_node = graph.call_function(auto_functionalized, (
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ),
kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
(fused_node, 1))
residual_node_new = graph.call_function(
operator.getitem, (fused_node, 2))
residual_node_new = graph.call_function(operator.getitem,
(fused_node, 2))

# Last part of replacement is rebinding the users of nodes in the
# match to use the new nodes.
2 changes: 1 addition & 1 deletion vllm/compilation/pass_manager.py
@@ -44,7 +44,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig):
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]

if pass_config.enable_fusion:
if True:
self.passes += [FusionPass.instance(pass_config)]

self.fix_functionalization = FixFunctionalizationPass(pass_config)
