Skip to content

Commit

Permalink
AG+GEMM overlap working
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 3, 2024
1 parent 9126c3a commit 616e301
Show file tree
Hide file tree
Showing 9 changed files with 873 additions and 638 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ namespace transformer_engine {
*/
bool ubuf_built_with_mpi();

enum class CommOverlapType : int32_t { RS = 0, AG = 1 };
enum class CommOverlapType : int { RS = 0, AG = 1 };

enum class CommOverlapAlgo : int32_t {
enum class CommOverlapAlgo : int {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG_P2P = 2,
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/common/util/pybind_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
pybind11::enum_<transformer_engine::DType>(m, "DType") \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kInt64", transformer_engine::DType::kInt64) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
Expand Down
762 changes: 486 additions & 276 deletions transformer_engine/jax/cpp_extensions/gemm.py

Large diffs are not rendered by default.

109 changes: 40 additions & 69 deletions transformer_engine/jax/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,44 +171,6 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_
bool fuse_bias, bool grad, bool accumulate,
bool use_split_accumulator);

struct CustomCallBufferDescriptor {
const std::string name;
const size_t *shape;
const size_t ndim;
DType dtype;
CommOverlapType comm_type;
};

pybind11::bytes PackCustomCallBufferDescriptor(const std::string &name,
const std::vector<size_t> &shape, DType dtype,
CommOverlapType comm_type);

struct CustomCallOverlapDescriptor {
size_t m;
size_t k;
size_t n;
size_t workspace_size;
DType operand_dtype;
DType bias_dtype;
DType out_dtype;
bool lhs_trans;
bool rhs_trans;
bool fuse_gelu;
bool fuse_bias;
bool grad;
bool accumulate;
bool use_split_accumulator;
CommOverlapType comm_type;
const std::string name;
};

pybind11::bytes PackCustomCallOverlapDescriptor(size_t m, size_t k, size_t n, size_t workspace_size,
DType operand_dtype, DType bias_dtype,
DType out_dtype, bool lhs_trans, bool rhs_trans,
bool fuse_gelu, bool fuse_bias, bool grad,
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, const std::string &name);

// Transpose

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
Expand Down Expand Up @@ -372,54 +334,63 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

// GEMM

XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasltHandleInitHandler);

void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out,
Result_Type out_amax_updated, Result_Type out_scale_updated,
Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type dummy_out,
Result_Type workspace, bool lhs_trans, bool rhs_trans, bool fuse_gelu,
bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator);
Error_Type GemmFFI(
cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out,
Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type dummy_in, Result_Type out_updated,
Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out,
Result_Type bias_grad, Result_Type dummy_out, Result_Type workspace, bool lhs_trans,
bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate,
bool use_split_accumulator);

XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);

// Comm+GEMM Overlap

void BootstrapCommGemmOverlap(const std::string &name, const std::string &method,
const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, int tp_size, int num_splits,
int num_max_streams, int comm_cga_size, int num_comm_sm,
int set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate,
bool pipeline_rs_overlap_first_gemm);
bool OverlapBufferIsFp8(const std::string &name);

void DestroyCommGemmOverlap(const std::string &name);
pybind11::object GetOverlapBuffer(const std::string &name, bool sharded);

void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv,
bool grad = false);
void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad);

bool OverlapBufferIsFp8(const std::string &name);
void BootstrapCommGemmOverlap(
const std::vector<size_t> &buffer_shape, DType buffer_dtype, const std::string &name,
const std::string &method, CommOverlapType comm_type, int64_t myrank, int64_t numranks,
int64_t tp_size, int64_t num_splits, int64_t num_max_streams, int64_t cga_size,
int64_t num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate,
bool pipeline_rs_overlap_first_gemm);

Error_Type BootstrapCommGemmOverlapFFI(
cudaStream_t, Buffer_Type sample_buffer, std::string_view name, std::string_view method,
int64_t comm_type_flag, int64_t myrank, int64_t numranks, int64_t tp_size, int64_t num_splits,
int64_t num_max_streams, int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm, bool aggregate, bool pipeline_rs_overlap_first_gemm);

XLA_FFI_DECLARE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler);

void DestroyCommGemmOverlap(const std::string &name);

pybind11::object GetOverlapBuffer(const std::string &name, CommOverlapType comm_type);
Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name);

void CopyIntoOverlapBuffer(cudaStream_t, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DestroyCommGemmOverlapHandler);

Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name,
int32_t comm_type_flag);
bool sharded);

XLA_FFI_DECLARE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler);

void CommGemmOverlap(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv,
Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale,
Result_Type out, Result_Type out_amax_new, Result_Type out_scale_new,
Result_Type pre_gelu_out, Result_Type bias_grad,
Result_Type extra_out, Result_Type workspace, bool lhs_trans,
bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad,
bool accumulate, bool use_split_accumulator, int32_t comm_type_flag,
std::string_view name);
Error_Type CommGemmOverlapFFI(
cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out,
Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated,
Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out,
Result_Type bias_grad, Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans,
bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate,
bool use_split_accumulator, int64_t comm_type_flag, std::string_view name);

XLA_FFI_DECLARE_HANDLER_SYMBOL(CommGemmOverlapHandler);

Expand Down
Loading

0 comments on commit 616e301

Please sign in to comment.