Skip to content

Commit

Permalink
Ring distributed backend (#1784)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Jan 28, 2025
1 parent 2235dee commit ccb61d7
Show file tree
Hide file tree
Showing 17 changed files with 1,078 additions and 44 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
/bin/bash python/tests/run_ring_test.sh
- run:
name: Build example extension
command: |
Expand Down
1 change: 1 addition & 0 deletions mlx/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
53 changes: 43 additions & 10 deletions mlx/distributed/distributed.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Copyright © 2024 Apple Inc.

#include <unordered_map>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/ring/ring.h"
#include "mlx/scheduler.h"

namespace mlx::core::distributed {
Expand Down Expand Up @@ -65,7 +68,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail

bool is_available() {
return mpi::is_available();
return mpi::is_available() || ring::is_available();
}

int Group::rank() const {
Expand All @@ -80,20 +83,50 @@ Group Group::split(int color, int key /* = -1 */) const {
return Group(group_->split(color, key));
}

Group init(bool strict /* = false */) {
auto init_group = [strict]() {
auto default_group = mpi::init(strict);
if (default_group == nullptr) {
default_group = std::make_shared<detail::EmptyGroup>();
Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
static std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>
backends;

// Already initialized so return the group.
if (auto g = backends.find(bk); g != backends.end()) {
return Group(g->second);
}

// Create the requested communication group
std::shared_ptr<detail::GroupImpl> group;
std::string bk_ = bk;
if (bk == "mpi") {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";
if (group == nullptr) {
group = mpi::init(false);
bk_ = "mpi";
}
return default_group;
};
static std::shared_ptr<detail::GroupImpl> default_group = init_group();
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}

if (group == nullptr) {
group = std::make_shared<detail::EmptyGroup>();
} else {
backends.insert({"any", group});
}
backends.insert({std::move(bk_), group});

// Ensure the communication stream is alive before
// the graph is evaluated
detail::communication_stream();
return Group(default_group);
return Group(group);
}

} // namespace mlx::core::distributed
2 changes: 1 addition & 1 deletion mlx/distributed/distributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ struct Group {
* distributed subsystem. Otherwise simply return a singleton group which will
* render communication operations as no-op.
*/
Group init(bool strict = false);
Group init(bool strict = false, const std::string& bk = "any");

} // namespace mlx::core::distributed
2 changes: 2 additions & 0 deletions mlx/distributed/distributed_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace mlx::core::distributed::detail {
*/
class GroupImpl {
public:
virtual ~GroupImpl() {}

virtual int rank() = 0;
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
Expand Down
5 changes: 5 additions & 0 deletions mlx/distributed/ring/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Choose which translation unit provides the ring backend symbols.
# Without CPU support, compile the stub (no_ring.cpp), which reports the
# backend as unavailable; otherwise compile the real implementation.
if(NOT MLX_BUILD_CPU)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
endif()
20 changes: 20 additions & 0 deletions mlx/distributed/ring/no_ring.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/ring/ring.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// Reports whether the ring backend can be used. This translation unit is the
// stub variant, so the answer is always "no".
bool is_available() {
  constexpr bool ring_supported = false;
  return ring_supported;
}

// Stub initializer for builds without the ring backend.
//
// strict: when true, the inability to create a ring group is reported by
//         throwing std::runtime_error; when false, a null group is returned
//         instead.
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  // Lenient callers just get "no group".
  if (!strict) {
    return nullptr;
  }
  throw std::runtime_error("Cannot initialize ring distributed backend.");
}

} // namespace mlx::core::distributed::ring
Loading

0 comments on commit ccb61d7

Please sign in to comment.