Skip to content

TL/UCP: add linear alltoall(v) #1116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
ucc_dt_size(args->dst.info.datatype);
rbuf = args->dst.info.buffer;
if (!UCC_IS_INPLACE(*args)) {
status = ctx->copy.post(PTR_OFFSET(args->dst.info.buffer, offset),
args->dst.info.mem_type,
args->src.info.buffer,
args->src.info.mem_type,
args->src.info.count *
ucc_dt_size(args->src.info.datatype),
task,
&task->allgather_kn.copy_task);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
task->allgather_kn.copy_task = NULL;
}
} else if (ct == UCC_COLL_TYPE_ALLGATHERV) {
ucc_kn_agv_pattern_init(size, rank, radix, args->dst.info_v.counts,
Expand Down
71 changes: 63 additions & 8 deletions src/components/tl/ucp/alltoall/alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define ALLTOALL_MAX_PATTERN_SIZE (sizeof(UCC_TL_UCP_ALLTOALL_DEFAULT_ALG_SELECT_STR_PATTERN) + 32)
#define ALLTOALL_DEFAULT_ALG_SWITCH 129
/* TODO: add as parameters */
#define MSG_MEDIUM 66000
#define NP_THRESH 32

ucc_status_t ucc_tl_ucp_alltoall_pairwise_start(ucc_coll_task_t *task);
void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *task);
Expand Down Expand Up @@ -43,13 +46,49 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
.name = "onesided",
.desc = "naive, linear one-sided implementation"},
[UCC_TL_UCP_ALLTOALL_ALG_LINEAR] =
{.id = UCC_TL_UCP_ALLTOALL_ALG_LINEAR,
.name = "linear",
.desc = "linear two-sided implementation"},
[UCC_TL_UCP_ALLTOALL_ALG_LAST] = {.id = 0, .name = NULL, .desc = NULL}};

ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task)
static ucc_rank_t get_num_posts(const ucc_tl_ucp_team_t *team,
const ucc_coll_args_t *args)
{
ucc_status_t status;
unsigned long posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts;
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size;

data_size = (size_t)args->src.info.count *
ucc_dt_size(args->src.info.datatype);
if (posts == UCC_ULUNITS_AUTO) {
if ((data_size > MSG_MEDIUM) && (tsize > NP_THRESH)) {
/* use pairwise algorithm */
posts = 1;
} else {
/* use linear algorithm */
posts = 0;
}
}

posts = (posts > tsize || posts == 0) ? tsize: posts;
return posts;
}

ucc_status_t ucc_tl_ucp_alltoall_pairwise_init_num_posts(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h,
ucc_rank_t num_posts)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task;
ucc_status_t status;

ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
task = ucc_tl_ucp_init_task(coll_args, team);
task->alltoall_pairwise.num_posts = num_posts;
*task_h = &task->super;

ALLTOALL_TASK_CHECK(TASK_ARGS(task), TASK_TEAM(task));
status = ucc_tl_ucp_alltoall_pairwise_init_common(task);
out:
return status;
Expand All @@ -60,12 +99,28 @@ ucc_status_t ucc_tl_ucp_alltoall_pairwise_init(ucc_base_coll_args_t *coll_args,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task;
ucc_status_t status;
ucc_rank_t num_posts;

num_posts = get_num_posts(tl_team, &coll_args->args);
return ucc_tl_ucp_alltoall_pairwise_init_num_posts(coll_args, team, task_h,
num_posts);
}

ucc_status_t ucc_tl_ucp_alltoall_linear_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
return ucc_tl_ucp_alltoall_pairwise_init_num_posts(coll_args, team, task_h, 0);
}

ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task)
{
ucc_status_t status;

ALLTOALL_TASK_CHECK(TASK_ARGS(task), TASK_TEAM(task));
task->alltoall_pairwise.num_posts = get_num_posts(TASK_TEAM(task),
&TASK_ARGS(task));

ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
task = ucc_tl_ucp_init_task(coll_args, team);
*task_h = &task->super;
status = ucc_tl_ucp_alltoall_pairwise_init_common(task);
out:
return status;
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/ucp/alltoall/alltoall.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum {
UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE,
UCC_TL_UCP_ALLTOALL_ALG_BRUCK,
UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
UCC_TL_UCP_ALLTOALL_ALG_LINEAR,
UCC_TL_UCP_ALLTOALL_ALG_LAST
};

Expand All @@ -37,6 +38,9 @@ ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoall_linear_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
41 changes: 7 additions & 34 deletions src/components/tl/ucp/alltoall/alltoall_pairwise.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -11,10 +11,6 @@
#include "utils/ucc_math.h"
#include "tl_ucp_sendrecv.h"

/* TODO: add as parameters */
#define MSG_MEDIUM 66000
#define NP_THRESH 32

static inline ucc_rank_t get_recv_peer(ucc_rank_t rank, ucc_rank_t size,
ucc_rank_t step)
{
Expand All @@ -27,44 +23,21 @@ static inline ucc_rank_t get_send_peer(ucc_rank_t rank, ucc_rank_t size,
return (rank - step + size) % size;
}

static ucc_rank_t get_num_posts(const ucc_tl_ucp_team_t *team,
const ucc_coll_args_t *args)
{
unsigned long posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts;
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size;

data_size = (size_t)args->src.info.count *
ucc_dt_size(args->src.info.datatype);
if (posts == UCC_ULUNITS_AUTO) {
if ((data_size > MSG_MEDIUM) && (tsize > NP_THRESH)) {
/* use pairwise algorithm */
posts = 1;
} else {
/* use linear algorithm */
posts = 0;
}
}

posts = (posts > tsize || posts == 0) ? tsize: posts;
return posts;
}

void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ptrdiff_t sbuf = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
ptrdiff_t rbuf = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
int polls = 0;
ucc_rank_t peer, nreqs;
ucc_rank_t nreqs = task->alltoall_pairwise.num_posts;
ucc_rank_t peer;
size_t data_size;

nreqs = get_num_posts(team, &TASK_ARGS(task));
data_size = (size_t)(TASK_ARGS(task).src.info.count / gsize) *
ucc_dt_size(TASK_ARGS(task).src.info.datatype);
while ((task->tagged.send_posted < gsize ||
Expand All @@ -75,7 +48,7 @@ void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)
((task->tagged.recv_posted - task->tagged.recv_completed) <
nreqs)) {
peer = get_recv_peer(grank, gsize, task->tagged.recv_posted);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb((void *)(rbuf + peer * data_size),
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, peer * data_size),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if ucc_tl_ucp_recv_nb returns UCC_INPROGRESS it jumps to out, but you could still post more tasks to match nreq tasks in flight

data_size, rmem, peer, team, task),
task, out);
polls = 0;
Expand All @@ -84,7 +57,7 @@ void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)
((task->tagged.send_posted - task->tagged.send_completed) <
nreqs)) {
peer = get_send_peer(grank, gsize, task->tagged.send_posted);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(sbuf + peer * data_size),
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(sbuf, peer * data_size),
data_size, smem, peer, team, task),
task, out);
polls = 0;
Expand Down
66 changes: 57 additions & 9 deletions src/components/tl/ucp/alltoallv/alltoallv.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -8,6 +8,9 @@
#include "tl_ucp.h"
#include "alltoallv.h"

