29 commits
bfb5ff3
[kunlunxin] skip unsupported operator and open supported operator (#802)
nianqi-tian Jul 18, 2025
51ae527
[KUNLUNXIN] tmp disable some operators and open operator (#805)
nianqi-tian Jul 21, 2025
c72320f
add max.dim_max
scatyf3 Jul 30, 2025
df98af1
use CUstream instead c10::cuda::CUDAStream
scatyf3 Jul 31, 2025
f2269a2
Merge branch 'FlagOpen:master' into misc
scatyf3 Jul 31, 2025
53923fd
use CUstream instead c10::cuda::CUDAStream
scatyf3 Jul 31, 2025
f7b0993
Refactor
scatyf3 Jul 31, 2025
c675f78
Merge branch 'FlagOpen:master' into misc
scatyf3 Jul 31, 2025
110f47d
add max_dim_max and sum, move permute_reduction_axes_right to utils, …
scatyf3 Jul 31, 2025
d674544
add and fix
scatyf3 Jul 31, 2025
9e50fd6
fix error
scatyf3 Jul 31, 2025
3a1e357
Merge branch 'FlagOpen:master' into misc
scatyf3 Aug 5, 2025
3d81e81
Merge branch 'FlagOpen:master' into misc
scatyf3 Aug 5, 2025
1a74427
tmp update
scatyf3 Aug 6, 2025
1117b82
tmp
scatyf3 Aug 7, 2025
b782083
Merge branch 'FlagOpen:master' into misc
scatyf3 Aug 7, 2025
3685efb
sum cpp wrapper
scatyf3 Aug 8, 2025
24c4e40
remove useless comment
scatyf3 Aug 8, 2025
e8e6015
tmp update for merge
scatyf3 Aug 8, 2025
b4e46fa
merge
scatyf3 Aug 8, 2025
9022878
pointwise dynamic add finish use fast path version
scatyf3 Aug 8, 2025
aedc21c
tmp
scatyf3 Aug 12, 2025
720192b
tmp change
scatyf3 Aug 14, 2025
e01d53c
use fast path work
scatyf3 Aug 15, 2025
dec3f3d
pointwise dynamic fastpath
scatyf3 Aug 15, 2025
425363c
merge
scatyf3 Aug 15, 2025
0b477d5
tmp
scatyf3 Aug 15, 2025
06af3be
merge
scatyf3 Aug 15, 2025
8ffd6a1
fix build
scatyf3 Aug 15, 2025
7 changes: 3 additions & 4 deletions ctests/test_triton_pointwise.cpp
@@ -2,13 +2,12 @@
#include "flag_gems/operators.h"
#include "torch/torch.h"

TEST(pointwise_op_test, add) {
TEST(pointwise_op_simple_test, add) {
const torch::Device device(torch::kCUDA, 0);
torch::Tensor a = torch::randn({10, 10}, device);
torch::Tensor b = torch::randn({10, 10}, device);
torch::Tensor a = torch::randn({128}, device);
torch::Tensor b = torch::randn({128}, device);

torch::Tensor out_torch = a + b;
torch::Tensor out_triton = flag_gems::add_tensor(a, b);

EXPECT_TRUE(torch::allclose(out_torch, out_triton));
}
55 changes: 49 additions & 6 deletions include/flag_gems/utils.h
@@ -11,16 +11,59 @@
#include "torch/torch.h"

namespace flag_gems::utils {

using Shape = c10::IntArrayRef;
std::filesystem::path get_path_of_this_library();
std::filesystem::path get_triton_src_path();
std::filesystem::path get_flag_gems_src_path();
int64_t next_power_of_2(int64_t n);
bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor &tensor,
std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(const at::Tensor& tensor,
int reduction_axis);
std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
const at::Tensor &tensor, at::OptionalIntArrayRef reduction_axes_opt);
std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis);
const at::Tensor& tensor, at::OptionalIntArrayRef reduction_axes_opt);
std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor& tensor, int reduction_axis);
int cdiv(int a, int b);
} // namespace flag_gems::utils
bool broadcastable_to(at::IntArrayRef s1, at::IntArrayRef s2);
}; // namespace flag_gems::utils

namespace flag_gems::pointwise_dynamic {
void checkIfScalar(const torch::Tensor& tensor1,
const torch::Tensor& tensor2,
std::array<bool, 2>& is_tensor);
bool use_fast_path(const std::vector<at::Tensor>& tensors);

class ParamStack {
private:
std::vector<void*> kernel_params;
std::string signature;
std::vector<void*> tensor_ptr;
std::vector<int64_t> strides;
std::vector<int64_t> task_shape;
std::vector<int64_t> task_partition;
std::string constexp;
void* global_scratch;

private:
void push_strides();
void push_task_shape();
void push_task_partition();
void add_global_scratch();

public:
ParamStack(int max_args = 32) {
kernel_params.reserve(max_args);
tensor_ptr.reserve(max_args);
global_scratch = nullptr;
}
void save_tensor(at::Tensor& tensor);
void save_tensor(const at::Tensor& tensor);
void save_stride(int64_t stride);
void save_task_shape(int64_t shape);
void save_task_partition(int64_t partition);
void save_constexpr(int64_t value);
void save_constexpr(bool value);
void** get_params();
std::string get_signature();

void build();
};
}; // namespace flag_gems::pointwise_dynamic
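For reference, here is a minimal sketch of how a caller is expected to drive ParamStack, mirroring the fast-path branch of add_tensor in lib/pointwise_dynamic.cpp further down in this diff. The helper name build_fast_path_params, the unit strides, and the single flattened task dimension are illustrative assumptions, not part of the change.

#include "flag_gems/utils.h"
#include "torch/torch.h"

// Illustrative sketch (not part of this diff): assembling raw launch arguments
// for a contiguous 1-D ("fast path") elementwise kernel with ParamStack.
void build_fast_path_params(const at::Tensor& a, const at::Tensor& b, at::Tensor& out,
                            int64_t tile_size, int64_t tiles_per_cta) {
  flag_gems::pointwise_dynamic::ParamStack stk;
  stk.save_stride(1);                      // unit stride for each tensor argument
  stk.save_stride(1);
  stk.save_stride(1);
  stk.save_task_shape(a.numel());          // flattened task shape
  stk.save_task_shape(a.numel());
  stk.save_task_partition(tiles_per_cta);  // how many tiles each CTA processes
  stk.save_constexpr(tile_size);           // constexpr args become part of the signature
  stk.save_constexpr(tiles_per_cta == 1);
  stk.save_tensor(a);                      // data pointers are recorded last
  stk.save_tensor(b);
  stk.save_tensor(out);
  stk.build();                             // finalizes kernel_params and the signature string
  // stk.get_signature() and stk.get_params() then feed TritonJITFunction::launch_with_raw_args.
}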
13 changes: 10 additions & 3 deletions lib/CMakeLists.txt
@@ -1,8 +1,13 @@
find_package(Python COMPONENTS Interpreter Development REQUIRED)
if(NOT Python_INCLUDE_DIRS OR NOT Python_LIBRARIES)
message(FATAL_ERROR "Python development files not found. Please ensure Python is installed and development headers are available.")
endif()
include_directories(${Python_INCLUDE_DIRS})

add_library(operators
SHARED
zeros.cpp
utils.cpp
add.cpp
sum.cpp
max.cpp
mm.cpp
@@ -18,9 +23,11 @@ add_library(operators
bmm.cpp
embedding.cpp
argmax.cpp
fill.cpp
softmax.cpp
exponential_.cpp
fill.cpp)
pointwise_dynamic.cpp
exponential_.cpp
)
target_include_directories(operators
PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
39 changes: 0 additions & 39 deletions lib/add.cpp

This file was deleted.

87 changes: 87 additions & 0 deletions lib/pointwise_dynamic.cpp
@@ -0,0 +1,87 @@
#include "flag_gems/operators.h"
#include "flag_gems/utils.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iostream>
#include "c10/cuda/CUDAStream.h"
#include "c10/util/Logging.h"
#include "pybind11/embed.h"
#include "triton_jit/pointwise_generator.h"
#include "triton_jit/triton_jit_function.h"

namespace flag_gems {
using namespace triton_jit;

namespace py = pybind11;
at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) {
pointwise_dynamic::ParamStack stk = pointwise_dynamic::ParamStack();
int64_t task_shape, ndim;
int64_t num_ctas;
int64_t tiles_per_cta;
int64_t tile_sizes;
int64_t num_tiles;
at::Tensor out = at::empty_like(a_);
std::vector<at::Tensor> tensors = {a_, b_, out};
const int num_warps = 4;
const int num_stages = 1;
if (pointwise_dynamic::use_fast_path(tensors)) {
task_shape = a_.numel();
int64_t stride = 1;
ndim = 1;
stk.save_stride(stride);
stk.save_stride(stride);
stk.save_stride(stride);
stk.save_task_shape(task_shape);
stk.save_task_shape(task_shape);
tile_sizes = num_warps * 32;
num_tiles = utils::cdiv(task_shape, tile_sizes);
num_ctas = std::min(static_cast<int64_t>(65536), num_tiles);
tiles_per_cta = utils::cdiv(num_tiles, num_ctas);
stk.save_task_partition(tiles_per_cta);
} else {
throw std::runtime_error("NotImplementedError");
}
stk.save_constexpr(tile_sizes);
int64_t one_tile_per_cta = (tiles_per_cta == 1);
stk.save_constexpr(one_tile_per_cta);

std::array<bool, 2> is_scalar;
pointwise_dynamic::checkIfScalar(a_, b_, is_scalar);
std::optional<TritonJITFunction> f;
auto ans_tuple = gen_add(ndim);
std::string kernel_name = std::get<0>(ans_tuple);
std::string file_path = std::get<1>(ans_tuple);
if (!is_scalar[0] && !is_scalar[1]) {
f = TritonJITFunction::getInstance(file_path, kernel_name);
} else if (!is_scalar[0] && is_scalar[1]) {
throw std::runtime_error("NotImplementedError");
f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
"add_func_tensor_scalar");
} else if (is_scalar[0] && !is_scalar[1]) {
throw std::runtime_error("NotImplementedError");
f = TritonJITFunction::getInstance(std::string(utils::get_flag_gems_src_path() / "ops" / "add.py"),
"add_func_scalar_tensor");
} else {
return a_ + b_;
}
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
c10::DeviceGuard guard(out.device());
CUstream raw_stream = static_cast<CUstream>(stream.stream());

stk.save_tensor(a_);
stk.save_tensor(b_);
stk.save_tensor(out);
stk.build();
f->launch_with_raw_args(raw_stream,
num_ctas,
1,
1,
num_warps,
num_stages,
stk.get_signature(),
stk.get_params());
return out;
}

}; // namespace flag_gems
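A minimal usage sketch, mirroring the updated ctests/test_triton_pointwise.cpp above: contiguous 1-D CUDA tensors are the case the fast-path branch in this file handles. The standalone main and the printed comparison are illustrative, not part of the change.

#include "flag_gems/operators.h"
#include "torch/torch.h"
#include <iostream>

// Illustrative sketch (not part of this diff): calling the new add_tensor path
// the same way the updated ctest does, and checking it against torch's own add.
int main() {
  const torch::Device device(torch::kCUDA, 0);
  torch::Tensor a = torch::randn({128}, device);   // contiguous 1-D inputs
  torch::Tensor b = torch::randn({128}, device);

  torch::Tensor out_triton = flag_gems::add_tensor(a, b);  // Triton kernel launch
  torch::Tensor out_torch = a + b;                          // reference result

  std::cout << "allclose: " << torch::allclose(out_torch, out_triton) << std::endl;
  return 0;
}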