Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Avoid all-to-all connection. #10840

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ class TCPSocket {
/**
* @brief Listen to incoming requests. Should be called after bind.
*/
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
[[nodiscard]] Result Listen(std::int32_t backlog = 512) {
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
Expand Down
1 change: 1 addition & 0 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "topo.h" // for BootstrapNext, BootstrapPrev
#include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span
Expand Down
79 changes: 33 additions & 46 deletions src/collective/broadcast.cc
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "broadcast.h"

#include <cmath> // for ceil, log2
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32

#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "topo.h" // for ParentRank

namespace xgboost::collective::cpu_impl {
namespace {
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}

CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto shifted_parent = shifted_rank - (1 << k);
return shifted_parent;
}

// Shift the root node to rank 0
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
auto shifted_rank = (rank + world - root) % world;
return shifted_rank;
}
// shift back to the original rank
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
auto orig = (rank + root) % world;
return orig;
}
} // namespace

Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
// Binomial tree broadcast
// * Wiki
Expand All @@ -56,28 +24,47 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
auto rank = comm.Rank();
auto world = comm.World();

// shift root to rank 0
auto shifted_rank = ShiftLeft(rank, world, root);
// Send data to the root to preserve the topology. Alternative is to shift the rank, but
// it requires a all-to-all connection.
//
// Most of the use of broadcasting in XGBoost are short messages, this should be
// fine. Otherwise, we can implement a linear pipeline broadcast.
if (root != 0) {
auto rc = Success() << [&] {
return (rank == 0) ? comm.Chan(root)->RecvAll(data) : Success();
} << [&] {
return (rank == root) ? comm.Chan(0)->SendAll(data) : Success();
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Broadcast failed to send data to root.", std::move(rc));
}
root = 0;
}

std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;

if (shifted_rank != 0) { // not root
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
<< [&] { return comm.Chan(parent)->Block(); };
if (rank != 0) { // not root
auto parent = ParentRank(rank, depth);
auto rc = Success() << [&] {
return comm.Chan(parent)->RecvAll(data);
} << [&] {
return comm.Chan(parent)->Block();
};
if (!rc.OK()) {
return Fail("broadcast failed.", std::move(rc));
return Fail("Broadcast failed to send data to parent.", std::move(rc));
}
}

for (std::int32_t i = depth; i >= 0; --i) {
CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative
if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
auto sft_peer = shifted_rank + (1 << i);
auto peer = ShiftRight(sft_peer, world, root);
if (rank % (1 << (i + 1)) == 0 && rank + (1 << i) < world) {
auto peer = rank + (1 << i);
CHECK_NE(peer, root);
auto rc = comm.Chan(peer)->SendAll(data);
if (!rc.OK()) {
return rc;
return Fail("Failed to seed to " + std::to_string(peer), std::move(rc));
}
}
}
Expand Down
133 changes: 89 additions & 44 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#endif // !defined(XGBOOST_USE_NCCL)
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "topo.h" // for BootstrapNext
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/json.h" // for Json, Object
Expand Down Expand Up @@ -58,6 +59,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}

// Connect ring and tree neighbors
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,
Expand All @@ -80,10 +82,10 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return prev->NonBlocking(true);
};
if (!rc.OK()) {
return rc;
return Fail("Bootstrap failed to recv from ring prev.", std::move(rc));
}

// exchange host name and port
// Exchange host name and port
std::vector<std::int8_t> buffer(HOST_NAME_MAX * comm.World(), 0);
auto s_buffer = common::Span{buffer.data(), buffer.size()};
auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX);
Expand All @@ -107,7 +109,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st

rc = std::move(rc) << [&] {
return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch);
} << [&] { return block(); };
} << [&] {
return block();
};
if (!rc.OK()) {
return Fail("Failed to get host names from peers.", std::move(rc));
}
Expand All @@ -118,7 +122,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
auto s_ports = common::Span{reinterpret_cast<std::int8_t*>(peers_port.data()),
peers_port.size() * sizeof(ninfo.port)};
return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch);
} << [&] { return block(); };
} << [&] {
return block();
};
if (!rc.OK()) {
return Fail("Failed to get the port from peers.", std::move(rc));
}
Expand All @@ -138,55 +144,94 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st

std::vector<std::shared_ptr<TCPSocket>>& workers = *out_workers;
workers.resize(comm.World());

for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
auto const& peer = peers[r];
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc)
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
<< [&] { return worker->RecvTimeout(timeout); };
if (!rc.OK()) {
return rc;
}

