Add silu-mul to rms-norm fusion #20
base: luka/rms-norm-fusion
Files changed

New CUDA source implementing the fused `silu_and_mul_quant` kernel (new file, 91 lines; the file path is not preserved in this excerpt):

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../cuda_compat.h"
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
// #include "quant_utils.cuh"

#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Use 224.0f for ROCm instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

// Scale a float (multiplying when the scale is already inverted, dividing
// otherwise), saturate to the FP8 E4M3 range, and convert to FP8.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Use the hardware cvt instruction for fp8 on ROCm.
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

static inline __device__ int8_t float_to_int8_rn(float x) {
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
}

template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Fused SiLU-and-mul plus static FP8 quantization: one block per token,
// out[token, i] = quant(silu(input[token, i]) * input[token, d + i]).
template <typename scalar_t>
__global__ void silu_and_mul_quant_kernel(
    FP8_TYPE* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2 * d]
    const int d, float* __restrict__ scale) {
  const int64_t token_idx = blockIdx.x;

  float inverted_scale = 1 / *scale;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const float x = (float)VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const float y = (float)VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    float t = silu(x) * y;
    out[token_idx * d + idx] = scaled_fp8_conversion<true>(t, inverted_scale);
  }
}

}  // namespace vllm

void silu_and_mul_quant(torch::Tensor& result,       // [..., d]
                        torch::Tensor const& input,  // [..., 2 * d]
                        torch::Tensor const& scale   // [num_tokens]
) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  float* scale_ptr = scale.data_ptr<float>();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_quant_kernel", [&] {
        vllm::silu_and_mul_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            result.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), d,
            scale_ptr);
      });
}
```
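For reference, the kernel computes, per token, `quant(silu(x) * y)` where `x` and `y` are the two halves of the input row. A rough eager-mode equivalent (a sketch, not part of the diff; the shape handling and the `torch.finfo` bound are illustrative, and `scale` is treated as a single-element tensor since the kernel only reads `*scale`):

```python
import torch

def silu_and_mul_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, 2 * d]; scale: single-element fp32 tensor (static scale).
    d = x.shape[-1] // 2
    gate, up = x[..., :d].float(), x[..., d:].float()
    t = torch.nn.functional.silu(gate) * up
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn on CUDA
    return (t / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
```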
New CUDA header with warp- and block-level reduction helpers (new file, 95 lines), adapted from FasterTransformer; this is likely the header the kernel above includes as `../../reduction_utils.cuh`:

```cuda
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cuda_compat.h"

namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _max(T a, T b) {
  return max(a, b);
}

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
  return a + b;
}

}  // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                "numLanes is not a positive power of 2!");
  static_assert(numLanes <= WARP_SIZE);
#pragma unroll
  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
  static_assert(maxBlockSize <= 1024);
  if constexpr (maxBlockSize > WARP_SIZE) {
    val = warpReduce<T>(val, fn);
    // Calculates max number of lanes that need to participate in the last
    // warpReduce
    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
    static __shared__ T shared[maxActiveLanes];
    int lane = threadIdx.x % WARP_SIZE;
    int wid = threadIdx.x / WARP_SIZE;
    if (lane == 0) shared[wid] = val;

    __syncthreads();

    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                        : (T)(0.0f);
    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
  } else {
    // A single warpReduce is equal to blockReduce
    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
  }
  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

}  // namespace vllm
```
Changes to the Python fusion-pass module (path not preserved in this excerpt). First hunk (`@@ -10,6 +10,29 @@`):

```python
logger = init_logger(__name__)


@torch.library.custom_op("neuralmagic::silu_mul_quant", mutates_args=())
def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    result = torch.empty(x.shape[0], x.shape[1] // 2, device=x.device,
                         dtype=torch.float8_e4m3fn)
    torch.ops._C.silu_and_mul_quant(result, x, scale)
    return result


@silu_mul_quant.register_fake
def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor):
    return torch.empty(x.shape[0], x.shape[1] // 2, device=x.device,
                       dtype=torch.float8_e4m3fn)


def silu_mul_quant_replacement(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # print("MATCH QUANT")
    return torch.ops.neuralmagic.silu_mul_quant(x, scale)


def silu_mul_quant_pattern(input_: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    silu_mul_result = torch.empty([input_.shape[0], input_.shape[1] // 2],
                                  dtype=torch.float16, device=input_.device)
    silu_mul_func = torch.ops.higher_order.auto_functionalized(
        torch.ops._C.silu_and_mul.default, result=silu_mul_result, input=input_)
    result = torch.empty([input_.shape[0], input_.shape[1] // 2],
                         dtype=torch.float8_e4m3fn, device=input_.device)
    static_fp8_quant_func = torch.ops.higher_order.auto_functionalized(
        torch.ops._C.static_scaled_fp8_quant.default, result=result,
        input=silu_mul_func[1], scale=scale)
    return static_fp8_quant_func[1]


# Unchanged context follows in the diff: rms_pattern_static(...) and the other
# rms-norm fusion patterns (truncated in this excerpt).
```
Second hunk (`@@ -75,6 +98,16 @@`, enclosing context `def record_match_fn(match: Match):`):

```python
    register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns,
                         extra_check=record_match_fn)

    # silu-mul quant
    x = torch.empty((128, 256), device="cuda", dtype=torch.float16)
    scale = torch.empty((1, 1), device="cuda", dtype=torch.float32)

    register_replacement(silu_mul_quant_pattern,
                         silu_mul_quant_replacement,
                         [x, scale],
                         fwd_only,
                         [my_patterns])

    return my_patterns, matches
```

Review comments on lines +102 to +103 (the dummy `x`/`scale` tensors):

> Do we need to set device/dtype here? I guess it doesn't matter a whole lot, since I think these are mostly ignored.

> This was just Sage's diff I added. I would not land this without fully removing the allocations from the patterns.
Review thread on the pattern/replacement pair:

> The double dispatching might be hurting us here. Is there some reason we can't just inline this pattern into the replacement?

> I'll let @Sage confirm, but I think this is hiding the slice issue.

> Or see if we can call `torch.ops._C.silu_and_mul_quant.default` directly? (if we make sure that this only gets applied for CUDA)
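A minimal sketch of that suggestion (not part of the PR): have the replacement call the `_C` op directly rather than going through the `neuralmagic::silu_mul_quant` wrapper, mirroring what the wrapper itself does; this assumes the pass is only applied on CUDA, as noted above:

```python
import torch

def silu_mul_quant_replacement_direct(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Allocate the FP8 output and call the fused C extension op in place,
    # exactly as the neuralmagic::silu_mul_quant wrapper does internally.
    result = torch.empty(x.shape[0], x.shape[1] // 2,
                         device=x.device, dtype=torch.float8_e4m3fn)
    torch.ops._C.silu_and_mul_quant.default(result, x, scale)
    return result
```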