added XLA custom ops and C++ infrastructure for comm+GEMM overlap in TE/JAX

Signed-off-by: Alp Dener <[email protected]>

comm+GEMM overlap API for TE/JAX compiles; untested, but does not break the existing collective GEMM op

Signed-off-by: Alp Dener <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
denera committed Nov 21, 2024
1 parent dc384b1 commit 11ad5ec
Showing 21 changed files with 1,721 additions and 215 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/dlpack"]
path = 3rdparty/dlpack
url = git@github.com:dmlc/dlpack.git
1 change: 1 addition & 0 deletions 3rdparty/dlpack
Submodule dlpack added at bbd2f4
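
DLPack is presumably vendored here so device buffers can cross the XLA/JAX <-> TE C++ boundary without copies. As a rough, hypothetical illustration of the header's role (not code from this commit; the helper name and the fixed fp16 dtype are assumptions), describing an existing CUDA buffer with DLPack's C structs looks like this:

#include <dlpack/dlpack.h>
#include <cstdint>

// Illustrative only: describe an existing fp16 CUDA buffer as a DLTensor so
// it can be exchanged across a framework boundary without a copy.
DLTensor describe_fp16_buffer(void *dev_ptr, int64_t *shape, int32_t ndim, int32_t device_id) {
  DLTensor t;
  t.data = dev_ptr;
  t.device = {kDLCUDA, device_id};  // device-resident CUDA memory
  t.ndim = ndim;
  t.dtype = {kDLFloat, 16, 1};      // type code, bits, lanes
  t.shape = shape;
  t.strides = nullptr;              // nullptr means compact row-major
  t.byte_offset = 0;
  return t;
}
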
20 changes: 20 additions & 0 deletions build_tools/jax.py
@@ -5,6 +5,7 @@
"""JAX related extensions."""
import os
from pathlib import Path
from typing import Optional

import setuptools
from glob import glob
@@ -36,6 +37,7 @@ def setup_jax_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
third_party_packages,
) -> setuptools.Extension:
"""Setup PyBind11 extension for JAX support"""
# Source files
@@ -55,12 +57,28 @@ def setup_jax_extension(
common_header_files / "common" / "include",
csrc_header_files,
xla_home,
third_party_packages / "dlpack" / "include",
]

# Compile flags
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]

# Userbuffers MPI dependence
libraries = []
library_dirs = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
mpi_home = os.getenv("MPI_HOME")
assert mpi_home is not None, "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(mpi_home)
libraries.append("mpi")
library_dirs.append(mpi_home / "lib")

include_dirs.append(mpi_home / "include")

cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

@@ -79,5 +97,7 @@ def _add_cflags(self, flags: List[str]) -> None:
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
library_dirs=[str(path) for path in library_dirs],
libraries=libraries,
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
)
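
For reference, a hypothetical sketch (not TE code) of what the new -DNVTE_UB_WITH_MPI definition gates on the C++ side. With MPI compiled in, Userbuffers can bootstrap rank/size from MPI_COMM_WORLD; otherwise the caller must supply the allgather/barrier callbacks seen in the CommOverlapBase constructor further down:

#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
#endif

// Hypothetical helper illustrating the compile-time switch.
int query_world(int *rank, int *size) {
#ifdef NVTE_UB_WITH_MPI
  MPI_Comm_rank(MPI_COMM_WORLD, rank);
  MPI_Comm_size(MPI_COMM_WORLD, size);
  return 0;
#else
  (void)rank;
  (void)size;
  return -1;  // MPI not compiled in; use the external bootstrap callbacks
#endif
}
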
1 change: 1 addition & 0 deletions setup.py
@@ -164,6 +164,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine",
current_file_path / "3rdparty",
)
)
if "paddle" in frameworks:
72 changes: 66 additions & 6 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -121,11 +121,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
- bool atomic_gemm)
+ bool atomic_gemm, bool overlap_first_gemm)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, false, atomic_gemm) {
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
_overlap_first_gemm = overlap_first_gemm;
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
"Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
"or 2 (multi-atomic).");
@@ -146,6 +147,36 @@ CommOverlapBase::~CommOverlapBase() {
cudaStreamDestroy(_stream_comm);
}

TensorWrapper CommOverlapBase::get_ubuf_output(CommOverlapType comm_type) {
char *output_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (comm_type == CommOverlapType::RS)
output_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
size_t output_c_dim0 =
(comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
size_t output_c_dim1 = _ubuf.size(1);
return TensorWrapper(reinterpret_cast<void *>(output_ptr), {output_c_dim0, output_c_dim1},
_ubuf.dtype());
}

void CommOverlapBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input,
CommOverlapType comm_type) {
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (comm_type == CommOverlapType::AG) {
if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() ||
input.element_size() != (int64_t)_ubuf.element_size()) {
NVTE_ERROR("Input and buffer sizes do not match!");
}
ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
} else {
if (input.numel() != (int64_t)_ubuf.numel() ||
input.element_size() != (int64_t)_ubuf.element_size()) {
NVTE_ERROR("Input and buffer sizes do not match!");
}
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.dptr(), input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, stream));
}
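
A minimal usage sketch for these two accessors (assumptions: `overlap` is a constructed CommOverlapBase, `input` is a device-resident TensorWrapper sized to match the communication buffer, and `stream` is a valid cudaStream_t; the overlapped GEMM call itself is elided):

// Bulk reduce-scatter pattern: stage the full input into the communication
// buffer up front, then read this rank's chunk back as a non-owning view.
overlap.copy_into_ubuf(stream, input, CommOverlapType::RS);
// ... enqueue the overlapped GEMM + reduce-scatter here ...
TensorWrapper rs_chunk = overlap.get_ubuf_output(CommOverlapType::RS);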

/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
@@ -201,8 +232,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
- bool gemm_overlap, TensorWrapper &rs_output,
- cudaStream_t stream_main) {
+ TensorWrapper &rs_output, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
@@ -301,8 +331,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
- bool gemm_overlap, TensorWrapper &rs_output,
- cudaStream_t stream_main) {
+ TensorWrapper &rs_output, cudaStream_t stream_main) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
@@ -334,7 +363,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap

assert(pre_gelu_out.numel() == 0);

- if (gemm_overlap) {
+ if (_overlap_first_gemm) {
auto input_a_chunk =
TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk =
@@ -541,6 +570,37 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaStreamDestroy(_stream_send);
}

TensorWrapper CommOverlapP2PBase::get_ubuf_output(CommOverlapType comm_type) {
char *output_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (comm_type == CommOverlapType::RS)
output_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
size_t output_c_dim0 =
(comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
size_t output_c_dim1 = _ubuf.size(1);
return TensorWrapper(reinterpret_cast<void *>(output_ptr), {output_c_dim0, output_c_dim1},
_ubuf.dtype());
}

void CommOverlapP2PBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input,
CommOverlapType comm_type) {
if (comm_type == CommOverlapType::RS) {
// Copy input to the target ubuf chunk by rank offset
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
NVTE_ERROR("Input and buffer sizes do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.dptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
stream));
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("Input and buffer sizes do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.dptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
stream));
}
}
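
The P2P variant stages differently; a sketch under the same assumptions, with `p2p` a constructed CommOverlapP2PBase and `input_chunk` sized like one buffer chunk:

// Ring-exchange (P2P) reduce-scatter: each rank stages only its own chunk
// (matching _ubufs[0] in size) instead of the full tensor as in the base class.
p2p.copy_into_ubuf(stream, input_chunk, CommOverlapType::RS);
// ... enqueue the P2P-overlapped GEMM + reduce-scatter here ...
TensorWrapper rs_chunk = p2p.get_ubuf_output(CommOverlapType::RS);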

/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
