[JAX] Fix incorrect sharding when only FSDP is enabled, and a memory misalignment in LN_BWD (#379)

* [JAX] Fix incorrect sharding when only FSDP is enabled.

Signed-off-by: Ming Huang <[email protected]>

* [JAX] Add a WAR for memory misalignment issues in LN BWD.

Signed-off-by: Ming Huang <[email protected]>

* [JAX] Reuse sm_arch to avoid duplicate code.

Signed-off-by: Ming Huang <[email protected]>

* [JAX] Support multiple-size allocation in WorkspaceManager.

Signed-off-by: Ming Huang <[email protected]>

* [JAX] Use templates and variadic arguments to improve the multiple-size allocator.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
mingxu1067 authored Aug 30, 2023
1 parent b8ba734 commit 3a63b13
Showing 5 changed files with 44 additions and 35 deletions.
1 change: 1 addition & 0 deletions transformer_engine/common/util/cuda_runtime.h
@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_

#include <cuda_runtime_api.h>
+#include <string>

namespace transformer_engine {

20 changes: 8 additions & 12 deletions transformer_engine/jax/csrc/modules.cpp
@@ -273,7 +273,7 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);

size_t workspace_size = kCublasLtForwardWorkspaceSize;
-auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);

nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
@@ -327,7 +327,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) +
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());

-void *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);

auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
@@ -412,13 +412,9 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
dummy_dgamma_part_tensor.shape().data[1] *
typeToSize(dummy_dgamma_part_tensor.dtype());
-size_t total_workspace_size =
-(workspace_size + barrier_size + dgamma_part_size + dbeta_part_size);

-void *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
-void *barrier = static_cast<char *>(workspace) + workspace_size;
-void *dgamma_part = static_cast<char *>(barrier) + barrier_size;
-void *dbeta_part = static_cast<char *>(dgamma_part) + dgamma_part_size;
+auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace(
+workspace_size, dgamma_part_size, dbeta_part_size, barrier_size);

auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
@@ -811,7 +807,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
output_s->data.dptr = softmax_aux;

auto workspace_size = query_workspace_tensor.shape().data[0];
-auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

@@ -894,7 +890,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
query_workspace_tensor.data(), stream);

size_t workspace_size = query_workspace_tensor.shape().data[0];
-auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

@@ -978,7 +974,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
-auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
+auto *workspace = WorkspaceManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

@@ -1074,7 +1070,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa

size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
-auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);

auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
16 changes: 2 additions & 14 deletions transformer_engine/jax/csrc/utils.cu
@@ -6,6 +6,7 @@
#include <cuda_runtime_api.h>
#include <cassert>

#include "common/util/cuda_runtime.h"
#include "utils.h"

namespace transformer_engine {
@@ -17,20 +18,7 @@ int GetCudaRuntimeVersion() {
return ver;
}

-int GetDeviceComputeCapability(int gpu_id) {
-int max_num_gpu = 0;
-NVTE_CHECK_CUDA(cudaGetDeviceCount(&max_num_gpu));
-assert(gpu_id < max_num_gpu);
-
-int major = 0;
-NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, gpu_id));
-
-int minor = 0;
-NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, gpu_id));
-
-int gpu_arch = major * 10 + minor;
-return gpu_arch;
-}
+int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }

__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) {
31 changes: 26 additions & 5 deletions transformer_engine/jax/csrc/utils.h
@@ -10,6 +10,7 @@
#include <pybind11/pybind11.h>

#include <cstdint>
+#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
@@ -26,25 +27,44 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);

-class cublasLtMetaManager {
+class WorkspaceManager {
public:
-static cublasLtMetaManager &Instance() {
-static thread_local cublasLtMetaManager instance;
+static WorkspaceManager &Instance() {
+static thread_local WorkspaceManager instance;
return instance;
}

-cublasLtMetaManager() {}
-~cublasLtMetaManager() { Clear_(); }
+WorkspaceManager() {}
+~WorkspaceManager() { Clear_(); }

void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size);
return workspace_;
}

+template <typename... Args>
+inline auto GetWorkspace(Args... args) {
+auto asks = std::array<size_t, sizeof...(Args)>{args...};
+std::array<size_t, sizeof...(Args) + 1> offsets = {0};
+std::array<void *, sizeof...(Args)> workspaces = {nullptr};
+std::transform_inclusive_scan(
+asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
+[=](auto x) { return PadSize_(x); }, 0);
+auto *workspace = GetWorkspace(offsets.back());
+std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
+[workspace](auto x) { return static_cast<char *>(workspace) + x; });
+return workspaces;
+}
+
private:
void *workspace_ = nullptr;
size_t size_ = 0;

+size_t PadSize_(size_t size) {
+constexpr size_t alignment = 128;
+return ((size + alignment - 1) / alignment) * alignment;
+}
+
void Clear_() {
if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_));
@@ -54,6 +74,7 @@ class cublasLtMetaManager {
}

void Allocate_(size_t new_size) {
+new_size = PadSize_(new_size);
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size;
}
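The WAR for the LN_BWD misalignment is the per-request padding above: every size handed to the variadic GetWorkspace is rounded up to a 128-byte boundary before the offsets are accumulated, so each sub-buffer carved out of the single allocation starts on an aligned offset. A minimal sketch of that arithmetic, in Python for brevity (pad_size and carve_offsets are illustrative names, not part of Transformer Engine):

ALIGNMENT = 128

def pad_size(size):
    # Same rounding as PadSize_ above: up to the next 128-byte boundary.
    return ((size + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT

def carve_offsets(*sizes):
    # Byte offset of each sub-buffer inside one shared allocation,
    # mirroring the transform_inclusive_scan in GetWorkspace(Args...).
    offsets, total = [], 0
    for size in sizes:
        offsets.append(total)
        total += pad_size(size)
    return offsets, total

# LayerNormBackwardImpl asks for workspace, dgamma_part, dbeta_part and barrier at once:
offsets, total = carve_offsets(1000, 300, 300, 64)
print(offsets, total)  # [0, 1024, 1408, 1792] 1920 -- every offset is a multiple of 128

Previously the four pointers were computed with unpadded offsets into one buffer, as the removed LayerNormBackwardImpl lines show, so a workspace_size that was not a multiple of the required alignment left dgamma_part, dbeta_part and barrier misaligned.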
11 changes: 7 additions & 4 deletions transformer_engine/jax/sharding.py
@@ -138,7 +138,7 @@ def infer_major_sharding_type() -> MajorShardingType:
"""
gsr = global_shard_resource()

-resources = [gsr.dp_resource, gsr.tp_resource]
+resources = [gsr.dp_resource, gsr.tp_resource, gsr.fsdp_resource]
for idx, rs in enumerate(resources):
try:
size, _ = _get_mesh_info(rs)
@@ -149,12 +149,15 @@

dp_resource = resources[0]
tp_resource = resources[1]
+fsdp_resource = resources[2]

-if dp_resource is not None and \
-tp_resource is not None :
+def dp_enabled():
+return (fsdp_resource is not None) or (dp_resource is not None)
+
+if dp_enabled() and tp_resource is not None:
return MajorShardingType.DPTP

-if dp_resource is not None:
+if dp_enabled():
return MajorShardingType.DP

if tp_resource is not None:
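The sharding.py change means a mesh that defines only an fsdp_resource is now treated as data-parallel instead of falling past the DP checks. A distilled sketch of the new branching (plain strings stand in for the MajorShardingType members; the real function also validates each resource via _get_mesh_info):

def major_sharding_type(dp_resource, tp_resource, fsdp_resource):
    # Mirrors the updated infer_major_sharding_type: FSDP counts as DP.
    dp_enabled = (fsdp_resource is not None) or (dp_resource is not None)
    if dp_enabled and tp_resource is not None:
        return "DPTP"
    if dp_enabled:
        return "DP"
    if tp_resource is not None:
        return "TP"
    return "SINGLE"

assert major_sharding_type(None, None, "fsdp") == "DP"    # FSDP-only: DP, not single-device
assert major_sharding_type(None, "tp", "fsdp") == "DPTP"  # FSDP + TP: DPTP, not TP-only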
