Skip to content

Commit

Permalink
ch4: use am_tag_{send,recv} in MPIDIG get
Browse files Browse the repository at this point in the history
When target reply data to origin get, use am_tag_send if available.
  • Loading branch information
hzhou committed Nov 7, 2024
1 parent 5598be6 commit d562aae
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ typedef struct MPIDIG_put_req_t {
typedef struct MPIDIG_get_req_t {
MPIR_Request *greq_ptr;
void *flattened_dt;
int am_tag;
} MPIDIG_get_req_t;

typedef struct MPIDIG_cswap_req_t {
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/ch4_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ typedef struct MPIDIG_get_msg_t {
MPI_Aint target_datatype;
MPI_Aint target_true_lb;
int flattened_sz;
int am_tag;
} MPIDIG_get_msg_t;

typedef struct MPIDIG_get_ack_msg_t {
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ enum {

enum {
MPIDIG_TAG_RECV_COMPLETE = 0,
MPIDIG_TAG_GET_COMPLETE,

MPIDIG_TAG_RECV_STATIC_MAX
};
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ int MPIDIG_am_init(void)

MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete);

MPIDIG_am_comm_abort_init();

Expand Down
17 changes: 14 additions & 3 deletions src/mpid/ch4/src/mpidig_rma.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co
* counter in request, thus it can be decreased at request completion. */
MPIDIG_win_cmpl_cnts_incr(win, target_rank, &sreq->dev.completion_notification);

bool is_local;
is_local = MPIDI_rank_is_local(target_rank, win->comm_ptr);
if (MPIDIG_can_do_tag(is_local)) {
am_hdr.am_tag = MPIDIG_get_next_am_tag(win->comm_ptr);
CH4_CALL(am_tag_recv(target_rank, win->comm_ptr, MPIDIG_TAG_GET_COMPLETE, am_hdr.am_tag,
origin_addr, origin_count, origin_datatype, vci_target, vci, sreq),
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
} else {
am_hdr.am_tag = -1;
}

int is_contig;
MPIR_Datatype_is_contig(target_datatype, &is_contig);
if (MPIR_DATATYPE_IS_PREDEFINED(target_datatype) || is_contig) {
Expand All @@ -228,8 +240,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co
MPIR_T_PVAR_TIMER_END(RMA, rma_amhdr_set);

CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr),
NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq),
MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno);
NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
goto fn_exit;
}
Expand All @@ -242,7 +253,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co

CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr),
flattened_dt, flattened_sz, MPI_BYTE, vci, vci_target, sreq),
MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno);
is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
Expand Down
64 changes: 38 additions & 26 deletions src/mpid/ch4/src/mpidig_rma_callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -896,42 +896,44 @@ static int get_target_cmpl_cb(MPIR_Request * rreq)
get_ack.greq_ptr = MPIDIG_REQUEST(rreq, req->greq.greq_ptr);
win = rreq->u.rma.win;

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt) == NULL) {
if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt)) {
/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;
/* count is still target_data_sz now, use it for reply */
get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count);
MPIDIG_REQUEST(rreq, count) /= dt->size;
} else {
MPIDI_Datatype_check_size(MPIDIG_REQUEST(rreq, datatype),
MPIDIG_REQUEST(rreq, count), get_ack.target_data_sz);
}

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
if (MPIDIG_REQUEST(rreq, req->greq.am_tag) >= 0) {
int src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank);
CH4_CALL(am_tag_send(src_rank, win->comm_ptr, MPIDIG_GET_ACK,
MPIDIG_REQUEST(rreq, req->greq.am_tag),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq),
MPIDI_REQUEST(rreq, is_local), mpi_errno);
} else {
CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_GET_ACK, &get_ack, sizeof(get_ack),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci,
rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno);
MPID_Request_complete(rreq);
MPIR_ERR_CHECK(mpi_errno);
goto fn_exit;
}

/* FIXME: MPIR_Typerep_unflatten should allocate the new object */
MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem);
if (!dt) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s",
"MPIR_Datatype_mem");
}
MPIR_Object_set_ref(dt, 1);
MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt));
MPIDIG_REQUEST(rreq, datatype) = dt->handle;
/* count is still target_data_sz now, use it for reply */
get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count);
MPIDIG_REQUEST(rreq, count) /= dt->size;

CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_GET_ACK, &get_ack, sizeof(get_ack),
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count), dt->handle, local_vci,
remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno);
MPID_Request_complete(rreq);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
Expand Down Expand Up @@ -2104,6 +2106,7 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
MPIDIG_REQUEST(rreq, req->greq.flattened_dt) = NULL;
MPIDIG_REQUEST(rreq, req->greq.greq_ptr) = msg_hdr->greq_ptr;
MPIDIG_REQUEST(rreq, u.target.origin_rank) = msg_hdr->src_rank;
MPIDIG_REQUEST(rreq, req->greq.am_tag) = msg_hdr->am_tag;

if (msg_hdr->flattened_sz) {
void *flattened_dt = MPL_malloc(msg_hdr->flattened_sz, MPL_MEM_BUFFER);
Expand Down Expand Up @@ -2164,3 +2167,12 @@ int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
MPIR_FUNC_EXIT;
return mpi_errno;
}

int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;

mpi_errno = get_ack_target_cmpl_cb(req);

return mpi_errno;
}
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig_rma_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
uint32_t attr, MPIR_Request ** req);
int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
uint32_t attr, MPIR_Request ** req);
int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status);

#endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */

0 comments on commit d562aae

Please sign in to comment.