diff --git a/.circleci/config.yml b/.circleci/config.yml index 559b0f4c67..492d8044c1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -160,6 +160,7 @@ jobs: LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py + /bin/bash python/tests/run_ring_test.sh - run: name: Build example extension command: | diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 8fd081b844..8e16bd40dd 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -5,3 +5,4 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 34d9583fa8..cb6edd1e80 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -1,8 +1,11 @@ // Copyright © 2024 Apple Inc. +#include + #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" +#include "mlx/distributed/ring/ring.h" #include "mlx/scheduler.h" namespace mlx::core::distributed { @@ -65,7 +68,7 @@ class EmptyGroup : public GroupImpl { } // namespace detail bool is_available() { - return mpi::is_available(); + return mpi::is_available() || ring::is_available(); } int Group::rank() const { @@ -80,20 +83,50 @@ Group Group::split(int color, int key /* = -1 */) const { return Group(group_->split(color, key)); } -Group init(bool strict /* = false */) { - auto init_group = [strict]() { - auto default_group = mpi::init(strict); - if (default_group == nullptr) { - default_group = std::make_shared(); +Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { + static std::unordered_map> + backends; + + // Already initialized so return the group. + if (auto g = backends.find(bk); g != backends.end()) { + return Group(g->second); + } + + // Create the requested communication group + std::shared_ptr group; + std::string bk_ = bk; + if (bk == "mpi") { + group = mpi::init(strict); + } else if (bk == "ring") { + group = ring::init(strict); + } else if (bk == "any") { + group = ring::init(false); + bk_ = "ring"; + if (group == nullptr) { + group = mpi::init(false); + bk_ = "mpi"; } - return default_group; - }; - static std::shared_ptr default_group = init_group(); + if (group == nullptr && strict) { + throw std::runtime_error("[distributed] Couldn't initialize any backend"); + } + } else { + std::ostringstream msg; + msg << "[distributed] The only valid values for backend are 'any', 'mpi' " + << "and 'ring' but '" << bk << "' was provided."; + throw std::invalid_argument(msg.str()); + } + + if (group == nullptr) { + group = std::make_shared(); + } else { + backends.insert({"any", group}); + } + backends.insert({std::move(bk_), group}); // Ensure the communication stream is alive before // the graph is evaluated detail::communication_stream(); - return Group(default_group); + return Group(group); } } // namespace mlx::core::distributed diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h index c06d10756c..1f1713866e 100644 --- a/mlx/distributed/distributed.h +++ b/mlx/distributed/distributed.h @@ -53,6 +53,6 @@ struct Group { * distributed subsystem. Otherwise simply return a singleton group which will * render communication operations as no-op. */ -Group init(bool strict = false); +Group init(bool strict = false, const std::string& bk = "any"); } // namespace mlx::core::distributed diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index fdcbf777dd..bf5fc13b5d 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -11,6 +11,8 @@ namespace mlx::core::distributed::detail { */ class GroupImpl { public: + virtual ~GroupImpl() {} + virtual int rank() = 0; virtual int size() = 0; virtual std::shared_ptr split(int color, int key = -1) = 0; diff --git a/mlx/distributed/ring/CMakeLists.txt b/mlx/distributed/ring/CMakeLists.txt new file mode 100644 index 0000000000..e94ca11426 --- /dev/null +++ b/mlx/distributed/ring/CMakeLists.txt @@ -0,0 +1,5 @@ +if(MLX_BUILD_CPU) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp) +endif() diff --git a/mlx/distributed/ring/no_ring.cpp b/mlx/distributed/ring/no_ring.cpp new file mode 100644 index 0000000000..0c31d28392 --- /dev/null +++ b/mlx/distributed/ring/no_ring.cpp @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/ring/ring.h" + +namespace mlx::core::distributed::ring { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available() { + return false; +} + +std::shared_ptr init(bool strict /* = false */) { + if (strict) { + throw std::runtime_error("Cannot initialize ring distributed backend."); + } + return nullptr; +} + +} // namespace mlx::core::distributed::ring diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp new file mode 100644 index 0000000000..c30de91a2b --- /dev/null +++ b/mlx/distributed/ring/ring.cpp @@ -0,0 +1,827 @@ +// Copyright © 2024 Apple Inc. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/threadpool.h" + +#define SWITCH_TYPE(x, ...) \ + switch ((x).dtype()) { \ + case bool_: { \ + using T = bool; \ + __VA_ARGS__; \ + } break; \ + case int8: { \ + using T = int8_t; \ + __VA_ARGS__; \ + } break; \ + case int16: { \ + using T = int16_t; \ + __VA_ARGS__; \ + } break; \ + case int32: { \ + using T = int32_t; \ + __VA_ARGS__; \ + } break; \ + case int64: { \ + using T = int64_t; \ + __VA_ARGS__; \ + } break; \ + case uint8: { \ + using T = uint8_t; \ + __VA_ARGS__; \ + } break; \ + case uint16: { \ + using T = uint16_t; \ + __VA_ARGS__; \ + } break; \ + case uint32: { \ + using T = uint32_t; \ + __VA_ARGS__; \ + } break; \ + case uint64: { \ + using T = uint64_t; \ + __VA_ARGS__; \ + } break; \ + case bfloat16: { \ + using T = bfloat16_t; \ + __VA_ARGS__; \ + } break; \ + case float16: { \ + using T = float16_t; \ + __VA_ARGS__; \ + } break; \ + case float32: { \ + using T = float; \ + __VA_ARGS__; \ + } break; \ + case complex64: { \ + using T = complex64_t; \ + __VA_ARGS__; \ + } break; \ + } + +namespace mlx::core::distributed::ring { + +constexpr const size_t PACKET_SIZE = 262144; +constexpr const int CONN_ATTEMPTS = 5; +constexpr const int CONN_WAIT = 1000; + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; +using json = nlohmann::json; + +namespace { + +class Barrier { + public: + explicit Barrier(int n_threads) + : n_threads_(n_threads), count_(0), flag_(false) {} + + void arrive_and_wait() { + std::unique_lock lock(mtx_); + + // Keep the flag that marks the current use of the barrier. The next use is + // going to have this flag flipped. + bool initial_flag = flag_; + + // Increment the count + count_++; + + // We are the last thread to arrive so reset the count, change the flag and + // notify everybody. + if (count_ == n_threads_) { + count_ = 0; + flag_ = !flag_; + cv_.notify_all(); + } + + // Wait for the rest to arrive + else { + cv_.wait(lock, [this, initial_flag]() { return initial_flag != flag_; }); + } + } + + private: + std::mutex mtx_; + std::condition_variable cv_; + int n_threads_; + + int count_; + bool flag_; // we need this for sequential use of the barrier +}; + +template +void log(std::ostream& os, T first) { + os << first << std::endl; +} + +template +void log(std::ostream& os, T first, Args... args) { + log(os << first << " ", args...); +} + +template +void log_info(bool verbose, Args... args) { + if (!verbose) { + return; + } + + log(std::cerr, "[ring]", args...); +} + +template +decltype(T() * U()) ceildiv(T a, U b) { + return (a + b - 1) / b; +} + +struct address_t { + sockaddr_storage addr; + socklen_t len; + + const sockaddr* get() const { + return (struct sockaddr*)&addr; + } +}; + +/** + * Parse a sockaddr from an ip and port provided as strings. + */ +address_t parse_address(const std::string& ip, const std::string& port) { + struct addrinfo hints, *res; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); + if (status != 0) { + std::ostringstream msg; + msg << "Can't parse address " << ip << ":" << port; + throw std::runtime_error(msg.str()); + } + + address_t result; + memcpy(&result.addr, res->ai_addr, res->ai_addrlen); + result.len = res->ai_addrlen; + freeaddrinfo(res); + + return result; +} + +/** + * Parse a sockaddr provided as an : string. + */ +address_t parse_address(const std::string& ip_port) { + auto colon = ip_port.find(":"); + if (colon == std::string::npos) { + std::ostringstream msg; + msg << "Can't parse address " << ip_port; + throw std::runtime_error(msg.str()); + } + std::string ip(ip_port.begin(), ip_port.begin() + colon); + std::string port(ip_port.begin() + colon + 1, ip_port.end()); + + return parse_address(ip, port); +} + +/** + * Load all addresses from the json hostfile. The hostfile is a list of + * addresses in order of rank. For each rank there can be many addresses so + * that we can have multiple connections between peers. + * + * For example: + * [ + * ["ip1:5000", "ip1:5001"], + * ["ip2:5000", "ip2:5001"], + * ["ip3:5000", "ip3:5001"], + * ] + */ +std::vector> load_nodes(const char* hostfile) { + std::vector> nodes; + std::ifstream f(hostfile); + + json hosts = json::parse(f); + for (auto& h : hosts) { + std::vector host; + for (auto& ips : h) { + host.push_back(std::move(parse_address(ips.get()))); + } + nodes.push_back(std::move(host)); + } + + return nodes; +} + +/** + * Create a socket and accept one connection for each of the provided + * addresses. + */ +std::vector accept_connections(const std::vector& addresses) { + std::vector sockets; + int success; + + for (auto& address : addresses) { + // Create the socket to wait for connections from the peers + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::ostringstream msg; + msg << "[ring] Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Make sure we can launch immediately after shutdown by setting the + // reuseaddr option so that we don't get address already in use errors + int enable = 1; + success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "[ring] Couldn't enable reuseport (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Bind the socket to the address and port + success = bind(sock, address.get(), address.len); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "[ring] Couldn't bind socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Wait for connections + success = listen(sock, 0); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "[ring] Couldn't listen (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + int peer_socket = accept(sock, nullptr, nullptr); + if (peer_socket < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "[ring] Accept failed (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Close the listening socket + shutdown(sock, 2); + close(sock); + + sockets.push_back(peer_socket); + } + + return sockets; +} + +/** + * The counterpoint of `accept_connections`. Basically connect to each of the + * provided addresses. + */ +std::vector make_connections( + const std::vector& addresses, + bool verbose) { + std::vector sockets; + int success; + + for (auto& address : addresses) { + int sock; + + // Attempt to connect to the peer CONN_ATTEMPTS times with exponential + // backoff. TODO: Do we need that? + for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) { + // Create the socket + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::ostringstream msg; + msg << "[ring] Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + if (attempt > 0) { + int wait = (1 << (attempt - 1)) * CONN_WAIT; + log_info( + verbose, + "Attempt", + attempt, + "wait", + wait, + "ms (error:", + errno, + ")"); + std::this_thread::sleep_for(std::chrono::milliseconds(wait)); + } + + success = connect(sock, address.get(), address.len); + if (success == 0) { + break; + } + } + if (success < 0) { + std::ostringstream msg; + msg << "[ring] Couldn't connect (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + sockets.push_back(sock); + } + + return sockets; +} + +array ensure_row_contiguous(const array& arr) { + if (arr.flags().row_contiguous) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + return arr_copy; + } +} + +template +void sum_inplace(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output += *input; + input++; + output++; + } +} + +template +void _send(int sock, T* data, size_t start, size_t stop) { + if (stop <= start) { + return; + } + data += start; + size_t len = (stop - start) * sizeof(T); + const char* buffer = (const char*)data; + while (len > 0) { + ssize_t r = send(sock, buffer, len, 0); + if (r <= 0) { + std::ostringstream msg; + msg << "Send of " << len << " bytes failed (errno: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + buffer += r; + len -= r; + } +} + +template +void _recv(int sock, T* data, size_t start, size_t stop) { + if (stop <= start) { + return; + } + data += start; + size_t len = (stop - start) * sizeof(T); + char* buffer = (char*)data; + while (len > 0) { + ssize_t r = recv(sock, buffer, len, 0); + if (r <= 0) { + std::ostringstream msg; + msg << "Recv of " << len << " bytes failed (errno: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + buffer += r; + len -= r; + } +} + +template +void _recv_sum(int sock, T* data, size_t start, size_t stop) { + if (stop <= start) { + return; + } + data += start; + char buffer[PACKET_SIZE]; + size_t len = (stop - start) * sizeof(T); + while (len > 0) { + ssize_t r = 0; + do { + ssize_t partial_r = + recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0); + if (partial_r <= 0) { + std::ostringstream msg; + msg << "Recv of " << len << " bytes failed (errno: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + r += partial_r; + } while (r % sizeof(T)); + sum_inplace((const T*)buffer, data, r / sizeof(T)); + data += r / sizeof(T); + len -= r; + } +} + +template +void ring_send( + Barrier& barrier, + int socket, + int rank, + int size, + T* data, + size_t data_size, + int direction = -1) { + // We split the data into `size_` segments of size `segment_size` + size_t segment_size = ceildiv(data_size, size); + + // Initial segment + int segment = rank; + + // 1st send + for (int i = 0; i < size - 1; i++) { + size_t start = segment * segment_size; + size_t stop = std::min((segment + 1) * segment_size, data_size); + _send(socket, data, start, stop); + barrier.arrive_and_wait(); + segment = (segment + size + direction) % size; + } + + // 2nd send + for (int i = 0; i < size - 1; i++) { + size_t start = segment * segment_size; + size_t stop = std::min((segment + 1) * segment_size, data_size); + _send(socket, data, start, stop); + barrier.arrive_and_wait(); + segment = (segment + size + direction) % size; + } +} + +template +void ring_recv_sum( + Barrier& barrier, + int socket, + int rank, + int size, + T* data, + size_t data_size, + int direction = -1) { + // We split the data into `size_` segments of size `segment_size` + size_t segment_size = ceildiv(data_size, size); + + // Initial segment + int segment = (rank + size + direction) % size; + + // Recv sum + for (int i = 0; i < size - 1; i++) { + size_t start = segment * segment_size; + size_t stop = std::min((segment + 1) * segment_size, data_size); + _recv_sum(socket, data, start, stop); + barrier.arrive_and_wait(); + segment = (segment + size + direction) % size; + } + + // Recv + for (int i = 0; i < size - 1; i++) { + size_t start = segment * segment_size; + size_t stop = std::min((segment + 1) * segment_size, data_size); + _recv(socket, data, start, stop); + barrier.arrive_and_wait(); + segment = (segment + size + direction) % size; + } +} + +} // namespace + +class RingGroup : public GroupImpl { + public: + RingGroup(int rank, std::vector> nodes, bool verbose) + : rank_(rank), verbose_(verbose), pool_(0) { + if (rank_ > 0 && rank_ >= nodes.size()) { + throw std::runtime_error( + "[ring] Rank cannot be larger than the size of the group"); + } + + size_ = nodes.size(); + int connect_to = (rank_ + 1) % size_; + + // We define the connection order by having the rank_ == size_ - 1 connect + // first and accept after. + if (rank_ < connect_to) { + log_info(verbose_, "Rank", rank_, "accepting"); + recv_sockets_ = std::move(accept_connections(nodes[rank_])); + log_info(verbose_, "Rank", rank_, "connecting to", connect_to); + send_sockets_ = std::move(make_connections(nodes[connect_to], verbose)); + } else { + log_info(verbose_, "Rank", rank_, "connecting to", connect_to); + send_sockets_ = std::move(make_connections(nodes[connect_to], verbose)); + log_info(verbose_, "Rank", rank_, "accepting"); + recv_sockets_ = std::move(accept_connections(nodes[rank_])); + } + + // Failure if we couldn't make send or recv sockets + if (send_sockets_.empty()) { + std::ostringstream msg; + msg << "[ring] Rank " << rank_ << " has no send sockets."; + throw std::invalid_argument(msg.str()); + } + if (recv_sockets_.empty()) { + std::ostringstream msg; + msg << "[ring] Rank " << rank_ << " has no recv sockets."; + throw std::invalid_argument(msg.str()); + } + + // The following could be relaxed since we can define non-homogeneous rings + // but it makes things a bit simpler for now. + if (send_sockets_.size() != recv_sockets_.size()) { + std::ostringstream msg; + msg << "[ring] It is required to have as many connections to the left as " + << "to the right but rank " << rank_ << " has " + << send_sockets_.size() << " connections to the right and " + << recv_sockets_.size() << " to the left."; + throw std::invalid_argument(msg.str()); + } + + // Start the necessary threads for completely parallel operation on all + // channels. One thread to send, one to receive per socket. + pool_.resize(send_sockets_.size() * 2 * 2); + } + + ~RingGroup() { + for (auto s : send_sockets_) { + shutdown(s, 2); + close(s); + } + for (auto s : recv_sockets_) { + shutdown(s, 2); + close(s); + } + } + + int rank() override { + return rank_; + } + + int size() override { + return size_; + } + + void all_sum(const array& input_, array& output) override { + SWITCH_TYPE(output, all_sum(input_, output)); + } + + std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[ring] Group split not supported."); + } + void all_gather(const array& input, array& output) override { + throw std::runtime_error("[ring] All gather not supported."); + } + void send(const array& input, int dst) override { + throw std::runtime_error("[ring] Send not supported."); + } + void recv(array& out, int src) override { + throw std::runtime_error("[ring] Recv not supported."); + } + + private: + template + void all_sum(const array& input_, array& output) { + // Make sure that the input is row contiguous + array input = ensure_row_contiguous(input_); + + // If the input data cannot be split into size_ segments then copy it and + // all reduce a local buffer prefilled with 0s. + if (input.size() < size_) { + // TODO: Maybe allocate dynamically so we don't have the constraint below? + if (input.itemsize() * size_ > 1024) { + std::ostringstream msg; + msg << "Can't perform the ring all reduce of " << output.size() + << " elements with a ring of size " << size_; + throw std::runtime_error(msg.str()); + } + + std::future sent, recvd; + auto barrier = std::make_unique(2); + char buffer[1024]; + std::memset(buffer, 0, size_ * input.itemsize()); + std::memcpy(buffer, input.data(), input.nbytes()); + sent = pool_.enqueue( + ring_send, + std::reference_wrapper(*barrier), + send_sockets_[0], + rank_, + size_, + (T*)buffer, + size_, + -1); + recvd = pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barrier), + recv_sockets_[0], + rank_, + size_, + (T*)buffer, + size_, + -1); + sent.wait(); + recvd.wait(); + std::memcpy(output.data(), buffer, output.nbytes()); + return; + } + + // If not inplace all reduce then copy the input to the output first + if (input.data() != output.data()) { + std::memcpy(output.data(), input.data(), input.nbytes()); + } + + // All reduce in place. We have `send_channels_.size()` bidirectional + // channels so let's split the message up and perform as many parallel + // ring-reductions as possible. + std::vector> reductions; + std::vector> barriers; + size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE); + + // Large all reduce territory so let's use all we got + if (packets >= 2 * send_sockets_.size()) { + size_t segment = ceildiv(output.size(), 2 * send_sockets_.size()); + for (int i = 0; i < send_sockets_.size(); i++) { + // 1st ring reduce + barriers.emplace_back(std::make_unique(2)); + reductions.push_back(pool_.enqueue( + ring_send, + std::reference_wrapper(*barriers.back()), + send_sockets_[i], + rank_, + size_, + output.data() + 2 * i * segment, + std::min(output.size() - 2 * i * segment, segment), + -1)); + reductions.push_back(pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barriers.back()), + recv_sockets_[i], + rank_, + size_, + output.data() + 2 * i * segment, + std::min(output.size() - 2 * i * segment, segment), + -1)); + + // 2nd ring reduce + barriers.emplace_back(std::make_unique(2)); + reductions.push_back(pool_.enqueue( + ring_send, + std::reference_wrapper(*barriers.back()), + recv_sockets_[i], + rank_, + size_, + output.data() + (2 * i + 1) * segment, + std::min(output.size() - (2 * i + 1) * segment, segment), + 1)); + reductions.push_back(pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barriers.back()), + send_sockets_[i], + rank_, + size_, + output.data() + (2 * i + 1) * segment, + std::min(output.size() - (2 * i + 1) * segment, segment), + 1)); + } + } + + // At least 2 reductions so we can be from small to medium + else if (packets > 1) { + size_t segment = ceildiv(output.size(), packets); + for (int i = 0; i < send_sockets_.size(); i++) { + barriers.emplace_back(std::make_unique(2)); + reductions.push_back(pool_.enqueue( + ring_send, + std::reference_wrapper(*barriers.back()), + send_sockets_[i], + rank_, + size_, + output.data() + i * segment, + std::min(output.size() - i * segment, segment), + -1)); + reductions.push_back(pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barriers.back()), + recv_sockets_[i], + rank_, + size_, + output.data() + i * segment, + std::min(output.size() - i * segment, segment), + -1)); + } + for (int i = 0; i < packets - send_sockets_.size(); i++) { + barriers.emplace_back(std::make_unique(2)); + reductions.push_back(pool_.enqueue( + ring_send, + std::reference_wrapper(*barriers.back()), + recv_sockets_[i], + rank_, + size_, + output.data() + (send_sockets_.size() + i) * segment, + std::min( + output.size() - (send_sockets_.size() + i) * segment, segment), + 1)); + reductions.push_back(pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barriers.back()), + send_sockets_[i], + rank_, + size_, + output.data() + (send_sockets_.size() + i) * segment, + std::min( + output.size() - (send_sockets_.size() + i) * segment, segment), + 1)); + } + } + + // Small reduction which won't really benefit much from parallelization. + // TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a + // fairly large array. + else { + barriers.emplace_back(std::make_unique(2)); + reductions.push_back(pool_.enqueue( + ring_send, + std::reference_wrapper(*barriers.back()), + send_sockets_[0], + rank_, + size_, + output.data(), + output.size(), + -1)); + reductions.push_back(pool_.enqueue( + ring_recv_sum, + std::reference_wrapper(*barriers.back()), + recv_sockets_[0], + rank_, + size_, + output.data(), + output.size(), + -1)); + } + + // Wait for the reductions to finish. + for (auto& f : reductions) { + f.wait(); + } + } + + int rank_; + int size_; + + bool verbose_; + + ThreadPool pool_; + + std::vector send_sockets_; + std::vector recv_sockets_; +}; + +bool is_available() { + return true; +} + +std::shared_ptr init(bool strict /* = false */) { + const char* hostfile = std::getenv("MLX_HOSTFILE"); + const char* rank_str = std::getenv("MLX_RANK"); + const char* ring_verbose = std::getenv("MLX_RING_VERBOSE"); + + if (!hostfile || !rank_str) { + if (strict) { + std::ostringstream msg; + msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) " + << "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\"" + << ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\"" + << ((hostfile) ? hostfile : "") << "\""; + throw std::runtime_error(msg.str()); + } + return nullptr; + } + + auto nodes = load_nodes(hostfile); + int rank = std::atoi(rank_str); + + return std::make_shared(rank, nodes, ring_verbose != nullptr); +} + +} // namespace mlx::core::distributed::ring diff --git a/mlx/distributed/ring/ring.h b/mlx/distributed/ring/ring.h new file mode 100644 index 0000000000..e0b3fd093c --- /dev/null +++ b/mlx/distributed/ring/ring.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::ring { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::ring diff --git a/mlx/io/load.h b/mlx/io/load.h index 4160d66939..138098e826 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -13,7 +13,7 @@ #include #endif -#include "mlx/io/threadpool.h" +#include "mlx/threadpool.h" // Strictly we need to operate on files in binary mode (to avoid \r getting // automatically inserted), but every modern system except for Windows no diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9dda4f228c..311ad830fe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -652,7 +652,7 @@ void normalize_dynamic_slice_inputs( const array& a, const array& start, std::vector& axes, - const std::string prefix) { + std::string_view prefix) { if (start.size() > a.ndim()) { std::ostringstream msg; msg << prefix << " Invalid number of starting positions for " @@ -690,7 +690,9 @@ void normalize_dynamic_slice_inputs( } std::set dims(axes.begin(), axes.end()); if (dims.size() != axes.size()) { - throw std::invalid_argument(prefix + " Repeat axes not allowed."); + std::ostringstream msg; + msg << prefix << " Repeat axes not allowed."; + throw std::invalid_argument(msg.str()); } } @@ -927,7 +929,7 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) { std::vector meshgrid( const std::vector& arrays, bool sparse /* = false */, - std::string indexing /* = "xy" */, + const std::string& indexing /* = "xy" */, StreamOrDevice s /* = {} */) { if (indexing != "xy" && indexing != "ij") { throw std::invalid_argument( @@ -1186,7 +1188,7 @@ array pad( const Shape& low_pad_size, const Shape& high_pad_size, const array& pad_value /*= array(0)*/, - const std::string mode /*= "constant"*/, + const std::string& mode /*= "constant"*/, StreamOrDevice s /* = {}*/) { if (axes.size() != low_pad_size.size() || axes.size() != high_pad_size.size()) { @@ -1238,7 +1240,7 @@ array pad( const array& a, const std::vector>& pad_width, const array& pad_value /*= array(0)*/, - const std::string mode /*= "constant"*/, + const std::string& mode /*= "constant"*/, StreamOrDevice s /*= {}*/) { std::vector axes(a.ndim(), 0); std::iota(axes.begin(), axes.end(), 0); @@ -1258,7 +1260,7 @@ array pad( const array& a, const std::pair& pad_width, const array& pad_value /*= array(0)*/, - const std::string mode /*= "constant"*/, + const std::string& mode /*= "constant"*/, StreamOrDevice s /*= {}*/) { return pad( a, @@ -1272,7 +1274,7 @@ array pad( const array& a, int pad_width, const array& pad_value /*= array(0)*/, - const std::string mode /*= "constant"*/, + const std::string& mode /*= "constant"*/, StreamOrDevice s /*= {}*/) { return pad( a, diff --git a/mlx/ops.h b/mlx/ops.h index 26bff5ec8b..141cfde709 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -222,7 +222,7 @@ split(const array& a, const Shape& indices, StreamOrDevice s = {}); std::vector meshgrid( const std::vector& arrays, bool sparse = false, - std::string indexing = "xy", + const std::string& indexing = "xy", StreamOrDevice s = {}); /** @@ -274,7 +274,7 @@ array pad( const Shape& low_pad_size, const Shape& high_pad_size, const array& pad_value = array(0), - const std::string mode = "constant", + const std::string& mode = "constant", StreamOrDevice s = {}); /** Pad an array with a constant value along all axes */ @@ -282,19 +282,19 @@ array pad( const array& a, const std::vector>& pad_width, const array& pad_value = array(0), - const std::string mode = "constant", + const std::string& mode = "constant", StreamOrDevice s = {}); array pad( const array& a, const std::pair& pad_width, const array& pad_value = array(0), - const std::string mode = "constant", + const std::string& mode = "constant", StreamOrDevice s = {}); array pad( const array& a, int pad_width, const array& pad_value = array(0), - const std::string mode = "constant", + const std::string& mode = "constant", StreamOrDevice s = {}); /** Permutes the dimensions in reverse order. */ diff --git a/mlx/io/threadpool.h b/mlx/threadpool.h similarity index 81% rename from mlx/io/threadpool.h rename to mlx/threadpool.h index 0ed9859483..b0e56d0f2c 100644 --- a/mlx/io/threadpool.h +++ b/mlx/threadpool.h @@ -38,9 +38,13 @@ class ThreadPool { template auto enqueue(F&& f, Args&&... args) -> std::future>; + void resize(size_t); ~ThreadPool(); private: + void stop_and_wait(); + void start_threads(size_t); + std::vector workers; std::queue> tasks; std::mutex queue_mutex; @@ -49,24 +53,7 @@ class ThreadPool { }; inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); + start_threads(threads); } template @@ -92,12 +79,55 @@ auto ThreadPool::enqueue(F&& f, Args&&... args) return res; } +inline void ThreadPool::resize(size_t threads) { + if (workers.size() == threads) { + return; + } + + if (workers.size() > threads) { + stop_and_wait(); + } + start_threads(threads - workers.size()); +} + inline ThreadPool::~ThreadPool() { + stop_and_wait(); +} + +inline void ThreadPool::stop_and_wait() { + // Stop the current threads and wait until they finish { std::unique_lock lock(queue_mutex); stop = true; } condition.notify_all(); - for (std::thread& worker : workers) + for (std::thread& worker : workers) { worker.join(); + } + + // Reset the member variables so that the threadpool is reusable + stop = false; + workers.clear(); +} + +inline void ThreadPool::start_threads(size_t threads) { + for (size_t i = 0; i < threads; ++i) { + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); + } } diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index f0459b8d3a..f3df1904d2 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) { "init", &mx::distributed::init, "strict"_a = false, - nb::sig("def init(strict: bool = False) -> Group"), + "backend"_a = "any", + nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"), R"pbdoc( Initialize the communication backend and create the global communication group. + Example: + + import mlx.core as mx + + group = mx.distributed.init(backend="ring") + + Args: strict (bool, optional): If set to False it returns a singleton group in case ``mx.distributed.is_available()`` returns False otherwise it throws a runtime error. Default: ``False`` + backend (str, optional): Select a specific distributed backend to + initialize. If set to ``any`` then try all available backends and + return the first one that succeeds. Subsequent calls will return + the first backend that was initialized. Default: ``any`` Returns: Group: The group representing all the launched processes. diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 2af7fcf9a9..79c58d8b6f 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -34,6 +34,8 @@ def test_all_reduce(self): mx.int32, mx.uint32, mx.float32, + mx.float16, + mx.bfloat16, mx.complex64, ] for dt in dtypes: diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py new file mode 100644 index 0000000000..215ecb44a4 --- /dev/null +++ b/python/tests/ring_test_distributed.py @@ -0,0 +1,61 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestRingDistributed(mlx_tests.MLXTestCase): + @classmethod + def setUpClass(cls): + world = mx.distributed.init(strict=True, backend="ring") + + def test_groups(self): + world = mx.distributed.init() + self.assertEqual(world.size(), 8) + self.assertTrue(0 <= world.rank() < 8) + + world2 = mx.distributed.init() + self.assertEqual(world.size(), world2.size()) + self.assertEqual(world.rank(), world2.rank()) + + with self.assertRaises(RuntimeError): + sub = world.split(world.rank() % 2) + + def test_all_reduce(self): + world = mx.distributed.init() + dtypes = [ + (mx.int8, 0), + (mx.uint8, 0), + (mx.int16, 0), + (mx.uint16, 0), + (mx.int32, 0), + (mx.uint32, 0), + (mx.float32, 1e-6), + (mx.float16, 5e-3), + (mx.bfloat16, 1e-1), + (mx.complex64, 1e-6), + ] + sizes = [ + (7,), + (10,), + (1024,), + (1024, 1024), + ] + key = mx.random.key(0) + for dt, rtol in dtypes: + for sh in sizes: + x = ( + mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 + ).astype(dt) + y = mx.distributed.all_sum(x[world.rank()]) + z = sum( + x[i] for i in range(world.size()) + ) # to ensure that we don't sum to int32 + maxrelerror = ((y - z).abs() / z.abs()).max() + self.assertLessEqual(maxrelerror, rtol) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/run_ring_test.sh b/python/tests/run_ring_test.sh new file mode 100644 index 0000000000..3106e49173 --- /dev/null +++ b/python/tests/run_ring_test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +tmpfile=$(mktemp) +cat <$tmpfile +[ + ["127.0.0.1:5000"], + ["127.0.0.1:5001"], + ["127.0.0.1:5002"], + ["127.0.0.1:5003"], + ["127.0.0.1:5004"], + ["127.0.0.1:5005"], + ["127.0.0.1:5006"], + ["127.0.0.1:5007"] +] +HOSTFILE + +ring_test="$(dirname ${BASH_SOURCE[0]})/ring_test_distributed.py" + +for i in {0..7}; do + if (($i == 7)); then + sleep 1 + fi + DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE=$tmpfile MLX_RANK=$i python $ring_test & +done +wait