Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ch4: use am_tag_{send,recv} in RMA get/put #7202

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions src/mpi/datatype/get_elements_x.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ static MPI_Count MPIR_Type_get_basic_type_elements(MPI_Count * bytes_p,
break;
}

if (type1_sz + type2_sz == 0) {
/* this is likely a struct type with mixed basic elements. Let's just bail for now */
*bytes_p = 0;
return 0;
}

/* determine the number of elements in the region */
elements = 2 * (usable_bytes / (type1_sz + type2_sz));
if (usable_bytes % (type1_sz + type2_sz) >= type1_sz)
Expand Down
2 changes: 2 additions & 0 deletions src/mpi/datatype/typerep/src/typerep_flatten.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ int MPIR_Typerep_unflatten(MPIR_Datatype * datatype_ptr, void *flattened_type)
MPIR_ERR_CHECK(mpi_errno);
#endif

MPID_Type_commit_hook(datatype_ptr);

fn_exit:
return mpi_errno;

Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ typedef struct MPIDIG_put_req_t {
typedef struct MPIDIG_get_req_t {
MPIR_Request *greq_ptr;
void *flattened_dt;
int am_tag;
} MPIDIG_get_req_t;

typedef struct MPIDIG_cswap_req_t {
Expand Down
4 changes: 4 additions & 0 deletions src/mpid/ch4/netmod/ofi/ofi_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_issue_ack_recv(MPIR_Request * sreq, MPIR_
ackreq->remote_addr = MPIDI_OFI_av_to_phys(addr, nic, vci_remote);
ackreq->match_bits = match_bits;

#ifndef MPIDI_CH4_DIRECT_NETMOD
/* set is_local in case we go into active messages later */
MPIDI_REQUEST(sreq, is_local) = 0;
#endif
MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ackreq->ctx_idx].rx,
ackreq->ack_hdr, ackreq->ack_hdr_sz, NULL, ackreq->remote_addr,
ackreq->match_bits, 0ULL, (void *) &(ackreq->context)),
Expand Down
2 changes: 1 addition & 1 deletion src/mpid/ch4/netmod/ucx/ucx_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ extern ucp_generic_dt_ops_t MPIDI_UCX_datatype_ops;
#define MPIDI_UCX_TAG_BITS (64 - MPIDI_UCX_CONTEXT_ID_BITS - MPIDI_UCX_RANK_BITS - MPIDI_UCX_PROTOCOL_BITS)

/* protocol bits */
#define MPIDI_UCX_TAG_AM (1 << MPIDI_UCX_TAG_BITS)
#define MPIDI_UCX_TAG_AM (1ULL << MPIDI_UCX_TAG_BITS)

#define MPIDI_UCX_RANK_SHIFT (MPIDI_UCX_TAG_BITS + MPIDI_UCX_PROTOCOL_BITS)
#define MPIDI_UCX_CONTEXT_ID_SHIFT (MPIDI_UCX_TAG_BITS + MPIDI_UCX_PROTOCOL_BITS + MPIDI_UCX_RANK_BITS)
Expand Down
2 changes: 0 additions & 2 deletions src/mpid/ch4/src/ch4_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_isend(const void *buf,
mpi_errno = MPIDI_SHM_mpi_isend(buf, count, datatype, rank, tag, comm, attr, av, req);
else
mpi_errno = MPIDI_NM_mpi_isend(buf, count, datatype, rank, tag, comm, attr, av, req);
if (mpi_errno == MPI_SUCCESS)
MPIDI_REQUEST(*req, is_local) = r;
#endif
MPIR_ERR_CHECK(mpi_errno);

Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch4/src/ch4_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ typedef struct MPIDIG_put_msg_t {

typedef struct MPIDIG_put_dt_ack_msg_t {
int src_rank;
int am_tag;
MPIR_Request *target_preq_ptr;
MPIR_Request *origin_preq_ptr;
} MPIDIG_put_dt_ack_msg_t;
Expand All @@ -142,6 +143,7 @@ typedef struct MPIDIG_get_msg_t {
MPI_Aint target_datatype;
MPI_Aint target_true_lb;
int flattened_sz;
int am_tag;
} MPIDIG_get_msg_t;

typedef struct MPIDIG_get_ack_msg_t {
Expand Down
11 changes: 11 additions & 0 deletions src/mpid/ch4/src/mpidig.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ enum {

enum {
MPIDIG_TAG_RECV_COMPLETE = 0,
MPIDIG_TAG_GET_COMPLETE,
MPIDIG_TAG_PUT_COMPLETE,

MPIDIG_TAG_RECV_STATIC_MAX
};
Expand Down Expand Up @@ -135,6 +137,15 @@ typedef struct MPIDIG_global_t {
} MPIDIG_global_t;
extern MPIDIG_global_t MPIDIG_global;

MPL_STATIC_INLINE_PREFIX int MPIDIG_can_do_tag(bool is_local)
{
#ifdef MPIDI_CH4_DIRECT_NETMOD
return MPIDI_NM_am_can_do_tag();
#else
return is_local ? MPIDI_SHM_am_can_do_tag() : MPIDI_NM_am_can_do_tag();
#endif
}

MPL_STATIC_INLINE_PREFIX int MPIDIG_get_next_am_tag(MPIR_Comm * comm)
{
int tag = comm->next_am_tag++;
Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch4/src/mpidig_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ int MPIDIG_am_init(void)

MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_PUT_COMPLETE, &MPIDIG_tag_put_complete);

MPIDIG_am_comm_abort_init();

Expand Down
20 changes: 8 additions & 12 deletions src/mpid/ch4/src/mpidig_pt2pt_callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
static int handle_unexp_cmpl(MPIR_Request * rreq);
static int recv_target_cmpl_cb(MPIR_Request * rreq);

static int can_do_tag(MPIR_Request * rreq)
{
#ifdef MPIDI_CH4_DIRECT_NETMOD
return MPIDI_NM_am_can_do_tag();
#else
return MPIDI_REQUEST(rreq, is_local) ? MPIDI_SHM_am_can_do_tag() : MPIDI_NM_am_can_do_tag();
#endif
}

int MPIDIG_do_cts(MPIR_Request * rreq)
{
int mpi_errno = MPI_SUCCESS;
Expand All @@ -30,13 +21,18 @@ int MPIDIG_do_cts(MPIR_Request * rreq)
MPIDIG_send_cts_msg_t am_hdr;
am_hdr.sreq_ptr = (MPIDIG_REQUEST(rreq, req->rreq.peer_req_ptr));
am_hdr.rreq_ptr = rreq;
if (can_do_tag(rreq)) {
#ifndef MPIDI_CH4_DIRECT_NETMOD
int is_local = MPIDI_REQUEST(rreq, is_local);
#else
int is_local = 0;
#endif
if (MPIDIG_can_do_tag(is_local)) {
am_hdr.tag = MPIDIG_get_next_am_tag(rreq->comm);
CH4_CALL(am_tag_recv(source_rank, rreq->comm,
MPIDIG_TAG_RECV_COMPLETE, am_hdr.tag,
MPIDIG_REQUEST(rreq, buffer), MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype), remote_vci, local_vci, rreq),
MPIDI_REQUEST(rreq, is_local), mpi_errno);
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
} else {
am_hdr.tag = -1;
Expand All @@ -48,7 +44,7 @@ int MPIDIG_do_cts(MPIR_Request * rreq)

CH4_CALL(am_send_hdr_reply(rreq->comm, source_rank, MPIDIG_SEND_CTS,
&am_hdr, sizeof(am_hdr), local_vci, remote_vci),
MPIDI_REQUEST(rreq, is_local), mpi_errno);
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
Expand Down
17 changes: 14 additions & 3 deletions src/mpid/ch4/src/mpidig_rma.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co
* counter in request, thus it can be decreased at request completion. */
MPIDIG_win_cmpl_cnts_incr(win, target_rank, &sreq->dev.completion_notification);

bool is_local;
is_local = MPIDI_rank_is_local(target_rank, win->comm_ptr);
if (MPIDIG_can_do_tag(is_local)) {
am_hdr.am_tag = MPIDIG_get_next_am_tag(win->comm_ptr);
CH4_CALL(am_tag_recv(target_rank, win->comm_ptr, MPIDIG_TAG_GET_COMPLETE, am_hdr.am_tag,
origin_addr, origin_count, origin_datatype, vci_target, vci, sreq),
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
} else {
am_hdr.am_tag = -1;
}

int is_contig;
MPIR_Datatype_is_contig(target_datatype, &is_contig);
if (MPIR_DATATYPE_IS_PREDEFINED(target_datatype) || is_contig) {
Expand All @@ -228,8 +240,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co
MPIR_T_PVAR_TIMER_END(RMA, rma_amhdr_set);

CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr),
NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq),
MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno);
NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
goto fn_exit;
}
Expand All @@ -242,7 +253,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co

CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr),
flattened_dt, flattened_sz, MPI_BYTE, vci, vci_target, sreq),
MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno);
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
Expand Down
143 changes: 99 additions & 44 deletions src/mpid/ch4/src/mpidig_rma_callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -896,42 +896,45 @@ static int get_target_cmpl_cb(MPIR_Request * rreq)
get_ack.greq_ptr = MPIDIG_REQUEST(rreq, req->greq.greq_ptr);
win = rreq->u.rma.win;

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt) == NULL) {
if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt)) {
/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;
/* count is still target_data_sz now, use it for reply */
get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count);
MPIDIG_REQUEST(rreq, count) /= dt->size;
} else {
MPIDI_Datatype_check_size(MPIDIG_REQUEST(rreq, datatype),
MPIDIG_REQUEST(rreq, count), get_ack.target_data_sz);
}

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
if (MPIDIG_REQUEST(rreq, req->greq.am_tag) >= 0) {
int src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank);
CH4_CALL(am_tag_send(src_rank, win->comm_ptr, MPIDIG_GET_ACK,
MPIDIG_REQUEST(rreq, req->greq.am_tag),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq),
MPIDI_REQUEST(rreq, is_local), mpi_errno);
} else {
CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_GET_ACK, &get_ack, sizeof(get_ack),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci,
rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno);
MPID_Request_complete(rreq);
MPIR_ERR_CHECK(mpi_errno);
goto fn_exit;
}

