diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 4ea4d8573..f36226fd7 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -31,8 +31,30 @@ array ensure_row_contiguous(const array& arr) { } } +// TODO: Change to a vectorized sum +template +void simple_sum( + void* input, + void* accumulator, + int* len, + MPI_Datatype* datatype) { + T* in = (T*)input; + T* acc = (T*)accumulator; + int N = *len; + + while (N-- > 0) { + *acc += *in; + acc++; + in++; + } +} +template void simple_sum(void*, void*, int*, MPI_Datatype*); +template void simple_sum(void*, void*, int*, MPI_Datatype*); + struct MPIWrapper { MPIWrapper() { + initialized_ = false; + libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; @@ -47,6 +69,9 @@ struct MPIWrapper { LOAD_SYMBOL(MPI_Comm_free, comm_free); LOAD_SYMBOL(MPI_Allreduce, all_reduce); LOAD_SYMBOL(MPI_Allgather, all_gather); + LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); + LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); + LOAD_SYMBOL(MPI_Op_create, mpi_op_create); // Objects LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_); @@ -76,7 +101,24 @@ struct MPIWrapper { if (!is_available()) { return false; } - return init(nullptr, nullptr) == MPI_SUCCESS; + bool success = init(nullptr, nullptr) == MPI_SUCCESS; + + // Initialize custom types and ops + if (success && !initialized_) { + // Custom float16 dtypes + mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_); + mpi_type_commit(&mpi_float16_); + mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_); + mpi_type_commit(&mpi_bfloat16_); + + // Custom sum ops + mpi_op_create(&simple_sum, 1, &op_sum_f16_); + mpi_op_create(&simple_sum, 1, &op_sum_bf16_); + + initialized_ = true; + } + + return success; } void finalize_safe() { @@ -114,13 +156,21 @@ struct MPIWrapper { case complex64: return mpi_complex_; case float16: + return mpi_float16_; case bfloat16: - throw std::runtime_error("MPI doesn't support 16-bit floats"); + return mpi_bfloat16_; } } - MPI_Op op_sum() { - return op_sum_; + MPI_Op op_sum(const array& arr) { + switch (arr.dtype()) { + case float16: + return op_sum_f16_; + case bfloat16: + return op_sum_bf16_; + default: + return op_sum_; + } } void* libmpi_handle_; @@ -147,6 +197,8 @@ struct MPIWrapper { // Ops MPI_Op op_sum_; + MPI_Op op_sum_f16_; + MPI_Op op_sum_bf16_; // Datatypes MPI_Datatype mpi_bool_; @@ -160,6 +212,16 @@ struct MPIWrapper { MPI_Datatype mpi_uint64_; MPI_Datatype mpi_float_; MPI_Datatype mpi_complex_; + MPI_Datatype mpi_float16_; + MPI_Datatype mpi_bfloat16_; + + private: + bool initialized_; + + // Private API + int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*); + int (*mpi_type_commit)(MPI_Datatype*); + int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*); }; MPIWrapper& mpi() { @@ -268,7 +330,7 @@ void all_sum(Group group, const array& input_, array& output) { output.data(), input.size(), mpi().datatype(input), - mpi().op_sum(), + mpi().op_sum(input), to_comm(group)); }