auto rank = comm.Rank();
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
workers[BootstrapNext(comm.Rank(), comm.World())] = next;
if (BootstrapNext(comm.Rank(), comm.World()) == BootstrapPrev(comm.Rank(), comm.World())) {
if (comm.Rank() == 0) {
if (comm.World() == 2) {
workers[BootstrapNext(comm.Rank(), comm.World())] = prev;
} else {
CHECK_EQ(comm.World(), 1);
}
}
workers[r] = std::move(worker);
} else {
workers[BootstrapPrev(comm.Rank(), comm.World())] = prev;
}

for (std::int32_t r = 0; r < comm.Rank(); ++r) {
auto peer = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
/**
* Construct tree.
*/
// All workers connect to rank 0 so that we can always use rank 0 as broadcast root.
if (comm.Rank() == 0) {
for (std::int32_t i = 0; i < comm.World() - 3; ++i) {
auto worker = std::make_shared<TCPSocket>();
SockAddress addr;
return listener->Accept(peer.get(), &addr);
} << [&] {
return peer->RecvTimeout(timeout);
};
if (!rc.OK()) {
return rc;
rc = listener->Accept(worker.get(), &addr);
if (!rc.OK()) {
return Fail("Failed to accept for rank 0.", std::move(rc));
}
std::int32_t r{-1};
std::size_t n_bytes{0};
rc = worker->RecvAll(&r, sizeof(r), &n_bytes);
if (!rc.OK()) {
return Fail("Failed to recv rank.", std::move(rc));
}
if (n_bytes != sizeof(r)) {
return Fail("Failed to recv rank due to size.", std::move(rc));
}
workers[r] = worker;
}
std::int32_t rank{-1};
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
} else {
if (!workers[0]) {
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
return Connect(peers[0].host, peers[0].port, retry, timeout, worker.get());
} << [&] {
auto rank = comm.Rank();
std::size_t n_bytes = 0;
auto rc = worker->SendAll(&rank, sizeof(rank), &n_bytes);
if (n_bytes != sizeof(rank)) {
return Fail("Failed to send rank due to size.", std::move(rc));
}
return rc;
};
if (!rc.OK()) {
return Fail("Failed to connect to root.", std::move(rc));
}
workers[0] = worker;
}
}
// Binomial tree connect
std::int32_t const kDepth = std::ceil(std::log2(static_cast<double>(comm.World()))) - 1;
if (comm.Rank() != 0) {
auto prank = ParentRank(comm.Rank(), kDepth);
if (!workers[prank]) { // Skip if it's part of the ring.
auto parent = std::make_shared<TCPSocket>();
SockAddress addr;
rc = listener->Accept(parent.get(), &addr);
if (!rc.OK()) {
return Fail("Failed to recv connection from tree parent.", std::move(rc));
}
workers[prank] = parent;
}
workers[rank] = std::move(peer);
}

for (std::int32_t r = 0; r < comm.World(); ++r) {
if (r == comm.Rank()) {
continue;
for (std::int32_t i = kDepth; i >= 0; --i) {
if (comm.Rank() % (1 << (i + 1)) == 0 && comm.Rank() + (1 << i) < comm.World()) {
auto peer = comm.Rank() + (1 << i);
if (workers[peer]) { // skip if it's part of the ring.
continue;
}
auto worker = std::make_shared<TCPSocket>();
rc = std::move(rc) << [&] {
return Connect(peers[peer].host, peers[peer].port, retry, timeout, worker.get());
} << [&] {
return worker->RecvTimeout(timeout);
};
if (!rc.OK()) {
return Fail("Failed to connect to tree neighbor", std::move(rc));
}
workers[peer] = worker;
}
CHECK(workers[r]);
}

return Success();
Expand Down
17 changes: 3 additions & 14 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,10 @@

namespace xgboost::collective {

inline constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min
inline constexpr std::int32_t DefaultRetry() { return 3; }
constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min
constexpr std::int32_t DefaultRetry() { return 3; }

// indexing into the ring
inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) {
auto nrank = (r + world + 1) % world;
return nrank;
}

inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) {
auto nrank = (r + world - 1) % world;
return nrank;
}

inline StringView DefaultNcclName() { return "libnccl.so.2"; }
constexpr StringView DefaultNcclName() { return "libnccl.so.2"; }

class Channel;
class Coll;
Expand Down
26 changes: 26 additions & 0 deletions src/collective/topo.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "topo.h"

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
namespace xgboost::collective {
std::int32_t ParentRank(std::int32_t rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}

CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto parent = rank - (1 << k);
return parent;
}
} // namespace xgboost::collective
Loading
Loading