4 changes: 2 additions & 2 deletions src/gda/backend_gda.cpp
@@ -156,7 +156,7 @@ void GDABackend::setup_host_ctx() {

void GDABackend::setup_default_ctx() {
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
- default_context_proxy_ = GDADefaultContextProxyT(this, tinfo);
+ default_context_proxy_ = GDADefaultContextProxyT(this, tinfo, gda_provider);
}

void GDABackend::setup_ctxs() {
@@ -166,7 +166,7 @@ void GDABackend::setup_ctxs() {
CHECK_HIP(hipMalloc(&ctx_array, sizeof(GDAContext) * envvar::max_num_contexts));
// 0th context is default context
for (size_t i = 0; i < envvar::max_num_contexts; i++) {
- new (&ctx_array[i]) GDAContext(this, i + 1);
+ new (&ctx_array[i]) GDAContext(this, i + 1, gda_provider);
ctx_free_list.get()->push_back(ctx_array + i);
}
}
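The provider tag added to these constructors is what lets device code branch per NIC later in this PR (see the `GDAProvider::BNXT` check in `context_gda_tmpl_device.hpp`). Below is a rough sketch of what such a tag could look like; only `GDAProvider::BNXT` is confirmed by this diff, so the other enumerators, the environment-variable name, and the detection helper are illustrative assumptions, not rocSHMEM's actual definitions.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical sketch only: BNXT is the one value visible in this PR; the
// other enumerators, the env-var name, and detect_gda_provider() are assumed.
enum GDAProvider : int { MLX5 = 0, BNXT = 1, IONIC = 2 };

int detect_gda_provider() {
  const char *p = std::getenv("ROCSHMEM_GDA_PROVIDER");  // assumed variable name
  if (p && std::strcmp(p, "bnxt") == 0) return GDAProvider::BNXT;
  if (p && std::strcmp(p, "ionic") == 0) return GDAProvider::IONIC;
  return GDAProvider::MLX5;  // assumed default
}
```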
88 changes: 88 additions & 0 deletions src/gda/bnxt/queue_pair_bnxt.cpp
@@ -213,6 +213,10 @@ __device__ void QueuePair::bnxt_quiet() {
}
}

__device__ void QueuePair::bnxt_quiet_single() {
poll_cq_until(sq.depth);
}

__device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
uint64_t active_lane_mask;
uint8_t active_lane_count;
@@ -301,6 +305,90 @@ __device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t *
}
}

__device__ void QueuePair::bnxt_post_wqe_rma_single(int pe, int32_t length, uintptr_t *laddr,
uintptr_t *raddr, uint8_t opcode) {
uint64_t active_lane_mask;
uint8_t active_lane_count;
uint8_t active_lane_id;
struct bnxt_re_bsqe hdr;
struct bnxt_re_rdma rdma;
struct bnxt_re_sge sge;
struct bnxt_re_bsqe *hdr_ptr;
struct bnxt_re_rdma *rdma_ptr;
struct bnxt_re_sge *sge_ptr;
uint32_t wqe_size;
uint32_t wqe_type;
uint32_t hdr_flags;
uint32_t inline_msg;

aquire_lock(&sq.lock);

inline_msg = length <= inline_threshold &&
opcode == gda_op_rdma_write;

poll_cq_until(GDA_BNXT_WQE_SLOT_COUNT);

hdr_ptr = (struct bnxt_re_bsqe*) bnxt_re_get_hwqe(&sq, 0);
rdma_ptr = (struct bnxt_re_rdma*) bnxt_re_get_hwqe(&sq, 1);
sge_ptr = (struct bnxt_re_sge*) bnxt_re_get_hwqe(&sq, 2);

/* Populate Header Segment */
wqe_type = BNXT_RE_HDR_WT_MASK & opcode;
wqe_size = BNXT_RE_HDR_WS_MASK & GDA_BNXT_WQE_SLOT_COUNT;
hdr_flags = ((uint32_t) BNXT_RE_HDR_FLAGS_MASK)
& ((uint32_t) BNXT_RE_WR_FLAGS_SIGNALED);

if (inline_msg) {
hdr_flags |= ((uint32_t) BNXT_RE_WR_FLAGS_INLINE);
}

hdr.rsv_ws_fl_wt = (wqe_size << BNXT_RE_HDR_WS_SHIFT)
| (hdr_flags << BNXT_RE_HDR_FLAGS_SHIFT)
| wqe_type;
hdr.key_immd = 0;
hdr.lhdr.qkey_len = length;

/* Populate RDMA Segment */
rdma.rva = (uint64_t) raddr;
rdma.rkey = rkey;

if (!inline_msg) {
/* Populate SG Segment */
sge.pa = (uint64_t) laddr;
sge.lkey = lkey;
sge.length = length;
}

/* Write WQE to SQ */
memcpy(hdr_ptr, &hdr, sizeof(struct bnxt_re_bsqe));
memcpy(rdma_ptr, &rdma, sizeof(struct bnxt_re_rdma));

if (inline_msg) {
memcpy(sge_ptr, laddr, length);
} else {
memcpy(sge_ptr, &sge, sizeof(struct bnxt_re_sge));
}

/* Populate MSN Table */
bnxt_re_fill_psns_for_msntbl(&sq, length);

/* Update SQ Pointer */
bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT);

/* Ring Doorbell
* Doorbell ring must be serialized as we cannot have all threads write to the same address */
active_lane_mask = get_active_lane_mask();
active_lane_count = get_active_lane_count(active_lane_mask);
active_lane_id = get_active_lane_num(active_lane_mask);

for (int i = 0; i < active_lane_count; i++) {
if (i == active_lane_id) {
bnxt_ring_doorbell(sq.tail);
release_lock(&sq.lock);
}
}
}

__device__ uint64_t QueuePair::bnxt_post_wqe_amo(int pe, int32_t length, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
uint64_t active_lane_mask;
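Both `bnxt_post_wqe_rma` and the new `bnxt_post_wqe_rma_single` end with the same idiom: the doorbell write is serialized across the active lanes because every thread targets the same doorbell address. A minimal standalone sketch of that idiom follows, with plain HIP intrinsics standing in for the `get_active_lane_*` helpers and an atomic counter standing in for `bnxt_ring_doorbell`.

```cpp
#include <hip/hip_runtime.h>

// Sketch of the serialized doorbell ring: AMD wavefront lanes execute in
// lockstep, so iterating over the active-lane count and letting exactly one
// lane act per iteration serializes the writes to the shared address.
__global__ void serialized_doorbell_sketch(unsigned long long *doorbell) {
  unsigned long long mask = __ballot(1);                      // active lanes
  int count = __popcll(mask);                                  // how many are active
  int my_slot = __popcll(mask & ((1ull << __lane_id()) - 1));  // my turn index

  for (int i = 0; i < count; ++i) {
    if (i == my_slot) {
      atomicAdd(doorbell, 1ull);  // stand-in for bnxt_ring_doorbell(sq.tail)
    }
  }
}
```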
7 changes: 6 additions & 1 deletion src/gda/context_gda_device.cpp
@@ -34,7 +34,7 @@

namespace rocshmem {

- __host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id)
+ __host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id, int gda_provider)
: Context(b, false) {
GDABackend *backend{static_cast<GDABackend *>(b)};
base_heap = backend->heap.get_heap_bases().data();
@@ -56,6 +56,7 @@ __host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id)
ipcImpl_.pes_with_ipc_avail = backend->ipcImpl.pes_with_ipc_avail;

ctx_id_ = ctx_id;
gda_provider_ = gda_provider;
}

__host__ GDAContext::~GDAContext() {
@@ -147,6 +148,10 @@ __device__ void GDAContext::pe_quiet(size_t pe) {
qps[pe].quiet();
}

__device__ void GDAContext::pe_quiet_single(size_t pe) {
qps[pe].quiet_single();
}

__device__ void *GDAContext::shmem_ptr(const void *dest, int pe) {
void *ret = nullptr;
int local_pe{-1};
13 changes: 12 additions & 1 deletion src/gda/context_gda_device.hpp
@@ -34,7 +34,7 @@ class QueuePair;

class GDAContext : public Context {
public:
- __host__ GDAContext(Backend *b, unsigned int ctx_id);
+ __host__ GDAContext(Backend *b, unsigned int ctx_id, int gda_provider);

__host__ ~GDAContext();

@@ -63,6 +63,7 @@ class GDAContext : public Context {
__device__ void quiet_wave();

__device__ void pe_quiet(size_t pe);
__device__ void pe_quiet_single(size_t pe);

__device__ void *shmem_ptr(const void *dest, int pe);

@@ -257,6 +258,10 @@ __device__ void alltoall_linear(rocshmem_team_t team, T *dest,
__device__ void alltoall_linear(rocshmem_team_t team, T *dest,
const T *source, int nelems);

template <typename T>
__device__ void alltoall_linear_thread_puts(rocshmem_team_t team, T *dest,
const T *source, int nelems);

__device__ void internal_sync(int pe, int PE_start, int stride, int PE_size,
int64_t *pSync);

@@ -272,6 +277,10 @@ class GDAContext : public Context {
__device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);

__device__ void internal_direct_barrier_wg_thread_puts(int pe, int PE_start,
int stride, int n_pes,
int64_t *pSync);

__device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);

@@ -298,6 +307,8 @@
*/
unsigned int ctx_id_{};

int gda_provider_{0};

public:
QueuePair *qps{nullptr};

48 changes: 48 additions & 0 deletions src/gda/context_gda_device_coll.cpp
@@ -121,6 +121,54 @@ __device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start,
}
}

__device__ void GDAContext::internal_direct_barrier_wg_thread_puts(int pe, int PE_start,
int stride, int n_pes,
int64_t *pSync) {
int64_t flag_val{1};

if (pe == PE_start) {
int tid = get_flat_block_id();

// Go through all PE offsets (except current offset = 0)
// and wait until they all reach
for (int j = tid + 1; j < n_pes; j+= WF_SIZE) {
wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val);
pSync[j] = ROCSHMEM_SYNC_VALUE;
}

__syncthreads();

// Announce to other PEs that all have reached
for (int i = tid + 1, j = PE_start + stride + tid;
i < n_pes;
i+= WF_SIZE, j += (WF_SIZE * stride)) {
uint64_t L_offset = reinterpret_cast<char*>(&pSync[0]) - base_heap[my_pe];
qps[j].put_nbi_single(base_heap[j] + L_offset, &flag_val, sizeof(long), j);
}

for (int i = tid + 1, j = PE_start + stride + tid;
i < n_pes;
i+= WF_SIZE, j += (WF_SIZE * stride)) {
pe_quiet_single(j);
}

__syncthreads();

if (is_thread_zero_in_block()) {
pSync[0] = ROCSHMEM_SYNC_VALUE;
}
} else {
if (is_thread_zero_in_block()) {
// Mark current PE offset as reached
size_t pe_offset = (pe - PE_start) / stride;
putmem(&pSync[pe_offset], &flag_val, sizeof(long), PE_start);
wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val);
pSync[0] = ROCSHMEM_SYNC_VALUE;
__threadfence_system();
}
}
}

__device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start,
int stride, int n_pes,
int64_t *pSync) {
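In the root-PE branch above, the first wavefront's threads split the waiting and notification work: thread `tid` waits on a strided set of pSync slots and then notifies a strided set of peer PEs. A small host-side sketch of that index mapping is below; the values of `WF_SIZE`, `PE_start`, `stride`, and `n_pes` are examples, not taken from the PR.

```cpp
#include <cstdio>

// Host-side sketch mirroring the two loops in the root-PE branch: the first
// loop is the wait on per-PE pSync slots, the second is the notification put
// back to each waiting PE's pSync[0].
int main() {
  const int WF_SIZE = 64, PE_start = 0, stride = 1, n_pes = 8;
  for (int tid = 0; tid < WF_SIZE; ++tid) {
    for (int j = tid + 1; j < n_pes; j += WF_SIZE)
      std::printf("tid %2d waits on pSync[%d]\n", tid, j);
    for (int i = tid + 1, j = PE_start + stride + tid; i < n_pes;
         i += WF_SIZE, j += WF_SIZE * stride)
      std::printf("tid %2d notifies PE %d\n", tid, j);
  }
  return 0;
}
```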
39 changes: 38 additions & 1 deletion src/gda/context_gda_tmpl_device.hpp
@@ -32,6 +32,7 @@
#include "gda_team.hpp"
#include "queue_pair.hpp"
#include "rocshmem_calc.hpp"
#include "backend_gda.hpp"

#include <hip/hip_runtime.h>

@@ -604,7 +605,11 @@ __device__ void GDAContext::internal_broadcast(T *dst, const T *src, int nelems,
template <typename T>
__device__ void GDAContext::alltoall(rocshmem_team_t team, T *dst,
const T *src, int nelems) {
- alltoall_linear(team, dst, src, nelems);
+ if (gda_provider_ == GDAProvider::BNXT) {
+   alltoall_linear_thread_puts(team, dst, src, nelems);
+ } else {
+   alltoall_linear(team, dst, src, nelems);
+ }
}

Reviewer, on `alltoall_linear_thread_puts(team, dst, src, nelems);`:

Why is this not applicable for other NICs?

Collaborator (author):

  • Should work on IONIC (untested, so I've left it out).
  • There is a pretty major PR in the works for CX7; adding the logic to support multiple threads writing into their own QP would create a major conflict for them. It would be best to add this support once the other PR is merged.

template <typename T>
@@ -637,6 +642,38 @@ __device__ void GDAContext::alltoall_linear(rocshmem_team_t team, T *dst,
internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync);
}

template <typename T>
__device__ void GDAContext::alltoall_linear_thread_puts(rocshmem_team_t team, T *dst,
const T *src, int nelems) {
GDATeam *team_obj = reinterpret_cast<GDATeam *>(team);

int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_size = team_obj->num_pes;
int stride = team_obj->tinfo_wrt_world->stride;
long *pSync = team_obj->alltoall_pSync;
int my_pe_in_team = team_obj->my_pe;

int tid = get_flat_block_id();
int wf_id = get_flat_block_id() / WF_SIZE;
int wf_count = (int) ceil((double)get_flat_block_size() / (double)WF_SIZE);
bool wf_leader = 0 == get_active_lane_num();

// Have each PE put their designated data to the other PEs
for (int j = tid; j < pe_size; j+= WF_SIZE) {
int dest_pe = team_obj->get_pe_in_world(j);
uint64_t L_offset = reinterpret_cast<char*>(&dst[my_pe_in_team * nelems]) - base_heap[my_pe];
qps[dest_pe].put_nbi_single(base_heap[dest_pe] + L_offset, &src[j * nelems], nelems * sizeof(T), dest_pe);
}

for (int j = tid; j < pe_size; j+= WF_SIZE) {
int dest_pe = team_obj->get_pe_in_world(j);
pe_quiet_single(dest_pe);
}

// wait until everyone has obtained their designated data
internal_direct_barrier_wg_thread_puts(my_pe, pe_start, stride, pe_size, pSync);
}

template <typename T>
__device__ void GDAContext::fcollect(rocshmem_team_t team, T *dst,
const T *src, int nelems) {
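The `alltoall_linear_thread_puts` path above follows the standard alltoall placement: PE `me` puts its j-th source block of `nelems` elements into slot `me` of PE j's destination buffer, quiets each QP, and then runs the barrier. A host-side sketch of just that placement, with example `npes`/`nelems` values and plain vectors standing in for the symmetric heap:

```cpp
#include <cstdio>
#include <vector>

// Sketch of the alltoall data placement: dst[j][me * nelems ...] receives
// src[me][j * nelems ...], so every PE ends up with the blocks ordered by
// source rank. npes and nelems are example values.
int main() {
  const int npes = 4, nelems = 2;
  std::vector<std::vector<int>> src(npes),
      dst(npes, std::vector<int>(npes * nelems, -1));
  for (int me = 0; me < npes; ++me)
    for (int i = 0; i < npes * nelems; ++i) src[me].push_back(me * 100 + i);

  for (int me = 0; me < npes; ++me)     // every PE ...
    for (int j = 0; j < npes; ++j)      // ... puts one block to each peer
      for (int e = 0; e < nelems; ++e)
        dst[j][me * nelems + e] = src[me][j * nelems + e];

  for (int pe = 0; pe < npes; ++pe) {
    std::printf("PE %d dst:", pe);
    for (int v : dst[pe]) std::printf(" %3d", v);
    std::printf("\n");
  }
  return 0;
}
```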
3 changes: 2 additions & 1 deletion src/gda/gda_context_proxy.hpp
@@ -44,10 +44,11 @@ class GDADefaultContextProxy {
* Placement new the memory which is allocated by proxy_
*/
explicit GDADefaultContextProxy(GDABackend* backend, TeamInfo *tinfo,
int gda_provider,
size_t num_elems = 1)
: constructed_{true}, proxy_{num_elems} {
auto ctx{proxy_.get()};
- new (ctx) GDAContext(reinterpret_cast<Backend*>(backend), 0);
+ new (ctx) GDAContext(reinterpret_cast<Backend*>(backend), 0, gda_provider);
ctx->tinfo = tinfo;
rocshmem_ctx_t local{ctx, tinfo};
set_internal_ctx(&local);