/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;
/* count is still target_data_sz now, use it for reply */
get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count);
MPIDIG_REQUEST(rreq, count) /= dt->size;

CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_GET_ACK, &get_ack, sizeof(get_ack),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count), dt->handle, local_vci,
remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno);
MPID_Request_complete(rreq);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
Expand Down Expand Up @@ -968,12 +971,43 @@ static int put_dt_target_cmpl_cb(MPIR_Request * rreq)

MPIR_FUNC_ENTER;

/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
/* Note: handle is filled in by MPIR_Handle_obj_alloc() */
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->preq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;

ack_msg.src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank);
ack_msg.origin_preq_ptr = MPIDIG_REQUEST(rreq, req->preq.preq_ptr);
ack_msg.target_preq_ptr = rreq;

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
MPIR_Comm *comm = rreq->u.rma.win->comm_ptr;

bool is_local;
#ifndef MPIDI_CH4_DIRECT_NETMOD
is_local = MPIDI_REQUEST(rreq, is_local);
#else
is_local = 0;
#endif
if (MPIDIG_can_do_tag(is_local)) {
ack_msg.am_tag = MPIDIG_get_next_am_tag(comm);
CH4_CALL(am_tag_recv(ack_msg.src_rank, comm, MPIDIG_TAG_PUT_COMPLETE, ack_msg.am_tag,
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype),
local_vci, remote_vci, rreq), is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
} else {
ack_msg.am_tag = -1;
}

