From bfb5ff34658dc63cd5168897b671abd6b567ef8f Mon Sep 17 00:00:00 2001
From: nianqi-tian <85143461+nianqi-tian@users.noreply.github.com>
Date: Fri, 18 Jul 2025 16:23:34 +0800
Subject: [PATCH 01/22] [kunlunxin] skip unsupported operator and open
 supported operator (#802)

* tmp disable

* format change

* tmp disable

* tmp disable

* tmp disable

* tmp skip
---
 benchmark/test_reduction_perf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index c273587df..2c3ab2ef1 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -244,7 +244,7 @@ def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
     if vendor_name == "kunlunxin":
         if op_name in ["nll_loss"]:
             pytest.skip("RUNTIME TODOFIX")
-        elif op_name in ["cummin", "cummax"]:
+        elif op_name in ["cummax"]:
             pytest.skip("CUMSUM UNSUPPORTED")
     bench = GenericBenchmark2DOnly(
         input_fn=input_fn, op_name=op_name, torch_op=torch_op, dtypes=dtypes

From 51ae5277687ad8f18b292c5ceefbb9849f66b6b2 Mon Sep 17 00:00:00 2001
From: nianqi-tian <85143461+nianqi-tian@users.noreply.github.com>
Date: Mon, 21 Jul 2025 14:12:50 +0800
Subject: [PATCH 02/22] [KUNLUNXIN] tmp disable some operators and open
 operator (#805)

* tmp disable and open operator

* format
---
 benchmark/test_reduction_perf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index 2c3ab2ef1..c273587df 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -244,7 +244,7 @@ def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
     if vendor_name == "kunlunxin":
         if op_name in ["nll_loss"]:
             pytest.skip("RUNTIME TODOFIX")
-        elif op_name in ["cummax"]:
+        elif op_name in ["cummin", "cummax"]:
            pytest.skip("CUMSUM UNSUPPORTED")
     bench = GenericBenchmark2DOnly(
         input_fn=input_fn, op_name=op_name, torch_op=torch_op, dtypes=dtypes
From c72320f82321d90e00ea6698af99379cfd7ea101 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Wed, 30 Jul 2025 17:53:20 +0800
Subject: [PATCH 03/22] add max.dim_max
---
 ctests/test_triton_reduction.cpp  | 40 +++++++++++++++++++++++++
 include/flag_gems/operators.h     |  5 ++++
 lib/max.cpp                       | 49 +++++++++++++++++++++++++++++++
 src/flag_gems/csrc/aten_patch.cpp |  1 +
 src/flag_gems/csrc/cstub.cpp      |  5 ++++
 5 files changed, 100 insertions(+)

diff --git a/ctests/test_triton_reduction.cpp b/ctests/test_triton_reduction.cpp
index 5382a0647..2225ecb3c 100644
--- a/ctests/test_triton_reduction.cpp
+++ b/ctests/test_triton_reduction.cpp
@@ -86,3 +86,43 @@ TEST(MaxTest, max) {
   }
   EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-5, 1e-8));
 }
+
+class MaxDimMaxTest : public ::testing::TestWithParam<MaxDimTestParam> {};
+
+TEST_P(MaxDimMaxTest, max_dim_max) {
+  const MaxDimTestParam param = GetParam();
+  const torch::Device device(torch::kCUDA, 0);
+  const at::TensorOptions opt = at::TensorOptions().device(device).dtype(param.dtype);
+  torch::Tensor input = torch::randn({param.m, param.n}, opt);
+  auto out_torch = at::max(input, param.dim_to_keep, param.keepdim);
+  torch::Tensor max_torch = std::get<0>(out_torch);
+  torch::Tensor index_torch = std::get<1>(out_torch);
+  torch::Tensor out_value = torch::empty_like(max_torch);
+  torch::Tensor out_index = torch::empty_like(index_torch);
+  auto out_triton = flag_gems::max_dim_max(input, param.dim_to_keep, param.keepdim, out_value, out_index);
+  torch::Tensor max_triton = std::get<0>(out_triton);
+  torch::Tensor index_triton = std::get<1>(out_triton);
+  if (!torch::allclose(max_torch, max_triton, 1e-5, 1e-8)) {
+    LOG(INFO) << "Max value difference (keepdim=" << param.keepdim << "):\n" << max_torch - max_triton;
+  }
+  EXPECT_TRUE(torch::allclose(max_torch, max_triton, 1e-5, 1e-8));
+  if (!torch::allclose(index_torch, index_triton, 0, 0)) {
+    LOG(INFO) << "Index difference (keepdim=" << param.keepdim << "):\n" << index_torch - index_triton;
+  }
+  EXPECT_TRUE(torch::allclose(index_torch, index_triton, 0, 0));
+}
+
+INSTANTIATE_TEST_SUITE_P(MaxDimMaxTests,
+                         MaxDimMaxTest,
+                         ::testing::Values(MaxDimTestParam {32, 1024, true, 0, at::ScalarType::Float},
+                                           MaxDimTestParam {32, 1024, true, 1, at::ScalarType::Float},
+                                           MaxDimTestParam {32, 1024, false, 0, at::ScalarType::Float},
+                                           MaxDimTestParam {32, 1024, false, 1, at::ScalarType::Float},
+                                           MaxDimTestParam {32, 1024, true, 0, at::ScalarType::Half},
+                                           MaxDimTestParam {32, 1024, true, 1, at::ScalarType::Half},
+                                           MaxDimTestParam {32, 1024, false, 0, at::ScalarType::Half},
+                                           MaxDimTestParam {32, 1024, false, 1, at::ScalarType::Half},
+                                           MaxDimTestParam {32, 1024, true, 0, at::ScalarType::BFloat16},
+                                           MaxDimTestParam {32, 1024, true, 1, at::ScalarType::BFloat16},
+                                           MaxDimTestParam {32, 1024, false, 0, at::ScalarType::BFloat16},
+                                           MaxDimTestParam {32, 1024, false, 1, at::ScalarType::BFloat16}));

diff --git a/include/flag_gems/operators.h b/include/flag_gems/operators.h
index 816a26299..c26627f38 100644
--- a/include/flag_gems/operators.h
+++ b/include/flag_gems/operators.h
@@ -15,6 +15,11 @@ at::Tensor sum_dim(const at::Tensor &self,
                    bool keepdim = false,
                    ::std::optional<c10::ScalarType> dtype = ::std::nullopt);
 std::tuple<at::Tensor, at::Tensor> max_dim(const at::Tensor &self, int64_t dim, bool keepdim);
+std::tuple<at::Tensor, at::Tensor> max_dim_max(const at::Tensor &self,
+                                               int64_t dim,
+                                               bool keepdim,
+                                               const at::Tensor out_value,
+                                               const at::Tensor out_index);
 at::Tensor max(const at::Tensor &self);
 at::Tensor rms_norm(const at::Tensor &input, const at::Tensor &weight, double epsilon = 1e-5);
 void fused_add_rms_norm(at::Tensor &input,
diff --git a/lib/max.cpp b/lib/max.cpp
index cccec49a9..be353672f 100644
--- a/lib/max.cpp
+++ b/lib/max.cpp
@@ -42,6 +42,55 @@ int cdiv(int a, int b) {
 namespace flag_gems {
 using namespace triton_jit;
 
+// max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) ->
+// (Tensor(a!) values, Tensor(b!) indices)
+::std::tuple<at::Tensor, at::Tensor> max_dim_max(const at::Tensor &self,
+                                                 int64_t dim,
+                                                 bool keepdim,
+                                                 const at::Tensor out_value,
+                                                 const at::Tensor out_index) {
+  auto [permuted_self, non_reduction_size, reduction_size] = permute_reduction_axes_right(self, dim);
+  // set_output(out_value,out_index);
+  permuted_self = permuted_self.contiguous();
+  const TritonJITFunction &f =
+      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "max.py"),
+                                     "max_kernel");
+  int64_t tile_m = 4;
+  int64_t tile_n = 512;
+  const int num_warps = 8;
+  const int num_stages = 2;
+  const unsigned int num_blocks = (non_reduction_size + tile_m - 1) / tile_m;
+  /*
+  def max_kernel(
+      inp,
+      out_value,
+      out_index,
+      M,
+      N,
+      BLOCK_M: tl.constexpr,
+      BLOCK_N: tl.constexpr,
+  ):
+  */
+  c10::DeviceGuard guard(out_value.device());
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
+
+  f(stream,
+    num_blocks,
+    1,
+    1,
+    num_warps,
+    num_stages,
+    permuted_self,
+    out_value,
+    out_index,
+    non_reduction_size,
+    reduction_size,
+    tile_m,
+    tile_n);
+
+  return std::make_tuple(out_value, out_index);
+}
 // max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
 ::std::tuple<at::Tensor, at::Tensor> max_dim(const at::Tensor &self, int64_t dim, bool keepdim) {
   at::DimVector shape = at::meta::get_reduction_shape(self, dim, keepdim, false);

diff --git a/src/flag_gems/csrc/aten_patch.cpp b/src/flag_gems/csrc/aten_patch.cpp
index cfbb29402..961f755cd 100644
--- a/src/flag_gems/csrc/aten_patch.cpp
+++ b/src/flag_gems/csrc/aten_patch.cpp
@@ -29,6 +29,7 @@ TORCH_LIBRARY_IMPL(aten, CUDA, m) {
   REGISTER_AND_LOG("addmm", addmm);
   REGISTER_AND_LOG("bmm", bmm);
   REGISTER_AND_LOG("mm", mm_tensor);
+  REGISTER_AND_LOG("max.dim_max", max_dim_max)
   REGISTER_AND_LOG("max.dim", max_dim)
   REGISTER_AND_LOG("max", max)
 }

diff --git a/src/flag_gems/csrc/cstub.cpp b/src/flag_gems/csrc/cstub.cpp
index 6d0b4501a..d23ad4b06 100644
--- a/src/flag_gems/csrc/cstub.cpp
+++ b/src/flag_gems/csrc/cstub.cpp
@@ -10,6 +10,7 @@ PYBIND11_MODULE(c_operators, m) {
   m.def("max_dim", &flag_gems::max_dim);
   m.def("max", &flag_gems::max);
   m.def("add_tensor", &flag_gems::add_tensor);
+  m.def("max_dim_max", &flag_gems::max_dim_max);
   m.def("rms_norm", &flag_gems::rms_norm);
   m.def("fused_add_rms_norm", &flag_gems::fused_add_rms_norm);
   m.def("nonzero", &flag_gems::nonzero);
@@ -30,6 +31,9 @@ TORCH_LIBRARY(flag_gems, m) {
       "zeros(SymInt[] size, ScalarType? dtype=None,Layout? layout=None, Device? device=None, bool? "
       "pin_memory=None) -> Tensor");
   m.def("sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
+  m.def(
+      "max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> "
+      "(Tensor(a!) values, Tensor(b!) indices)");
   m.def("max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
   m.def("max(Tensor self) -> Tensor");
   m.def("add_tensor(Tensor self, Tensor other) -> Tensor", {at::Tag::pt2_compliant_tag});
@@ -69,6 +73,7 @@ TORCH_LIBRARY_IMPL(flag_gems, CUDA, m) {
   m.impl("zeros", TORCH_FN(zeros));
   m.impl("sum.dim_IntList", TORCH_FN(sum_dim));
+  m.impl("max.dim_max", TORCH_FN(max_dim_max));
   m.impl("max.dim", TORCH_FN(max_dim));
   m.impl("max", TORCH_FN(max));
   m.impl("add_tensor", TORCH_FN(add_tensor));
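Once this patch is applied, the out-variant is reachable through the regular ATen
surface. A minimal caller-side sketch (illustrative only: the shapes and the use
of at::max_out are assumptions, not part of the patch; it assumes the FlagGems
aten patch library is loaded so that "max.dim_max" on CUDA dispatches to
flag_gems::max_dim_max):

#include <torch/torch.h>

int main() {
  // Preallocate the two outputs, then call the out-variant; with the
  // registration above, this routes to the Triton max_kernel.
  torch::Tensor x = torch::randn({32, 1024}, torch::kCUDA);
  torch::Tensor values = torch::empty({32}, x.options());
  torch::Tensor indices = torch::empty({32}, x.options().dtype(torch::kLong));
  at::max_out(values, indices, x, /*dim=*/1, /*keepdim=*/false);
  return 0;
}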
indices)"); m.def("max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); m.def("max(Tensor self) -> Tensor"); m.def("add_tensor(Tensor self, Tensor other) -> Tensor", {at::Tag::pt2_compliant_tag}); @@ -69,6 +73,7 @@ TORCH_LIBRARY_IMPL(flag_gems, CUDA, m) { m.impl("zeros", TORCH_FN(zeros)); m.impl("sum.dim_IntList", TORCH_FN(sum_dim)); + m.impl("max.dim_max", TORCH_FN(max_dim_max)); m.impl("max.dim", TORCH_FN(max_dim)); m.impl("max", TORCH_FN(max)); m.impl("add_tensor", TORCH_FN(add_tensor)); From df98af1b95ab4980187a380135cc79fe9631ea9a Mon Sep 17 00:00:00 2001 From: scatyf3 Date: Thu, 31 Jul 2025 15:15:29 +0800 Subject: [PATCH 04/22] use CUstream instead c10::cuda::CUDAStream --- lib/add.cpp | 2 +- lib/argmax.cpp | 4 ++-- lib/contiguous.cpp | 2 +- lib/zeros.cpp | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/add.cpp b/lib/add.cpp index 8270570a8..abea7267f 100644 --- a/lib/add.cpp +++ b/lib/add.cpp @@ -33,7 +33,7 @@ at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) { c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); c10::DeviceGuard guard(out.device()); CUstream raw_stream = static_cast(stream.stream()); - f(stream, num_blocks, 1, 1, num_warps, num_stages, a, b, out, n, tile_size); + f(raw_stream, num_blocks, 1, 1, num_warps, num_stages, a, b, out, n, tile_size); return out; } } // namespace flag_gems diff --git a/lib/argmax.cpp b/lib/argmax.cpp index 05425453b..12c933893 100644 --- a/lib/argmax.cpp +++ b/lib/argmax.cpp @@ -105,8 +105,8 @@ at::Tensor argmax(const at::Tensor &self, std::optional dim, bool keepd c10::DeviceGuard guard(self.device()); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - - f(stream, grid_x, grid_y, 1, num_warps, num_stages, contiguous_self, out, M, N, K, tile_m, tile_n); + CUstream = raw_stream = static_cast(stream.stream()); + f(raw_stream, grid_x, grid_y, 1, num_warps, num_stages, contiguous_self, out, M, N, K, tile_m, tile_n); return out; } diff --git a/lib/contiguous.cpp b/lib/contiguous.cpp index 31969d3bc..9ddf83cd0 100644 --- a/lib/contiguous.cpp +++ b/lib/contiguous.cpp @@ -33,7 +33,7 @@ at::Tensor contiguous(const at::Tensor &self, at::MemoryFormat memory_format) { c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); c10::DeviceGuard guard(out.device()); CUstream raw_stream = static_cast(stream.stream()); - f(stream, + f(raw_stream, num_blocks, 1, 1, diff --git a/lib/zeros.cpp b/lib/zeros.cpp index d512d3521..3f8d7cf4d 100644 --- a/lib/zeros.cpp +++ b/lib/zeros.cpp @@ -46,7 +46,7 @@ at::Tensor zeros(at::IntArrayRef size, c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); CUstream raw_stream = static_cast(stream.stream()); - f(stream, + f(raw_stream, num_blocks, /* grid_y = */ 1, /* grid_z = */ 1, From 53923fd84c503b0778d2cf957edc38f5bf4607d9 Mon Sep 17 00:00:00 2001 From: scatyf3 Date: Thu, 31 Jul 2025 15:22:47 +0800 Subject: [PATCH 05/22] use CUstream instead c10::cuda::CUDAStream --- lib/argmax.cpp | 1 + lib/max.cpp | 2 +- lib/sum.cpp | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/argmax.cpp b/lib/argmax.cpp index 12c933893..c0bf7bae7 100644 --- a/lib/argmax.cpp +++ b/lib/argmax.cpp @@ -106,6 +106,7 @@ at::Tensor argmax(const at::Tensor &self, std::optional dim, bool keepd c10::DeviceGuard guard(self.device()); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); CUstream = raw_stream = static_cast(stream.stream()); + f(raw_stream, grid_x, grid_y, 1, num_warps, num_stages, 
From f7b0993b75bd44d8a9df9910af3a6b29369d74ba Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Thu, 31 Jul 2025 16:24:55 +0800
Subject: [PATCH 06/22] refactor
---
 include/flag_gems/utils.h |  5 ++++
 lib/max.cpp               | 40 +++------------------------
 lib/sum.cpp               | 42 +++++------------------------
 lib/utils.cpp             | 57 +++++++++++++++++++++++++++++++++++++--
 src/flag_gems/ops/sum.py  | 34 +++++++++++++++++++++++
 triton_src/sum.py         | 50 ----------------------------------
 6 files changed, 105 insertions(+), 123 deletions(-)
 delete mode 100644 triton_src/sum.py

diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index 3aa08909d..fd245cbcd 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -17,4 +17,9 @@ std::filesystem::path get_triton_src_path();
 std::filesystem::path get_flag_gems_src_path();
 int64_t next_power_of_2(int64_t n);
 bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
+                                                                      int reduction_axis);
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
+    const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt);
+int cdiv(int a, int b);
 }  // namespace flag_gems::utils

diff --git a/lib/max.cpp b/lib/max.cpp
index ee1f9252a..acf00d4c0 100644
--- a/lib/max.cpp
+++ b/lib/max.cpp
@@ -11,35 +11,6 @@
 #include "ATen/native/ReduceOpsUtils.h"
 #include "c10/util/DimVector.h"
 
-namespace {
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
-                                                                      int reduction_axis) {
-  int64_t dim = tensor.dim();
-  c10::DimVector left_axes, right_axes;
-  int64_t non_reduction_size = 1, reduction_size = 1;
-
-  for (int64_t i = 0; i < dim; ++i) {
-    if (i == reduction_axis) {
-      right_axes.push_back(i);
-      reduction_size *= tensor.size(i);
-    } else {
-      left_axes.push_back(i);
-      non_reduction_size *= tensor.size(i);
-    }
-  }
-  c10::DimVector permute_order = left_axes;
-  permute_order.insert(permute_order.end(), right_axes.begin(), right_axes.end());
-
-  return {tensor.permute(permute_order), non_reduction_size, reduction_size};
-}
-int next_power_of_2(int x) {
-  return (1 << (32 - __builtin_clz(x - 1)));
-}
-int cdiv(int a, int b) {
-  return (a + b - 1) / b;
-}
-}  // anonymous namespace
-
 namespace flag_gems {
 using namespace triton_jit;
 // max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) ->
// (Tensor(a!) values, Tensor(b!) indices)

 ::std::tuple<at::Tensor, at::Tensor> max_dim(const at::Tensor &self, int64_t dim, bool keepdim) {
   at::Tensor out_value = at::empty(shape, self.options());
   at::Tensor out_index = at::empty(shape, self.options().dtype(at::kLong));
 
-  auto [permuted_self, non_reduction_size, reduction_size] = permute_reduction_axes_right(self, dim);
+  auto [permuted_self, non_reduction_size, reduction_size] = utils::permute_reduction_axes_right(self, dim);
   permuted_self = permuted_self.contiguous();
   const TritonJITFunction &f =
       TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "max.py"),

 at::Tensor max(const at::Tensor &self) {
   TORCH_CHECK(self.is_contiguous(), "Input tensor must be contiguous");
   int64_t M = self.numel();
-  int64_t block_size = 1 << static_cast<int64_t>(std::ceil(std::log2(std::sqrt(M))));
-  int64_t mid_size = (M + block_size - 1) / block_size;
-  int64_t block_mid = 1 << static_cast<int64_t>(std::ceil(std::log2(mid_size)));
+  int64_t block_size = utils::next_power_of_2(static_cast<int64_t>(std::ceil(std::sqrt(M))));
+  int64_t mid_size = utils::cdiv(M, block_size);
+  int64_t block_mid = utils::next_power_of_2(mid_size);
 
   at::Tensor mid = torch::empty({mid_size}, self.options());
   at::Tensor out = torch::empty({}, self.options());

   const TritonJITFunction &max_kernel_2 =
       TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "max.py"),
                                      "max_kernel_2");
-  block_size = next_power_of_2(static_cast<int>(std::ceil(std::sqrt(M))));
-  mid_size = cdiv(M, block_size);
-  block_mid = next_power_of_2(mid_size);
   const int num_warps = 8;
   const int num_stages = 2;
   c10::DeviceGuard guard(out.device());

diff --git a/lib/sum.cpp b/lib/sum.cpp
index b8c150180..309c04d4e 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -11,41 +11,12 @@
 #include "ATen/native/ReduceOpsUtils.h"
 #include "c10/util/DimVector.h"
 
-namespace {
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
-    const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt) {
-  int64_t dim = tensor.dim();
-  c10::DimVector reduction_axes;
-
-  if (reduction_axes_opt.has_value()) {
-    reduction_axes = reduction_axes_opt.value().vec();
-  }
-
-  std::unordered_set<int64_t> reduction_set(reduction_axes.begin(), reduction_axes.end());
-
-  c10::DimVector left_axes, right_axes;
-  int64_t non_reduction_size = 1, reduction_size = 1;
-
-  for (int64_t i = 0; i < dim; ++i) {
-    if (reduction_set.count(i)) {
-      right_axes.push_back(i);
-      reduction_size *= tensor.size(i);
-    } else {
-      left_axes.push_back(i);
-      non_reduction_size *= tensor.size(i);
-    }
-  }
-
-  // Concatenate left and right axes to form the new permutation order
-  c10::DimVector permute_order = left_axes;
-  permute_order.insert(permute_order.end(), right_axes.begin(), right_axes.end());
-
-  return {tensor.permute(permute_order), non_reduction_size, reduction_size};
-}
-}  // anonymous namespace
-
 namespace flag_gems {
 using namespace triton_jit;
+// sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+at::Tensor sum_dim(const at::Tensor &self, ::std::optional<c10::ScalarType> dtype) {
+}
+
 // signature
 // sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType?
 // dtype=None) -> Tensor

 at::Tensor sum_dim(const at::Tensor &self,
   c10::ScalarType out_dtype = at::native::get_dtype_from_self(self, dtype, true);
   at::Tensor out = at::empty(shape, self.options());
 
-  auto [permuted_self, non_reduction_size, reduction_size] = permute_reduction_axes_right(self, dims_);
+  auto [permuted_self, non_reduction_size, reduction_size] = utils::permute_reduction_axes_right(self, dims_);
   permuted_self = permuted_self.contiguous();
 
   /* signature to remind yourself
   def sum_kernel(
       in_ptr,
       out_ptr,
       M,
       N,
       BLOCK_M: tl.constexpr,
       BLOCK_N: tl.constexpr,
   ):
   */
   const TritonJITFunction &f =
-      TritonJITFunction::getInstance(std::string(utils::get_triton_src_path() / "sum.py"), "sum_kernel");
+      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
+                                     "sum_kernel");
 
   // add utility to build this automatically
   int64_t tile_m = 4;

diff --git a/lib/utils.cpp b/lib/utils.cpp
index cd04079ae..4ed35665d 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -80,4 +80,57 @@ bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2) {
   return true;
 }
 
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
+    const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt) {
+  int64_t dim = tensor.dim();
+  c10::DimVector reduction_axes;
+
+  if (reduction_axes_opt.has_value()) {
+    reduction_axes = reduction_axes_opt.value().vec();
+  }
+
+  std::unordered_set<int64_t> reduction_set(reduction_axes.begin(), reduction_axes.end());
+
+  c10::DimVector left_axes, right_axes;
+  int64_t non_reduction_size = 1, reduction_size = 1;
+
+  for (int64_t i = 0; i < dim; ++i) {
+    if (reduction_set.count(i)) {
+      right_axes.push_back(i);
+      reduction_size *= tensor.size(i);
+    } else {
+      left_axes.push_back(i);
+      non_reduction_size *= tensor.size(i);
+    }
+  }
+
+  // Concatenate left and right axes to form the new permutation order
+  c10::DimVector permute_order = left_axes;
+  permute_order.insert(permute_order.end(), right_axes.begin(), right_axes.end());
+
+  return {tensor.permute(permute_order), non_reduction_size, reduction_size};
+}
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
+                                                                      int reduction_axis) {
+  int64_t dim = tensor.dim();
+  c10::DimVector left_axes, right_axes;
+  int64_t non_reduction_size = 1, reduction_size = 1;
+
+  for (int64_t i = 0; i < dim; ++i) {
+    if (i == reduction_axis) {
+      right_axes.push_back(i);
+      reduction_size *= tensor.size(i);
+    } else {
+      left_axes.push_back(i);
+      non_reduction_size *= tensor.size(i);
+    }
+  }
+  c10::DimVector permute_order = left_axes;
+  permute_order.insert(permute_order.end(), right_axes.begin(), right_axes.end());
+
+  return {tensor.permute(permute_order), non_reduction_size, reduction_size};
+}
+int cdiv(int a, int b) {
+  return (a + b - 1) / b;
+}
 }  // namespace flag_gems::utils

diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py
index 47699b0ee..c65b37163 100644
--- a/src/flag_gems/ops/sum.py
+++ b/src/flag_gems/ops/sum.py
@@ -58,6 +58,40 @@ def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
     tl.store(out, sum_val)
 
 
+@libentry()
+@triton.jit
+def sum_kernel(
+    in_ptr,
+    out_ptr,
+    M,
+    N,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    STAGE: tl.constexpr,
+):
+    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
+        in_ptr.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = in_ptr.dtype.element_ty
+
+    # Map the program id to the row of inp it should compute.
+    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    row_mask = row_ids < M
+
+    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
+    for off in tl.range(0, N, BLOCK_N, STAGE):
+        col_ids = off + tl.arange(0, BLOCK_N)
+        col_mask = col_ids < N
+        mask = row_mask[:, None] & col_mask[None, :]
+
+        a = tl.load(in_ptr + row_ids[:, None] * N + col_ids, mask, other=0).to(cdtype)
+        acc += a
+    out = tl.sum(acc, axis=1)
+    tl.store(out_ptr + row_ids, out, row_mask)
+
+
 def sum(inp, *, dtype=None):
     logger.debug("GEMS SUM")
     M = inp.numel()
diff --git a/triton_src/sum.py b/triton_src/sum.py
deleted file mode 100644
index 0290a9679..000000000
--- a/triton_src/sum.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import triton
-from triton import language as tl
-
-
-@triton.jit
-def sum_kernel(
-    in_ptr,
-    out_ptr,
-    M,
-    N,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    STAGE: tl.constexpr,
-):
-    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
-        in_ptr.dtype.element_ty == tl.bfloat16
-    ):
-        cdtype = tl.float32
-    else:
-        cdtype = in_ptr.dtype.element_ty
-
-    # Map the program id to the row of inp it should compute.
-    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
-    row_mask = row_ids < M
-
-    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
-    for off in tl.range(0, N, BLOCK_N, STAGE):
-        col_ids = off + tl.arange(0, BLOCK_N)
-        col_mask = col_ids < N
-        mask = row_mask[:, None] & col_mask[None, :]
-
-        a = tl.load(in_ptr + row_ids[:, None] * N + col_ids, mask, other=0).to(cdtype)
-        acc += a
-    out = tl.sum(acc, axis=1)
-    tl.store(out_ptr + row_ids, out, row_mask)
-
-
-if __name__ == "__main__":
-    import torch
-
-    m = 1024
-    n = 1024 * 16
-    x = torch.randn((m, n), device="cuda:0")
-    out = torch.empty((m,), device="cuda:0")
-    BLOCK_M = 1
-    BLOCK_N = 1024
-    grid = (triton.cdiv(m, BLOCK_M), 1, 1)
-    sum_kernel[grid](x, out, m, n, BLOCK_M, BLOCK_N, STAGE=2, num_warps=4)
-    print(out)
-    print(x.sum(1))
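To make the refactor concrete, here is what the relocated helper computes for a
small input (a sketch; the shape is arbitrary and the assertions just restate
the function's own arithmetic):

// Reducing axis 1 of a (2, 3, 4) tensor: the reduction axis is permuted to the
// rightmost position, so the permuted view has shape (2, 4, 3),
// non_reduction_size = 2 * 4 = 8 and reduction_size = 3.
at::Tensor t = at::randn({2, 3, 4});
auto [permuted, non_reduction, reduction] =
    flag_gems::utils::permute_reduction_axes_right(t, /*reduction_axis=*/1);
// permuted.sizes() == (2, 4, 3); non_reduction == 8; reduction == 3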
From 110f47d3e3788592ff343958a41ccadfe96b0d43 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Thu, 31 Jul 2025 16:47:50 +0800
Subject: [PATCH 07/22] add max_dim_max and sum, move
 permute_reduction_axes_right to utils, remove sum in triton_src
---
 include/flag_gems/operators.h     |  1 +
 lib/argmax.cpp                    |  2 +-
 lib/max.cpp                       |  2 +-
 lib/sum.cpp                       | 21 +++++++++++++++++++++
 src/flag_gems/csrc/aten_patch.cpp |  1 +
 src/flag_gems/csrc/cstub.cpp      |  4 +++-
 6 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/include/flag_gems/operators.h b/include/flag_gems/operators.h
index c26627f38..f1444adb5 100644
--- a/include/flag_gems/operators.h
+++ b/include/flag_gems/operators.h
@@ -14,6 +14,7 @@ at::Tensor sum_dim(const at::Tensor &self,
                    at::OptionalIntArrayRef dim,
                    bool keepdim = false,
                    ::std::optional<c10::ScalarType> dtype = ::std::nullopt);
+at::Tensor sum(const at::Tensor &self, ::std::optional<c10::ScalarType> dtype);
 std::tuple<at::Tensor, at::Tensor> max_dim(const at::Tensor &self, int64_t dim, bool keepdim);
 std::tuple<at::Tensor, at::Tensor> max_dim_max(const at::Tensor &self,
                                                int64_t dim,

diff --git a/lib/argmax.cpp b/lib/argmax.cpp
index c0bf7bae7..868731258 100644
--- a/lib/argmax.cpp
+++ b/lib/argmax.cpp
@@ -105,7 +105,7 @@ at::Tensor argmax(const at::Tensor &self, std::optional<int64_t> dim, bool keepd
   c10::DeviceGuard guard(self.device());
   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
-  CUstream = raw_stream = static_cast<CUstream>(stream.stream());
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
 
   f(raw_stream, grid_x, grid_y, 1, num_warps, num_stages, contiguous_self, out, M, N, K, tile_m, tile_n);
 
   return out;

diff --git a/lib/max.cpp b/lib/max.cpp
index acf00d4c0..b30d5d2e2 100644
--- a/lib/max.cpp
+++ b/lib/max.cpp
@@ -20,7 +20,7 @@ ::std::tuple<at::Tensor, at::Tensor> max_dim_max(const at::Tensor &self,
                                                  bool keepdim,
                                                  const at::Tensor out_value,
                                                  const at::Tensor out_index) {
-  auto [permuted_self, non_reduction_size, reduction_size] = permute_reduction_axes_right(self, dim);
+  auto [permuted_self, non_reduction_size, reduction_size] = utils::permute_reduction_axes_right(self, dim);
   // set_output(out_value,out_index);
   permuted_self = permuted_self.contiguous();

diff --git a/lib/sum.cpp b/lib/sum.cpp
index 309c04d4e..3944bb254 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -15,6 +15,27 @@ namespace flag_gems {
 using namespace triton_jit;
 // sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
 at::Tensor sum_dim(const at::Tensor &self, ::std::optional<c10::ScalarType> dtype) {
+  TORCH_CHECK(self.is_contiguous(), "Input tensor must be contiguous");
+  int64_t M = self.numel();
+  int64_t block_size = utils::next_power_of_2(static_cast<int64_t>(std::ceil(std::sqrt(M))));
+  int64_t mid_size = utils::cdiv(M, block_size);
+  int64_t block_mid = utils::next_power_of_2(mid_size);
+  at::Tensor mid = torch::empty({mid_size}, self.options());
+  at::Tensor out = torch::empty({}, self.options());
+  const TritonJITFunction &sum_kernel_1 =
+      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
+                                     "sum_kernel_1");
+  const TritonJITFunction &sum_kernel_2 =
+      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "max.py"),
+                                     "sum_kernel_2");
+  const int num_warps = 8;
+  const int num_stages = 2;
+  c10::DeviceGuard guard(out.device());
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
+  sum_kernel_1(raw_stream, mid_size, 1, 1, num_warps, num_stages, self, mid, M, block_size);
+  sum_kernel_2(raw_stream, 1, 1, 1, num_warps, num_stages, mid, out, mid_size, block_mid);
+  return out;
 }
 
 // signature

diff --git a/src/flag_gems/csrc/aten_patch.cpp b/src/flag_gems/csrc/aten_patch.cpp
index 961f755cd..374e979c7 100644
--- a/src/flag_gems/csrc/aten_patch.cpp
+++ b/src/flag_gems/csrc/aten_patch.cpp
@@ -32,5 +32,6 @@ TORCH_LIBRARY_IMPL(aten, CUDA, m) {
   REGISTER_AND_LOG("max.dim_max", max_dim_max)
   REGISTER_AND_LOG("max.dim", max_dim)
   REGISTER_AND_LOG("max", max)
+  REGISTER_AND_LOG("sum", sum)
 }
 }  // namespace flag_gems

diff --git a/src/flag_gems/csrc/cstub.cpp b/src/flag_gems/csrc/cstub.cpp
index d23ad4b06..c81833809 100644
--- a/src/flag_gems/csrc/cstub.cpp
+++ b/src/flag_gems/csrc/cstub.cpp
@@ -7,6 +7,7 @@
 // bindings provided by torch library, since it is in a boxed fashion
 PYBIND11_MODULE(c_operators, m) {
   m.def("sum_dim", &flag_gems::sum_dim);
+  m.def("sum", &flag_gems::sum);
   m.def("max_dim", &flag_gems::max_dim);
   m.def("max", &flag_gems::max);
   m.def("add_tensor", &flag_gems::add_tensor);
@@ -19,7 +20,6 @@ PYBIND11_MODULE(c_operators, m) {
   m.def("rotary_embedding_inplace", &flag_gems::rotary_embedding_inplace);
   m.def("bmm", &flag_gems::bmm);
 }
-
 namespace flag_gems {
 TORCH_LIBRARY(flag_gems, m) {
   // blas
@@ -31,6 +31,7 @@ TORCH_LIBRARY(flag_gems, m) {
       "zeros(SymInt[] size, ScalarType? dtype=None,Layout? layout=None, Device? device=None, bool? "
       "pin_memory=None) -> Tensor");
   m.def("sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
+  m.def("sum(Tensor self, *, ScalarType? dtype=None) -> Tensor");
   m.def(
       "max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> "
      "(Tensor(a!) values, Tensor(b!) indices)");
indices)"); @@ -73,6 +74,7 @@ TORCH_LIBRARY_IMPL(flag_gems, CUDA, m) { m.impl("zeros", TORCH_FN(zeros)); m.impl("sum.dim_IntList", TORCH_FN(sum_dim)); + m.impl("sum", TORCH_FN(sum)); m.impl("max.dim_max", TORCH_FN(max_dim_max)); m.impl("max.dim", TORCH_FN(max_dim)); m.impl("max", TORCH_FN(max)); From d674544805ff4c0417d9d3ecac56ee31ea74e8d9 Mon Sep 17 00:00:00 2001 From: scatyf3 Date: Thu, 31 Jul 2025 16:59:46 +0800 Subject: [PATCH 08/22] add and fix --- ctests/test_triton_reduction.cpp | 14 ++++++++++++++ include/flag_gems/operators.h | 2 +- lib/sum.cpp | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/ctests/test_triton_reduction.cpp b/ctests/test_triton_reduction.cpp index 2225ecb3c..a2e84c161 100644 --- a/ctests/test_triton_reduction.cpp +++ b/ctests/test_triton_reduction.cpp @@ -2,10 +2,24 @@ #include "c10/util/Logging.h" #include "flag_gems/operators.h" #include "torch/torch.h" + TEST(reduction_op_test, sum) { const torch::Device device(torch::kCUDA, 0); torch::Tensor a = torch::randn({32, 1024}, device); + torch::Tensor out_torch = at::sum(a); + torch::Tensor out_triton = flag_gems::sum(a); + if (!torch::allclose(out_torch, out_triton, 1e-5, 1e-8)) { + LOG(INFO) << "Difference:\n" << out_torch - out_triton; + } + + EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-5, 1e-8)); +} + +TEST(reduction_op_test, sum_dim) { + const torch::Device device(torch::kCUDA, 0); + torch::Tensor a = torch::randn({32, 1024}, device); + torch::Tensor out_torch = at::sum(a, {1}); torch::Tensor out_triton = flag_gems::sum_dim(a, {1}); if (!torch::allclose(out_torch, out_triton, 1e-5, 1e-8)) { diff --git a/include/flag_gems/operators.h b/include/flag_gems/operators.h index f1444adb5..cd4f02b6f 100644 --- a/include/flag_gems/operators.h +++ b/include/flag_gems/operators.h @@ -14,7 +14,7 @@ at::Tensor sum_dim(const at::Tensor &self, at::OptionalIntArrayRef dim, bool keepdim = false, ::std::optional dtype = ::std::nullopt); -at::Tensor sum(const at::Tensor &self, ::std::optional dtype); +at::Tensor sum(const at::Tensor &self, ::std::optional dtype = ::std::nullopt); std::tuple max_dim(const at::Tensor &self, int64_t dim, bool keepdim); std::tuple max_dim_max(const at::Tensor &self, int64_t dim, diff --git a/lib/sum.cpp b/lib/sum.cpp index 3944bb254..88c647ce5 100644 --- a/lib/sum.cpp +++ b/lib/sum.cpp @@ -14,7 +14,7 @@ namespace flag_gems { using namespace triton_jit; // sum(Tensor self, *, ScalarType? 
From 9e50fd6eeb2f8d2a70503d8f6d57b16cf450c6dd Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Thu, 31 Jul 2025 17:03:53 +0800
Subject: [PATCH 09/22] fix error
---
 lib/sum.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/sum.cpp b/lib/sum.cpp
index 88c647ce5..ee4842066 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -26,7 +26,7 @@ at::Tensor sum(const at::Tensor &self, ::std::optional<c10::ScalarType> dtype) {
       TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
                                      "sum_kernel_1");
   const TritonJITFunction &sum_kernel_2 =
-      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "max.py"),
+      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
                                      "sum_kernel_2");
   const int num_warps = 8;
   const int num_stages = 2;

From 1a74427b84005d42be43c4fea77038f59c3dea9a Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Wed, 6 Aug 2025 15:35:57 +0800
Subject: [PATCH 10/22] tmp update
---
 demo.py                                  |   9 +
 lib/CMakeLists.txt                       |   4 +-
 lib/add.cpp                              |  63 ++
 lib/pointwise_dynamic.cpp                | 201 +++++
 lib/utils.cpp                            | 209 +++++
 pointwise.log                            | 960 +++++++++++++++++++++++
 src/flag_gems/utils/pointwise_dynamic.py |   8 +
 src/flag_gems/utils/tensor_wrapper.py    |   9 +
 8 files changed, 1462 insertions(+), 1 deletion(-)
 create mode 100644 demo.py
 create mode 100644 lib/pointwise_dynamic.cpp
 create mode 100644 pointwise.log

diff --git a/demo.py b/demo.py
new file mode 100644
index 000000000..81dd32f43
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,9 @@
+import torch
+
+import flag_gems
+
+shape = (4,)
+x = torch.randn(shape, device=flag_gems.device)
+y = torch.randn_like(x)
+with flag_gems.use_gems():
+    C = torch.add(x, y)

diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index d4b4d7827..ee06fbdaf 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -18,7 +18,9 @@ add_library(operators
   bmm.cpp
   embedding.cpp
   argmax.cpp
-  fill.cpp)
+  fill.cpp
+  pointwise_dynamic.cpp
+  )
 
 target_include_directories(operators
   PUBLIC $
diff --git a/lib/add.cpp b/lib/add.cpp
index 8270570a8..47444fe17 100644
--- a/lib/add.cpp
+++ b/lib/add.cpp
@@ -8,6 +8,7 @@
 namespace flag_gems {
 using namespace triton_jit;
 
+/*
 at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
   auto res = torch::broadcast_tensors({a_, b_});
   res[0] = res[0].contiguous();
@@ -36,4 +37,66 @@ at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
   f(stream, num_blocks, 1, 1, num_warps, num_stages, a, b, out, n, tile_size);
   return out;
 }
+*/
+
+at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
+  // 1. Broadcasting and ensuring contiguous memory layout
+  auto res = torch::broadcast_tensors({a_, b_});
+  const at::Tensor& a = res[0].contiguous();
+  const at::Tensor& b = res[1].contiguous();
+
+  // 2. Determine output dtype and create output tensor
+  at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
+  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));
+
+  // 3. Get the TritonJITFunction instance
+  const TritonJITFunction& f =
+      TritonJITFunction::getInstance(std::string(utils::get_triton_src_path() / "binary_add.py"),
+                                     "binary_pointwise_kernel");
+
+  // 4. Manually prepare the raw argument list (void**) and the signature
+  int64_t tile_size = 1024;
+  int64_t n = out.numel();
+
+  // This is the raw C-style argument array that the CUDA kernel expects.
+  // It contains the addresses of all the kernel's parameters.
+  std::vector<void*> raw_args_list;
+
+  // Push the data pointers for the tensors.
+  raw_args_list.push_back(a.data_ptr());
+  raw_args_list.push_back(b.data_ptr());
+  raw_args_list.push_back(out.data_ptr());
+
+  // Push the addresses of scalar values.
+  // NOTE: The scalars 'n' and 'tile_size' must have their addresses taken.
+  // This is why we use references or variables.
+  raw_args_list.push_back(&n);
+  raw_args_list.push_back(&tile_size);
+
+  // 5. Manually generate the signature string
+  // This must match the kernel's type-based signature for overload resolution.
+  // This is an example; the exact signature depends on the kernel definition.
+  std::string signature = "tl.pointer_type,tl.pointer_type,tl.pointer_type,int64,int64";
+
+  // 6. Set up the launch configuration
+  const int num_warps = 8;
+  const int num_stages = 1;
+  const unsigned int num_blocks = (n + tile_size - 1) / tile_size;
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  c10::DeviceGuard guard(out.device());
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
+
+  // 7. Launch the kernel using the raw argument list
+  f.launch_with_raw_args(raw_stream,
+                         num_blocks,
+                         1,
+                         1,
+                         num_warps,
+                         num_stages,
+                         signature,
+                         raw_args_list.data());
+
+  return out;
+}
+
 }  // namespace flag_gems
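The raw_args_list layout mirrors what the CUDA driver ultimately consumes:
cuLaunchKernel takes a void** array in which every entry, pointer and scalar
alike, is the address of the corresponding kernel parameter. A reduced
driver-API sketch of that contract (independent of the TritonJITFunction
wrapper; launch_with_raw_args above is the author's experimental entry point,
and the grid/block numbers below are placeholders):

#include <cuda.h>

// Each entry of `params` points at the storage of one kernel argument; the
// driver copies the pointed-to bytes into the kernel's parameter space.
void launch_sketch(CUfunction fn, CUstream stream,
                   CUdeviceptr a, CUdeviceptr b, CUdeviceptr out, long long n) {
  void* params[] = {&a, &b, &out, &n};
  cuLaunchKernel(fn,
                 /*gridDimX=*/static_cast<unsigned>((n + 1023) / 1024), 1, 1,
                 /*blockDimX=*/4 * 32, 1, 1,  // e.g. 4 warps of 32 threads
                 /*sharedMemBytes=*/0, stream, params, /*extra=*/nullptr);
}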
", " : ""); + msg += ")"; + throw std::invalid_argument(msg); + } + } + + return s; +} + +template +Shape broadcast_shapes(const Iterable& shapes) { + if (std::empty(shapes)) { + return {}; + } + + auto it = std::begin(shapes); + Shape result_shape = *it; + ++it; + + for (; it != std::end(shapes); ++it) { + result_shape = broadcast(result_shape, *it); + } + + return result_shape; +} +}; // namespace pointwise_dynamic + +namespace flag_gems { +using namespace triton_jit; +int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; +} +at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { + // TODO: parse tensor meta info + std::vector kernel_params; + // 2 input + kernel_params.push(a_); + kernel_params.push(b_); + // 1 output + at::Tensor out = at::empty(a.sizes(), a.options()); + kernel_params.push(&out); + // if input和output都连续 + // 或者stride相同和第一个tensor torch.ops.aten.is_non_overlapping_and_dense + // 但是后者不是都连续,为什么stride=1,如连续的ab转置,我们可以忽略它的stride,只计算element wise就行 + // 但返回的时候,是不是要somehow拿回它的stride,不过这可能是python端里的问题 + std::vector tensors = {a_, b_, out}; + // WrapperGenerator: gen_kernel_launch + // KernelGenerator: + if (pointwise_dynamic::use_fast_path(tensors)) { + int task_shape = tensors[0].numel(); + void* task_shape_ptr = &task_shape; + int stride = 1; + void* stride_ptr = &stride; + int ndim = 1; + int fast_path_stride_order = 0; + void* fast_path_stride_order_ptr = &fast_path_stride_order + // push args + // stride for input + kernel_params.push(stride_ptr); + kernel_params.push(fast_path_stride_order_ptr); + kernel_params.push(stride_ptr); + kernel_params.push(fast_path_stride_order_ptr); + // stride for output + kernel_params.push(stride_ptr); + + // task_space -> shape_args... shape = out0.shape + // use fast path需要考虑shape吗 + // prepare args里设置 task_shape = (tensors[0].numel(),) + kernel_params.push(task_shape_ptr); + // num_tasks -> num_tasks = out0.numel() + kernel_params.push(task_shape_ptr); + } else { + // TODO + // stride for input/output + // calculate task space + // shapes = tuple(item.shape for item in in_tensors), + std::vector shapes; + shapes.reserve(2); + for (const auto& tensor : in_tensors) { + shapes.push_back(tensor.shape()); + } + Shape task_shape = broadcast_shapes(shapes); + int64_t ndim = task_shape.size(); + // task_shape = broadcast_shapes(shapes) + // get stride, TODO,using ndim as python warpper + auto a_strides = a_.strides(); + for (int64_t stride : a_strides) { + kernel_params.push_back(&stride); + } + auto b_strides = b_.strides(); + for (int64_t stride : b_strides) { + kernel_params.push_back(&stride); + } + auto out_strides = out.strides(); + for (int64_t stride : out_strides) { + kernel_params.push_back(&stride); + } + } + // # tile size & tiles_per_cta, gsl style + // tile_sizes = heuristics_for_tile_size(512, *shape) + int64_t tile_sizes = 1024; + int64_t num_tiles = cdiv(task_shape, tile_sizes); // aka num blocks + // num_ctas = min(65536, num_tiles) + int64_t num_ctas = std::min(65536, num_tiles); + // tiles_per_cta = triton.cdiv(num_tiles, num_ctas) + int64_t tiles_per_cta = cdiv(num_tiles, num_ctas); + // one_tile_per_cta = tiles_per_cta==1 + bool one_tile_per_cta = (tiles_per_cta == 1); + // get function + std::array is_tensor; + checkIfScalar(scalar_tensor, vector_tensor, is_tensor); + const TritonKernel kernel; + if (is_tensor[0] && is_tensor[1]) { + &f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), + "add_func"); + } else if (is_tensor[0] && !is_tensor[1]) { + &f = 
+namespace flag_gems {
+using namespace triton_jit;
+int64_t cdiv(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
+  // TODO: parse tensor meta info
+  std::vector<void*> kernel_params;
+  // 2 inputs
+  kernel_params.push_back(a_.data_ptr());
+  kernel_params.push_back(b_.data_ptr());
+  // 1 output
+  at::Tensor out = at::empty(a_.sizes(), a_.options());
+  kernel_params.push_back(out.data_ptr());
+  // The fast path applies if the inputs and output are all contiguous,
+  // or if they share the same strides and the first tensor is
+  // torch.ops.aten.is_non_overlapping_and_dense. The latter case is not
+  // necessarily contiguous -- e.g. the transpose of a contiguous tensor --
+  // but there we can ignore the strides and treat the task as plain
+  // element-wise; recovering the strides on return is probably a
+  // Python-side concern.
+  std::vector<at::Tensor> tensors = {a_, b_, out};
+  // WrapperGenerator: gen_kernel_launch
+  // KernelGenerator:
+  int64_t task_shape = 0;
+  if (pointwise_dynamic::use_fast_path(tensors)) {
+    task_shape = tensors[0].numel();
+    void* task_shape_ptr = &task_shape;
+    int64_t stride = 1;
+    void* stride_ptr = &stride;
+    int ndim = 1;
+    int fast_path_stride_order = 0;
+    void* fast_path_stride_order_ptr = &fast_path_stride_order;
+    // push args
+    // strides for the inputs
+    kernel_params.push_back(stride_ptr);
+    kernel_params.push_back(fast_path_stride_order_ptr);
+    kernel_params.push_back(stride_ptr);
+    kernel_params.push_back(fast_path_stride_order_ptr);
+    // stride for the output
+    kernel_params.push_back(stride_ptr);
+
+    // task_space -> shape_args... shape = out0.shape
+    // does the fast path need to consider shape at all?
+    // prepare_args sets task_shape = (tensors[0].numel(),)
+    kernel_params.push_back(task_shape_ptr);
+    // num_tasks -> num_tasks = out0.numel()
+    kernel_params.push_back(task_shape_ptr);
+  } else {
+    // TODO
+    // strides for input/output
+    // calculate the task space
+    // shapes = tuple(item.shape for item in in_tensors),
+    std::vector<pointwise_dynamic::Shape> shapes;
+    shapes.reserve(2);
+    shapes.push_back(a_.sizes().vec());
+    shapes.push_back(b_.sizes().vec());
+    pointwise_dynamic::Shape broadcasted = pointwise_dynamic::broadcast_shapes(shapes);
+    int64_t ndim = broadcasted.size();
+    // task_shape = broadcast_shapes(shapes)
+    task_shape = out.numel();
+    // get strides; TODO: pad to ndim as the Python wrapper does
+    auto a_strides = a_.strides();
+    for (int64_t stride : a_strides) {
+      kernel_params.push_back(&stride);
+    }
+    auto b_strides = b_.strides();
+    for (int64_t stride : b_strides) {
+      kernel_params.push_back(&stride);
+    }
+    auto out_strides = out.strides();
+    for (int64_t stride : out_strides) {
+      kernel_params.push_back(&stride);
+    }
+  }
+  // # tile size & tiles_per_cta, gsl style
+  // tile_sizes = heuristics_for_tile_size(512, *shape)
+  int64_t tile_sizes = 1024;
+  int64_t num_tiles = cdiv(task_shape, tile_sizes);  // aka num blocks
+  // num_ctas = min(65536, num_tiles)
+  int64_t num_ctas = std::min<int64_t>(65536, num_tiles);
+  // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
+  int64_t tiles_per_cta = cdiv(num_tiles, num_ctas);
+  // one_tile_per_cta = tiles_per_cta==1
+  bool one_tile_per_cta = (tiles_per_cta == 1);
+  // get the function
+  std::array<bool, 2> is_tensor;
+  pointwise_dynamic::checkIfScalar(a_, b_, is_tensor);
+  const TritonJITFunction* f = nullptr;
+  if (is_tensor[0] && is_tensor[1]) {
+    f = &TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
+                                        "add_func");
+  } else if (is_tensor[0] && !is_tensor[1]) {
+    f = &TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
+                                        "add_func_tensor_scalar");
+  } else if (!is_tensor[0] && is_tensor[1]) {
+    f = &TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
+                                        "add_func_scalar_tensor");
+  } else {
+    return a_ + b_;
+  }
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  c10::DeviceGuard guard(out.device());
+  CUstream raw_stream = static_cast<CUstream>(stream.stream());
+  const int num_warps = 8;
+  const int num_stages = 1;
+  (*f)(raw_stream, num_tiles, 1, 1, num_warps, num_stages, kernel_params);
+  return out;
+}
+
+}  // namespace flag_gems
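The tail of add_tensor sizes the launch in the "gsl" (grid-stride-loop) style
the comments sketch: the task is cut into tiles, the grid is capped at 65536
CTAs, and each CTA loops over tiles_per_cta tiles. Concretely (the element
counts are arbitrary examples; the comments just evaluate the formulas above):

// 10,000,000 elements with tile_sizes = 1024:
//   num_tiles = cdiv(10000000, 1024) = 9766; num_ctas = 9766; tiles_per_cta = 1
// 100,000,000 elements:
//   num_tiles = 97657; num_ctas = 65536; tiles_per_cta = 2, so each CTA
//   strides over up to two tiles and one_tile_per_cta is false.
int64_t num_tiles = flag_gems::cdiv(100000000, 1024);          // 97657
int64_t num_ctas = std::min<int64_t>(65536, num_tiles);        // 65536
int64_t tiles_per_cta = flag_gems::cdiv(num_tiles, num_ctas);  // 2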
diff --git a/lib/utils.cpp b/lib/utils.cpp
index cd04079ae..17ffd3bb7 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -8,7 +8,7 @@ std::filesystem::path get_path_of_this_library() {
   // there is no build system generator to take care of this.
   static const std::filesystem::path cached_path = []() {
     Dl_info dl_info;
-    if (dladdr(reinterpret_cast<void*>(&get_path_of_this_library), &dl_info) && dl_info.dli_fname) {
+    if (dladdr(reinterpret_cast<void *>(&get_path_of_this_library), &dl_info) && dl_info.dli_fname) {
       return std::filesystem::canonical(dl_info.dli_fname);  // Ensure absolute, resolved path
     } else {
       throw std::runtime_error("cannot get the path of libjit_utils.so");
@@ -34,7 +34,7 @@ std::filesystem::path get_triton_src_path() {
 
 std::filesystem::path get_flag_gems_src_path() {
   const static std::filesystem::path flag_gems_src_dir = []() {
-    const char* flag_gems_dir = std::getenv("FLAGGEMS_SOURCE_DIR");
+    const char *flag_gems_dir = std::getenv("FLAGGEMS_SOURCE_DIR");
     if (!flag_gems_dir) {
       throw std::runtime_error("Environment variable FLAGGEMS_SOURCE_DIR not set");
     }
   return flag_gems_src_dir;
 }
 
+std::filesystem::path get_code_cache_dir() {
+  const char *env_cache_dir = std::getenv("FLAGGEMS_CACHE_DIR");
+  std::filesystem::path cache_dir;
+  if (env_cache_dir) {
+    cache_dir = std::filesystem::path(env_cache_dir);
+  } else {
+    cache_dir = std::filesystem::path(std::getenv("HOME")) / ".flaggems";
+  }
+  std::filesystem::create_directories(cache_dir);
+  std::filesystem::path code_cache_dir = cache_dir / "code_cache";
+  std::filesystem::create_directories(code_cache_dir);
+  return code_cache_dir;
+}
+
 int64_t next_power_of_2(int64_t n) {
   if (n <= 1) return 1;
   --n;

   return true;
 }
 }  // namespace flag_gems::utils
+
+namespace flag_gems::pointwise_dynamic {
+using Shape = std::vector<int64_t>;
+using Stride = std::vector<int64_t>;
+void checkIfScalar(const torch::Tensor& tensor1,
+                   const torch::Tensor& tensor2,
+                   std::array<bool, 2>& is_tensor) {
+  // A 0-dim tensor is treated as a scalar; everything else is a tensor.
+  is_tensor[0] = (tensor1.dim() != 0);
+  is_tensor[1] = (tensor2.dim() != 0);
+}
+
+class StridedBuffer {
+ public:
+  StridedBuffer(const torch::Tensor& base,
+                c10::optional<at::IntArrayRef> shape = c10::nullopt,
+                c10::optional<at::IntArrayRef> strides = c10::nullopt,
+                c10::optional<dType> dtype = c10::nullopt,
+                int64_t offset = 0)
+      : base_(base),
+        dtype_(dtype.has_value() ? dtype.value() : to_custom_dtype(base.dtype())),
+        offset_(offset),
+        device_(base.device()) {
+    if (offset_ == 0) {
+      data_ptr_ = base_.data_ptr();
+    } else {
+      // TODO kunlunxin case
+      data_ptr_ = static_cast<char*>(base_.data_ptr()) + base_.element_size() * offset_;
+    }
+    shape_ = shape.has_value() ? shape.value().vec() : base_.sizes().vec();
+    strides_ = strides.has_value() ? strides.value().vec() : base_.strides().vec();
+    ndim_ = shape_.size();
+  }
+
+  const c10::IntArrayRef strides() const {
+    return strides_;
+  }
+
+  const c10::IntArrayRef sizes() const {
+    return shape_;
+  }
+
+  size_t element_size() const {
+    return torch::elementSize(to_torch_dtype(dtype_));
+  }
+
+  long numel() const {
+    long num = 1;
+    for (long s : shape_) {
+      num *= s;
+    }
+    return num;
+  }
+
+  int64_t dim() const {
+    return ndim_;
+  }
+
+  const torch::Tensor& unwrap() const {
+    return base_;
+  }
+
+  void* data_ptr() const {
+    return data_ptr_;
+  }
+
+  torch::Storage untyped_storage() const {
+    return base_.storage();
+  }
+
+  StridedBuffer clone() const {
+    return StridedBuffer(base_.clone(), shape_, strides_, dtype_, offset_);
+  }
+
+  StridedBuffer& copy_(const StridedBuffer& src) {
+    base_.copy_(base_.new_empty(src.sizes(), src.strides())
+                    .as_strided(src.sizes(), src.strides())
+                    .copy_(src.unwrap()));
+    strides_ = src.strides();
+    shape_ = src.sizes();
+    dtype_ = src.dtype();
+    offset_ = src.offset_;
+    data_ptr_ = src.data_ptr();
+
+    return *this;
+  }
+
+  StridedBuffer& copy_(const torch::Tensor& src) {
+    StridedBuffer src_buffer(src);
+    return this->copy_(src_buffer);
+  }
+
+  torch::Device device() const {
+    return device_;
+  }
+
+  dType dtype() const {
+    return dtype_;
+  }
+
+  long offset() const {
+    return offset_;
+  }
+
+ private:
+  torch::Tensor base_;
+  dType dtype_;
+  void* data_ptr_;
+  int64_t offset_;
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> strides_;
+  torch::Device device_;
+  int64_t ndim_;
+};
+
+bool broadcastable_to(const Shape& shape, const Shape& new_shape) {
+  int r1 = shape.size();
+  int r2 = new_shape.size();
+  int min_rank = std::min(r1, r2);
+
+  for (int i = 1; i <= min_rank; ++i) {
+    int dim1 = shape[r1 - i];
+    int dim2 = new_shape[r2 - i];
+    if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape) {
+  assert(broadcastable_to(shape, new_shape) && "Shapes are not broadcastable.");
+
+  int r1 = shape.size();
+  int r2 = new_shape.size();
+  int d = r2 - r1;
+
+  Stride new_stride(r2, 0);
+  for (int i = 0; i < r1; ++i) {
+    int new_dim_index = d + i;
+    if (shape[i] == 1 && new_shape[new_dim_index] > 1) {
+      new_stride[new_dim_index] = 0;
+    } else {
+      new_stride[new_dim_index] = stride[i];
+    }
+  }
+  return new_stride;
+}
+
+bool all_the_same_shape(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  const auto& first_shape = tensors[0].sizes();
+  for (const auto& tensor : tensors) {
+    if (!tensor.sizes().equals(first_shape)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool all_c_contiguous(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  for (const auto& tensor : tensors) {
+    if (!tensor.is_contiguous()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool all_the_same_stride(const std::vector<at::Tensor>& tensors) {
+  if (tensors.empty()) {
+    return true;
+  }
+  const auto& first_stride = tensors[0].strides();
+  for (const auto& tensor : tensors) {
+    if (!tensor.strides().equals(first_stride)) {
+      return false;
+    }
+  }
+  return true;
+}
+bool use_fast_path(const std::vector<at::Tensor>& tensors) {
+  if (!all_the_same_shape(tensors)) {
+    return false;
+  }
+  if (all_c_contiguous(tensors)) {
+    return true;
+  }
+  return all_the_same_stride(tensors) && tensors[0].is_non_overlapping_and_dense();
+}
+}  // namespace flag_gems::pointwise_dynamic

diff --git a/pointwise.log b/pointwise.log
new file
mode 100644 index 000000000..06b58f0bf --- /dev/null +++ b/pointwise.log @@ -0,0 +1,960 @@ +============================= test session starts ============================== +platform linux -- Python 3.10.12, pytest-8.4.1, pluggy-1.6.0 +rootdir: /home/fyf/FlagGems +configfile: pytest.ini +plugins: hypothesis-6.136.4 +collected 79 items / 77 deselected / 2 selected + +test_pointwise_dynamic.py prepare args +(tensor(-0.2970, device='cuda:0'), tensor(-0.9287, device='cuda:0')) +(StridedBuffer(shape=(1,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(1,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +codegen config +CodeGenConfig(max_tile_size=1024, max_grid_size=(65535, 65535, 65535), max_num_warps_per_cta=32, prefer_block_pointer=True, prefer_1d_tile=False) +prepare args +(tensor([1.1823, 1.3317], device='cuda:0'), tensor([ 0.5293, -1.8621], device='cuda:0')) +(StridedBuffer(shape=(2,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(2,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +prepare args +(tensor([[-0.8132, 1.1786], + [-0.7438, -0.1906]], device='cuda:0'), tensor([[ 0.4930, -1.2164], + [ 0.9088, -0.1768]], device='cuda:0')) +(StridedBuffer(shape=(4,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(4,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +prepare args +(tensor([[[-1.8601, -0.4227], + [ 1.3705, -1.0152]], + + [[ 0.4315, -1.4218], + [ 1.0209, -0.2430]]], device='cuda:0'), tensor([[[ 1.4617, 0.6862], + [-0.1194, 1.3785]], + + [[ 1.1918, -0.0827], + [ 0.8230, 0.9496]]], device='cuda:0')) +(StridedBuffer(shape=(8,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(8,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +prepare args +(tensor([[[[-0.3224, 1.3309], + [ 0.6534, 1.0808]], + + [[ 1.0072, -1.1201], + [ 0.6319, 0.4277]]], + + + [[[ 1.9672, 0.6689], + [-3.6808, 0.7714]], + + [[-1.7392, 0.3517], + [ 0.9204, 0.1764]]]], device='cuda:0'), tensor([[[[-0.7463, -0.7569], + [-0.8233, -0.6181]], + + [[ 0.6934, -0.0429], + [-0.7358, 0.1099]]], + + + [[[-0.1601, -0.0552], + [-0.3414, 1.3781]], + + [[-0.2480, -0.6523], + [-1.7379, -0.6234]]]], device='cuda:0')) +(StridedBuffer(shape=(16,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(16,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +prepare args +(tensor([[[[[ 0.2704, -1.0088], + [-0.3647, -1.2222]], + + [[ 1.0919, -0.2680], + [ 0.8758, 1.6583]]], + + + [[[-0.5207, 0.2852], + [-0.8225, -1.1796]], + + [[ 0.2731, -0.4979], + [-0.9975, 0.6967]]]], + + + + [[[[-1.2260, 0.2383], + [-0.2300, -0.1896]], + + [[ 0.3966, 1.9089], + [ 1.2267, 1.0300]]], + + + [[[-0.2939, 1.3133], + [ 0.0414, 0.4334]], + + [[-0.3053, 0.8554], + [ 0.6063, 0.1726]]]]], device='cuda:0'), tensor([[[[[-1.6378, 1.6238], + [-0.5613, 1.1061]], + + [[-0.0363, 0.0104], + [-0.3310, 0.0274]]], + + + [[[ 0.8872, -1.0808], + [ 0.5427, 0.2029]], + + [[-0.8999, 0.1127], + [-0.1466, 0.4300]]]], + + + + [[[[ 0.2388, -0.4143], + [-1.6903, -1.2033]], + + [[-1.6196, 0.6702], + [-1.5508, -0.0859]]], + + + [[[-0.8872, -1.0914], + [ 0.4906, 0.9720]], + + [[ 1.5728, 0.7931], + [-0.4117, 1.1539]]]]], device='cuda:0')) +(StridedBuffer(shape=(32,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0), StridedBuffer(shape=(32,), strides=(1,), dtype=torch.float32, offset=0, device=cuda:0)) +{} +prepare args 
+[... several hundred lines of captured pytest output elided (this is the accidentally
+ committed pointwise.log, removed again in PATCH 17): repeated "prepare args" entries
+ dumping each pair of random input tensors together with the two StridedBuffer wrappers
+ produced for them, for 1-D task shapes growing from (1,) up to (128,), plus one
+ "codegen config" entry printing CodeGenConfig(max_tile_size=1024,
+ max_grid_size=(65535, 65535, 65535), max_num_warps_per_cta=32,
+ prefer_block_pointer=False, prefer_1d_tile=False) ...]
+
+=============================== warnings summary ===============================
+../../.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: 11 warnings
+  /home/fyf/.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
+    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
+
+-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
+================ 2 passed, 77 deselected, 11 warnings in 0.98s =================
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index 9e3506c7e..66984fe21 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -296,6 +296,8 @@ def gen_signature(self, code, with_block_pointer=False):
 
         # signature: strides, for each tensor arguments
         ndim = self.ndim
+        if ndim == 1:
+            code.writeline("# use fast path or simple linear tensor")
         if ndim > 0:
             # strides for inputs
             for i in range(schema.num_input_tensors()):
@@ -1054,6 +1056,8 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
     def codegen(self, code: IndentedBuffer):
         # the only runtime-determined factor is the rank of the task space
         code = self.generate_imports(code)
+        print("codegen config")
+        print(self.config)
         if self.config.prefer_1d_tile:
             code = self.wrapper_gen.codegen_1d_tile(code)
             code = self.kernel_gen.codegen_1d_tile(code)
@@ -1108,6 +1112,7 @@ def use_fast_path(tensors):
     def prepare_args(self, *args, **kwargs):
         # output allocation (when needed)
         # task simplification & task-rank inference & input-output reinterpretation
+        print("prepare args")
         schema = self.fx
         outputs_that_need_allocation: List[int] = []
         out_tensors = []
@@ -1142,6 +1147,7 @@ def prepare_args(self, *args, **kwargs):
             task_shape = (tensors[0].numel(),)
             strides = (1,)
             ndim = 1
+            # print(args)  # input
             args = tuple(
                 (
                     StridedBuffer(item, task_shape, strides)
@@ -1150,10 +1156,12 @@ def prepare_args(self, *args, **kwargs):
                 )
                 for i, item in enumerate(args)
             )
+            print(args)  # usually two tensors
             kwargs = {
                 k: StridedBuffer(item, task_shape, strides)
                 for k, item in kwargs.items()
             }
+            print(kwargs)
             for seq_id, output_id in enumerate(outputs_that_need_allocation):
                 kwargs[f"out{output_id}"] = StridedBuffer(
                     allocated_outputs[seq_id], task_shape, strides
diff --git a/src/flag_gems/utils/tensor_wrapper.py b/src/flag_gems/utils/tensor_wrapper.py
index 0b871108f..184639b83 100644
--- a/src/flag_gems/utils/tensor_wrapper.py
+++ b/src/flag_gems/utils/tensor_wrapper.py
@@ -121,3 +121,12 @@ def copy_(self, src):
         src_buffer = StridedBuffer(src)
         self.copy_(src_buffer)
         return self
+
+    def __repr__(self):
+        return (
+            f"StridedBuffer(shape={self.shape}, "
+            f"strides={self._strides}, "
+            f"dtype={self.dtype}, "
+            f"offset={self.offset}, "
+            f"device={self.device})"
+        )
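The log above records prepare_args taking its fast path: when all operands share one shape and are C-contiguous, each tensor is viewed as a rank-1 StridedBuffer of numel() elements with stride 1, which is exactly the StridedBuffer(shape=(N,), strides=(1,)) lines captured in the output. A minimal C++ sketch of that reinterpretation; the FastPathView type is illustrative and not part of the repository:

#include <cstdint>
#include "torch/torch.h"

// Illustrative only: the fast path flattens each operand to one dimension.
struct FastPathView {
  void* data;
  int64_t numel;   // task_shape == (numel,)
  int64_t stride;  // always 1 on the fast path
};

inline FastPathView make_fast_path_view(const at::Tensor& t) {
  return FastPathView{t.data_ptr(), t.numel(), /*stride=*/1};
}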
f"dtype={self.dtype}, " + f"offset={self.offset}, " + f"device={self.device})" + ) From 1117b82a37afea5b376d2b4e2f629b3bf89b2633 Mon Sep 17 00:00:00 2001 From: scatyf3 Date: Thu, 7 Aug 2025 11:27:16 +0800 Subject: [PATCH 11/22] tmp --- lib/pointwise_dynamic.cpp | 148 ++++++++------------------------------ lib/utils.cpp | 4 +- 2 files changed, 33 insertions(+), 119 deletions(-) diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp index 826abb530..d6f01be6c 100644 --- a/lib/pointwise_dynamic.cpp +++ b/lib/pointwise_dynamic.cpp @@ -5,115 +5,24 @@ #include "c10/cuda/CUDAStream.h" #include "triton_jit/triton_jit_function.h" -namespace pointwise_dynamic { - -// 构造函数 -// src/flag_gems/utils/pointwise_dynamic.py:prepare_args -/* -args = tuple( -( -StridedBuffer( - item, - task_shape, - broadcasted_stride(item.shape, item.stride(), task_shape), -) -if schema.is_tensor(i) -else item -) -for i, item in enumerate(args) -) -kwargs = { - k: StridedBuffer( - item, - task_shape, - broadcasted_stride(item.shape, item.stride(), task_shape), - ) - for k, item in kwargs.items() -} -*/ -pointwise_dynamic::StridedBuffer - - Shape - broadcast(const Shape& s1, const Shape& s2) { - if (s1.empty()) { - return s2; - } - if (s2.empty()) { - return s1; - } - - const Shape* _s1 = &s1; - const Shape* _s2 = &s2; - - if (_s1->size() < _s2->size()) { - std::swap(_s1, _s2); - } - - size_t r1 = _s1->size(); - size_t r2 = _s2->size(); - size_t d = r1 - r2; - - Shape s = *_s1; - - for (size_t i = 0; i < r2; ++i) { - if ((*_s1)[d + i] == 1) { - s[d + i] = (*_s2)[i]; - } else if ((*_s2)[i] == 1) { - s[d + i] = (*_s1)[d + i]; - } else if ((*_s2)[i] == (*_s1)[d + i]) { - s[d + i] = (*_s2)[i]; - } else { - std::string msg = "Unbroadcastable shapes: ("; - for (size_t j = 0; j < s1.size(); ++j) msg += std::to_string(s1[j]) + (j < s1.size() - 1 ? ", " : ""); - msg += ") and ("; - for (size_t j = 0; j < s2.size(); ++j) msg += std::to_string(s2[j]) + (j < s2.size() - 1 ? 
", " : ""); - msg += ")"; - throw std::invalid_argument(msg); - } - } - - return s; -} - -template -Shape broadcast_shapes(const Iterable& shapes) { - if (std::empty(shapes)) { - return {}; - } - - auto it = std::begin(shapes); - Shape result_shape = *it; - ++it; - - for (; it != std::end(shapes); ++it) { - result_shape = broadcast(result_shape, *it); - } - - return result_shape; -} -}; // namespace pointwise_dynamic - namespace flag_gems { using namespace triton_jit; -int64_t cdiv(int64_t a, int64_t b) { - return (a + b - 1) / b; -} +using Shape = std::vector; +using Stride = std::vector; + at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { // TODO: parse tensor meta info std::vector kernel_params; // 2 input - kernel_params.push(a_); - kernel_params.push(b_); + void* a_ptr = a_.data_ptr(); + void* b_ptr = b_.data_ptr(); + kernel_params.push_back(&a_ptr); + kernel_params.push_back(&b_ptr); + // 1 output - at::Tensor out = at::empty(a.sizes(), a.options()); - kernel_params.push(&out); - // if input和output都连续 - // 或者stride相同和第一个tensor torch.ops.aten.is_non_overlapping_and_dense - // 但是后者不是都连续,为什么stride=1,如连续的ab转置,我们可以忽略它的stride,只计算element wise就行 - // 但返回的时候,是不是要somehow拿回它的stride,不过这可能是python端里的问题 + at::Tensor out = at::empty(a_.sizes(), a_.options()); + kernel_params.push_back(&out); std::vector tensors = {a_, b_, out}; - // WrapperGenerator: gen_kernel_launch - // KernelGenerator: if (pointwise_dynamic::use_fast_path(tensors)) { int task_shape = tensors[0].numel(); void* task_shape_ptr = &task_shape; @@ -125,18 +34,18 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { // push args // stride for input kernel_params.push(stride_ptr); - kernel_params.push(fast_path_stride_order_ptr); - kernel_params.push(stride_ptr); - kernel_params.push(fast_path_stride_order_ptr); + kernel_params.push_back(fast_path_stride_order_ptr); + kernel_params.push_back(stride_ptr); + kernel_params.push_back(fast_path_stride_order_ptr); // stride for output - kernel_params.push(stride_ptr); + kernel_params.push_back(stride_ptr); // task_space -> shape_args... 
shape = out0.shape // use fast path需要考虑shape吗 // prepare args里设置 task_shape = (tensors[0].numel(),) - kernel_params.push(task_shape_ptr); + kernel_params.push_back(task_shape_ptr); // num_tasks -> num_tasks = out0.numel() - kernel_params.push(task_shape_ptr); + kernel_params.push_back(task_shape_ptr); } else { // TODO // stride for input/output @@ -164,29 +73,31 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { kernel_params.push_back(&stride); } } + void* global_scratch = nullptr; + kernel_params.push_back(&global_scratch); // # tile size & tiles_per_cta, gsl style // tile_sizes = heuristics_for_tile_size(512, *shape) int64_t tile_sizes = 1024; - int64_t num_tiles = cdiv(task_shape, tile_sizes); // aka num blocks + int64_t num_tiles = utils::cdiv(task_shape, tile_sizes); // aka num blocks // num_ctas = min(65536, num_tiles) int64_t num_ctas = std::min(65536, num_tiles); // tiles_per_cta = triton.cdiv(num_tiles, num_ctas) - int64_t tiles_per_cta = cdiv(num_tiles, num_ctas); + int64_t tiles_per_cta = utils::cdiv(num_tiles, num_ctas); // one_tile_per_cta = tiles_per_cta==1 bool one_tile_per_cta = (tiles_per_cta == 1); // get function std::array is_tensor; checkIfScalar(scalar_tensor, vector_tensor, is_tensor); - const TritonKernel kernel; + TritonJITFunction f; if (is_tensor[0] && is_tensor[1]) { - &f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), - "add_func"); + f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), + "add_func"); } else if (is_tensor[0] && !is_tensor[1]) { - &f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), - "add_func_tensor_scalar"); + f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), + "add_func_tensor_scalar"); } else if (!is_tensor[0] && is_tensor[1]) { - &f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), - "add_func_scalar_tensor"); + f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), + "add_func_scalar_tensor"); } else { return a_ + b_; } @@ -195,7 +106,10 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { CUstream raw_stream = static_cast(stream.stream()); const int num_warps = 8; const int num_stages = 1; - f(stream, num_tiles, 1, 1, num_warps, num_stages, kernel_params); + + std::string signature = "*fp32:16,*fp32:16,*fp32:16,i64,1024"; + f.launch_with_raw_args(raw_stream, num_tiles, 1, 1, num_warps, num_stages, signature, kernel_params.data()); + return out; } }; // namespace flag_gems diff --git a/lib/utils.cpp b/lib/utils.cpp index 17ffd3bb7..a4a111179 100644 --- a/lib/utils.cpp +++ b/lib/utils.cpp @@ -105,7 +105,7 @@ void checkIfScalar(const torch::Tensor& tensor1, is_tensor[0] = (tensor1.dim() == 0); is_tensor[1] = (tensor2.dim() == 0); } - +/* class StridedBuffer { public: StridedBuffer(const torch::Tensor& base, @@ -208,7 +208,7 @@ class StridedBuffer { torch::Device device_; int64_t ndim_; }; - +*/ bool broadcastable_to(const Shape& shape, const Shape& new_shape) { int r1 = shape.size(); int r2 = new_shape.size(); From 3685efb421da3b934f97a51c6504e7fa447f5f10 Mon Sep 17 00:00:00 2001 From: scatyf3 Date: Fri, 8 Aug 2025 15:07:24 +0800 Subject: [PATCH 12/22] sum cpp warper --- ctests/test_triton_reduction.cpp | 44 +++++++++- include/flag_gems/utils.h | 1 + lib/sum.cpp | 144 +++++++++++++++++++++++-------- lib/utils.cpp | 23 
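The launch arithmetic ported in this patch follows the Python wrapper's GSL-style scheme: the task is cut into fixed-size tiles and the tiles are distributed over at most 65536 CTAs. A standalone worked example using the same constants (the local cdiv stands in for utils::cdiv):

#include <algorithm>
#include <cstdint>

// GSL-style launch arithmetic as used in add_tensor (tile size fixed at 1024).
inline int64_t cdiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t n = 1 << 20;                                // number of elements
  const int64_t tile = 1024;
  const int64_t num_tiles = cdiv(n, tile);                  // 1024 tiles
  const int64_t num_ctas = std::min<int64_t>(65536, num_tiles);
  const int64_t tiles_per_cta = cdiv(num_tiles, num_ctas);  // 1 here
  const bool one_tile_per_cta = (tiles_per_cta == 1);       // true
  return one_tile_per_cta ? 0 : 1;
}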
From 3685efb421da3b934f97a51c6504e7fa447f5f10 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Fri, 8 Aug 2025 15:07:24 +0800
Subject: [PATCH 12/22] sum cpp wrapper

---
 ctests/test_triton_reduction.cpp |  44 +++++++++-
 include/flag_gems/utils.h        |   1 +
 lib/sum.cpp                      | 144 +++++++++++++++++++--------
 lib/utils.cpp                    |  23 +++
 4 files changed, 171 insertions(+), 41 deletions(-)

diff --git a/ctests/test_triton_reduction.cpp b/ctests/test_triton_reduction.cpp
index a2e84c161..7802d5afb 100644
--- a/ctests/test_triton_reduction.cpp
+++ b/ctests/test_triton_reduction.cpp
@@ -16,17 +16,55 @@ TEST(reduction_op_test, sum) {
   EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-5, 1e-8));
 }
 
-TEST(reduction_op_test, sum_dim) {
+TEST(reduction_op_test, sum_dim_to_sum) {
+  const torch::Device device(torch::kCUDA, 0);
+  torch::Tensor a = torch::randn({32, 1024}, device);
+
+  torch::Tensor out_torch = at::sum(a, {at::IntArrayRef {}}, false, c10::nullopt);
+  torch::Tensor out_triton = flag_gems::sum_dim(a, {at::IntArrayRef {}}, false, c10::nullopt);
+  if (!torch::allclose(out_torch, out_triton, 1e-3, 1e-3)) {
+    LOG(INFO) << "Difference:\n" << out_torch - out_triton;
+  }
+
+  EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-3, 1e-3));
+}
+
+TEST(reduction_op_test, sum_dim_inner) {
   const torch::Device device(torch::kCUDA, 0);
   torch::Tensor a = torch::randn({32, 1024}, device);
 
   torch::Tensor out_torch = at::sum(a, {1});
   torch::Tensor out_triton = flag_gems::sum_dim(a, {1});
-  if (!torch::allclose(out_torch, out_triton, 1e-5, 1e-8)) {
+  if (!torch::allclose(out_torch, out_triton, 1e-3, 1e-3)) {
     LOG(INFO) << "Difference:\n" << out_torch - out_triton;
   }
 
-  EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-5, 1e-8));
+  EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-3, 1e-3));
+}
+
+TEST(reduction_op_test, sum_dim_non_inner) {
+  const torch::Device device(torch::kCUDA, 0);
+  torch::Tensor a = torch::randn({32, 1024, 32}, device);
+
+  torch::Tensor out_torch = at::sum(a, {1});
+  torch::Tensor out_triton = flag_gems::sum_dim(a, {1});
+  if (!torch::allclose(out_torch, out_triton, 1e-3, 1e-3)) {
+    LOG(INFO) << "Difference:\n" << out_torch - out_triton;
+  }
+
+  EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-3, 1e-3));
+}
+
+TEST(reduction_op_test, sum_dim_multi) {
+  const torch::Device device(torch::kCUDA, 0);
+  torch::Tensor a = torch::randn({32, 1024, 32}, device);
+
+  torch::Tensor out_torch = at::sum(a, {2, 0});
+  torch::Tensor out_triton = flag_gems::sum_dim(a, {2, 0});
+  if (!torch::allclose(out_torch, out_triton, 1e-3, 1e-3)) {
+    LOG(INFO) << "Difference:\n" << out_torch - out_triton;
+  }
+  EXPECT_TRUE(torch::allclose(out_torch, out_triton, 1e-3, 1e-3));
 }
 
 TEST(reduction_op_test, nonzero) {
diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index fd245cbcd..baa6d62ce 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -21,5 +21,6 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
                                                                       int reduction_axis);
 std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
     const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt);
+std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis);
 int cdiv(int a, int b);
 }  // namespace flag_gems::utils
diff --git a/lib/sum.cpp b/lib/sum.cpp
index ee4842066..0c64b53ab 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -45,54 +45,122 @@ at::Tensor sum_dim(const at::Tensor &self,
                    at::OptionalIntArrayRef dim,
                    bool keepdim,
                    ::std::optional<at::ScalarType> dtype) {
+  at::TensorOptions out_options = self.options();
+  at::ScalarType out_dtype;
+  if (dtype.has_value()) {
+    out_dtype = dtype.value();
+  } else {
+    out_dtype = self.dtype().toScalarType();
+    if (out_dtype == torch::kBool) {
+      out_dtype = torch::kInt64;
+    }
+  }
+  out_options = out_options.dtype(out_dtype);
   at::DimVector dims_ = at::native::make_dim_vector(dim, self.dim());
   at::maybe_wrap_dims(dims_, self.dim());
   at::DimVector shape = at::meta::get_reduction_shape(self, dims_, keepdim, false);
-  c10::ScalarType out_dtype = at::native::get_dtype_from_self(self, dtype, true);
-  at::Tensor out = at::empty(shape, self.options());
+  at::Tensor out = at::empty(at::IntArrayRef(shape), out_options);
+  out = out.contiguous();
 
-  auto [permuted_self, non_reduction_size, reduction_size] = utils::permute_reduction_axes_right(self, dims_);
-  permuted_self = permuted_self.contiguous();
-
-  /* signature to remind yourself
-  def sum_kernel(
-      in_ptr,
-      out_ptr,
-      M,
-      N,
-      BLOCK_M: tl.constexpr,
-      BLOCK_N: tl.constexpr,
-      STAGE: tl.constexpr,
-  ):
-  */
-  const TritonJITFunction &f =
-      TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
-                                     "sum_kernel");
-
-  // add utility to build this automatically
+  if (!dim.has_value() || dim->empty()) {
+    if (!keepdim) {
+      return flag_gems::sum(self, std::optional<at::ScalarType> {});
+    } else {
+      at::Tensor result = flag_gems::sum(self, dtype);
+      return result.reshape(std::vector<int64_t>(self.dim(), 1));
+    }
+  }
   int64_t tile_m = 4;
   int64_t tile_n = 512;
+  int64_t tile_k = 4;
   const int num_warps = 8;
   const int num_stages = 2;
-  const unsigned int num_blocks = (non_reduction_size + tile_m - 1) / tile_m;
-
-  c10::DeviceGuard guard(out.device());
+  c10::DeviceGuard guard(self.device());
   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
   CUstream raw_stream = static_cast<CUstream>(stream.stream());
-  f(raw_stream,
-    num_blocks,
-    1,
-    1,
-    num_warps,
-    num_stages,
-    permuted_self,
-    out,
-    non_reduction_size,
-    reduction_size,
-    tile_m,
-    tile_n,
-    num_stages);
-  return out;
+  at::Tensor self_contiguous = self.contiguous();
+  LOG(INFO) << "dims_.size()" << dims_.size();
+  if (dims_.size() == 1) {
+    int64_t selected_dim = dims_[0];
+    // M, N, K in python sum_dim_comm
+    auto [non_reduction_size, reduction_size, remain_size] = utils::parse_reduction_axes(self, selected_dim);
+    bool ONE_TILE_PER_CTA = (tile_n >= reduction_size);
+    if (remain_size > 1) {
+      const TritonJITFunction &f =
+          TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
+                                         "sum_dim_kernel_non_inner");
+      f(raw_stream,
+        non_reduction_size,
+        utils::cdiv(remain_size, tile_k),
+        1,
+        num_warps,
+        num_stages,
+        out,
+        self_contiguous,
+        non_reduction_size,
+        reduction_size,
+        remain_size,
+        tile_n,
+        tile_k,
+        ONE_TILE_PER_CTA);
+    } else {
+      LOG(INFO) << "K=1";
+      auto [non_reduction_size, reduction_size, remain_size] =
+          utils::parse_reduction_axes(self, selected_dim);
+      const TritonJITFunction &f =
+          TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
+                                         "sum_dim_kernel_inner");
+      f(raw_stream,
+        non_reduction_size,
+        1,
+        1,
+        num_warps,
+        num_stages,
+        out,
+        self_contiguous,
+        non_reduction_size,
+        reduction_size,
+        tile_n,
+        ONE_TILE_PER_CTA);
+    }
+    return out;
+  } else {
+    auto [permuted_self, non_reduction_size, reduction_size] =
+        utils::permute_reduction_axes_right(self, dims_);
+    const unsigned int num_blocks = (non_reduction_size + tile_m - 1) / tile_m;
+    permuted_self = permuted_self.contiguous();
+    /* signature to remind yourself
+    def sum_kernel(
+        in_ptr,
+        out_ptr,
+        M,
+        N,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+        STAGE: tl.constexpr,
+    ):
+    */
+    const TritonJITFunction &f =
+        TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "sum.py"),
+                                       "sum_dim_kernel");
+    c10::DeviceGuard guard(out.device());
+    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+    CUstream raw_stream = static_cast<CUstream>(stream.stream());
+    f(raw_stream,
+      num_blocks,
+      1,
+      1,
+      num_warps,
+      num_stages,
+      permuted_self,
+      out,
+      non_reduction_size,
+      reduction_size,
+      tile_m,
+      tile_n);
+    return out;
+  }
 }
 }  // namespace flag_gems
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 4ed35665d..0feb6ecca 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -80,6 +80,7 @@ bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2) {
   return true;
 }
 
+
 std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
     const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt) {
   int64_t dim = tensor.dim();
@@ -130,6 +131,28 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
   return {tensor.permute(permute_order), non_reduction_size, reduction_size};
 }
 
+
+std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis) {
+  int64_t dim = tensor.dim();
+  c10::DimVector left_axes, right_axes, remain_axes;  // declare remain_axes
+  int64_t non_reduction_size = 1;
+  int64_t reduction_size = 1;
+  int64_t remain_size = 1;
+
+  for (int64_t i = 0; i < dim; ++i) {
+    if (i < reduction_axis) {  // dims to the left of the reduction axis
+      left_axes.push_back(i);
+      non_reduction_size *= tensor.size(i);
+    } else if (i == reduction_axis) {  // the reduction axis itself
+      right_axes.push_back(i);  // keep the reduction axis in the middle
+      reduction_size *= tensor.size(i);
+    } else {  // dims to the right of the reduction axis
+      remain_axes.push_back(i);
+      remain_size *= tensor.size(i);
+    }
+  }
+  return {non_reduction_size, reduction_size, remain_size};
+}
 int cdiv(int a, int b) {
   return (a + b - 1) / b;
 }

From 24c4e400d7fb3a0c5f503d719ec3d9ee719fccee Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Fri, 8 Aug 2025 15:17:53 +0800
Subject: [PATCH 13/22] remove useless comment

---
 lib/sum.cpp   |  2 --
 lib/utils.cpp | 10 +++++-----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/lib/sum.cpp b/lib/sum.cpp
index 0c64b53ab..00226006e 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -80,7 +80,6 @@ at::Tensor sum_dim(const at::Tensor &self,
   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
   CUstream raw_stream = static_cast<CUstream>(stream.stream());
   at::Tensor self_contiguous = self.contiguous();
-  LOG(INFO) << "dims_.size()" << dims_.size();
   if (dims_.size() == 1) {
     int64_t selected_dim = dims_[0];
     // M, N, K in python sum_dim_comm
@@ -105,7 +104,6 @@ at::Tensor sum_dim(const at::Tensor &self,
         tile_k,
         ONE_TILE_PER_CTA);
     } else {
-      LOG(INFO) << "K=1";
       auto [non_reduction_size, reduction_size, remain_size] =
           utils::parse_reduction_axes(self, selected_dim);
       const TritonJITFunction &f =
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 0feb6ecca..7b6b240ef 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -134,19 +134,19 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
 
 std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis) {
   int64_t dim = tensor.dim();
-  c10::DimVector left_axes, right_axes, remain_axes;  // declare remain_axes
+  c10::DimVector left_axes, right_axes, remain_axes;
   int64_t non_reduction_size = 1;
   int64_t reduction_size = 1;
   int64_t remain_size = 1;
 
   for (int64_t i = 0; i < dim; ++i) {
-    if (i < reduction_axis) {  // dims to the left of the reduction axis
+    if (i < reduction_axis) {
       left_axes.push_back(i);
       non_reduction_size *= tensor.size(i);
-    } else if (i == reduction_axis) {  // the reduction axis itself
+    } else if (i == reduction_axis) {
       right_axes.push_back(i);
       reduction_size *= tensor.size(i);
-    } else {  // dims to the right of the reduction axis
+    } else {
       remain_axes.push_back(i);
       remain_size *= tensor.size(i);
     }
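To make the (M, N, K) decomposition concrete: parse_reduction_axes splits a shape around a single reduction axis into the product of the dimensions to its left (M), the axis itself (N), and the product of the dimensions to its right (K). A standalone check mirroring the new tests, using plain integers rather than ATen types:

#include <cassert>
#include <cstdint>
#include <vector>

// (M, N, K) split of a shape around one reduction axis, as in parse_reduction_axes.
static void split(const std::vector<int64_t>& shape, int axis,
                  int64_t& m, int64_t& n, int64_t& k) {
  m = n = k = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (static_cast<int>(i) < axis) m *= shape[i];
    else if (static_cast<int>(i) == axis) n *= shape[i];
    else k *= shape[i];
  }
}

int main() {
  int64_t m, n, k;
  split({32, 1024, 32}, 1, m, n, k);
  assert(m == 32 && n == 1024 && k == 32);  // K > 1 selects sum_dim_kernel_non_inner
  split({32, 1024}, 1, m, n, k);
  assert(m == 32 && n == 1024 && k == 1);   // K == 1 selects sum_dim_kernel_inner
  return 0;
}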
From e8e60155bb24cf8629545511b3e75b2e5a8f128a Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Fri, 8 Aug 2025 16:03:59 +0800
Subject: [PATCH 14/22] tmp update for merge

---
 include/flag_gems/utils.h | 12 +++++++++++-
 lib/pointwise_dynamic.cpp | 29 ++++-------------------------
 2 files changed, 15 insertions(+), 26 deletions(-)

diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index 3aa08909d..8dd96f076 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -17,4 +17,14 @@ std::filesystem::path get_triton_src_path();
 std::filesystem::path get_flag_gems_src_path();
 int64_t next_power_of_2(int64_t n);
 bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
-}  // namespace flag_gems::utils
+};  // namespace flag_gems::utils
+
+namespace flag_gems::pointwise_dynamic {
+using Shape = std::vector<int64_t>;
+using Stride = std::vector<int64_t>;
+bool broadcastable_to(const Shape& shape, const Shape& new_shape);
+Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape);
+bool all_the_same_shape(const std::vector<at::Tensor>& tensors);
+bool all_c_contiguous(const std::vector<at::Tensor>& tensors);
+bool use_fast_path(const std::vector<at::Tensor>& tensors);
+};  // namespace flag_gems::pointwise_dynamic
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index d6f01be6c..b1e1d6f26 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -30,10 +30,10 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   void* stride_ptr = &stride;
   int ndim = 1;
   int fast_path_stride_order = 0;
-  void* fast_path_stride_order_ptr = &fast_path_stride_order
-      // push args
-      // stride for input
-      kernel_params.push(stride_ptr);
+  void* fast_path_stride_order_ptr = &fast_path_stride_order;
+  // push args
+  // stride for input
+  kernel_params.push(stride_ptr);
   kernel_params.push_back(fast_path_stride_order_ptr);
   kernel_params.push_back(stride_ptr);
   kernel_params.push_back(fast_path_stride_order_ptr);
@@ -51,27 +51,6 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     // stride for input/output
     // calculate task space
     // shapes = tuple(item.shape for item in in_tensors),
-    std::vector<Shape> shapes;
-    shapes.reserve(2);
-    for (const auto& tensor : in_tensors) {
-      shapes.push_back(tensor.shape());
-    }
-    Shape task_shape = broadcast_shapes(shapes);
-    int64_t ndim = task_shape.size();
-    // task_shape = broadcast_shapes(shapes)
-    // get stride, TODO: using ndim as python wrapper
-    auto a_strides = a_.strides();
-    for (int64_t stride : a_strides) {
-      kernel_params.push_back(&stride);
-    }
-    auto b_strides = b_.strides();
-    for (int64_t stride : b_strides) {
-      kernel_params.push_back(&stride);
-    }
-    auto out_strides = out.strides();
-    for (int64_t stride : out_strides) {
-      kernel_params.push_back(&stride);
-    }
   }
   void* global_scratch = nullptr;
   kernel_params.push_back(&global_scratch);
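The broadcasted_stride helper declared here mirrors src/flag_gems/utils/shape_utils.py: broadcast dimensions (size 1, or dimensions added on the left) get stride 0 so one element is reused across the whole dimension. A self-contained sketch of that rule under the same contract:

#include <cassert>
#include <cstdint>
#include <vector>

// Broadcast-stride rule mirrored from the Python helper: size-1 dims that are
// expanded, and dims prepended on the left, get stride 0.
static std::vector<int64_t> bcast_stride(const std::vector<int64_t>& shape,
                                         const std::vector<int64_t>& stride,
                                         const std::vector<int64_t>& new_shape) {
  std::vector<int64_t> out(new_shape.size(), 0);
  const size_t d = new_shape.size() - shape.size();
  for (size_t i = 0; i < shape.size(); ++i) {
    out[d + i] = (shape[i] == 1 && new_shape[d + i] > 1) ? 0 : stride[i];
  }
  return out;
}

int main() {
  // A (50,) vector broadcast against a (30, 50) matrix: strides become (0, 1),
  // so each row of the task re-reads the same 50 elements.
  auto s = bcast_stride({50}, {1}, {30, 50});
  assert(s[0] == 0 && s[1] == 1);
  return 0;
}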
From 90228784e0250a1060afe1f285bb65396e611553 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Fri, 8 Aug 2025 17:04:32 +0800
Subject: [PATCH 15/22] pointwise dynamic add finish use fast path version

---
 include/flag_gems/utils.h | 19 ++++++++++---------
 lib/CMakeLists.txt        |  1 -
 lib/pointwise_dynamic.cpp | 22 ++++++++++++++--------
 lib/utils.cpp             | 15 ---------------
 4 files changed, 24 insertions(+), 33 deletions(-)

diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index a45394b0a..f638e5dab 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -11,27 +11,28 @@
 #include "torch/torch.h"
 
 namespace flag_gems::utils {
-
+using Shape = std::vector<int64_t>;
 std::filesystem::path get_path_of_this_library();
 std::filesystem::path get_triton_src_path();
 std::filesystem::path get_flag_gems_src_path();
 int64_t next_power_of_2(int64_t n);
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor& tensor,
+                                                                      int reduction_axis);
+std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
+    const at::Tensor& tensor, at::OptionalIntArrayRef reduction_axes_opt);
+std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor& tensor, int reduction_axis);
+int cdiv(int a, int b);
 bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
 };  // namespace flag_gems::utils
 
 namespace flag_gems::pointwise_dynamic {
 using Shape = std::vector<int64_t>;
 using Stride = std::vector<int64_t>;
-bool broadcastable_to(const Shape& shape, const Shape& new_shape);
 Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape);
 bool all_the_same_shape(const std::vector<at::Tensor>& tensors);
 bool all_c_contiguous(const std::vector<at::Tensor>& tensors);
 bool use_fast_path(const std::vector<at::Tensor>& tensors);
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor& tensor,
-                                                                      int reduction_axis);
-std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
-    const at::Tensor& tensor, at::OptionalIntArrayRef reduction_axes_opt);
-std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor& tensor, int reduction_axis);
-int cdiv(int a, int b);
-
+void checkIfScalar(const torch::Tensor& tensor1,
+                   const torch::Tensor& tensor2,
+                   std::array<bool, 2>& is_tensor);
 };  // namespace flag_gems::pointwise_dynamic
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index ee06fbdaf..180ad60db 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -2,7 +2,6 @@ add_library(operators
   SHARED
   zeros.cpp
   utils.cpp
-  add.cpp
   sum.cpp
   max.cpp
   mm.cpp
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index b1e1d6f26..3f41cdb59 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -23,8 +23,9 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   at::Tensor out = at::empty(a_.sizes(), a_.options());
   kernel_params.push_back(&out);
   std::vector<at::Tensor> tensors = {a_, b_, out};
+  int task_shape;
   if (pointwise_dynamic::use_fast_path(tensors)) {
-    int task_shape = tensors[0].numel();
+    task_shape = tensors[0].numel();
     void* task_shape_ptr = &task_shape;
     int stride = 1;
     void* stride_ptr = &stride;
@@ -33,7 +34,7 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     void* fast_path_stride_order_ptr = &fast_path_stride_order;
     // push args
     // stride for input
-    kernel_params.push(stride_ptr);
+    kernel_params.push_back(stride_ptr);
     kernel_params.push_back(fast_path_stride_order_ptr);
     kernel_params.push_back(stride_ptr);
     kernel_params.push_back(fast_path_stride_order_ptr);
@@ -41,8 +42,6 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     kernel_params.push_back(stride_ptr);
     // task_space -> shape_args... shape = out0.shape
-    // does use_fast_path need to consider the shape?
-    // set in prepare_args: task_shape = (tensors[0].numel(),)
     kernel_params.push_back(task_shape_ptr);
     // num_tasks -> num_tasks = out0.numel()
     kernel_params.push_back(task_shape_ptr);
@@ -59,15 +58,15 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   int64_t tile_sizes = 1024;
   int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
   // num_ctas = min(65536, num_tiles)
-  int64_t num_ctas = std::min(65536, num_tiles);
+  int64_t num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
   // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
   int64_t tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
   // one_tile_per_cta = tiles_per_cta==1
   bool one_tile_per_cta = (tiles_per_cta == 1);
   // get function
   std::array<bool, 2> is_tensor;
-  checkIfScalar(scalar_tensor, vector_tensor, is_tensor);
-  TritonJITFunction f;
+  pointwise_dynamic::checkIfScalar(a_, b_, is_tensor);
+  std::optional<TritonJITFunction> f;
   if (is_tensor[0] && is_tensor[1]) {
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func");
@@ -87,7 +86,14 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   const int num_stages = 1;
 
   std::string signature = "*fp32:16,*fp32:16,*fp32:16,i64,1024";
-  f.launch_with_raw_args(raw_stream, num_tiles, 1, 1, num_warps, num_stages, signature, kernel_params.data());
+  f->launch_with_raw_args(raw_stream,
+                          num_tiles,
+                          1,
+                          1,
+                          num_warps,
+                          num_stages,
+                          signature,
+                          kernel_params.data());
 
   return out;
 }
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 98b3c4b16..31f7d293c 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -285,21 +285,6 @@ class StridedBuffer {
   int64_t ndim_;
 };
 */
-bool broadcastable_to(const Shape& shape, const Shape& new_shape) {
-  int r1 = shape.size();
-  int r2 = new_shape.size();
-  int min_rank = std::min(r1, r2);
-
-  for (int i = 1; i <= min_rank; ++i) {
-    int dim1 = shape[r1 - i];
-    int dim2 = new_shape[r2 - i];
-    if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
-      return false;
-    }
-  }
-  return true;
-}
-
 Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape) {
   assert(broadcastable_to(shape, new_shape) && "Shapes are not broadcastable.");
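A note on the hard-coded signature string: in Triton's launch convention, "*fp32:16" denotes a pointer to fp32 data with 16-byte alignment, "i64" a 64-bit integer scalar, and the bare "1024" the tl.constexpr tile size baked into the compiled kernel. The string here covers only the three pointers, one i64 and the tile size; the stride scalars pushed above are not yet reflected in it, which is what PATCH 17 below starts to address by appending "i64," per pushed stride. A sketch of assembling the same string; make_add_signature is illustrative, not a repository function:

#include <cstdint>
#include <string>

// "*fp32:16" = 16-byte-aligned fp32 pointer; "i64" = 64-bit scalar;
// the trailing number is the tl.constexpr tile size.
std::string make_add_signature(int num_tensors, int64_t tile_size) {
  std::string sig;
  for (int i = 0; i < num_tensors; ++i) sig += "*fp32:16,";
  sig += "i64,";                     // num_tasks
  sig += std::to_string(tile_size);  // tl.constexpr
  return sig;  // make_add_signature(3, 1024) == "*fp32:16,*fp32:16,*fp32:16,i64,1024"
}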
From aedc21c2881d14a922297600a348d30dcfe75114 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Tue, 12 Aug 2025 10:09:26 +0800
Subject: [PATCH 16/22] tmp

---
 ctests/test_triton_pointwise.cpp         |  14 +-
 demo.py                                  |  17 +-
 include/flag_gems/utils.h                |  42 ++++-
 lib/add.cpp                              | 102 ------------
 lib/pointwise_dynamic.cpp                | 105 ++++++++++--
 lib/sum.cpp                              |   4 +-
 lib/utils.cpp                            | 193 ++++++++++++-----------
 src/flag_gems/csrc/aten_patch.cpp        |   1 +
 src/flag_gems/utils/pointwise_dynamic.py |   9 +-
 src/flag_gems/utils/shape_utils.py       |   4 +-
 10 files changed, 276 insertions(+), 215 deletions(-)
 delete mode 100644 lib/add.cpp

diff --git a/ctests/test_triton_pointwise.cpp b/ctests/test_triton_pointwise.cpp
index 664896cf1..a06d5b1a4 100644
--- a/ctests/test_triton_pointwise.cpp
+++ b/ctests/test_triton_pointwise.cpp
@@ -2,7 +2,7 @@
 #include "flag_gems/operators.h"
 #include "torch/torch.h"
 
-TEST(pointwise_op_test, add) {
+TEST(pointwise_op_simple_test, add) {
   const torch::Device device(torch::kCUDA, 0);
   torch::Tensor a = torch::randn({10, 10}, device);
   torch::Tensor b = torch::randn({10, 10}, device);
@@ -12,3 +12,15 @@ TEST(pointwise_op_test, add) {
 
   EXPECT_TRUE(torch::allclose(out_torch, out_triton));
 }
+
+TEST(pointwise_op_broadcast_test, add) {
+  const torch::Device device(torch::kCUDA, 0);
+  torch::Tensor a = torch::randn({30, 50}, device);
+  torch::Tensor b = torch::randn({50}, device);
+
+  torch::Tensor out_torch = a + b;
+  torch::Tensor out_triton = flag_gems::add_tensor(a, b);
+  std::cout << "out_torch sizes: " << out_torch.sizes() << std::endl;
+  std::cout << "out_triton sizes: " << out_triton.sizes() << std::endl;
+  EXPECT_TRUE(torch::allclose(out_torch, out_triton));
+}
diff --git a/demo.py b/demo.py
index 81dd32f43..b6af2bc8e 100644
--- a/demo.py
+++ b/demo.py
@@ -2,8 +2,19 @@ import torch
 
 import flag_gems
 
-shape = (4,)
-x = torch.randn(shape, device=flag_gems.device)
-y = torch.randn_like(x)
+# create a tensor x with shape (3, 4)
+shape_x = (3, 4)
+x = torch.randn(shape_x, device=flag_gems.device)
+
+# create a tensor y with shape (1, 4)
+shape_y = (1, 4)
+y = torch.randn(shape_y, device=flag_gems.device)
+
+# run under the flag_gems context
 with flag_gems.use_gems():
+    # y is broadcast to shape (3, 4) to match x
     C = torch.add(x, y)
+
+print("x:", x)
+print("y:", y)
+print("C:", C)
diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index f638e5dab..ff44bce19 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -11,7 +11,7 @@
 #include "torch/torch.h"
 
 namespace flag_gems::utils {
-using Shape = std::vector<int64_t>;
+using Shape = c10::IntArrayRef;
 std::filesystem::path get_path_of_this_library();
 std::filesystem::path get_triton_src_path();
 std::filesystem::path get_flag_gems_src_path();
@@ -26,13 +26,47 @@ bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
 };  // namespace flag_gems::utils
 
 namespace flag_gems::pointwise_dynamic {
-using Shape = std::vector<int64_t>;
-using Stride = std::vector<int64_t>;
-Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape);
+using ShapeR = c10::IntArrayRef;
+using ShapeW = std::vector<int64_t>;
+using StrideR = c10::IntArrayRef;
+using StrideW = std::vector<int64_t>;
 bool all_the_same_shape(const std::vector<at::Tensor>& tensors);
 bool all_c_contiguous(const std::vector<at::Tensor>& tensors);
 bool use_fast_path(const std::vector<at::Tensor>& tensors);
 void checkIfScalar(const torch::Tensor& tensor1,
                    const torch::Tensor& tensor2,
                    std::array<bool, 2>& is_tensor);
+ShapeW broadcast(const ShapeR& s1, const ShapeR& s2);
+ShapeW broadcast_shapes(const std::vector<ShapeR>& shapes);
+StrideW broadcasted_stride(const ShapeR& shape, const StrideR& stride, const ShapeR& new_shape);
+void print_shapes(const std::vector<ShapeW>& shapes);
+StrideW stride_order(const StrideR& strides);
+StrideR create_stride_r_view(const StrideW& stride_w);
+class StridedBuffer {
+ public:
+  StridedBuffer(const torch::Tensor& base,
+                c10::optional<c10::IntArrayRef> shape = c10::nullopt,
+                c10::optional<c10::IntArrayRef> strides = c10::nullopt,
+                int64_t offset = 0);
+
+  const c10::IntArrayRef strides() const;
+  const c10::IntArrayRef sizes() const;
+  long numel() const;
+  int64_t dim() const;
+  const torch::Tensor& unwrap() const;
+  void* data_ptr() const;
+  torch::Storage untyped_storage() const;
+  StridedBuffer clone() const;
+  StridedBuffer& copy_(const StridedBuffer& src);
+  StridedBuffer& copy_(const torch::Tensor& src);
+  long offset() const;
+
+ private:
+  torch::Tensor base_;
+  void* data_ptr_;
+  int64_t offset_;
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> strides_;
+  int64_t ndim_;
+};
 };  // namespace flag_gems::pointwise_dynamic
diff --git a/lib/add.cpp b/lib/add.cpp
deleted file mode 100644
index 102e4ec73..000000000
--- a/lib/add.cpp
+++ /dev/null
@@ -1,102 +0,0 @@
-#include "flag_gems/operators.h"
-#include "flag_gems/utils.h"
-
-#include <iostream>
-#include "c10/cuda/CUDAStream.h"
-#include "triton_jit/triton_jit_function.h"
-
-namespace flag_gems {
-using namespace triton_jit;
-
-/*
-at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
-  auto res = torch::broadcast_tensors({a_, b_});
-  res[0] = res[0].contiguous();
-  res[1] = res[1].contiguous();
-  const at::Tensor &a = res[0];
-  const at::Tensor &b = res[1];
-
-  at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
-  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));
-
-  const TritonJITFunction &f =
-      TritonJITFunction::getInstance(std::string(utils::get_triton_src_path() / "binary_add.py"),
-                                     "binary_pointwise_kernel");
-
-  // add utility to build this automatically
-  int64_t tile_size = 1024;
-  const int num_warps = 8;
-  const int num_stages = 1;
-  int64_t n = out.numel();
-  const unsigned int num_blocks = (n + tile_size - 1) / tile_size;
-
-  // getCurrentCUDAStream ensures that the stream is initialized, a default stream for each device
-  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
-  c10::DeviceGuard guard(out.device());
-  CUstream raw_stream = static_cast<CUstream>(stream.stream());
-  f(raw_stream, num_blocks, 1, 1, num_warps, num_stages, a, b, out, n, tile_size);
-  return out;
-}
-*/
-
-at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
-  // 1. Broadcasting and ensuring contiguous memory layout
-  auto res = torch::broadcast_tensors({a_, b_});
-  const at::Tensor& a = res[0].contiguous();
-  const at::Tensor& b = res[1].contiguous();
-
-  // 2. Determine output dtype and create output tensor
-  at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
-  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));
-
-  // 3. Get the TritonJITFunction instance
-  const TritonJITFunction& f =
-      TritonJITFunction::getInstance(std::string(utils::get_triton_src_path() / "binary_add.py"),
-                                     "binary_pointwise_kernel");
-
-  // 4. Manually prepare the raw argument list (void**) and the signature
-  int64_t tile_size = 1024;
-  int64_t n = out.numel();
-
-  // This is the raw C-style argument array that the CUDA kernel expects.
-  // It contains the addresses of all the kernel's parameters.
-  std::vector<void*> raw_args_list;
-
-  // Push the data pointers for the tensors.
-  raw_args_list.push_back(a.data_ptr());
-  raw_args_list.push_back(b.data_ptr());
-  raw_args_list.push_back(out.data_ptr());
-
-  // Push the addresses of scalar values.
-  // NOTE: The scalars 'n' and 'tile_size' must have their addresses taken.
-  // This is why we use references or variables.
-  raw_args_list.push_back(&n);
-  raw_args_list.push_back(&tile_size);
-
-  // 5. Manually generate the signature string
-  // This must match the kernel's type-based signature for overload resolution.
-  // This is an example; the exact signature depends on the kernel definition.
-  std::string signature = "tl.pointer_type,tl.pointer_type,tl.pointer_type,int64,int64";
-
-  // 6. Set up the launch configuration
-  const int num_warps = 8;
-  const int num_stages = 1;
-  const unsigned int num_blocks = (n + tile_size - 1) / tile_size;
-  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
-  c10::DeviceGuard guard(out.device());
-  CUstream raw_stream = static_cast<CUstream>(stream.stream());
-
-  // 7. Launch the kernel using the raw argument list
-  f.launch_with_raw_args(raw_stream,
-                         num_blocks,
-                         1,
-                         1,
-                         num_warps,
-                         num_stages,
-                         signature,
-                         raw_args_list.data());
-
-  return out;
-}
-
-}  // namespace flag_gems
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index 3f41cdb59..8d7d43271 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -3,28 +3,39 @@
 #include <iostream>
 #include "c10/cuda/CUDAStream.h"
+#include "c10/util/Logging.h"
 #include "triton_jit/triton_jit_function.h"
 
 namespace flag_gems {
 using namespace triton_jit;
-using Shape = std::vector<int64_t>;
-using Stride = std::vector<int64_t>;
-
+using Shape = c10::IntArrayRef;
+using Stride = c10::IntArrayRef;
 at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   // TODO: parse tensor meta info
+  // LOG(INFO) << fmt::format("add tensor");
+  std::cout << "add tensor";
   std::vector<void*> kernel_params;
   // 2 input
   void* a_ptr = a_.data_ptr();
   void* b_ptr = b_.data_ptr();
   kernel_params.push_back(&a_ptr);
   kernel_params.push_back(&b_ptr);
+  int64_t val0 = 1;
+  kernel_params.push_back(&val0);
 
-  // 1 output
-  at::Tensor out = at::empty(a_.sizes(), a_.options());
+  // calculate task_space
+  std::vector<Shape> shapes;
+  shapes.push_back(a_.sizes());
+  shapes.push_back(b_.sizes());
+  pointwise_dynamic::ShapeW task_space = pointwise_dynamic::broadcast_shapes(shapes);
+  int ndim = task_space.size();
+  // prepare output with size of task_space
+  at::Tensor out = at::empty(task_space);
   kernel_params.push_back(&out);
   std::vector<at::Tensor> tensors = {a_, b_, out};
   int task_shape;
   if (pointwise_dynamic::use_fast_path(tensors)) {
+    std::cout << "use fast path";
     task_shape = tensors[0].numel();
     void* task_shape_ptr = &task_shape;
     int stride = 1;
@@ -46,10 +57,85 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     // num_tasks -> num_tasks = out0.numel()
     kernel_params.push_back(task_shape_ptr);
   } else {
-    // TODO
-    // stride for input/output
-    // calculate task space
-    // shapes = tuple(item.shape for item in in_tensors),
+    std::cout << "else";
+    // broadcast tensor
+    // ndim = len(task_shape)
+    // shapes = tuple(item.shape for item in in_tensors)
+    // task_shape = broadcast_shapes(shapes)
+    // c10::IntArrayRef vs at::DimVector
+
+    // broadcast the tensors and wrap them with StridedBuffer
+    // TODO: verify whether this copy mechanism is efficient
+    pointwise_dynamic::StridedBuffer a = pointwise_dynamic::StridedBuffer(
+        a_,
+        task_shape,
+        pointwise_dynamic::broadcasted_stride(a_.sizes(), a_.strides(), task_shape));
+    pointwise_dynamic::StridedBuffer b = pointwise_dynamic::StridedBuffer(
+        b_,
+        task_shape,
+        pointwise_dynamic::broadcasted_stride(b_.sizes(), b_.strides(), task_shape));
+
+    // input stride
+    const c10::IntArrayRef a_strides = a.strides();
+    for (int i = 0; i < ndim; i++) {
+      kernel_params.push_back(const_cast<int64_t*>(&a_strides[i]));
+    }
+    if (ndim >= 2) {
+      const pointwise_dynamic::StrideW a_strides_vec(a_strides.begin(), a_strides.end());
+      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(a_strides_vec);
+      for (int i = 0; i < ndim; i++) {
+        long order_val = order_vec[i];
+        kernel_params.push_back(const_cast<long*>(&order_val));
+      }
+    } else {
+      pointwise_dynamic::StrideW zero_stride(1, 0);
+      void* zero_stride_ptr = zero_stride.data();
+      kernel_params.push_back(&zero_stride_ptr);
+    }
+
+    const c10::IntArrayRef b_strides = b.strides();
+    for (int i = 0; i < ndim; i++) {
+      kernel_params.push_back(const_cast<int64_t*>(&b_strides[i]));
+    }
+    if (ndim >= 2) {
+      const pointwise_dynamic::StrideW b_strides_vec(b_strides.begin(), b_strides.end());
+      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(b_strides_vec);
+      for (int i = 0; i < ndim; i++) {
+        long order_val = order_vec[i];
+        kernel_params.push_back(const_cast<long*>(&order_val));
+      }
+    } else {
+      pointwise_dynamic::StrideW zero_stride(1, 0);
+      void* zero_stride_ptr = zero_stride.data();
+      kernel_params.push_back(&zero_stride_ptr);
+    }
+    // output stride
+    // TODO: factor out a helper that pushes one tensor's 1-D metadata
+    const c10::IntArrayRef output_strides = out.strides();
+    for (int i = 0; i < ndim; i++) {
+      kernel_params.push_back(const_cast<int64_t*>(&output_strides[i]));
+    }
+    if (ndim >= 2) {
+      const pointwise_dynamic::StrideW output_strides_vec(output_strides.begin(), output_strides.end());
+      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(output_strides_vec);
+      for (int i = 0; i < ndim; i++) {
+        long order_val = order_vec[i];
+        kernel_params.push_back(const_cast<long*>(&order_val));
+      }
+    } else {
+      pointwise_dynamic::StrideW zero_stride(1, 0);
+      void* zero_stride_ptr = zero_stride.data();
+      kernel_params.push_back(&zero_stride_ptr);
+    }
+
+    // task space
+    for (int i = 0; i < ndim; i++) {
+      int64_t si = task_space[i];
+      kernel_params.push_back(const_cast<int64_t*>(&si));
+    }
+    // num_tasks, taken from out
+    int64_t num_task = out.numel();
+    kernel_params.push_back(const_cast<int64_t*>(&num_task));
   }
   void* global_scratch = nullptr;
   kernel_params.push_back(&global_scratch);
@@ -59,6 +153,7 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   std::array<bool, 2> is_tensor;
   pointwise_dynamic::checkIfScalar(a_, b_, is_tensor);
   std::optional<TritonJITFunction> f;
+  // TODO: code gen in c++
   if (is_tensor[0] && is_tensor[1]) {
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func");
diff --git a/lib/sum.cpp b/lib/sum.cpp
index 00226006e..485e2dd46 100644
--- a/lib/sum.cpp
+++ b/lib/sum.cpp
@@ -57,9 +57,9 @@ at::Tensor sum_dim(const at::Tensor &self,
     }
   }
   out_options = out_options.dtype(out_dtype);
-  at::DimVector dims_ = at::native::make_dim_vector(dim, self.dim());
+  c10::DimVector dims_ = at::native::make_dim_vector(dim, self.dim());
   at::maybe_wrap_dims(dims_, self.dim());
-  at::DimVector shape = at::meta::get_reduction_shape(self, dims_, keepdim, false);
+  c10::IntArrayRef shape = at::meta::get_reduction_shape(self, dims_, keepdim, false);
   at::Tensor out = at::empty(at::IntArrayRef(shape), out_options);
   out = out.contiguous();
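The TODO above asks for a helper that pushes one tensor's metadata. Note also a lifetime hazard in the loops above: pushing the address of the loop-local order_val means every pushed pointer aliases the same soon-dead variable, so the pushed arguments must outlive the launch. A sketch of such a helper with the storage made explicit; the name and placement are hypothetical, not the repository's:

#include <cstdint>
#include <vector>

// Push one tensor's per-dimension strides, then its stride order, onto the raw
// kernel-argument list. order_storage must stay alive until after the launch,
// because kernel_params only holds raw addresses into it.
static void push_tensor_metadata(std::vector<void*>& kernel_params,
                                 const int64_t* strides, int64_t ndim,
                                 std::vector<int64_t>& order_storage) {
  for (int64_t i = 0; i < ndim; ++i) {
    kernel_params.push_back(const_cast<int64_t*>(&strides[i]));
  }
  for (int64_t& o : order_storage) {
    kernel_params.push_back(&o);  // stable addresses, unlike a loop-local copy
  }
}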
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 31f7d293c..682ea100d 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -173,126 +173,127 @@ int cdiv(int a, int b) {
 }  // namespace flag_gems::utils
 
 namespace flag_gems::pointwise_dynamic {
-using Shape = std::vector<int64_t>;
-using Stride = std::vector<int64_t>;
 void checkIfScalar(const torch::Tensor& tensor1,
                    const torch::Tensor& tensor2,
                    std::array<bool, 2>& is_tensor) {
   is_tensor[0] = (tensor1.dim() == 0);
   is_tensor[1] = (tensor2.dim() == 0);
 }
-/*
+StridedBuffer::StridedBuffer(const torch::Tensor& base,
+                             c10::optional<c10::IntArrayRef> shape,
+                             c10::optional<c10::IntArrayRef> strides,
+                             int64_t offset)
+    : base_(base.contiguous()), offset_(offset) {
+  if (offset_ == 0) {
+    data_ptr_ = base_.data_ptr();
+  } else {
+    data_ptr_ = static_cast<char*>(base_.data_ptr()) + base_.element_size() * offset_;
+  }
+  shape_ = shape.has_value() ? shape.value().vec() : base_.sizes().vec();
+  strides_ = strides.has_value() ? strides.value().vec() : base_.strides().vec();
+  ndim_ = shape_.size();
+}
 
+const c10::IntArrayRef StridedBuffer::strides() const {
+  return strides_;
+}
 
+const c10::IntArrayRef StridedBuffer::sizes() const {
+  return shape_;
+}
 
+long StridedBuffer::numel() const {
+  long num = 1;
+  for (long s : shape_) {
+    num *= s;
+  }
+  return num;
+}
 
+int64_t StridedBuffer::dim() const {
+  return ndim_;
+}
 
+const torch::Tensor& StridedBuffer::unwrap() const {
+  return base_;
+}
 
+void* StridedBuffer::data_ptr() const {
+  return data_ptr_;
+}
 
+torch::Storage StridedBuffer::untyped_storage() const {
+  return base_.storage();
+}
 
+StridedBuffer StridedBuffer::clone() const {
+  torch::Tensor cloned_base = base_.clone();
+  return StridedBuffer(cloned_base, shape_, strides_, offset_);
+}
 
+StridedBuffer& StridedBuffer::copy_(const StridedBuffer& src) {
+  torch::Tensor temp_dst = torch::empty_like(src.unwrap());
+  temp_dst.copy_(src.unwrap());
 
+  base_ = temp_dst;
+  strides_ = src.strides().vec();
+  shape_ = src.sizes().vec();
+  offset_ = src.offset();
+  data_ptr_ = base_.data_ptr();
 
+  return *this;
+}
 
+StridedBuffer& StridedBuffer::copy_(const torch::Tensor& src) {
+  StridedBuffer src_buffer(src);
+  return this->copy_(src_buffer);
+}
+
+long StridedBuffer::offset() const {
+  return offset_;
+}
+
+ShapeW broadcast(const ShapeR& s1, const ShapeR& s2) {
+  long ndim = std::max(s1.size(), s2.size());
+  ShapeW output_shape(ndim);
+  long p1 = s1.size() - 1;
+  long p2 = s2.size() - 1;
 
-  torch::Device device() const {
-    return device_;
+  for (long i = ndim - 1; i >= 0; --i) {
+    long d1 = (p1 >= 0) ? s1[p1] : 1;
+    long d2 = (p2 >= 0) ? s2[p2] : 1;
+
+    if (d1 != d2 && d1 != 1 && d2 != 1) {
+      // throw or return an error: the shapes are not broadcastable
+      throw std::runtime_error("Shapes are not broadcastable.");
+    }
+    output_shape[i] = std::max(d1, d2);
+    if (p1 >= 0) p1--;
+    if (p2 >= 0) p2--;
   }
+  return output_shape;
+}
 
-  dType dtype() const {
-    return dtype_;
+ShapeW broadcast_shapes(const std::vector<ShapeR>& shapes) {
+  if (shapes.empty()) {
+    return {};
   }
 
-  long offset() const {
-    return offset_;
+  ShapeW output_shape(shapes[0].begin(), shapes[0].end());
+  for (size_t i = 1; i < shapes.size(); ++i) {
+    output_shape = broadcast(output_shape, shapes[i]);
   }
+  return output_shape;
+}
 
- private:
-  torch::Tensor base_;
-  dType dtype_;
-  void* data_ptr_;
-  int64_t offset_;
-  std::vector<int64_t> shape_;
-  std::vector<int64_t> strides_;
-  torch::Device device_;
-  int64_t ndim_;
-};
-*/
-Stride broadcasted_stride(const Shape& shape, const Stride& stride, const Shape& new_shape) {
+StrideW broadcasted_stride(const ShapeR& shape, const StrideR& stride, const ShapeR& new_shape) {
   assert(broadcastable_to(shape, new_shape) && "Shapes are not broadcastable.");
 
   int r1 = shape.size();
   int r2 = new_shape.size();
   int d = r2 - r1;
-  Stride new_stride(r2, 0);
+  StrideW new_stride(r2, 0);
   for (int i = 0; i < r1; ++i) {
     int new_dim_index = d + i;
     if (shape[i] == 1 && new_shape[new_dim_index] > 1) {
@@ -350,4 +351,20 @@ bool use_fast_path(const std::vector<at::Tensor>& tensors) {
   }
   return all_the_same_stride(tensors) && tensors[0].is_non_overlapping_and_dense();
 }
+StrideW stride_order(const StrideR& strides) {
+  // Create a vector of indices from 0 to strides.size() - 1
+  StrideW indices(strides.size());
+  std::iota(indices.begin(), indices.end(), 0);
+
+  // Sort the indices based on the absolute value of the corresponding stride
+  std::sort(indices.begin(), indices.end(), [&](int64_t i, int64_t j) {
+    return std::abs(strides[i]) < std::abs(strides[j]);
+  });
+
+  return indices;
+}
+
+StrideR create_stride_r_view(const StrideW& stride_w) {
+  return StrideR(reinterpret_cast<const int64_t*>(stride_w.data()), stride_w.size());
+}
 };  // namespace flag_gems::pointwise_dynamic
diff --git a/src/flag_gems/csrc/aten_patch.cpp b/src/flag_gems/csrc/aten_patch.cpp
index 374e979c7..3c6932272 100644
--- a/src/flag_gems/csrc/aten_patch.cpp
+++ b/src/flag_gems/csrc/aten_patch.cpp
@@ -33,5 +33,6 @@ TORCH_LIBRARY_IMPL(aten, CUDA, m) {
   REGISTER_AND_LOG("max.dim", max_dim)
   REGISTER_AND_LOG("max", max)
   REGISTER_AND_LOG("sum", sum)
+  REGISTER_AND_LOG("add", add_tensor)
 }
 }  // namespace flag_gems
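A worked example of the stride_order rule implemented above: indices are sorted by ascending |stride|, so a contiguous row-major tensor with strides (32, 1) yields order (1, 0), while its transpose with strides (1, 32) yields (0, 1):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  auto order = [](std::vector<int64_t> s) {
    std::vector<int64_t> idx(s.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(),
              [&](int64_t i, int64_t j) { return std::abs(s[i]) < std::abs(s[j]); });
    return idx;
  };
  assert((order({32, 1}) == std::vector<int64_t>{1, 0}));  // contiguous
  assert((order({1, 32}) == std::vector<int64_t>{0, 1}));  // transposed
  return 0;
}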
diff --git a/src/flag_gems/csrc/aten_patch.cpp b/src/flag_gems/csrc/aten_patch.cpp
index 374e979c7..3c6932272 100644
--- a/src/flag_gems/csrc/aten_patch.cpp
+++ b/src/flag_gems/csrc/aten_patch.cpp
@@ -33,5 +33,6 @@ TORCH_LIBRARY_IMPL(aten, CUDA, m) {
   REGISTER_AND_LOG("max.dim", max_dim)
   REGISTER_AND_LOG("max", max)
   REGISTER_AND_LOG("sum", sum)
+  REGISTER_AND_LOG("add", add_tensor)
 }
 }  // namespace flag_gems
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index 66984fe21..e7fb0a976 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -296,8 +296,6 @@ def gen_signature(self, code, with_block_pointer=False):
         # signature: strides, for each tensor arguments
         ndim = self.ndim
-        if ndim == 1:
-            code.writeline("# use fast path or simple linear tensor")
         if ndim > 0:
             # strides for inputs
             for i in range(schema.num_input_tensors()):
@@ -875,6 +873,7 @@ def gen_kernel_launch(
         with_block_pointer = self.config.prefer_block_pointer
 
         code.writeline("# kernel launch")
+        code.writeline("print('launch with wrapper')")
         for i in range(schema.num_input_tensors()):
             code.writeline(f"in{i}_strides = in{i}.stride()")
             if not with_block_pointer:
@@ -1156,7 +1155,6 @@ def prepare_args(self, *args, **kwargs):
             )
             for i, item in enumerate(args)
         )
-        print(args)  # usually two tensors
         kwargs = {
             k: StridedBuffer(item, task_shape, strides)
             for k, item in kwargs.items()
@@ -1171,7 +1169,6 @@ def prepare_args(self, *args, **kwargs):
        # tensors that are not broadcast, no attempts to simplify task, no reordering,
        # no dimension collapsing
        shapes = tuple(item.shape for item in in_tensors)
-        task_shape = broadcast_shapes(shapes)
 
        if out_tensors:
@@ -1200,6 +1197,7 @@ def prepare_args(self, *args, **kwargs):
            torch.empty(task_shape, dtype=dtype, device=device)
            for dtype in outputs_dtypes_for_allocation
        ]
+        print(args)
        args = tuple(
            (
                StridedBuffer(
@@ -1212,11 +1210,12 @@ def prepare_args(self, *args, **kwargs):
            )
            for i, item in enumerate(args)
        )
+        print(args)
        kwargs = {
            k: StridedBuffer(
                item,
                task_shape,
-                broadcasted_stride(item.shape, item.stride(), task_shape),
+                (item.shape, item.stride(), task_shape),
            )
            for k, item in kwargs.items()
        }
diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py
index 58c6009b1..50285af18 100644
--- a/src/flag_gems/utils/shape_utils.py
+++ b/src/flag_gems/utils/shape_utils.py
@@ -159,7 +159,9 @@ def ordered_stride(shape: Shape, order: Perm) -> Stride:
 
 def stride_order(strides):
     # we also handle negative strides
-    return sorted(range(len(strides)), key=lambda i: abs(strides[i]))
+    res = sorted(range(len(strides)), key=lambda i: abs(strides[i]))
+    print("stride_order", res)
+    return res
 
 
 def all_the_same_shape(tensors: Sequence[torch.Tensor]) -> bool:
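The C++ stride_order added earlier in this series mirrors this Python helper: both return the axes ordered from fastest-varying to slowest. A small sketch of the expected values (stride_order_demo is a hypothetical check, not part of the patch):

#include <vector>
#include "flag_gems/utils.h"

void stride_order_demo() {
  using namespace flag_gems::pointwise_dynamic;
  // A contiguous (4, 3) tensor has strides {3, 1}: dim 1 varies fastest.
  std::vector<int64_t> contig = {3, 1};
  StrideW order = stride_order(StrideR(contig));  // order == {1, 0}
  // Its transpose has strides {1, 3}: dim 0 now varies fastest.
  std::vector<int64_t> transposed = {1, 3};
  order = stride_order(StrideR(transposed));      // order == {0, 1}
}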
From 720192b3f66d19aadcd0498a399c449ff2674ae9 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Thu, 14 Aug 2025 10:34:43 +0800
Subject: [PATCH 17/22] tmp change

---
 lib/CMakeLists.txt                                 |   8 +
 lib/pointwise_dynamic.cpp                          | 112 +-
 lib/utils.cpp                                      |   6 +-
 pointwise.log                                      | 960 ------------------
 src/flag_gems/ops/sum.py                           |  34 -
 .../_kunlunxin/utils/pointwise_dynamic.py          |   2 +
 src/flag_gems/utils/pointwise_dynamic.py           |   6 +-
 src/flag_gems/utils/triton_lang_extension.py       |  35 +
 8 files changed, 134 insertions(+), 1029 deletions(-)
 delete mode 100644 pointwise.log

diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 180ad60db..0eb657a20 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,3 +1,11 @@
+find_package(Python COMPONENTS Interpreter Development REQUIRED)
+
+if(NOT Python_INCLUDE_DIRS OR NOT Python_LIBRARIES)
+  message(FATAL_ERROR "Python development files not found. Please ensure Python is installed and development headers are available.")
+endif()
+
+include_directories(${Python_INCLUDE_DIRS})
+
 add_library(operators
   SHARED
   zeros.cpp
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index 8d7d43271..6a3da6a35 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -1,62 +1,93 @@
 #include "flag_gems/operators.h"
 #include "flag_gems/utils.h"
+#include
+#include
 #include
 #include "c10/cuda/CUDAStream.h"
 #include "c10/util/Logging.h"
+#include "pybind11/embed.h"
+#include "triton_jit/pointwise_generator.h"
 #include "triton_jit/triton_jit_function.h"
 
 namespace flag_gems {
 using namespace triton_jit;
-using Shape = c10::IntArrayRef;
-using Stride = c10::IntArrayRef;
+
+/*
+def add_func(
+    in0_ptr: tl.tensor,  # of tl.pointer_type
+    in1_ptr: tl.tensor,  # of tl.pointer_type
+    out0_ptr: tl.tensor,  # of tl.pointer_type
+    in0_stride0: int, in0_stride1: int,  # strides for in0
+    in1_stride0: int, in1_stride1: int,  # strides for in1
+    out0_stride0: int, out0_stride1: int,  # strides for out0
+    s0: int, s1: int,  # task_space
+    num_tasks: int,
+    tiles_per_cta: int,
+    tile_size: tl.constexpr,
+    one_tile_per_cta: tl.constexpr,
+):
+*/
+
+namespace py = pybind11;
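The commented-out add_func signature above records the parameter order the generated 2-D kernel expects; the raw launch below must pack one void* per non-constexpr parameter in exactly this order. A hedged sketch of that packing (variable names hypothetical):

// Hypothetical packing for the 2-D add_func above: one void* per pointer
// or scalar parameter, in declaration order.
std::vector<void*> params;
void* in0 = a.data_ptr();
void* in1 = b.data_ptr();
void* out0 = out.data_ptr();
params = {&in0, &in1, &out0};
int64_t in0_s0 = a.stride(0), in0_s1 = a.stride(1);
params.push_back(&in0_s0);
params.push_back(&in0_s1);
// ... likewise for in1 and out0 strides, then s0, s1, num_tasks and
// tiles_per_cta. tile_size and one_tile_per_cta are tl.constexpr and are
// encoded in the signature string rather than in the pointer array.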
 at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   // TODO: parse tensor meta info
   // LOG(INFO)<< fmt::format("add tensor");
-  std::cout << "add tensor";
+  std::string signature;
+  std::cout << "add tensor\n";
   std::vector<void*> kernel_params;
   // 2 input
   void* a_ptr = a_.data_ptr();
   void* b_ptr = b_.data_ptr();
   kernel_params.push_back(&a_ptr);
+  signature.append("*fp32:16,");
   kernel_params.push_back(&b_ptr);
-  int64_t val0 = 1;
-  kernel_params.push_back(&val0);
-
-  // calculate task_space
-  std::vector<c10::IntArrayRef> shapes;
-  shapes.push_back(a_.sizes());
-  shapes.push_back(b_.sizes());
-  pointwise_dynamic::ShapeW task_space = pointwise_dynamic::broadcast_shapes(shapes);
-  int ndim = task_space.size();
-  // prepare output with size of task_space
-  at::Tensor out = at::empty(task_space);
+  signature.append("*fp32:16,");
+  // TODO: the fast path does not have this, but
+  // int64_t val0 = 1;
+  // signature.push("1,");
+  // kernel_params.push_back(&val0);
+  int ndim;
+  at::Tensor out = at::empty_like(a_);
   kernel_params.push_back(&out);
+  signature.append("*fp32:16,");
   std::vector<at::Tensor> tensors = {a_, b_, out};
   int task_shape;
   if (pointwise_dynamic::use_fast_path(tensors)) {
-    std::cout << "use fast path";
-    task_shape = tensors[0].numel();
+    // prepare output with size of task_space
+    std::cout << "use fast path\n";
+    task_shape = a_.numel();
     void* task_shape_ptr = &task_shape;
     int stride = 1;
     void* stride_ptr = &stride;
-    int ndim = 1;
+    ndim = 1;
     int fast_path_stride_order = 0;
     void* fast_path_stride_order_ptr = &fast_path_stride_order;
     // push args
     // stride for input
     kernel_params.push_back(stride_ptr);
-    kernel_params.push_back(fast_path_stride_order_ptr);
+    signature.append("i64,");
+    // kernel_params.push_back(fast_path_stride_order_ptr);
     kernel_params.push_back(stride_ptr);
-    kernel_params.push_back(fast_path_stride_order_ptr);
+    signature.append("i64,");
+    // kernel_params.push_back(fast_path_stride_order_ptr);
     // stride for output
     kernel_params.push_back(stride_ptr);
-
+    signature.append("i64,");
     // task_space -> shape_args... shape = out0.shape
     kernel_params.push_back(task_shape_ptr);
+    signature.append("i64,");
     // num_tasks -> num_tasks = out0.numel()
     kernel_params.push_back(task_shape_ptr);
+    signature.append("i64,");
   } else {
+    // calculate task_space
+    std::vector<c10::IntArrayRef> shapes;
+    shapes.push_back(a_.sizes());
+    shapes.push_back(b_.sizes());
+    pointwise_dynamic::ShapeW task_space = pointwise_dynamic::broadcast_shapes(shapes);
+    ndim = task_space.size();
+    // prepare output with size of task_space
     std::cout << "else";
     // broadcast tensor
     // ndim = len(task_shape)
@@ -137,33 +168,55 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     int64_t num_task = out.numel();
     kernel_params.push_back(const_cast<int64_t*>(&num_task));
   }
-  void* global_scratch = nullptr;
-  kernel_params.push_back(&global_scratch);
+  /*
+  tiles_per_cta: int,
+  tile_size: tl.constexpr,
+  one_tile_per_cta: tl.constexpr,
+  */
   // # tile size & tiles_per_cta, gsl style
   // tile_sizes = heuristics_for_tile_size(512, *shape)
   int64_t tile_sizes = 1024;
   int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
+  // num_ctas = min(65536, num_tiles)
   int64_t num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
   // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
   int64_t tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
+  kernel_params.push_back(reinterpret_cast<void*>(tiles_per_cta));
+  signature.append("i64,");
+  signature.append(std::to_string(tile_sizes));
+  signature.append(",");
   // one_tile_per_cta = tiles_per_cta==1
   bool one_tile_per_cta = (tiles_per_cta == 1);
+  signature.append(std::to_string(one_tile_per_cta));
+
+  void* global_scratch = nullptr;
+  kernel_params.push_back(&global_scratch);
   // get function
-  std::array<bool, 2> is_tensor;
-  pointwise_dynamic::checkIfScalar(a_, b_, is_tensor);
+  std::array<bool, 2> is_scalar;
+  pointwise_dynamic::checkIfScalar(a_, b_, is_scalar);
   std::optional<TritonJITFunction> f;
   // TODO: code gen in c++
-  if (is_tensor[0] && is_tensor[1]) {
-    f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
-                                       "add_func");
-  } else if (is_tensor[0] && !is_tensor[1]) {
+
+  auto ans_tuple = gen_add(ndim);
+  std::string kernel_name = std::get<0>(ans_tuple);
+  std::string file_path = std::get<1>(ans_tuple);
+
+  std::cout << "file_path:" << file_path << std::endl;
+
+  // TODO: four cases
+  if (!is_scalar[0] && !is_scalar[1]) {
+    f = TritonJITFunction::getInstance(file_path, "add");
+  } else if (!is_scalar[0] && is_scalar[1]) {
+    // TODO
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func_tensor_scalar");
-  } else if (!is_tensor[0] && is_tensor[1]) {
+  } else if (is_scalar[0] && !is_scalar[1]) {
+    // TODO
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func_scalar_tensor");
   } else {
+    std::cout << "else";
     return a_ + b_;
   }
   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
@@ -172,7 +225,6 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
 
   const int num_warps = 8;
   const int num_stages = 1;
-  std::string signature = "*fp32:16,*fp32:16,*fp32:16,i64,1024";
   f->launch_with_raw_args(raw_stream,
                           num_tiles,
                           1,
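The hardcoded signature removed above documents the string format launch_with_raw_args consumes: one comma-separated entry per kernel parameter, where "*fp32:16" is a float pointer assumed to be 16-byte aligned, "i64" is a 64-bit integer scalar, and a bare literal such as "1024" bakes a tl.constexpr value into the compiled kernel. A sketch of assembling the same string incrementally, as the patch now does:

std::string sig;
sig.append("*fp32:16,");           // in0_ptr
sig.append("*fp32:16,");           // in1_ptr
sig.append("*fp32:16,");           // out0_ptr
sig.append("i64,");                // num_tasks
sig.append(std::to_string(1024));  // tile_size, a constexpr literal
// sig == "*fp32:16,*fp32:16,*fp32:16,i64,1024"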
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 682ea100d..01da5ae91 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -175,9 +175,9 @@ int cdiv(int a, int b) {
 namespace flag_gems::pointwise_dynamic {
 void checkIfScalar(const torch::Tensor& tensor1,
                    const torch::Tensor& tensor2,
-                   std::array<bool, 2>& is_tensor) {
-  is_tensor[0] = (tensor1.dim() == 0);
-  is_tensor[1] = (tensor2.dim() == 0);
+                   std::array<bool, 2>& is_scalar) {
+  is_scalar[0] = (tensor1.dim() == 0);
+  is_scalar[1] = (tensor2.dim() == 0);
 }
 StridedBuffer::StridedBuffer(const torch::Tensor& base,
                              c10::optional<c10::IntArrayRef> shape,
diff --git a/pointwise.log b/pointwise.log
deleted file mode 100644
index 06b58f0bf..000000000
--- a/pointwise.log
+++ /dev/null
@@ -1,960 +0,0 @@
-============================= test session starts ==============================
-platform linux -- Python 3.10.12, pytest-8.4.1, pluggy-1.6.0
-rootdir: /home/fyf/FlagGems
-configfile: pytest.ini
-plugins: hypothesis-6.136.4
-collected 79 items / 77 deselected / 2 selected
-
-test_pointwise_dynamic.py prepare args
-codegen config
-CodeGenConfig(max_tile_size=1024, max_grid_size=(65535, 65535, 65535), max_num_warps_per_cta=32, prefer_block_pointer=True, prefer_1d_tile=False)
-[... "prepare args" StridedBuffer traces and tensor dumps elided ...]
-codegen config
-CodeGenConfig(max_tile_size=1024, max_grid_size=(65535, 65535, 65535), max_num_warps_per_cta=32, prefer_block_pointer=False, prefer_1d_tile=False)
-[... "prepare args" StridedBuffer traces and tensor dumps elided ...]
-
-=============================== warnings summary ===============================
-../../.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: 11 warnings
-  /home/fyf/.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
-    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
-
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-================ 2 passed, 77 deselected, 11 warnings in 0.98s =================
diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py
index c65b37163..47699b0ee 100644
--- a/src/flag_gems/ops/sum.py
+++ b/src/flag_gems/ops/sum.py
@@ -58,40 +58,6 @@ def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
     tl.store(out, sum_val)
 
 
-@libentry()
-@triton.jit
-def sum_kernel(
-    in_ptr,
-    out_ptr,
-    M,
-    N,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    STAGE: tl.constexpr,
-):
-    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
-        in_ptr.dtype.element_ty == tl.bfloat16
-    ):
-        cdtype = tl.float32
-    else:
-        cdtype = in_ptr.dtype.element_ty
-
-    # Map the program id to the row of inp it should compute.
-    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
-    row_mask = row_ids < M
-
-    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
-    for off in tl.range(0, N, BLOCK_N, STAGE):
-        col_ids = off + tl.arange(0, BLOCK_N)
-        col_mask = col_ids < N
-        mask = row_mask[:, None] & col_mask[None, :]
-
-        a = tl.load(in_ptr + row_ids[:, None] * N + col_ids, mask, other=0).to(cdtype)
-        acc += a
-    out = tl.sum(acc, axis=1)
-    tl.store(out_ptr + row_ids, out, row_mask)
-
-
 def sum(inp, *, dtype=None):
     logger.debug("GEMS SUM")
     M = inp.numel()
diff --git a/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py b/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
index 4b4f3b29c..6c168ef6c 100644
--- a/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
+++ b/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
@@ -197,6 +197,8 @@ def signature(self, outputs_in_arg: bool = False) -> str:
         for _ in range(self.num_outputs()):
             output_types.append("StridedBuffer")
         sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
+        print("signature function call")
+        print(sig)
         return sig
 
     def _compute_input_id(self):
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index e7fb0a976..edd040b8f 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -238,9 +238,11 @@ def __init__(
         self.fn_module = scalar_fn.__module__
 
     def gen_import_function(self, code: IndentedBuffer):
-        code.writeline(f'"""Quoted source of {self.fn_name}:')
+        code.writemultiline("# gen import functions")
+        print("# gen import functions")
+        # code.writeline(f'"""Quoted source of {self.fn_name}:')
         code.writemultiline(self.fn.src)
-        code.writeline('"""')
+        # code.writeline('"""')
         code.newline()
 
     def gen_decorators(self, code):
diff --git a/src/flag_gems/utils/triton_lang_extension.py b/src/flag_gems/utils/triton_lang_extension.py
index 8ea38aac8..4fab21986 100644
--- a/src/flag_gems/utils/triton_lang_extension.py
+++ b/src/flag_gems/utils/triton_lang_extension.py
@@ -103,3 +103,38 @@ def fmod(x, y):
 def trunc(x):
     """trunc default - truncate to integer"""
     return tl.where(x >= 0, tl.floor(x), tl.ceil(x))
+
+
+# --- Pointwise Functions ---
+
+# src/flag_gems/ops/add.py
+
+
+def add_func(x, y, alpha):
+    return x + y * alpha
+
+
+def add_func_tensor_scalar(x, y, alpha):
+    return x + y * alpha
+
+
+def add_func_scalar_tensor(x, y, alpha):
+    return x + y * alpha
+
+
+# src/flag_gems/ops/fill.py for lib/fill.cpp
+def fill_scalar_func(inp, value_scalar):
+    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
+
+
+def fill_tensor_func(inp, value):
+    return value
+
+
+# zeros is not pointwise-dynamic yet
+
+# src/flag_gems/ops/copy.py for lib/cat.cpp
+
+
+def copy(src):
+    return src
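With these scalar functions in place, the C++ wrapper resolves a generated kernel by name and path. A hedged sketch of that dispatch, assuming gen_add(ndim) emits a Triton source file for an ndim-dimensional add and returns the pair (kernel_name, file_path), as the code later in this series uses it:

auto ans = gen_add(1);
const std::string& kernel_name = std::get<0>(ans);
const std::string& file_path = std::get<1>(ans);
std::optional<TritonJITFunction> f =
    TritonJITFunction::getInstance(file_path, kernel_name);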
From e01d53c8b94ea490459e6dbfa57ba41590f4451d Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Fri, 15 Aug 2025 17:43:43 +0800
Subject: [PATCH 18/22] use fast path work

---
 ctests/test_triton_pointwise.cpp             |   7 +-
 include/flag_gems/utils.h                    |  33 +++++
 lib/pointwise_dynamic.cpp                    | 112 +++++++++++------
 lib/utils.cpp                                | 123 +++++++++++++++++++
 src/flag_gems/utils/triton_lang_extension.py |  23 ++--
 5 files changed, 244 insertions(+), 54 deletions(-)

diff --git a/ctests/test_triton_pointwise.cpp b/ctests/test_triton_pointwise.cpp
index a06d5b1a4..bc87d516c 100644
--- a/ctests/test_triton_pointwise.cpp
+++ b/ctests/test_triton_pointwise.cpp
@@ -4,12 +4,13 @@
 TEST(pointwise_op_simple_test, add) {
   const torch::Device device(torch::kCUDA, 0);
-  torch::Tensor a = torch::randn({10, 10}, device);
-  torch::Tensor b = torch::randn({10, 10}, device);
+  torch::Tensor a = torch::randn({128}, device);
+  torch::Tensor b = torch::randn({128}, device);
 
   torch::Tensor out_torch = a + b;
   torch::Tensor out_triton = flag_gems::add_tensor(a, b);
-
+  std::cout << "out_torch sizes: " << out_torch.sizes() << std::endl;
+  std::cout << "out_triton sizes: " << out_triton.sizes() << std::endl;
   EXPECT_TRUE(torch::allclose(out_torch, out_triton));
 }
diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index ff44bce19..b17f064ce 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -69,4 +69,37 @@ class StridedBuffer {
   std::vector<int64_t> strides_;
   int64_t ndim_;
 };
+class ParamStack {
+ private:
+  std::vector<void*> kernel_params;
+  std::string signature;
+  std::vector<void*> tensor_ptr;
+  std::vector<int64_t> strides;
+  std::vector<int64_t> task_shape;
+  std::vector<int64_t> task_partition;
+  std::string constexp;
+
+ private:
+  void push_strides();
+  void push_task_shape();
+  void push_task_partition();
+  void add_global_scratch();
+
+ public:
+  ParamStack(int max_args = 32) {
+    kernel_params.reserve(max_args);
+    tensor_ptr.reserve(max_args);
+  }
+  void save_tensor(at::Tensor& tensor);
+  void save_tensor(const at::Tensor& tensor);
+  void save_stride(int64_t stride);
+  void save_task_shape(int64_t shape);
+  void save_task_partition(int64_t partition);
+  void save_constexpr(int64_t value);
+  void save_constexpr(bool value);
+  void** get_params();
+  std::string get_signature();
+
+  void build();
+};
 };  // namespace flag_gems::pointwise_dynamic
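A sketch of the intended calling convention for ParamStack, matching the declaration above: the save_stride/save_task_shape/save_task_partition/save_constexpr calls only record values, save_tensor appends the tensor pointer and its signature entry immediately, and build() then lays out strides, shapes, partitions, the recorded constexprs, and a trailing global-scratch pointer. Tensors a, b and out here are hypothetical same-shape contiguous float tensors:

flag_gems::pointwise_dynamic::ParamStack stk;
int64_t n = a.numel();
stk.save_stride(1);          // in0 stride
stk.save_stride(1);          // in1 stride
stk.save_stride(1);          // out0 stride
stk.save_task_shape(n);      // s0
stk.save_task_shape(n);      // num_tasks
stk.save_task_partition(1);  // tiles_per_cta
stk.save_constexpr(static_cast<int64_t>(128));  // tile_size
stk.save_constexpr(true);                       // one_tile_per_cta
stk.save_tensor(a);
stk.save_tensor(b);
stk.save_tensor(out);
stk.build();  // pushes strides, shapes, partitions, constexprs, scratch
// stk.get_signature() and stk.get_params() now feed launch_with_raw_args.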
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index 6a3da6a35..f04ccd171 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -34,8 +34,8 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   // TODO: parse tensor meta info
   // LOG(INFO)<< fmt::format("add tensor");
   std::string signature;
-  std::cout << "add tensor\n";
   std::vector<void*> kernel_params;
+  pointwise_dynamic::ParamStack stk = pointwise_dynamic::ParamStack();
   // 2 input
   void* a_ptr = a_.data_ptr();
   void* b_ptr = b_.data_ptr();
@@ -47,39 +47,63 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
   signature.append("*fp32:16,");
   // TODO: the fast path does not have this, but
   // int64_t val0 = 1;
   // signature.push("1,");
   // kernel_params.push_back(&val0);
-  int ndim;
+  // general args
+  int64_t ndim;
+  int64_t num_ctas;
+  int64_t tiles_per_cta;
+  int64_t tile_sizes;
   at::Tensor out = at::empty_like(a_);
-  kernel_params.push_back(&out);
+  void* out_ptr = out.data_ptr();
+  kernel_params.push_back(&out_ptr);
   signature.append("*fp32:16,");
   std::vector<at::Tensor> tensors = {a_, b_, out};
-  int task_shape;
+  int64_t task_shape;
+  const int num_warps = 4;  // TODO: statically specified by pointwise codegen
+  const int num_stages = 1;
   if (pointwise_dynamic::use_fast_path(tensors)) {
+    // prepare output with size of task_space
     std::cout << "use fast path\n";
     task_shape = a_.numel();
     void* task_shape_ptr = &task_shape;
-    int stride = 1;
+    int64_t stride = 1;
     void* stride_ptr = &stride;
     ndim = 1;
-    int fast_path_stride_order = 0;
+    int64_t fast_path_stride_order = 0;
     void* fast_path_stride_order_ptr = &fast_path_stride_order;
     // push args
     // stride for input
-    kernel_params.push_back(stride_ptr);
-    signature.append("i64,");
+    // kernel_params.push_back(stride_ptr);
+    signature.append("i64:1,");
     // kernel_params.push_back(fast_path_stride_order_ptr);
-    kernel_params.push_back(stride_ptr);
-    signature.append("i64,");
+    // kernel_params.push_back(stride_ptr);
+    signature.append("i64:1,");
     // kernel_params.push_back(fast_path_stride_order_ptr);
     // stride for output
-    kernel_params.push_back(stride_ptr);
-    signature.append("i64,");
+    // kernel_params.push_back(stride_ptr);
+    signature.append("i64:1,");
+    stk.save_stride(stride);
+    stk.save_stride(stride);
+    stk.save_stride(stride);
     // task_space -> shape_args... shape = out0.shape
     kernel_params.push_back(task_shape_ptr);
     signature.append("i64,");
+    stk.save_task_shape(task_shape);
     // num_tasks -> num_tasks = out0.numel()
     kernel_params.push_back(task_shape_ptr);
     signature.append("i64,");
+    stk.save_task_shape(task_shape);
+
+    int64_t tile_sizes = num_warps * 32;
+    int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
+
+    // num_ctas = min(65536, num_tiles)
+    num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
+    // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
+    tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
+    void* tiles_per_cta_ptr = &tiles_per_cta;
+    // kernel_params.push_back(tiles_per_cta_ptr);
+    signature.append("i64:1,");
+    // stk.save_task_partition(tiles_per_cta);
   } else {
     // calculate task_space
     std::vector<c10::IntArrayRef> shapes;
     shapes.push_back(a_.sizes());
     shapes.push_back(b_.sizes());
     pointwise_dynamic::ShapeW task_space = pointwise_dynamic::broadcast_shapes(shapes);
     ndim = task_space.size();
     // prepare output with size of task_space
     std::cout << "else";
     // broadcast tensor
     // ndim = len(task_shape)
@@ -164,34 +188,33 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
     int64_t si = task_space[i];
     kernel_params.push_back(const_cast<int64_t*>(&si));
   }
-  // num_tasks comes from out
-  int64_t num_task = out.numel();
-  kernel_params.push_back(const_cast<int64_t*>(&num_task));
+  tile_sizes = num_warps * 32;
+  int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
+  // num_ctas = min(65536, num_tiles)
+  /* TODO: handle tiles_per_cta
+  num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
+  // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
+  int64_t tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
+  void* tiles_per_cta_ptr = &tiles_per_cta;
+  kernel_params.push_back(tiles_per_cta_ptr);
+  // num_tasks -> num_tasks = out0.numel()
+  kernel_params.push_back(task_shape_ptr);
+  // num_tasks comes from out
+  int64_t num_task = out.numel();
+  kernel_params.push_back(const_cast<int64_t*>(&num_task));
+  */
 }
   signature.append(std::to_string(tile_sizes));
   signature.append(",");
+  stk.save_constexpr(tile_sizes);
   // one_tile_per_cta = tiles_per_cta==1
   bool one_tile_per_cta = (tiles_per_cta == 1);
   signature.append(std::to_string(one_tile_per_cta));
+  stk.save_constexpr(one_tile_per_cta);
 
   void* global_scratch = nullptr;
   kernel_params.push_back(&global_scratch);
+
   // get function
   std::array<bool, 2> is_scalar;
   pointwise_dynamic::checkIfScalar(a_, b_, is_scalar);
   std::optional<TritonJITFunction> f;
   // TODO: code gen in c++
 
   auto ans_tuple = gen_add(ndim);
   std::string kernel_name = std::get<0>(ans_tuple);
   std::string file_path = std::get<1>(ans_tuple);
 
-  std::cout << "file_path:" << file_path << std::endl;
-
   // TODO: four cases
   if (!is_scalar[0] && !is_scalar[1]) {
-    f = TritonJITFunction::getInstance(file_path, "add");
+    f = TritonJITFunction::getInstance(file_path, kernel_name);
   } else if (!is_scalar[0] && is_scalar[1]) {
     // TODO
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func_tensor_scalar");
   } else if (is_scalar[0] && !is_scalar[1]) {
     // TODO
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
                                        "add_func_scalar_tensor");
   } else {
     std::cout << "else";
     return a_ + b_;
   }
   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
   c10::DeviceGuard guard(out.device());
   CUstream raw_stream = static_cast<CUstream>(stream.stream());
-  const int num_warps = 8;
-  const int num_stages = 1;
 
+  stk.save_tensor(a_);
+  stk.save_tensor(b_);
+  stk.save_tensor(out);
+  // constexpr args need to go here...
+  stk.build();
+  std::cout << "size of params" << kernel_params.size() << std::endl;
+
+  std::cout << "file_path:" << file_path << std::endl;
+  std::cout << "signature:" << signature << std::endl;
+
+  std::cout << "--- Launching with raw args ---" << std::endl;
+  std::cout << "raw_stream: " << raw_stream << std::endl;
+  std::cout << "num_ctas: " << num_ctas << std::endl;
+  std::cout << "num_warps: " << num_warps << std::endl;
+  std::cout << "num_stages: " << num_stages << std::endl;
+  std::cout << "signature: " << signature << std::endl;
+  std::cout << "params: " << kernel_params << std::endl;
 
   f->launch_with_raw_args(raw_stream,
-                          num_tiles,
+                          num_ctas,
                           1,
                           1,
                           num_warps,
                           num_stages,
+                          // stk.get_signature(),
+                          // stk.get_params()
                           signature,
                           kernel_params.data());
   return out;
 }
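A worked example of the grid-sizing arithmetic used above, assuming cdiv(a, b) = (a + b - 1) / b:

// task_shape = 1'000'000 elements, num_warps = 4
//   tile_sizes    = 4 * 32               = 128
//   num_tiles     = cdiv(1'000'000, 128) = 7813
//   num_ctas      = min(65536, 7813)     = 7813
//   tiles_per_cta = cdiv(7813, 7813)     = 1   -> one_tile_per_cta = true
// Only when num_tiles exceeds the 65536 CTA cap does each CTA loop over
// several tiles, grid-stride-loop style.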
diff --git a/lib/utils.cpp b/lib/utils.cpp
index 01da5ae91..a01eb5b1d 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -367,4 +367,127 @@ StrideW stride_order(const StrideR& strides) {
 StrideR create_stride_r_view(const StrideW& stride_w) {
   return StrideR(reinterpret_cast<const int64_t*>(stride_w.data()), stride_w.size());
 }
+
+void ParamStack::save_tensor(const at::Tensor& tensor) {
+  void* p_item = tensor.data_ptr();
+  tensor_ptr.push_back(p_item);
+  if (tensor.dtype() == at::kFloat) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp32:16,");
+  } else if (tensor.dtype() == at::kInt) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*int32:16,");
+  } else if (tensor.dtype() == at::kDouble) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp64:16,");
+  } else if (tensor.dtype() == at::kHalf) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp16:16,");
+  } else {
+    throw std::runtime_error("TypeError: we only support fp64/32/16 and int32 now");
+  }
+}
+
+/*
+void *p_item = item.data_ptr();
+data_pointers.push_back(p_item);
+kernel_args.push_back(&(data_pointers.back()));
+*/
+void ParamStack::save_tensor(at::Tensor& tensor) {
+  void* p_item = tensor.data_ptr();
+  tensor_ptr.push_back(p_item);
+  if (tensor.dtype() == at::kFloat) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp32:16,");
+  } else if (tensor.dtype() == at::kInt) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*int32:16,");
+  } else if (tensor.dtype() == at::kDouble) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp64:16,");
+  } else if (tensor.dtype() == at::kHalf) {
+    kernel_params.push_back(&(tensor_ptr.back()));
+    signature.append("*fp16:16,");
+  } else {
+    throw std::runtime_error("TypeError: we only support fp64/32/16 and int32 now");
+  }
+}
+
+std::string ParamStack::get_signature() {
+  if (!signature.empty() && signature.back() == ',') {
+    signature.pop_back();
+  }
+  return signature;
+}
+
+void** ParamStack::get_params() {
+  void** res = kernel_params.empty() ? nullptr : kernel_params.data();
+  if (res == nullptr) {
+    // kernel_params is empty
+    std::cout << "The parameter stack is empty." << std::endl;
+  } else {
+    // kernel_params is not empty
+    std::cout << "The parameter stack is not empty." << std::endl;
+  }
+  return res;
+}
+
+void ParamStack::save_stride(int64_t stride) {
+  strides.push_back(stride);
+}
+
+void ParamStack::save_task_shape(int64_t shape) {
+  task_shape.push_back(shape);
+}
+
+void ParamStack::save_task_partition(int64_t partition) {
+  task_partition.push_back(partition);
+}
+
+void ParamStack::push_strides() {
+  for (auto& stride : strides) {
+    kernel_params.push_back(static_cast<void*>(&stride));
+    signature.append("i64,");
+  }
+}
+
+void ParamStack::push_task_shape() {
+  for (auto& shape : task_shape) {
+    kernel_params.push_back(static_cast<void*>(&shape));
+    signature.append("i64,");
+  }
+}
+
+void ParamStack::push_task_partition() {
+  for (auto& partition : task_partition) {
+    kernel_params.push_back(static_cast<void*>(&partition));
+    signature.append("i64,");
+  }
+}
+
+void ParamStack::add_global_scratch() {
+  void* global_scratch = nullptr;
+  kernel_params.push_back(global_scratch);
+}
+
+void ParamStack::build() {
+  push_strides();
+  push_task_shape();
+  push_task_partition();
+  signature.append(constexp);
+  add_global_scratch();
+}
+
+void ParamStack::save_constexpr(int64_t value) {
+  constexp.append(std::to_string(value) + ",");
+}
+
+void ParamStack::save_constexpr(bool value) {
+  if (value) {
+    constexp.append("True,");
+  } else {
+    constexp.append("False,");
+  }
+}
+
 };  // namespace flag_gems::pointwise_dynamic
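The two save_tensor overloads duplicate the same dtype dispatch; a sketch of the mapping factored into one helper, using the same Triton type strings as above (including the "*int32:16" spelling this patch uses):

std::string triton_pointer_type(at::ScalarType t) {
  switch (t) {
    case at::kFloat:  return "*fp32:16,";
    case at::kHalf:   return "*fp16:16,";
    case at::kDouble: return "*fp64:16,";
    case at::kInt:    return "*int32:16,";
    default:
      throw std::runtime_error("TypeError: we only support fp64/32/16 and int32 now");
  }
}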
diff --git a/src/flag_gems/utils/triton_lang_extension.py b/src/flag_gems/utils/triton_lang_extension.py
index 4fab21986..8fc017176 100644
--- a/src/flag_gems/utils/triton_lang_extension.py
+++ b/src/flag_gems/utils/triton_lang_extension.py
@@ -107,34 +107,29 @@ def trunc(x):
 
 # --- Pointwise Functions ---
 
-# src/flag_gems/ops/add.py
-
-
-def add_func(x, y, alpha):
+# src/flag_gems/ops/add.py for lib/add.cpp
+@triton.jit
+def add_func(x, y, alpha=1):
     return x + y * alpha
 
 
-def add_func_tensor_scalar(x, y, alpha):
+@triton.jit
+def add_func_tensor_scalar(x, y, alpha=1):
     return x + y * alpha
 
 
-def add_func_scalar_tensor(x, y, alpha):
+@triton.jit
+def add_func_scalar_tensor(x, y, alpha=1):
     return x + y * alpha
 
 
 # src/flag_gems/ops/fill.py for lib/fill.cpp
+@triton.jit
 def fill_scalar_func(inp, value_scalar):
     return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
 
 
+@triton.jit
 def fill_tensor_func(inp, value):
     return value
-
-
-# zeros is not pointwise-dynamic yet
-
-# src/flag_gems/ops/copy.py for lib/cat.cpp
-
-
-def copy(src):
-    return src
From dec3f3d9a88f56943ee233903d0e8c3e9a05482f Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Sat, 16 Aug 2025 01:43:47 +0800
Subject: [PATCH 19/22] pointwise dynamic fastpath

---
 ctests/test_triton_pointwise.cpp |  12 --
 include/flag_gems/utils.h        |  42 +-----
 lib/pointwise_dynamic.cpp        | 214 ++-----------------------------
 lib/utils.cpp                    | 177 ++++---------------------
 4 files changed, 38 insertions(+), 407 deletions(-)

diff --git a/ctests/test_triton_pointwise.cpp b/ctests/test_triton_pointwise.cpp
index bc87d516c..0dfce6b46 100644
--- a/ctests/test_triton_pointwise.cpp
+++ b/ctests/test_triton_pointwise.cpp
@@ -13,15 +13,3 @@ TEST(pointwise_op_simple_test, add) {
   std::cout << "out_triton sizes: " << out_triton.sizes() << std::endl;
   EXPECT_TRUE(torch::allclose(out_torch, out_triton));
 }
-
-TEST(pointwise_op_broadcast_test, add) {
-  const torch::Device device(torch::kCUDA, 0);
-  torch::Tensor a = torch::randn({30, 50}, device);
-  torch::Tensor b = torch::randn({50}, device);
-
-  torch::Tensor out_torch = a + b;
-  torch::Tensor out_triton = flag_gems::add_tensor(a, b);
-  std::cout << "out_torch sizes: " << out_torch.sizes() << std::endl;
-  std::cout << "out_triton sizes: " << out_triton.sizes() << std::endl;
-  EXPECT_TRUE(torch::allclose(out_torch, out_triton));
-}
diff --git a/include/flag_gems/utils.h b/include/flag_gems/utils.h
index b17f064ce..353aeeb9d 100644
--- a/include/flag_gems/utils.h
+++ b/include/flag_gems/utils.h
@@ -26,49 +26,11 @@ bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
 };  // namespace flag_gems::utils
 
 namespace flag_gems::pointwise_dynamic {
-using ShapeR = c10::IntArrayRef;
-using ShapeW = std::vector<int64_t>;
-using StrideR = c10::IntArrayRef;
-using StrideW = std::vector<int64_t>;
-bool all_the_same_shape(const std::vector<at::Tensor>& tensors);
-bool all_c_contiguous(const std::vector<at::Tensor>& tensors);
-bool use_fast_path(const std::vector<at::Tensor>& tensors);
 void checkIfScalar(const torch::Tensor& tensor1,
                    const torch::Tensor& tensor2,
                    std::array<bool, 2>& is_tensor);
-ShapeW broadcast(const ShapeR& s1, const ShapeR& s2);
-ShapeW broadcast_shapes(const std::vector<ShapeR>& shapes);
-StrideW broadcasted_stride(const ShapeR& shape, const StrideR& stride, const ShapeR& new_shape);
-void print_shapes(const std::vector<ShapeW>& shapes);
-StrideW stride_order(const StrideR& strides);
-StrideR create_stride_r_view(const StrideW& stride_w);
-class StridedBuffer {
- public:
-  StridedBuffer(const torch::Tensor& base,
-                c10::optional<c10::IntArrayRef> shape = c10::nullopt,
-                c10::optional<c10::IntArrayRef> strides = c10::nullopt,
-                int64_t offset = 0);
-
-  const c10::IntArrayRef strides() const;
-  const c10::IntArrayRef sizes() const;
-  long numel() const;
-  int64_t dim() const;
-  const torch::Tensor& unwrap() const;
-  void* data_ptr() const;
-  torch::Storage untyped_storage() const;
-  StridedBuffer clone() const;
-  StridedBuffer& copy_(const StridedBuffer& src);
-  StridedBuffer& copy_(const torch::Tensor& src);
-  long offset() const;
+bool use_fast_path(const std::vector<at::Tensor>& tensors);
 
-  private:
-  torch::Tensor base_;
-  void* data_ptr_;
-  int64_t offset_;
-  std::vector<int64_t> shape_;
-  std::vector<int64_t> strides_;
-  int64_t ndim_;
-};
 class ParamStack {
  private:
   std::vector<void*> kernel_params;
@@ -78,6 +40,7 @@ class ParamStack {
   std::vector<int64_t> task_shape;
   std::vector<int64_t> task_partition;
   std::string constexp;
+  void* global_scratch;
 
  private:
   void push_strides();
@@ -89,6 +52,7 @@ class ParamStack {
   ParamStack(int max_args = 32) {
     kernel_params.reserve(max_args);
     tensor_ptr.reserve(max_args);
+    global_scratch = nullptr;
   }
   void save_tensor(at::Tensor& tensor);
   void save_tensor(const at::Tensor& tensor);
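After this patch only the fast path remains wired up, which is why the broadcast test above is dropped: use_fast_path requires every operand to share one stride layout and the first to be non-overlapping and dense. A sketch of what does and does not qualify (device is assumed to be a CUDA torch::Device):

torch::Tensor a = torch::randn({128}, device);
torch::Tensor b = torch::randn({128}, device);
torch::Tensor out = torch::empty_like(a);
// Same shape, same contiguous strides: the fast path applies.
bool fast = flag_gems::pointwise_dynamic::use_fast_path({a, b, out});  // true
// A broadcast pair such as {30, 50} + {50} has differing stride
// layouts, so it falls through to the not-implemented branch below.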
diff --git a/lib/pointwise_dynamic.cpp b/lib/pointwise_dynamic.cpp
index f04ccd171..eb3fac1cb 100644
--- a/lib/pointwise_dynamic.cpp
+++ b/lib/pointwise_dynamic.cpp
@@ -13,231 +13,56 @@
 namespace flag_gems {
 using namespace triton_jit;
 
-/*
-def add_func(
-    in0_ptr: tl.tensor,  # of tl.pointer_type
-    in1_ptr: tl.tensor,  # of tl.pointer_type
-    out0_ptr: tl.tensor,  # of tl.pointer_type
-    in0_stride0: int, in0_stride1: int,  # strides for in0
-    in1_stride0: int, in1_stride1: int,  # strides for in1
-    out0_stride0: int, out0_stride1: int,  # strides for out0
-    s0: int, s1: int,  # task_space
-    num_tasks: int,
-    tiles_per_cta: int,
-    tile_size: tl.constexpr,
-    one_tile_per_cta: tl.constexpr,
-):
-*/
-
 namespace py = pybind11;
 
 at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
-  // TODO: parse tensor meta info
-  // LOG(INFO)<< fmt::format("add tensor");
-  std::string signature;
-  std::vector<void*> kernel_params;
   pointwise_dynamic::ParamStack stk = pointwise_dynamic::ParamStack();
-  // 2 input
-  void* a_ptr = a_.data_ptr();
-  void* b_ptr = b_.data_ptr();
-  kernel_params.push_back(&a_ptr);
-  signature.append("*fp32:16,");
-  kernel_params.push_back(&b_ptr);
-  signature.append("*fp32:16,");
-  // TODO: the fast path does not have this, but
-  // int64_t val0 = 1;
-  // signature.push("1,");
-  // kernel_params.push_back(&val0);
-  // general args
-  int64_t ndim;
+  int64_t task_shape, ndim;
   int64_t num_ctas;
   int64_t tiles_per_cta;
   int64_t tile_sizes;
+  int64_t num_tiles;
   at::Tensor out = at::empty_like(a_);
-  void* out_ptr = out.data_ptr();
-  kernel_params.push_back(&out_ptr);
-  signature.append("*fp32:16,");
   std::vector<at::Tensor> tensors = {a_, b_, out};
-  int64_t task_shape;
-  const int num_warps = 4;  // TODO: statically specified by pointwise codegen
+  const int num_warps = 4;
   const int num_stages = 1;
   if (pointwise_dynamic::use_fast_path(tensors)) {
-    // prepare output with size of task_space
-    std::cout << "use fast path\n";
     task_shape = a_.numel();
-    void* task_shape_ptr = &task_shape;
     int64_t stride = 1;
-    void* stride_ptr = &stride;
     ndim = 1;
-    int64_t fast_path_stride_order = 0;
-    void* fast_path_stride_order_ptr = &fast_path_stride_order;
-    // push args
-    // stride for input
-    // kernel_params.push_back(stride_ptr);
-    signature.append("i64:1,");
-    // kernel_params.push_back(fast_path_stride_order_ptr);
-    // kernel_params.push_back(stride_ptr);
-    signature.append("i64:1,");
-    // kernel_params.push_back(fast_path_stride_order_ptr);
-    // stride for output
-    // kernel_params.push_back(stride_ptr);
-    signature.append("i64:1,");
     stk.save_stride(stride);
     stk.save_stride(stride);
     stk.save_stride(stride);
-    // task_space -> shape_args... shape = out0.shape
-    kernel_params.push_back(task_shape_ptr);
-    signature.append("i64,");
     stk.save_task_shape(task_shape);
-    // num_tasks -> num_tasks = out0.numel()
-    kernel_params.push_back(task_shape_ptr);
-    signature.append("i64,");
     stk.save_task_shape(task_shape);
-
-    int64_t tile_sizes = num_warps * 32;
-    int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
-
-    // num_ctas = min(65536, num_tiles)
+    tile_sizes = num_warps * 32;
+    num_tiles = utils::cdiv(task_shape, tile_sizes);
     num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
-    // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
     tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
-    void* tiles_per_cta_ptr = &tiles_per_cta;
-    // kernel_params.push_back(tiles_per_cta_ptr);
-    signature.append("i64:1,");
-    // stk.save_task_partition(tiles_per_cta);
+    stk.save_task_partition(tiles_per_cta);
   } else {
-    // calculate task_space
-    std::vector<c10::IntArrayRef> shapes;
-    shapes.push_back(a_.sizes());
-    shapes.push_back(b_.sizes());
-    pointwise_dynamic::ShapeW task_space = pointwise_dynamic::broadcast_shapes(shapes);
-    ndim = task_space.size();
-    // prepare output with size of task_space
-    std::cout << "else";
-    // broadcast tensor
-    // ndim = len(task_shape)
-    // shapes = tuple(item.shape for item in in_tensors)
-    // task_shape = broadcast_shapes(shapes)
-    // c10::IntArrayRef vs at::DimVector
-
-    // broad tensor and warp with StridedBuffer
-    // TODO: check whether this copy mechanism is efficient
-    pointwise_dynamic::StridedBuffer a = pointwise_dynamic::StridedBuffer(
-        a_,
-        task_shape,
-        pointwise_dynamic::broadcasted_stride(a_.sizes(), a_.strides(), task_shape));
-    pointwise_dynamic::StridedBuffer b = pointwise_dynamic::StridedBuffer(
-        b_,
-        task_shape,
-        pointwise_dynamic::broadcasted_stride(b_.sizes(), b_.strides(), task_shape));
-
-    // input stride
-    const c10::IntArrayRef a_strides = a.strides();
-    for (int i = 0; i < ndim; i++) {
-      kernel_params.push_back(const_cast<int64_t*>(&a_strides[i]));
-    }
-    if (ndim >= 2) {
-      const pointwise_dynamic::StrideW a_strides_vec(a_strides.begin(), a_strides.end());
-      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(a_strides_vec);
-      for (int i = 0; i < ndim; i++) {
-        long order_val = order_vec[i];
-        kernel_params.push_back(const_cast<long*>(&order_val));
-      }
-    } else {
-      pointwise_dynamic::StrideW zero_stride(1, 0);
-      void* zero_stride_ptr = zero_stride.data();
-      kernel_params.push_back(&zero_stride_ptr);
-    }
-
-    const c10::IntArrayRef b_strides = b.strides();
-    for (int i = 0; i < ndim; i++) {
-      kernel_params.push_back(const_cast<int64_t*>(&b_strides[i]));
-    }
-    if (ndim >= 2) {
-      const pointwise_dynamic::StrideW b_strides_vec(b_strides.begin(), b_strides.end());
-      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(b_strides_vec);
-      for (int i = 0; i < ndim; i++) {
-        long order_val = order_vec[i];
-        kernel_params.push_back(const_cast<long*>(&order_val));
-      }
-    } else {
-      pointwise_dynamic::StrideW zero_stride(1, 0);
-      void* zero_stride_ptr = zero_stride.data();
-      kernel_params.push_back(&zero_stride_ptr);
-    }
-    // output stride
-    // TODO: wrap pushing 1-d tensor metadata into a helper
-    const c10::IntArrayRef output_strides = out.strides();
-    for (int i = 0; i < ndim; i++) {
-      kernel_params.push_back(const_cast<int64_t*>(&output_strides[i]));
-    }
-    if (ndim >= 2) {
-      const pointwise_dynamic::StrideW output_strides_vec(output_strides.begin(), output_strides.end());
-      std::vector<int64_t> order_vec = pointwise_dynamic::stride_order(output_strides_vec);
-      for (int i = 0; i < ndim; i++) {
-        long order_val = order_vec[i];
-        kernel_params.push_back(const_cast<long*>(&order_val));
-      }
-    } else {
-      pointwise_dynamic::StrideW zero_stride(1, 0);
-      void* zero_stride_ptr = zero_stride.data();
-      kernel_params.push_back(&zero_stride_ptr);
-    }
-
-    // task space
-    for (int i = 0; i < ndim; i++) {
-      int64_t si = task_space[i];
-      kernel_params.push_back(const_cast<int64_t*>(&si));
-    }
-    tile_sizes = num_warps * 32;
-    int64_t num_tiles = utils::cdiv(task_shape, tile_sizes);  // aka num blocks
-    // num_ctas = min(65536, num_tiles)
-    /* TODO: handle tiles_per_cta
-    num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
-    // tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
-    int64_t tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
-    void* tiles_per_cta_ptr = &tiles_per_cta;
-    kernel_params.push_back(tiles_per_cta_ptr);
-    // num_tasks -> num_tasks = out0.numel()
-    kernel_params.push_back(task_shape_ptr);
-    // num_tasks comes from out
-    int64_t num_task = out.numel();
-    kernel_params.push_back(const_cast<int64_t*>(&num_task));
-    */
+    throw std::runtime_error("NotImplementedError");
   }
-  signature.append(std::to_string(tile_sizes));
-  signature.append(",");
   stk.save_constexpr(tile_sizes);
-  // one_tile_per_cta = tiles_per_cta==1
-  bool one_tile_per_cta = (tiles_per_cta == 1);
-  signature.append(std::to_string(one_tile_per_cta));
+  int64_t one_tile_per_cta = (tiles_per_cta == 1);
   stk.save_constexpr(one_tile_per_cta);
-
-  void* global_scratch = nullptr;
-  kernel_params.push_back(&global_scratch);
-
-  // get function
   std::array<bool, 2> is_scalar;
   pointwise_dynamic::checkIfScalar(a_, b_, is_scalar);
   std::optional<TritonJITFunction> f;
-  // TODO: code gen in c++
   auto ans_tuple = gen_add(ndim);
   std::string kernel_name = std::get<0>(ans_tuple);
   std::string file_path = std::get<1>(ans_tuple);
-  // TODO: four cases
   if (!is_scalar[0] && !is_scalar[1]) {
     f = TritonJITFunction::getInstance(file_path, kernel_name);
   } else if (!is_scalar[0] && is_scalar[1]) {
-    // TODO
+    throw std::runtime_error("NotImplementedError");
     f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
"add_func_tensor_scalar"); } else if (is_scalar[0] && !is_scalar[1]) { - // TODO + std::runtime_error("NotImplementError"); f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"), "add_func_scalar_tensor"); } else { - std::cout << "else"; return a_ + b_; } c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); @@ -247,30 +72,15 @@ at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { stk.save_tensor(a_); stk.save_tensor(b_); stk.save_tensor(out); - // const expr需要在这里... stk.build(); - std::cout << "size of params" << kernel_params.size() << std::endl; - - std::cout << "file_path:" << file_path << std::endl; - std::cout << "signature:" << signature << std::endl; - - std::cout << "--- Launching with raw args ---" << std::endl; - std::cout << "raw_stream: " << raw_stream << std::endl; - std::cout << "num_ctas: " << num_ctas << std::endl; - std::cout << "num_warps: " << num_warps << std::endl; - std::cout << "num_stages: " << num_stages << std::endl; - std::cout << "signature: " << signature << std::endl; - std::cout << "params: " << kernel_params << std::endl; f->launch_with_raw_args(raw_stream, num_ctas, 1, 1, num_warps, num_stages, - // stk.get_signature(), - // stk.get_params() - signature, - kernel_params.data()); + stk.get_signature(), + stk.get_params()); return out; } diff --git a/lib/utils.cpp b/lib/utils.cpp index a01eb5b1d..c06432343 100644 --- a/lib/utils.cpp +++ b/lib/utils.cpp @@ -179,131 +179,6 @@ void checkIfScalar(const torch::Tensor& tensor1, is_scalar[0] = (tensor1.dim() == 0); is_scalar[1] = (tensor2.dim() == 0); } -StridedBuffer::StridedBuffer(const torch::Tensor& base, - c10::optional shape, - c10::optional strides, - int64_t offset) - : base_(base.contiguous()), offset_(offset) { - if (offset_ == 0) { - data_ptr_ = base_.data_ptr(); - } else { - data_ptr_ = static_cast(base_.data_ptr()) + base_.element_size() * offset_; - } - shape_ = shape.has_value() ? shape.value().vec() : base_.sizes().vec(); - strides_ = strides.has_value() ? 
diff --git a/lib/utils.cpp b/lib/utils.cpp
index a01eb5b1d..c06432343 100644
--- a/lib/utils.cpp
+++ b/lib/utils.cpp
@@ -179,131 +179,6 @@ void checkIfScalar(const torch::Tensor& tensor1,
   is_scalar[0] = (tensor1.dim() == 0);
   is_scalar[1] = (tensor2.dim() == 0);
 }
-StridedBuffer::StridedBuffer(const torch::Tensor& base,
-                             c10::optional<c10::IntArrayRef> shape,
-                             c10::optional<c10::IntArrayRef> strides,
-                             int64_t offset)
-    : base_(base.contiguous()), offset_(offset) {
-  if (offset_ == 0) {
-    data_ptr_ = base_.data_ptr();
-  } else {
-    data_ptr_ = static_cast<char*>(base_.data_ptr()) + base_.element_size() * offset_;
-  }
-  shape_ = shape.has_value() ? shape.value().vec() : base_.sizes().vec();
-  strides_ = strides.has_value() ? strides.value().vec() : base_.strides().vec();
-  ndim_ = shape_.size();
-}
-
-const c10::IntArrayRef StridedBuffer::strides() const {
-  return strides_;
-}
-
-const c10::IntArrayRef StridedBuffer::sizes() const {
-  return shape_;
-}
-
-long StridedBuffer::numel() const {
-  long num = 1;
-  for (long s : shape_) {
-    num *= s;
-  }
-  return num;
-}
-
-int64_t StridedBuffer::dim() const {
-  return ndim_;
-}
-
-const torch::Tensor& StridedBuffer::unwrap() const {
-  return base_;
-}
-
-void* StridedBuffer::data_ptr() const {
-  return data_ptr_;
-}
-
-torch::Storage StridedBuffer::untyped_storage() const {
-  return base_.storage();
-}
-
-StridedBuffer StridedBuffer::clone() const {
-  torch::Tensor cloned_base = base_.clone();
-  return StridedBuffer(cloned_base, shape_, strides_, offset_);
-}
-
-StridedBuffer& StridedBuffer::copy_(const StridedBuffer& src) {
-  torch::Tensor temp_dst = torch::empty_like(src.unwrap());
-  temp_dst.copy_(src.unwrap());
-
-  base_ = temp_dst;
-  strides_ = src.strides().vec();
-  shape_ = src.sizes().vec();
-  offset_ = src.offset();
-  data_ptr_ = base_.data_ptr();
-
-  return *this;
-}
-
-StridedBuffer& StridedBuffer::copy_(const torch::Tensor& src) {
-  StridedBuffer src_buffer(src);
-  return this->copy_(src_buffer);
-}
-
-long StridedBuffer::offset() const {
-  return offset_;
-}
-
-ShapeW broadcast(const ShapeR& s1, const ShapeR& s2) {
-  long ndim = std::max(s1.size(), s2.size());
-  ShapeW output_shape(ndim);
-  long p1 = s1.size() - 1;
-  long p2 = s2.size() - 1;
-
-  for (long i = ndim - 1; i >= 0; --i) {
-    long d1 = (p1 >= 0) ? s1[p1] : 1;
-    long d2 = (p2 >= 0) ? s2[p2] : 1;
-
-    if (d1 != d2 && d1 != 1 && d2 != 1) {
-      // throw, since the shapes are not broadcastable
-      throw std::runtime_error("Shapes are not broadcastable.");
-    }
-    output_shape[i] = std::max(d1, d2);
-    if (p1 >= 0) p1--;
-    if (p2 >= 0) p2--;
-  }
-  return output_shape;
-}
-
-ShapeW broadcast_shapes(const std::vector<ShapeR>& shapes) {
-  if (shapes.empty()) {
-    return {};
-  }
-
-  ShapeW output_shape(shapes[0].begin(), shapes[0].end());
-  for (size_t i = 1; i < shapes.size(); ++i) {
-    output_shape = broadcast(output_shape, shapes[i]);
-  }
-  return output_shape;
-}
-
-StrideW broadcasted_stride(const ShapeR& shape, const StrideR& stride, const ShapeR& new_shape) {
-  assert(broadcastable_to(shape, new_shape) && "Shapes are not broadcastable.");
-
-  int r1 = shape.size();
-  int r2 = new_shape.size();
-  int d = r2 - r1;
-
-  StrideW new_stride(r2, 0);
-  for (int i = 0; i < r1; ++i) {
-    int new_dim_index = d + i;
-    if (shape[i] == 1 && new_shape[new_dim_index] > 1) {
-      new_stride[new_dim_index] = 0;
-    } else {
-      new_stride[new_dim_index] = stride[i];
-    }
-  }
-  return new_stride;
-}
 
 bool all_the_same_shape(const std::vector<torch::Tensor>& tensors) {
   if (tensors.empty()) {
@@ -351,22 +226,6 @@ bool use_fast_path(const std::vector<torch::Tensor>& tensors) {
   }
   return all_the_same_stride(tensors) && tensors[0].is_non_overlapping_and_dense();
 }
-StrideW stride_order(const StrideR& strides) {
-  // Create a vector of indices from 0 to strides.size() - 1
-  StrideW indices(strides.size());
-  std::iota(indices.begin(), indices.end(), 0);
-
-  // Sort the indices based on the absolute value of the corresponding stride
-  std::sort(indices.begin(), indices.end(), [&](int64_t i, int64_t j) {
-    return std::abs(strides[i]) < std::abs(strides[j]);
-  });
-
-  return indices;
-}
-
-StrideR create_stride_r_view(const StrideW& stride_w) {
-  return StrideR(reinterpret_cast<const int64_t*>(stride_w.data()), stride_w.size());
-}
 
 void ParamStack::save_tensor(const at::Tensor& tensor) {
   void* p_item = tensor.data_ptr();
@@ -388,11 +247,6 @@ void ParamStack::save_tensor(const at::Tensor& tensor) {
   }
 }
 
-/*
-void *p_item = item.data_ptr();
-data_pointers.push_back(p_item);
-kernel_args.push_back(&(data_pointers.back()));
-*/
 void ParamStack::save_tensor(at::Tensor& tensor) {
   void* p_item = tensor.data_ptr();
   tensor_ptr.push_back(p_item);
@@ -433,7 +287,11 @@ void** ParamStack::get_params() {
 }
 
 void ParamStack::save_stride(int64_t stride) {
-  strides.push_back(stride);
+  if (stride == 1) {
+    strides.push_back(0);
+  } else {
+    strides.push_back(stride);
+  }
 }
 
 void ParamStack::save_task_shape(int64_t shape) {
@@ -441,13 +299,21 @@ void ParamStack::save_task_shape(int64_t shape) {
 }
 
 void ParamStack::save_task_partition(int64_t partition) {
-  task_partition.push_back(partition);
+  if (partition == 1) {
+    task_partition.push_back(0);
+  } else {
+    task_partition.push_back(partition);
+  }
 }
 
 void ParamStack::push_strides() {
   for (auto& stride : strides) {
-    kernel_params.push_back(static_cast<void*>(&stride));
-    signature.append("i64,");
+    if (stride != 0) {
+      kernel_params.push_back(static_cast<void*>(&stride));
+      signature.append("i64,");
+    } else {
+      signature.append("i64:1,");
+    }
   }
 }
 
@@ -460,14 +326,17 @@ void ParamStack::push_task_shape() {
 
 void ParamStack::push_task_partition() {
   for (auto& partition : task_partition) {
-    kernel_params.push_back(static_cast<void*>(&partition));
-    signature.append("i64,");
+    if (partition != 0) {
+      kernel_params.push_back(static_cast<void*>(&partition));
+      signature.append("i64,");
+    } else {
+      signature.append("i64:1,");
+    }
   }
 }
 
 void ParamStack::add_global_scratch() {
-  void* global_scratch = nullptr;
-  kernel_params.push_back(global_scratch);
+  kernel_params.push_back(&global_scratch);
 }
 
 void ParamStack::build() {
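The ParamStack change above implements constant specialization through the signature: save_stride and save_task_partition record a value of 1 as the sentinel 0, and the push_* methods later turn the sentinel into the signature fragment "i64:1," (this i64 is the compile-time constant 1) instead of appending "i64," plus a runtime pointer, so the JIT compiles a kernel with the stride baked in. The add_global_scratch fix is separate but real: the old code pushed the null pointer's value where the launcher expects the address of a pointer slot, which &global_scratch (now presumably a member) provides. A compressed, single-phase Python sketch of the same bookkeeping (class shape is illustrative, not the C++ API):

class ParamStack:
    def __init__(self):
        self.signature = []  # one fragment per kernel parameter
        self.params = []     # runtime arguments actually passed at launch

    def save_stride(self, stride: int) -> None:
        if stride == 1:
            # Specialize: bake the constant 1 into the signature and pass no
            # runtime argument for this parameter.
            self.signature.append("i64:1,")
        else:
            self.signature.append("i64,")
            self.params.append(stride)

stk = ParamStack()
for s in (1, 1, 1):  # fast path: all three operands have stride 1
    stk.save_stride(s)
print("".join(stk.signature), stk.params)  # i64:1,i64:1,i64:1, []

The C++ version needs the sentinel because saving and pushing happen in two phases; the sketch folds them into one call for brevity.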
From 0b477d594fedecf9bd8f7d8a1056c5573d2c4036 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Sat, 16 Aug 2025 02:04:40 +0800
Subject: [PATCH 20/22] tmp

---
 ctests/test_triton_pointwise.cpp |  2 --
 lib/CMakeLists.txt               |  8 --------
 reduce.log                       | 32 --------------------------------
 3 files changed, 42 deletions(-)
 delete mode 100644 reduce.log

diff --git a/ctests/test_triton_pointwise.cpp b/ctests/test_triton_pointwise.cpp
index 0dfce6b46..610b7fda6 100644
--- a/ctests/test_triton_pointwise.cpp
+++ b/ctests/test_triton_pointwise.cpp
@@ -9,7 +9,5 @@ TEST(pointwise_op_simple_test, add) {
   torch::Tensor out_torch = a + b;
   torch::Tensor out_triton = flag_gems::add_tensor(a, b);
 
-  std::cout << "out_torch sizes: " << out_torch.sizes() << std::endl;
-  std::cout << "out_triton sizes: " << out_triton.sizes() << std::endl;
   EXPECT_TRUE(torch::allclose(out_torch, out_triton));
 }
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index e90ab92c5..9d318bd9f 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,11 +1,3 @@
-find_package(Python COMPONENTS Interpreter Development REQUIRED)
-
-if(NOT Python_INCLUDE_DIRS OR NOT Python_LIBRARIES)
-  message(FATAL_ERROR "Python development files not found. Please ensure Python is installed and development headers are available.")
-endif()
-
-include_directories(${Python_INCLUDE_DIRS})
-
 add_library(operators
   SHARED
   zeros.cpp
diff --git a/reduce.log b/reduce.log
deleted file mode 100644
index bec07588f..000000000
--- a/reduce.log
+++ /dev/null
@@ -1,32 +0,0 @@
-============================= test session starts ==============================
-platform linux -- Python 3.10.12, pytest-8.4.1, pluggy-1.6.0
-rootdir: /home/fyf/FlagGems
-configfile: pytest.ini
-plugins: hypothesis-6.136.4
-collected 1468 items / 1462 deselected / 6 selected
-
-test_reduction_ops.py dim == [] in sum_dim_comm
-.dim == [] in sum_dim_comm
-.dim == [] in sum_dim_comm
-.dim == [] in sum_dim_comm
-.dim == [] in sum_dim_comm
-.dim == [] in sum_dim_comm
-.
-
-=============================== warnings summary ===============================
-../../.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: 11 warnings
-  /home/fyf/.virtualenvs/flaggem/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
-    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
-
-tests/test_reduction_ops.py::test_accuracy_sum_all_dims_keepdim_forward_only[dtype0-shape0]
-  /home/fyf/.virtualenvs/flaggem/lib/python3.10/site-packages/torch/library.py:365: UserWarning: Warning only once for all operators, other operators may also be overridden.
-    Overriding a previously registered kernel for the same operator and the same dispatch key
-    operator: aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
-    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
-    dispatch key: CUDA
-    previous kernel: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18106
-    new kernel: registered at /dev/null:223 (Triggered internally at /pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:154.)
-    self.m.impl(
-
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-=============== 6 passed, 1462 deselected, 12 warnings in 0.81s ================
From 06af3becfaec43249d26a2f93d80cc5741d67c6c Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Sat, 16 Aug 2025 02:15:41 +0800
Subject: [PATCH 21/22] merge

---
 Testing/Temporary/CTestCostData.txt                |   1 -
 Testing/Temporary/LastTest.log                     |   3 -
 .../_kunlunxin/utils/pointwise_dynamic.py          |   2 -
 src/flag_gems/utils/pointwise_dynamic.py           |  33 +-
 src/flag_gems/utils/shape_utils.py                 |   4 +-
 src/flag_gems/utils/tensor_wrapper.py              |   9 -
 tests/test_pointwise_dynamic.py                    | 944 ++++++++++++++++++
 7 files changed, 946 insertions(+), 50 deletions(-)
 delete mode 100644 Testing/Temporary/CTestCostData.txt
 delete mode 100644 Testing/Temporary/LastTest.log
 create mode 100644 tests/test_pointwise_dynamic.py

diff --git a/Testing/Temporary/CTestCostData.txt b/Testing/Temporary/CTestCostData.txt
deleted file mode 100644
index ed97d539c..000000000
--- a/Testing/Temporary/CTestCostData.txt
+++ /dev/null
@@ -1 +0,0 @@
----
diff --git a/Testing/Temporary/LastTest.log b/Testing/Temporary/LastTest.log
deleted file mode 100644
index c4e94e56c..000000000
--- a/Testing/Temporary/LastTest.log
+++ /dev/null
@@ -1,3 +0,0 @@
-Start testing: Aug 07 17:33 CST
-----------------------------------------------------------
-End testing: Aug 07 17:33 CST
diff --git a/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py b/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
index 6c168ef6c..4b4f3b29c 100644
--- a/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
+++ b/src/flag_gems/runtime/backend/_kunlunxin/utils/pointwise_dynamic.py
@@ -197,8 +197,6 @@ def signature(self, outputs_in_arg: bool = False) -> str:
         for _ in range(self.num_outputs()):
             output_types.append("StridedBuffer")
         sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
-        print("signature function call")
-        print(sig)
         return sig
 
     def _compute_input_id(self):
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index 72a6c4661..74f33bba5 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -950,36 +950,6 @@ def gen_kernel_launch_1d(
     for i in range(schema.num_output_tensors()):
         code.writeline(f"out{i}_strides = out{i}.stride()")
 
-    # ---- codegen added to print the kernel parameters (debug only) ----
-    code.writeline("print('Kernel parameters:')")
-    # print input tensors
-    for i in range(schema.num_input_tensors()):
-        code.writeline(f"print(f'  in{i}: {{in{i}}}')")
-    # print output tensors
-    for i in range(schema.num_output_tensors()):
-        code.writeline(f"print(f'  out{i}: {{out{i}}}')")
-
-    # print strides of the input tensors
-    if ndim > 0:
-        for i in range(schema.num_input_tensors()):
-            code.writeline(f"print(f'  in{i}_strides: {{in{i}_strides}}')")
-
-    # print strides of the output tensors
-    if ndim > 0:
-        for i in range(schema.num_output_tensors()):
-            code.writeline(f"print(f'  out{i}_strides: {{out{i}_strides}}')")
-
-    # print the remaining parameters
-    if ndim > 0:
-        shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
-        code.writeline("print(f'  shape: {shape[0]}')")
-    code.writeline("print(f'  num_tasks: {num_tasks}')")
-    code.writeline("print(f'  tiles_per_cta: {tiles_per_cta}')")
-    code.writeline("print(f'  tile_size: {tile_size}')")
-    code.writeline("print(f'  one_tile_per_cta: {one_tile_per_cta}')")
-
-    code.writeline("print(f'  num_warps: {num_warps}')")
     for i in range(schema.num_input_tensors()):
         code.writeline(f"in{i}_strides = in{i}.stride()")
     for i in range(schema.num_output_tensors()):
@@ -1177,7 +1147,6 @@ def prepare_args(self, *args, **kwargs):
             task_shape = (tensors[0].numel(),)
             strides = (1,)
             ndim = 1
-            # print(args)
             # input
             args = tuple(
                 (
                     StridedBuffer(item, task_shape, strides)
@@ -1243,7 +1212,7 @@ def prepare_args(self, *args, **kwargs):
                 k: StridedBuffer(
                     item,
                     task_shape,
-                    (item.shape, item.stride(), task_shape),
+                    broadcasted_stride(item.shape, item.stride(), task_shape),
                 )
                 for k, item in kwargs.items()
             }
diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py
index 50285af18..58c6009b1 100644
--- a/src/flag_gems/utils/shape_utils.py
+++ b/src/flag_gems/utils/shape_utils.py
@@ -159,9 +159,7 @@ def ordered_stride(shape: Shape, order: Perm) -> Stride:
 
 def stride_order(strides):
     # we also handle negative strides
-    res = sorted(range(len(strides)), key=lambda i: abs(strides[i]))
-    print("stride_order", res)
-    return res
+    return sorted(range(len(strides)), key=lambda i: abs(strides[i]))
 
 
 def all_the_same_shape(tensors: Sequence[torch.Tensor]) -> bool:
diff --git a/src/flag_gems/utils/tensor_wrapper.py b/src/flag_gems/utils/tensor_wrapper.py
index 184639b83..0b871108f 100644
--- a/src/flag_gems/utils/tensor_wrapper.py
+++ b/src/flag_gems/utils/tensor_wrapper.py
@@ -121,12 +121,3 @@ def copy_(self, src):
             src_buffer = StridedBuffer(src)
             self.copy_(src_buffer)
         return self
-
-    def __repr__(self):
-        return (
-            f"StridedBuffer(shape={self.shape}, "
-            f"strides={self._strides}, "
-            f"dtype={self.dtype}, "
-            f"offset={self.offset}, "
-            f"device={self.device})"
-        )
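The substantive fix above is in prepare_args: the broadcast branch used to pass the raw tuple (item.shape, item.stride(), task_shape) where the strides of the StridedBuffer were expected; it now passes broadcasted_stride(...), which right-aligns the dimensions and zeroes the stride of every broadcast dimension. A short sketch of what that helper computes, mirroring the C++ broadcasted_stride removed from lib/utils.cpp earlier in this series:

def broadcasted_stride(shape, stride, new_shape):
    # Align trailing dims; use stride 0 wherever a size-1 dim is broadcast up.
    pad = len(new_shape) - len(shape)
    out = [0] * len(new_shape)  # new leading dims always get stride 0
    for i, (size, st) in enumerate(zip(shape, stride)):
        j = pad + i
        out[j] = 0 if (size == 1 and new_shape[j] > 1) else st
    return tuple(out)

# A (3, 1) tensor broadcast to (2, 3, 4): the new leading dim and the
# broadcast size-1 dim both read with stride 0, so one element serves many.
print(broadcasted_stride((3, 1), (1, 1), (2, 3, 4)))  # (0, 1, 0)

The new test file below exercises exactly this path, via the broadcasting and nd-buffer tests.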
diff --git a/tests/test_pointwise_dynamic.py b/tests/test_pointwise_dynamic.py
new file mode 100644
index 000000000..a15a8a3e3
--- /dev/null
+++ b/tests/test_pointwise_dynamic.py
@@ -0,0 +1,944 @@
+import concurrent.futures
+import multiprocessing
+
+import pytest
+import torch
+import triton
+
+import flag_gems
+from flag_gems.utils import get_device_properties
+from flag_gems.utils.pointwise_dynamic import (
+    CodeGenConfig,
+    FunctionSchema,
+    pointwise_dynamic,
+)
+from flag_gems.utils.tensor_wrapper import StridedBuffer
+
+MAX_GRID_SIZES = (65535, 65535, 65535)
+MAX_GRID_SIZE_X = MAX_GRID_SIZES[0]
+
+USE_BLOCK_POINTER = [True, False]
+triton_version_less_than3 = int(triton.__version__[0]) < 3
+
+if flag_gems.vendor_name == "kunlunxin":
+    pytestmark = pytest.mark.skip("Test Files for Operators Not Pending Testing")
+
+
+def test_function_schema_with_non_tensor_input():
+    schema = FunctionSchema(
+        is_tensor=[True, False, True],
+        dtypes=[None, float, None],
+        promotion_methods=[(0, 1, 2, "DEFAULT")],
+    )
+    assert schema.num_input_tensors() == 2
+    assert schema.num_output_tensors() == 1
+    assert schema.num_inputs() == 3
+    assert schema.num_non_tensor_args() == 1
+    assert schema.input_index(0) == 0  # the first input is the first input tensor
+    assert schema.input_index(1) == 0  # the second input is the first non tensor input
+    assert schema.input_index(2) == 1  # the third input is the second input tensor
+
+
+def test_function_schema_mismatch_input_num1():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            is_tensor=[True, False, True],
+            dtypes=[None],
+            promotion_methods=[(0, 1, 2, "DEFAULT")],
+        )
+        _ = schema
+
+
+def test_function_schema_mismatch_input_num2():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            is_tensor=[True, False, True],
+            num_inputs=2,
+            promotion_methods=[(0, 1, 2, "DEFAULT")],
+        )
+        _ = schema
+
+
+def test_function_schema_mismatch_input_num3():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            num_inputs=2,
+            dtypes=[None, None, None],
+            promotion_methods=[(0, 1, 2, "DEFAULT")],
+        )
+        _ = schema
+
+
+def test_function_schema_missing_output_dtype_promotion_rules():
+    with pytest.raises(ValueError):
+        schema = FunctionSchema(
+            num_inputs=2,
+            dtypes=[None, None, None],
+        )
+        _ = schema
+
+
+def test_function_schema_mismatch_output_num():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            num_inputs=1,
+            num_outputs=2,
+            promotion_methods=[(0, 1, 2, "DEFAULT")],
+        )
+        _ = schema
+
+
+def test_function_schema_missing_input_info():
+    with pytest.raises(ValueError):
+        schema = FunctionSchema(
+            num_outputs=2,
+            promotion_methods=[(0, 1, 2, "DEFAULT")],
+        )
+        _ = schema
+
+
+def test_function_schema_no_tensor_inputs1():
+    # no tensor input is okay with FunctionSchema
+    schema = FunctionSchema(
+        is_tensor=[False, False, False],
+        promotion_methods=[(0, 1, 2, "DEFAULT")],
+    )
+    _ = schema
+
+
+def test_function_schema_no_tensor_inputs2():
+    schema = FunctionSchema(
+        num_inputs=3,
+        is_tensor=[False, False, False],
+        promotion_methods=[(0, 1, 2, "DEFAULT")],
+    )
+    _ = schema
+
+
+def test_function_schema_no_outputs1():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            is_tensor=[False, False, False],
+            promotion_methods=[],
+        )
+        _ = schema
+
+
+def test_function_schema_no_outputs2():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(
+            is_tensor=[False, False, False],
+            num_outputs=0,
+            promotion_methods=[],
+        )
+        _ = schema
+
+
+def test_function_schema_illegal_dtypes():
+    with pytest.raises(AssertionError):
+        schema = FunctionSchema(dtypes=[0, False, "a"])
+        _ = schema
+
+
+def test_function_schema_multiple_outputs():
+    schema = FunctionSchema(
+        num_inputs=3,
+        num_outputs=2,
+        promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, "ALWAYS_BOOL")],
+    )
+    _ = schema
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_without_non_tensor_args(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")], config=config
+    )
+    @triton.jit
+    def add(x, y):
+        return x + y
+
+    SIZE = 2
+    for ndim in range(8):
+        shape = [SIZE] * ndim
+        x = torch.randn(shape, device=flag_gems.device)
+        y = torch.randn_like(x)
+        out = add(x, y)
+        torch.testing.assert_close(out, x + y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_non_tensor_args(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpy(x, y, alpha):
+        return alpha * x + y
+
+    SIZE = 2
+    for ndim in range(8):
+        shape = [SIZE] * ndim
+        x = torch.randn(shape, device=flag_gems.device)
+        y = torch.randn_like(x)
+        alpha = 2.0
+        out = axpy(x, y, alpha)
+        torch.testing.assert_close(out, alpha * x + y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_multiple_outputs(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        num_outputs=2,
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def multiple_out(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 2
+    for ndim in range(8):
+        shape = [SIZE] * ndim
+        x = torch.randn(shape, device=flag_gems.device)
+        y = torch.randn_like(x)
+        alpha = 2.0
+        out0, out1 = multiple_out(x, y, alpha)
+        torch.testing.assert_close(out0, alpha * x + y)
+        torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_broadcasting(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=True,  # [misaligned address]
+    )
+
+    # NOTE: [misaligned address]
+    # triton 2.2 may cause Misaligned address when using >=3d tiles in some
+    # cases with some zero strides
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpy(x, y, alpha):
+        return alpha * x + y
+
+    SIZE = 10
+    x = torch.randn([SIZE, 1, SIZE], device=flag_gems.device)
+    y = torch.randn([1, SIZE, 1], device=flag_gems.device)
+    alpha = 2.0
+    out = axpy(x, y, alpha)
+    torch.testing.assert_close(out, alpha * x + y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_broadcasting2(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=True,  # [misaligned address]
+    )
+
+    # NOTE: See note [misaligned address]
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpy(x, y, alpha):
+        return alpha * x + y
+
+    SIZE = 10
+    x = torch.randn([SIZE, 1, SIZE], device=flag_gems.device)
+    y = torch.randn([], device=flag_gems.device)
+    alpha = 2.0
+    out = axpy(x, y, alpha)
+    torch.testing.assert_close(out, alpha * x + y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_predefined_out(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpy(x, y, alpha):
+        return alpha * x + y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([], device=flag_gems.device)
+    alpha = 2.0
+    o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    out = axpy(x, y, alpha, out0=o)
+    torch.testing.assert_close(out, alpha * x + y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_some_predefined_out1(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([], device=flag_gems.device)
+    alpha = 2.0
+    o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o)
+    assert out0 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_some_predefined_out2(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([], device=flag_gems.device)
+    alpha = 2.0
+    o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    out0, out1 = axpyaxmy(x, y, alpha, out1=o)
+    assert out1 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_bool_input_and_output(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=1,
+        is_tensor=[True],
+        promotion_methods=[(0, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def invert(x):
+        return ~x
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) > 0
+    notx = invert(x)
+
+    torch.testing.assert_close(notx, ~x)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_manual_instantiation(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=1,
+        is_tensor=[True],
+        promotion_methods=[(0, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def invert(x):
+        return ~x
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) > 0
+    o = torch.empty_like(x)
+    # manually instantiated overload does not handle output allocation
+    # since it is kind of low level
+    notx = invert.instantiate(3)(x, out0=o)
+    torch.testing.assert_close(notx, ~x)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_nd_buffer(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    M, N, K = 40, 60, 80
+    x = torch.randn([M, N, K], device=flag_gems.device)[::2, ::2, ::2]
+    y = torch.randn([N // 2, K // 2, M // 2], device=flag_gems.device).permute(2, 0, 1)
+    alpha = 2.0
+    o = torch.empty([M // 2, N // 2, K // 2], device=flag_gems.device)
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o)
+    assert out0 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+# Cambricon add.
+@pytest.mark.skipif(flag_gems.vendor_name != "cambricon", reason="Only for cambricon")
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_nd_buffer_out_permute(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    M, N, K = 40, 60, 80
+    x = torch.randn([M, N, K], device="cuda")[::2, ::2, ::2]
+    y = torch.randn([M // 2, N // 2, K // 2], device="cuda")
+    alpha = 2.0
+    o = torch.empty([M // 2, K // 2, N // 2], device="cuda").permute(0, 2, 1)
+    o2 = torch.empty([K // 2, M // 2, N // 2], device="cuda").permute(1, 2, 0)
+    print(o.stride(), o2.stride())
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o, out1=o2)
+    assert out0 is o and out1 is o2
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.skipif(flag_gems.vendor_name != "cambricon", reason="Only for cambricon")
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_nd_buffer_broadcast(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    M, N, K = 40, 60, 80
+    x = torch.randn([M, N, 2], device="cuda")[::2, ::2, ::2]
+    y = torch.randn([1, K // 2, M // 2], device="cuda").permute(2, 0, 1)
+    alpha = 2.0
+    o = torch.empty([M // 2, N // 2, K // 2], device="cuda")
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o)
+    assert out0 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.skipif(flag_gems.vendor_name != "cambricon", reason="Only for cambricon")
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_nd_buffer_expand(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    M, N, K = 40, 60, 80
+    x = (
+        torch.randn([1, K // 2, N // 2], device="cuda")
+        .permute(0, 2, 1)
+        .expand([M // 2, N // 2, K // 2])
+    )
+    y = (
+        torch.randn([1, K // 2, M // 2], device="cuda")
+        .permute(2, 0, 1)
+        .expand([M // 2, N // 2, K // 2])
+    )
+    alpha = 2.0
+    o = torch.empty([M // 2, N // 2, K // 2], device="cuda")
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o)
+    assert out0 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+# Cambricon add end.
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_different_stride_order(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    M, N, K = 40, 60, 80
+    x = torch.randn([M, N, K], device=flag_gems.device)
+    y = torch.randn([N, K, M], device=flag_gems.device).permute(2, 0, 1)
+    alpha = 2.0
+    o = torch.empty([M, N, K], device=flag_gems.device)
+    out0, out1 = axpyaxmy(x, y, alpha, out0=o)
+    assert out0 is o
+    torch.testing.assert_close(out0, alpha * x + y)
+    torch.testing.assert_close(out1, alpha * x - y)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_manual_instantiation_mixing_strided_buffer_and_tensor(
+    use_block_pointer,
+):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    alpha = 2.0
+    _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device))
+    out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1)
+
+    assert isinstance(out0, torch.Tensor)
+    assert isinstance(out1, StridedBuffer)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_manual_instantiation_does_not_support_broadcasting1(
+    use_block_pointer,
+):
+    # manually instantiated overload does not support broadcasting of operands
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([1, SIZE], device=flag_gems.device)
+    alpha = 2.0
+    _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device))
+
+    with pytest.raises(Exception):
+        out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_manual_instantiation_does_not_support_broadcasting2(
+    use_block_pointer,
+):
+    # manually instantiated overload does not support broadcasting of operands
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([SIZE, 1, SIZE], device=flag_gems.device)
+    alpha = 2.0
+    _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)
+    _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device))
+
+    with pytest.raises(Exception):
+        out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_manual_instantiation_does_not_allocate_output(
+    use_block_pointer,
+):
+    # manually instantiated overload does not support broadcasting of operands
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpyaxmy(x, y, alpha):
+        return alpha * x + y, alpha * x - y
+
+    SIZE = 10
+    x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device)
+    y = torch.randn([SIZE, 1, SIZE], device=flag_gems.device)
+    alpha = 2.0
+
+    with pytest.raises(Exception):
+        out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_gsl(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=512,
+        max_grid_size=(80, 1, 1),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")], config=config
+    )
+    @triton.jit
+    def add(x, y):
+        return x + y
+
+    SIZE = 2
+    for ndim in range(8):
+        shape = [SIZE] * ndim
+        x = torch.randn(shape, device=flag_gems.device)
+        y = torch.randn_like(x)
+        out = add(x, y)
+        torch.testing.assert_close(out, x + y)
+
+
+@pytest.mark.skipif(
+    get_device_properties(0).total_memory < (80 * 1024**3),
+    reason="This test requires a lot of memory.",
+)
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_int64_index(use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=(MAX_GRID_SIZE_X, 1, 1),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")], config=config)
+    @triton.jit
+    def f(x):
+        return x * 2.0
+
+    x = torch.randn((2, 1024, 1024, 1024), dtype=torch.float16, device=flag_gems.device)
+    y1 = f(x)
+    y2 = x * 2.0
+    torch.testing.assert_close(y1, y2)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_0d_task(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")], config=config
+    )
+    @triton.jit
+    def add(x, y):
+        return x + y
+
+    shape = ()
+    x = torch.randn(shape, device=flag_gems.device)
+    y = torch.randn_like(x)
+    out = add(x, y)
+    torch.testing.assert_close(out, x + y)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+@pytest.mark.skipif(flag_gems.device == "musa", reason="TOFIX")
+def test_dynamic_function_zero_sized_task_unary(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=(65536, 65536, 65536),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")], config=config)
+    @triton.jit
+    def f(x):
+        return x * 2.0
+
+    shape = (0, 10)
+    x = torch.randn(shape, device=flag_gems.device)
+    out = f(x)
+    torch.testing.assert_close(out, x * 2.0)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+@pytest.mark.skipif(flag_gems.device == "musa", reason="TOFIX")
+def test_dynamic_function_zero_sized_task_binary(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=(65536, 65536, 65536),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")], config=config
+    )
+    @triton.jit
+    def f(x, y):
+        return x * 2.0 + y
+
+    shape = (0, 10)
+    x = torch.randn(shape, device=flag_gems.device)
+    y = torch.randn_like(x)
+    out = f(x, y)
+    torch.testing.assert_close(out, x * 2.0 + y)
+
+
+def f_for_concurrency_test(x, alpha, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=MAX_GRID_SIZES,
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=False,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=3,
+        is_tensor=[True, True, False],
+        promotion_methods=[(0, 1, "DEFAULT")],
+        config=config,
+    )
+    @triton.jit
+    def axpy(x, y, alpha):
+        return alpha * x + y
+
+    y = torch.zeros_like(x)
+    out = axpy(x, y, alpha)
+    return out
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_multithread(use_block_pointer):
+    shape = [128]
+    alpha = 2.0
+    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
+        inputs = [torch.randn(shape, device=flag_gems.device) for _ in range(32)]
+        expected_outs = [item * alpha for item in inputs]
+        outs = []
+        for item in inputs:
+            out_future = executor.submit(
+                f_for_concurrency_test, item, alpha, use_block_pointer
+            )
+            outs.append(out_future)
+        outs = [item.result() for item in outs]
+
+    for out, expected_out in zip(outs, expected_outs):
+        torch.testing.assert_close(out, expected_out)
+
+
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_with_multiprocess(use_block_pointer):
+    shape = [128]
+    alpha = 2.0
+    ctx = multiprocessing.get_context("spawn")
+    with concurrent.futures.ProcessPoolExecutor(
+        max_workers=8, mp_context=ctx
+    ) as executor:
+        inputs = [torch.randn(shape, device=flag_gems.device) for _ in range(32)]
+        expected_outs = [item * alpha for item in inputs]
+        outs = []
+        for item in inputs:
+            out_future = executor.submit(
+                f_for_concurrency_test, item, alpha, use_block_pointer
+            )
+            outs.append(out_future)
+        outs = [item.result() for item in outs]
+
+    for out, expected_out in zip(outs, expected_outs):
+        torch.testing.assert_close(out, expected_out)

From 8ffd6a15759d969971dbb20fb690efa2815e6183 Mon Sep 17 00:00:00 2001
From: scatyf3
Date: Sat, 16 Aug 2025 02:28:01 +0800
Subject: [PATCH 22/22] fix build

---
 lib/CMakeLists.txt | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 9d318bd9f..d2d79a25e 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,3 +1,9 @@
+find_package(Python COMPONENTS Interpreter Development REQUIRED)
+if(NOT Python_INCLUDE_DIRS OR NOT Python_LIBRARIES)
+  message(FATAL_ERROR "Python development files not found. Please ensure Python is installed and development headers are available.")
+endif()
+include_directories(${Python_INCLUDE_DIRS})
+
 add_library(operators
   SHARED
   zeros.cpp