Add silu-mul to rms-norm fusion #20

Status: Draft. Wants to merge 2 commits into base: luka/rms-norm-fusion.

5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -49,7 +49,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
@@ -240,7 +240,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/layernorm_kernels/activation_kernels.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
3 changes: 2 additions & 1 deletion csrc/activation_kernels.cu
@@ -69,9 +69,10 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.data_ptr<scalar_t>(), d); \
});

void silu_and_mul(torch::Tensor& out, // [..., d]
void silu_and_mul(torch::Tensor& result, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
torch::Tensor& out = result;
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

5 changes: 4 additions & 1 deletion csrc/ops.h
@@ -49,7 +49,10 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul(torch::Tensor& result, torch::Tensor& input);

void silu_and_mul_quant(torch::Tensor& result, torch::Tensor const& input,
torch::Tensor const& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

91 changes: 91 additions & 0 deletions csrc/quantization/layernorm_kernels/activation_kernels.cu
@@ -0,0 +1,91 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../cuda_compat.h"
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
// #include "quant_utils.cuh"
#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
namespace vllm {

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

static inline __device__ int8_t float_to_int8_rn(float x) {
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
}

template <typename T>
__device__ __forceinline__ T silu(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <typename scalar_t>
__global__ void silu_and_mul_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2 * d]
const int d,
float* __restrict__ scale) {
const int64_t token_idx = blockIdx.x;

float inverted_scale = 1 / *scale;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x = (float)VLLM_LDG(&input[token_idx * 2 * d + idx]);
const float y = (float)VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
float t = silu(x) * y;
out[token_idx * d + idx] = scaled_fp8_conversion<true>(
t, inverted_scale);
}

}
} // namespace vllm

void silu_and_mul_quant(torch::Tensor& result, // [..., d]
torch::Tensor const& input, // [..., 2 * d]
torch::Tensor const& scale // [num_tokens]
) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
float* scale_ptr = scale.data_ptr<float>();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_quant_kernel", [&] {
vllm::silu_and_mul_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
result.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), d,
scale_ptr);
});
}
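For reference, a minimal smoke test of the new binding is sketched below; it is not part of this PR. The shapes follow the comments above (input [..., 2 * d], FP8 result [..., d], a single-element float32 scale), the op is assumed to be exposed as torch.ops._C.silu_and_mul_quant per the registration in csrc/torch_bindings.cpp, and the eager reference uses 448.0 as the e4m3fn saturation bound, matching FP8_E4M3_MAX on CUDA.

```python
# Hypothetical smoke test (not part of this PR); assumes a CUDA build of the
# extension that includes the kernel above.
import torch

num_tokens, d = 16, 128
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

# Fused path: result = fp8_quant(silu(x[..., :d]) * x[..., d:], scale)
result = torch.empty(num_tokens, d, dtype=torch.float8_e4m3fn, device="cuda")
torch.ops._C.silu_and_mul_quant(result, x, scale)

# Unfused eager reference: silu-and-mul, then static per-tensor FP8 quantization.
ref = torch.nn.functional.silu(x[:, :d].float()) * x[:, d:].float()
ref_fp8 = (ref / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
print(torch.allclose(result.float(), ref_fp8.float(), atol=0.2, rtol=0.1))
```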
95 changes: 95 additions & 0 deletions csrc/reduction_utils.cuh
@@ -0,0 +1,95 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "cuda_compat.h"

namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _max(T a, T b) {
return max(a, b);
}

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
return a + b;
}

} // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduce<T>(val, fn);
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) shared[wid] = val;

__syncthreads();

val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
}
return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

} // namespace vllm
22 changes: 15 additions & 7 deletions csrc/torch_bindings.cpp
@@ -49,7 +49,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
@@ -107,15 +107,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, Tensor scale, float epsilon) -> "
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, &rms_norm_static_fp8_quant);
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant);

// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor! residual, Tensor weight, "
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant);
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
@@ -281,6 +285,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

// Mamba selective scan kernel
ops.def(
@@ -330,12 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) -> "
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

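The Tensor! annotations in these schemas are what mark result (and scale in the dynamic variant) as mutated; that is what makes torch.compile wrap the ops in auto_functionalized nodes, which the fusion pass later has to undo. A hedged sketch for inspecting this at runtime, assuming the built extension is loaded:

```python
# Hypothetical schema inspection; assumes the vLLM C extension is built and loaded.
import torch

op = torch.ops._C.silu_and_mul_quant.default
print(op._schema)  # roughly: _C::silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()

# Arguments whose alias info is marked as a write are the ones functionalization wraps.
mutated = [a.name for a in op._schema.arguments
           if a.alias_info is not None and a.alias_info.is_write]
print(mutated)  # expected: ['result']
```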
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,8 +6,8 @@ requires = [
"packaging",
"setuptools>=61",
"setuptools-scm>=8.0",
"torch == 2.4.0",
"wheel",
"torch",
"jinja2",
]
build-backend = "setuptools.build_meta"
20 changes: 19 additions & 1 deletion vllm/compilation/backends.py
@@ -170,7 +170,7 @@ def fix_functionalization(graph: fx.Graph):
                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']
                out = kwargs['result']  # TODO

                # Create a new call to torch.ops._C.rotary_embedding.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
@@ -189,6 +189,24 @@ def fix_functionalization(graph: fx.Graph):
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops.neuralmagic.silu_mul_quant.default:
                #
                kwargs = node.kwargs

                replace_node = kwargs['result']
                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_mul_quant.default, kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
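The rewrite above mirrors the existing branches in fix_functionalization: find the auto_functionalized wrapper node for the op, insert a direct call to the in-place custom op, and redirect operator.getitem users to the tensor the op mutates. A toy, self-contained sketch of that mechanic is below; wrapped_add and inplace_add are hypothetical stand-ins used purely for illustration, not vLLM ops.

```python
# Toy illustration of the de-functionalization rewrite; wrapped_add / inplace_add
# are hypothetical stand-ins, not vLLM ops.
import operator
import torch
import torch.fx as fx

def wrapped_add(out: torch.Tensor, x: torch.Tensor):
    # "functional" form: returns a tuple, the way auto_functionalized does
    return None, x + 1

def inplace_add(out: torch.Tensor, x: torch.Tensor) -> None:
    out.copy_(x + 1)  # in-place form, analogous to the _C out-variant ops

torch.fx.wrap("wrapped_add")  # keep it as a leaf call_function node when tracing

class M(torch.nn.Module):
    def forward(self, x):
        out = torch.empty_like(x)
        t = wrapped_add(out, x)
        return t[1]  # traced as operator.getitem(t, 1)

gm = fx.symbolic_trace(M())
graph, nodes_to_remove = gm.graph, []
for node in graph.nodes:
    if node.op == "call_function" and node.target is wrapped_add:
        out_arg, x_arg = node.args
        with graph.inserting_before(node):
            graph.call_function(inplace_add, args=(out_arg, x_arg))
        for user in list(node.users):
            if user.op == "call_function" and user.target is operator.getitem:
                user.replace_all_uses_with(out_arg)  # reads now come from `out`
                nodes_to_remove.append(user)
        nodes_to_remove.append(node)
for node in nodes_to_remove:
    graph.erase_node(node)
gm.recompile()
print(gm(torch.zeros(4)))  # tensor([1., 1., 1., 1.]) from the in-place op
```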
33 changes: 33 additions & 0 deletions vllm/compilation/fusion.py
@@ -10,6 +10,29 @@

logger = init_logger(__name__)

@torch.library.custom_op("neuralmagic::silu_mul_quant", mutates_args=())
def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    result = torch.empty(x.shape[0], x.shape[1] // 2, device=x.device, dtype=torch.float8_e4m3fn)

    torch.ops._C.silu_and_mul_quant(result, x, scale)
    return result
Review comment from bnellnm (Member), Oct 11, 2024:
  The double dispatching might be hurting us here. Is there some reason we can't just inline this pattern into the replacement?

Reply from the author:
  I'll let @Sage confirm but I think this is hiding the slice issue.

Reply from a member:
  Or see if we can call torch.ops._C.silu_and_mul_quant.default directly? (if we make sure that this only gets applied for CUDA)

@silu_mul_quant.register_fake
def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor):
    return torch.empty(x.shape[0], x.shape[1] // 2, device=x.device, dtype=torch.float8_e4m3fn)

def silu_mul_quant_replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # print("MATCH QUANT")
    return torch.ops.neuralmagic.silu_mul_quant(x, scale)

def silu_mul_quant_pattern(input_: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    silu_mul_result = torch.empty([input_.shape[0], input_.shape[1] // 2], dtype=torch.float16, device=input_.device)
    silu_mul_func = torch.ops.higher_order.auto_functionalized(torch.ops._C.silu_and_mul.default, result=silu_mul_result, input=input_)
    result = torch.empty([input_.shape[0], input_.shape[1] // 2], dtype=torch.float8_e4m3fn, device=input_.device)
    static_fp8_quant_func = torch.ops.higher_order.auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=silu_mul_func[1], scale=scale)
    return static_fp8_quant_func[1]


def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                       scale: torch.Tensor):
    at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight,
@@ -75,6 +98,16 @@ def record_match_fn(match: Match):
    register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns,
                         extra_check=record_match_fn)

    # silu-mul quant
    x = torch.empty((128, 256), device="cuda", dtype=torch.float16)
    scale = torch.empty((1, 1), device="cuda", dtype=torch.float32)
Review comment on lines +102 to +103, from a member:
  Do we need to set device/type here? I guess it doesn't matter a whole lot since I think these are mostly ignored.

Reply from the author:
  This was just sage's diff I added. I would not land this without fully removing the allocations from patterns.

    register_replacement(silu_mul_quant_pattern,
                         silu_mul_quant_replacement,
                         [x, scale],
                         fwd_only,
                         [my_patterns])

    return my_patterns, matches


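One GPU-free way to sanity-check the custom-op plumbing above is to run it under FakeTensorMode: both the pattern matcher and torch.compile rely on the register_fake implementation to infer the output's shape and dtype without launching the CUDA kernel. A hedged sketch, assuming vllm/compilation/fusion.py (above) is importable and that fake tensors may model a CUDA device even on a CPU-only host:

```python
# Hypothetical shape/dtype check for the neuralmagic::silu_mul_quant custom op;
# assumes vllm/compilation/fusion.py has been imported so the op is registered.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import vllm.compilation.fusion  # noqa: F401

with FakeTensorMode():
    x = torch.empty(128, 256, dtype=torch.float16, device="cuda")
    scale = torch.empty(1, 1, dtype=torch.float32, device="cuda")
    out = torch.ops.neuralmagic.silu_mul_quant(x, scale)
    # The fake implementation promises an FP8 tensor of half the width.
    assert out.dtype == torch.float8_e4m3fn
    assert out.shape == (128, 128)
```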
1 change: 0 additions & 1 deletion vllm/compilation/wrapper.py
@@ -26,7 +26,6 @@ class TorchCompileWrapperWithCustomDispatcher:
"""

def __init__(self, compiled_callable: Optional[Callable] = None):

if compiled_callable is None:
# default compilation settings
# compiling the forward method