CH4_CALL(am_send_hdr_reply
(rreq->u.rma.win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_PUT_DT_ACK, &ack_msg, sizeof(ack_msg), local_vci, remote_vci),
Expand Down Expand Up @@ -1591,13 +1625,25 @@ int MPIDIG_put_dt_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_s
/* origin datatype to be released in MPIDIG_put_data_origin_cb */
MPIDIG_REQUEST(rreq, datatype) = MPIDIG_REQUEST(origin_req, datatype);

CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(origin_req, u.origin.target_rank),
MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg),
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);
int target_rank = MPIDIG_REQUEST(origin_req, u.origin.target_rank);
if (msg_hdr->am_tag >= 0) {
CH4_CALL(am_tag_send(target_rank, win->comm_ptr, MPIDIG_PUT_DAT_REQ,
msg_hdr->am_tag,
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);

} else {
CH4_CALL(am_isend_reply(win->comm_ptr, target_rank,
MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg),
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);
}
MPIR_ERR_CHECK(mpi_errno);

if (attr & MPIDIG_AM_ATTR__IS_ASYNC) {
Expand Down Expand Up @@ -1715,19 +1761,9 @@ int MPIDIG_put_data_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,

rreq = (MPIR_Request *) msg_hdr->preq_ptr;

/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
/* Note: handle is filled in by MPIR_Handle_obj_alloc() */
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->preq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;

MPIDIG_REQUEST(rreq, req->target_cmpl_cb) = put_target_cmpl_cb;
MPIDIG_recv_type_init(MPIDIG_REQUEST(rreq, req->preq.origin_data_sz), rreq);
mpi_errno = MPIDIG_recv_type_init(MPIDIG_REQUEST(rreq, req->preq.origin_data_sz), rreq);
MPIR_ERR_CHECK(mpi_errno);

if (attr & MPIDIG_AM_ATTR__IS_ASYNC) {
*req = rreq;
Expand Down Expand Up @@ -2104,6 +2140,7 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
MPIDIG_REQUEST(rreq, req->greq.flattened_dt) = NULL;
MPIDIG_REQUEST(rreq, req->greq.greq_ptr) = msg_hdr->greq_ptr;
MPIDIG_REQUEST(rreq, u.target.origin_rank) = msg_hdr->src_rank;
MPIDIG_REQUEST(rreq, req->greq.am_tag) = msg_hdr->am_tag;

if (msg_hdr->flattened_sz) {
void *flattened_dt = MPL_malloc(msg_hdr->flattened_sz, MPL_MEM_BUFFER);
Expand Down Expand Up @@ -2164,3 +2201,21 @@ int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
MPIR_FUNC_EXIT;
return mpi_errno;
}

int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;

mpi_errno = get_ack_target_cmpl_cb(req);

return mpi_errno;
}

int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;

mpi_errno = put_target_cmpl_cb(req);

return mpi_errno;
}
Loading