diff --git a/ctests/test_triton_pointwise.cpp b/ctests/test_triton_pointwise.cpp
index 664896cf1..610b7fda6 100644
--- a/ctests/test_triton_pointwise.cpp
+++ b/ctests/test_triton_pointwise.cpp
@@ -2,13 +2,12 @@
 #include "flag_gems/operators.h"
 #include "torch/torch.h"
 
-TEST(pointwise_op_test, add) {
+TEST(pointwise_op_simple_test, add) {
   const torch::Device device(torch::kCUDA, 0);
-  torch::Tensor a = torch::randn({10, 10}, device);
-  torch::Tensor b = torch::randn({10, 10}, device);
+  torch::Tensor a = torch::randn({128}, device);
+  torch::Tensor b = torch::randn({128}, device);
   torch::Tensor out_torch = a + b;
   torch::Tensor out_triton = flag_gems::add_tensor(a, b);
-
   EXPECT_TRUE(torch::allclose(out_torch, out_triton));
 }
 
diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index baa6d62ce..353aeeb9d 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -11,16 +11,59 @@
 #include "torch/torch.h"
 
 namespace flag_gems::utils {
-
+using Shape = c10::IntArrayRef;
 std::filesystem::path get_path_of_this_library();
 std::filesystem::path get_triton_src_path();
 std::filesystem::path get_flag_gems_src_path();
 int64_t next_power_of_2(int64_t n);
-bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor& tensor,
                                                                       int reduction_axis);
 std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
-    const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt);
-std::tuple<int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis);
+    const at::Tensor& tensor, at::OptionalIntArrayRef reduction_axes_opt);
+std::tuple<int64_t, int64_t> parse_reduction_axes(const at::Tensor& tensor, int reduction_axis);
 int cdiv(int a, int b);
-}  // namespace flag_gems::utils
+bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
+};  // namespace flag_gems::utils
+
+namespace flag_gems::pointwise_dynamic {
+void checkIfScalar(const torch::Tensor& tensor1,
+                   const torch::Tensor& tensor2,
+                   std::array<bool, 2>& is_scalar);
+bool use_fast_path(const std::vector<at::Tensor>& tensors);
+
+class ParamStack {
+ private:
+  std::vector<void*> kernel_params;
+  std::string signature;
+  std::vector<void*> tensor_ptr;
+  std::vector<int64_t> strides;
+  std::vector<int64_t> task_shape;
+  std::vector<int64_t> task_partition;
+  std::string constexp;
+  void* global_scratch;
+
+ private:
+  void push_strides();
+  void push_task_shape();
+  void push_task_partition();
+  void add_global_scratch();
+
+ public:
+  ParamStack(int max_args = 32) {
+    kernel_params.reserve(max_args);
+    tensor_ptr.reserve(max_args);
+    global_scratch = nullptr;
+  }
+  void save_tensor(at::Tensor& tensor);
+  void save_tensor(const at::Tensor& tensor);
+  void save_stride(int64_t stride);
+  void save_task_shape(int64_t shape);
+  void save_task_partition(int64_t partition);
+  void save_constexpr(int64_t value);
+  void save_constexpr(bool value);
+  void** get_params();
+  std::string get_signature();
+
+  void build();
+};
+};  // namespace flag_gems::pointwise_dynamic
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index e57c88da5..d2d79a25e 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,8 +1,13 @@
+find_package(Python COMPONENTS Interpreter Development REQUIRED)
+if(NOT Python_INCLUDE_DIRS OR NOT Python_LIBRARIES)
+  message(FATAL_ERROR "Python development files not found. Please ensure Python is installed and development headers are available.")
+endif()
+include_directories(${Python_INCLUDE_DIRS})
+
 add_library(operators
   SHARED
   zeros.cpp
   utils.cpp
-  add.cpp
   sum.cpp
   max.cpp
   mm.cpp
@@ -18,9 +23,11 @@ add_library(operators
   bmm.cpp
   embedding.cpp
   argmax.cpp
+  fill.cpp
   softmax.cpp
-  exponential_.cpp
-  fill.cpp)
+  pointwise_dynamic.cpp
+  exponential_.cpp
+  )
 
 target_include_directories(operators
   PUBLIC $
diff --git a/lib/add.cpp b/lib/add.cpp
deleted file mode 100644
index abea7267f..000000000
--- a/lib/add.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-#include "flag_gems/operators.h"
-#include "flag_gems/utils.h"
-
-#include
-#include "c10/cuda/CUDAStream.h"
-#include "triton_jit/triton_jit_function.h"
-
-namespace flag_gems {
-using namespace triton_jit;
-
-at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
-  auto res = torch::broadcast_tensors({a_, b_});
-  res[0] = res[0].contiguous();
-  res[1] = res[1].contiguous();
-  const at::Tensor &a = res[0];
-  const at::Tensor &b = res[1];
-
-  at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
-  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));
-
-  const TritonJITFunction &f =
-      TritonJITFunction::getInstance(std::string(utils::get_triton_src_path() / "binary_add.py"),
-                                     "binary_pointwise_kernel");
-
-  // add utility to build this automatically
-  int64_t tile_size = 1024;
-  const int num_warps = 8;
-  const int num_stages = 1;
-  int64_t n = out.numel();
-  const unsigned int num_blocks = (n + tile_size - 1) / tile_size;
-
-  // getCurrentCUDAStream ensures that the stream is initialized, a default stream for each device
-  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
-  c10::DeviceGuard guard(out.device());
-  CUstream raw_stream = static_cast<CUstream>(stream.stream());
-  f(raw_stream, num_blocks, 1, 1, num_warps, num_stages, a, b, out, n, tile_size);
-  return out;
-}
-}  // namespace flag_gems
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
new file mode 100644
index 000000000..eb3fac1cb
--- /dev/null
+++ b/lib/pointwise_dynamic.cpp
@@ -0,0 +1,87 @@
+#include "flag_gems/operators.h"
+#include "flag_gems/utils.h"
+
+#include <algorithm>
+#include <array>
+#include <optional>
+#include "c10/cuda/CUDAStream.h"
+#include "c10/util/Logging.h"
+#include "pybind11/embed.h"
+#include "triton_jit/pointwise_generator.h"
+#include "triton_jit/triton_jit_function.h"
+
+namespace flag_gems {
+using namespace triton_jit;
+
+namespace py = pybind11;
+at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
+  pointwise_dynamic::ParamStack stk = pointwise_dynamic::ParamStack();
+  int64_t task_shape, ndim;
+  int64_t num_ctas;
+  int64_t tiles_per_cta;
+  int64_t tile_sizes;
+  int64_t num_tiles;
+  at::Tensor out = at::empty_like(a_);
+  std::vector<at::Tensor> tensors = {a_, b_, out};
+  const int num_warps = 4;
+  const int num_stages = 1;
+  if (pointwise_dynamic::use_fast_path(tensors)) {
+    task_shape = a_.numel();
+    int64_t stride = 1;
+    ndim = 1;
+    stk.save_stride(stride);
+    stk.save_stride(stride);
+    stk.save_stride(stride);
+    stk.save_task_shape(task_shape);
+    stk.save_task_shape(task_shape);
+    tile_sizes = num_warps * 32;
+    num_tiles = utils::cdiv(task_shape, tile_sizes);
+    num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
+    tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
+    stk.save_task_partition(tiles_per_cta);
+  } else {
+    throw std::runtime_error("NotImplementedError");
+  }
+  stk.save_constexpr(tile_sizes);
+  int64_t one_tile_per_cta = (tiles_per_cta == 1);
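+  // one_tile_per_cta is kept as an int64_t, so save_constexpr(int64_t) records it in the
+  // constexpr part of the launch signature as 0/1 rather than True/False.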
+  stk.save_constexpr(one_tile_per_cta);
+
+  std::array<bool, 2> is_scalar;
+  pointwise_dynamic::checkIfScalar(a_, b_, is_scalar);
+  std::optional<TritonJITFunction> f;
+  auto ans_tuple = gen_add(ndim);
+  std::string kernel_name = std::get<0>(ans_tuple);
+  std::string file_path = std::get<1>(ans_tuple);
+  if (!is_scalar[0] && !is_scalar[1]) {
+    f = TritonJITFunction::getInstance(file_path, kernel_name);
+  } else if (!is_scalar[0] && is_scalar[1]) {
+    throw std::runtime_error("NotImplementedError");
+    f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
+                                       "add_func_tensor_scalar");
+  } else if (is_scalar[0] && !is_scalar[1]) {
+    throw std::runtime_error("NotImplementedError");
+    f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
+                                       "add_func_scalar_tensor");
+  } else {
+    return a_ + b_;
+  }
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  c10::DeviceGuard guard(out.device());
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
+
+  stk.save_tensor(a_);
+  stk.save_tensor(b_);
+  stk.save_tensor(out);
+  stk.build();
+  f->launch_with_raw_args(raw_stream,
+                          num_ctas,
+                          1,
+                          1,
+                          num_warps,
+                          num_stages,
+                          stk.get_signature(),
+                          stk.get_params());
+  return out;
+}
+
+};  // namespace flag_gems
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 7b6b240ef..c5b304d95 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -8,7 +8,7 @@ std::filesystem::path get_path_of_this_library() {
   // there is no build system generator to take care of this.
   static const std::filesystem::path cached_path = []() {
     Dl_info dl_info;
-    if (dladdr(reinterpret_cast<void *>(&get_path_of_this_library), &dl_info) && dl_info.dli_fname) {
+    if (dladdr(reinterpret_cast<void*>(&get_path_of_this_library), &dl_info) && dl_info.dli_fname) {
      return std::filesystem::canonical(dl_info.dli_fname);  // Ensure absolute, resolved path
     } else {
       throw std::runtime_error("cannot get the path of libjit_utils.so");
@@ -34,7 +34,7 @@ std::filesystem::path get_triton_src_path() {
 
 std::filesystem::path get_flag_gems_src_path() {
   const static std::filesystem::path flag_gems_src_dir = []() {
-    const char *flag_gems_dir = std::getenv("FLAGGEMS_SOURCE_DIR");
+    const char* flag_gems_dir = std::getenv("FLAGGEMS_SOURCE_DIR");
     if (!flag_gems_dir) {
       throw std::runtime_error("Environment variable FLAGGEMS_SOURCE_DIR not set");
     }
@@ -43,6 +43,20 @@ std::filesystem::path get_flag_gems_src_path() {
   return flag_gems_src_dir;
 }
 
+std::filesystem::path get_code_cache_dir() {
+  const char* env_cache_dir = std::getenv("FLAGGEMS_CACHE_DIR");
+  std::filesystem::path cache_dir;
+  if (env_cache_dir) {
+    cache_dir = std::filesystem::path(env_cache_dir);
+  } else {
+    cache_dir = std::filesystem::path(std::getenv("HOME")) / ".flaggems";
+  }
+  std::filesystem::create_directories(cache_dir);
+  std::filesystem::path code_cache_dir = cache_dir / "code_cache";
+  std::filesystem::create_directories(code_cache_dir);
+  return code_cache_dir;
+}
+
 int64_t next_power_of_2(int64_t n) {
   if (n <= 1) return 1;
   --n;
@@ -82,7 +96,7 @@ bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2) {
 }
 
 std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
-    const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt) {
+    const at::Tensor& tensor, at::OptionalIntArrayRef reduction_axes_opt) {
   int64_t dim = tensor.dim();
   c10::DimVector reduction_axes;
 
@@ -111,7 +125,7 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
   return {tensor.permute(permute_order), non_reduction_size, reduction_size};
 }
 
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor& tensor,
                                                                       int reduction_axis) {
   int64_t dim = tensor.dim();
   c10::DimVector left_axes, right_axes;
@@ -132,7 +146,7 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
   return {tensor.permute(permute_order), non_reduction_size, reduction_size};
 }
 
-std::tuple<int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis) {
+std::tuple<int64_t, int64_t> parse_reduction_axes(const at::Tensor& tensor, int reduction_axis) {
   int64_t dim = tensor.dim();
   c10::DimVector left_axes, right_axes, remain_axes;
   int64_t non_reduction_size = 1;
@@ -156,4 +170,193 @@ std::tuple<int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis) {
 int cdiv(int a, int b) {
   return (a + b - 1) / b;
 }
-}  // namespace flag_gems::utils
+};  // namespace flag_gems::utils
+
+namespace flag_gems::pointwise_dynamic {
+void checkIfScalar(const torch::Tensor& tensor1,
+                   const torch::Tensor& tensor2,
+                   std::array<bool, 2>& is_scalar) {
+  is_scalar[0] = (tensor1.dim() == 0);
+  is_scalar[1] = (tensor2.dim() == 0);
+}
+
+bool all_the_same_shape(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  const auto& first_shape = tensors[0].sizes();
+  for (const auto& tensor : tensors) {
+    if (!tensor.sizes().equals(first_shape)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool all_c_contiguous(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  for (const auto& tensor : tensors) {
+    if (!tensor.is_contiguous()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool all_the_same_stride(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  const auto& first_stride = tensors[0].strides();
+  for (const auto& tensor : tensors) {
+    if (!tensor.strides().equals(first_stride)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool use_fast_path(const std::vector<at::Tensor>& tensors) {
+  if (!all_the_same_shape(tensors)) {
+    return false;
+  }
+  if (all_c_contiguous(tensors)) {
+    return true;
+  }
+  return all_the_same_stride(tensors) && tensors[0].is_non_overlapping_and_dense();
+}
+
+void ParamStack::save_tensor(const at::Tensor& tensor) {
+  void* p_item = tensor.data_ptr();
+  tensor_ptr.push_back(p_item);
+  if (tensor.dtype() == at::kFloat) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp32:16,");
+  } else if (tensor.dtype() == at::kInt) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*int32:16,");
+  } else if (tensor.dtype() == at::kDouble) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp64:16,");
+  } else if (tensor.dtype() == at::kHalf) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp16:16,");
+  } else {
+    throw std::runtime_error("TypeError: we only support fp64/32/16 and int32 now");
+  }
+}
+
+void ParamStack::save_tensor(at::Tensor& tensor) {
+  void* p_item = tensor.data_ptr();
+  tensor_ptr.push_back(p_item);
+  if (tensor.dtype() == at::kFloat) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp32:16,");
+  } else if (tensor.dtype() == at::kInt) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*int32:16,");
+  } else if (tensor.dtype() == at::kDouble) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp64:16,");
+  } else if (tensor.dtype() == at::kHalf) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp16:16,");
+  } else {
+    throw std::runtime_error("TypeError: we only support fp64/32/16 and int32 now");
+  }
+}
+
+std::string ParamStack::get_signature() {
+  if (!signature.empty() && signature.back() == ',') {
+    signature.pop_back();
+  }
+  return signature;
+}
+
+void** ParamStack::get_params() {
+  void** res = kernel_params.empty() ? nullptr : kernel_params.data();
+  if (res == nullptr) {
+    // kernel_params is empty
+    std::cout << "The parameter stack is empty." << std::endl;
+  } else {
+    // kernel_params is not empty
+    std::cout << "The parameter stack is not empty." << std::endl;
+  }
+  return res;
+}
+
+void ParamStack::save_stride(int64_t stride) {
+  if (stride == 1) {
+    strides.push_back(0);
+  } else {
+    strides.push_back(stride);
+  }
+}
+
+void ParamStack::save_task_shape(int64_t shape) {
+  task_shape.push_back(shape);
+}
+
+void ParamStack::save_task_partition(int64_t partition) {
+  if (partition == 1) {
+    task_partition.push_back(0);
+  } else {
+    task_partition.push_back(partition);
+  }
+}
+
+void ParamStack::push_strides() {
+  for (auto& stride : strides) {
+    if (stride != 0) {
+      kernel_params.push_back(static_cast<void*>(&stride));
+      signature.append("i64,");
+    } else {
+      signature.append("i64:1,");
+    }
+  }
+}
+
+void ParamStack::push_task_shape() {
+  for (auto& shape : task_shape) {
+    kernel_params.push_back(static_cast<void*>(&shape));
+    signature.append("i64,");
+  }
+}
+
+void ParamStack::push_task_partition() {
+  for (auto& partition : task_partition) {
+    if (partition != 0) {
+      kernel_params.push_back(static_cast<void*>(&partition));
+      signature.append("i64,");
+    } else {
+      signature.append("i64:1,");
+    }
+  }
+}
+
+void ParamStack::add_global_scratch() {
+  kernel_params.push_back(&global_scratch);
+}
+
+void ParamStack::build() {
+  push_strides();
+  push_task_shape();
+  push_task_partition();
+  signature.append(constexp);
+  add_global_scratch();
+}
+
+void ParamStack::save_constexpr(int64_t value) {
+  constexp.append(std::to_string(value) + ",");
+}
+
+void ParamStack::save_constexpr(bool value) {
+  if (value) {
+    constexp.append("True,");
+  } else {
+    constexp.append("False,");
+  }
+}
+
+};  // namespace flag_gems::pointwise_dynamic
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index 9e3506c7e..74f33bba5 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -950,6 +950,11 @@ def gen_kernel_launch_1d(
         for i in range(schema.num_output_tensors()):
             code.writeline(f"out{i}_strides = out{i}.stride()")
 
+        for i in range(schema.num_input_tensors()):
+            code.writeline(f"in{i}_strides = in{i}.stride()")
+        for i in range(schema.num_output_tensors()):
+            code.writeline(f"out{i}_strides = out{i}.stride()")
+
         code.writeline("with torch_device_fn.device(in0.device.index):")
         with code.indent():
             code.writeline(f"{self.jit_fn_name}[grid](")
@@ -1163,7 +1168,6 @@ def prepare_args(self, *args, **kwargs):
     # tensor that is not broadcated, no attempts to simplify task, no reordering,
     # no dimenion collapsing
     shapes = tuple(item.shape for item in in_tensors)
-
     task_shape = broadcast_shapes(shapes)
 
     if out_tensors:
diff --git a/src/flag_gems/utils/triton_lang_extension.py b/src/flag_gems/utils/triton_lang_extension.py
index 8ea38aac8..8fc017176 100644
--- a/src/flag_gems/utils/triton_lang_extension.py
+++ b/src/flag_gems/utils/triton_lang_extension.py
@@ -103,3 +103,33 @@ def fmod(x, y):
 def trunc(x):
     """trunc default - truncate to integer"""
     return tl.where(x >= 0, tl.floor(x), tl.ceil(x))
+
+
+# --- Pointwise Functions ---
+
+
+# src/flag_gems/ops/add.py for lib/add.cpp
+@triton.jit
+def add_func(x, y, alpha=1):
+    return x + y * alpha
+
+
+@triton.jit
+def add_func_tensor_scalar(x, y, alpha=1):
+    return x + y * alpha
+
+
+@triton.jit
+def add_func_scalar_tensor(x, y, alpha=1):
+    return x + y * alpha
+
+
+# src/flag_gems/ops/fill.py for lib/fill.cpp
+@triton.jit
+def fill_scalar_func(inp, value_scalar):
+    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
+
+
+@triton.jit
+def fill_tensor_func(inp, value):
+    return value
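
For review convenience, here is a small, self-contained sketch (not part of the patch) of the launch signature that `ParamStack` appears to build for the fast-path `add_tensor` call above, assuming three contiguous fp32 tensors and `tiles_per_cta > 1`. The `SignatureSketch` struct below is a hypothetical stand-in that only replays the string-building rules from `lib/utils.cpp`; the real class and its argument order are defined by the patch itself.

```cpp
// Hypothetical stand-in that replays ParamStack's signature rules from lib/utils.cpp.
#include <cstdint>
#include <iostream>
#include <string>

struct SignatureSketch {
  std::string sig;
  void tensor_fp32() { sig += "*fp32:16,"; }                          // save_tensor: fp32 pointer, 16-byte aligned
  void stride(int64_t s) { sig += (s == 1) ? "i64:1," : "i64,"; }     // push_strides: stride 1 is specialized
  void shape() { sig += "i64,"; }                                     // push_task_shape
  void partition(int64_t p) { sig += (p == 1) ? "i64:1," : "i64,"; }  // push_task_partition
  void constexpr_i64(int64_t v) { sig += std::to_string(v) + ","; }   // save_constexpr(int64_t)
  std::string finish() {                                              // get_signature: drop trailing comma
    if (!sig.empty() && sig.back() == ',') sig.pop_back();
    return sig;
  }
};

int main() {
  SignatureSketch s;
  // add_tensor fast path: a, b, out are contiguous fp32 with unit stride,
  // the flattened task shape is pushed twice, then one partition value.
  s.tensor_fp32(); s.tensor_fp32(); s.tensor_fp32();
  s.stride(1); s.stride(1); s.stride(1);
  s.shape(); s.shape();
  s.partition(2);        // tiles_per_cta > 1
  s.constexpr_i64(128);  // tile_sizes = num_warps * 32 = 128 in add_tensor
  s.constexpr_i64(0);    // one_tile_per_cta is saved as int64_t, so it prints 0/1, not True/False
  std::cout << s.finish() << "\n";
  // Expected: *fp32:16,*fp32:16,*fp32:16,i64:1,i64:1,i64:1,i64,i64,i64,128,0
  return 0;
}
```

The `:16` and `:1` suffixes appear to mirror Triton's specialization hints (16-byte-aligned pointers and values known to equal 1), which would explain why `save_stride` and `save_task_partition` record a value of 1 as 0 internally and skip pushing it into `kernel_params`.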