Showing 8 changed files with 198 additions and 16 deletions.
@@ -0,0 +1,127 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"

using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
// using FP8_TYPE = c10::Float8_e4m3fnuz;

namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Quantize a single float to fp8, clamping to the representable range.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<FP8_TYPE>(r);
}
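// Worked example (illustrative values, not from the diff): with a tensor-wide
// scale of 0.5 passed in inverted form, scaled_fp8_conversion<true>(300.f, 2.f)
// computes 300 * 2 = 600, clamps it to FP8_E4M3_MAX (448 for e4m3fn), and
// casts the result to FP8_TYPE.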

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_quant_kernel(
    FP8_TYPE* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale,
    const int d) {
  const int32_t token_idx = blockIdx.x;
  const int32_t blocks_per_token = gridDim.y;

  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
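  // e.g. 16 bytes per load / 2 bytes per element = 8 elements for fp16/bf16.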

  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
  const int32_t elems_per_block =
      next_multiple_of(elems_per_128bit_load, tgt_elems_per_block);
  const int32_t block_start = blockIdx.y * elems_per_block;
  int32_t block_end = block_start + elems_per_block;
  block_end = block_end > d ? d : block_end;

  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
  FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
  const int32_t vec_loop_end =
      prev_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

  float inverted_scale = 1 / *scale;
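  // Multiplying by the reciprocal lets the inner loops avoid a per-element
  // divide; scaled_fp8_conversion<true> expects the scale in inverted form.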
#pragma unroll
  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
       vec_idx += blockDim.x) {
    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
    using scalar_64bit_vec_t = std::array<FP8_TYPE, elems_per_128bit_load>;

    scalar_64bit_vec_t out_vec;
    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i],
                                               inverted_scale);
    }

    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
  if (block_end > vec_loop_end) {
    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
         idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] =
          scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
    }
  }
}

}  // namespace vllm

// Launch activation, gating, and quantize kernel.
// Grid: one row of blocks per token (x dimension); the y dimension gives
// 1, 2, or 4 blocks per token (fewer tokens -> more blocks per token).
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
  int d = input.size(-1) / 2;                                            \
  int64_t num_tokens = input.numel() / input.size(-1);                   \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);  \
  dim3 block(std::min(d, 512));                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  VLLM_DISPATCH_FLOATING_TYPES(                                          \
      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
        vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>>       \
            <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),       \
                                         input.data_ptr<scalar_t>(),     \
                                         scale.data_ptr<float>(),        \
                                         d);                             \
      });

void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale)  // single-element fp32 scale
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
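A minimal host-side usage sketch (hypothetical: it assumes this file is built
into a libtorch extension that exposes silu_and_mul_quant; shapes and dtypes
follow the comments above):

#include <torch/torch.h>

void example() {
  const int64_t num_tokens = 8, d = 1024;
  // Gate and up projections packed as [..., 2 * d] on the GPU.
  auto input = torch::randn({num_tokens, 2 * d},
                            torch::dtype(torch::kHalf).device(torch::kCUDA));
  // Single-element fp32 quantization scale.
  auto scale = torch::full({1}, 0.5f,
                           torch::dtype(torch::kFloat32).device(torch::kCUDA));
  // fp8 output buffer, one d-sized row per token.
  auto out = torch::empty({num_tokens, d},
                          torch::dtype(at::ScalarType::Float8_e4m3fn)
                              .device(torch::kCUDA));
  silu_and_mul_quant(out, input, scale);
}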
@@ -0,0 +1,18 @@
#pragma once

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  return (a + b - 1) / b;
}

// Compute the smallest multiple of a that is greater than or equal to b.
template <typename A, typename B>
static inline constexpr auto next_multiple_of(A a, B b) {
  return div_ceil(b, a) * a;
}

// Compute the largest multiple of a that is less than or equal to b.
template <typename A, typename B>
static inline constexpr auto prev_multiple_of(A a, B b) {
  return (b / a) * a;
}
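As a quick illustration of the helpers (hypothetical compile-time checks, not
part of the diff):

static_assert(div_ceil(10, 4) == 3, "ceil(10 / 4)");
static_assert(next_multiple_of(8, 13) == 16, "smallest multiple of 8 >= 13");
static_assert(prev_multiple_of(8, 13) == 8, "largest multiple of 8 <= 13");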