/* TODO: add as parameters */
#define NP_THRESH 32

ucc_base_coll_alg_info_t
ucc_tl_ucp_alltoallv_algs[UCC_TL_UCP_ALLTOALLV_ALG_LAST + 1] = {
[UCC_TL_UCP_ALLTOALLV_ALG_PAIRWISE] =
Expand All @@ -23,14 +26,45 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLTOALLV_ALG_ONESIDED,
.name = "onesided",
.desc = "O(N) onesided alltoallv"},
[UCC_TL_UCP_ALLTOALLV_ALG_LINEAR] =
{.id = UCC_TL_UCP_ALLTOALLV_ALG_LINEAR,
.name = "linear",
.desc = "O(N) linear alltoallv"},
[UCC_TL_UCP_ALLTOALLV_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

ucc_status_t ucc_tl_ucp_alltoallv_init(ucc_tl_ucp_task_t *task)
static ucc_rank_t get_num_posts(const ucc_tl_ucp_team_t *team)
{
ucc_status_t status;
unsigned long posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoallv_pairwise_num_posts;
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);

ALLTOALLV_TASK_CHECK(TASK_ARGS(task), TASK_TEAM(task));
if (posts == UCC_ULUNITS_AUTO) {
if (UCC_TL_TEAM_SIZE(team) <= NP_THRESH) {
/* use linear algorithm */
posts = 0;
} else {
/* use pairwise algorithm */
posts = 1;
}
}

posts = (posts > tsize || posts == 0) ? tsize: posts;
return posts;
}

ucc_status_t ucc_tl_ucp_alltoallv_pairwise_init_num_posts(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h,
ucc_rank_t num_posts)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task;
ucc_status_t status;

ALLTOALLV_TASK_CHECK(coll_args->args, tl_team);
task = ucc_tl_ucp_init_task(coll_args, team);
task->alltoallv_pairwise.num_posts = num_posts;
*task_h = &task->super;
status = ucc_tl_ucp_alltoallv_pairwise_init_common(task);
out:
return status;
Expand All @@ -41,12 +75,26 @@ ucc_status_t ucc_tl_ucp_alltoallv_pairwise_init(ucc_base_coll_args_t *coll_args,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task;
ucc_status_t status;
ucc_rank_t num_posts;

ALLTOALLV_TASK_CHECK(coll_args->args, tl_team);
task = ucc_tl_ucp_init_task(coll_args, team);
*task_h = &task->super;
num_posts = get_num_posts(tl_team);
return ucc_tl_ucp_alltoallv_pairwise_init_num_posts(coll_args, team, task_h,
num_posts);
}

ucc_status_t ucc_tl_ucp_alltoallv_linear_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
return ucc_tl_ucp_alltoallv_pairwise_init_num_posts(coll_args, team, task_h, 0);
}

ucc_status_t ucc_tl_ucp_alltoallv_init(ucc_tl_ucp_task_t *task)
{
ucc_status_t status;

ALLTOALLV_TASK_CHECK(TASK_ARGS(task), TASK_TEAM(task));
task->alltoallv_pairwise.num_posts = get_num_posts(TASK_TEAM(task));
status = ucc_tl_ucp_alltoallv_pairwise_init_common(task);
out:
return status;
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/alltoallv/alltoallv.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum {
UCC_TL_UCP_ALLTOALLV_ALG_PAIRWISE,
UCC_TL_UCP_ALLTOALLV_ALG_HYBRID,
UCC_TL_UCP_ALLTOALLV_ALG_ONESIDED,
UCC_TL_UCP_ALLTOALLV_ALG_LINEAR,
UCC_TL_UCP_ALLTOALLV_ALG_LAST
};

Expand All @@ -29,6 +30,10 @@ ucc_status_t ucc_tl_ucp_alltoallv_pairwise_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoallv_linear_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_alltoallv_hybrid_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);
Expand Down
Loading
Loading