diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 6bb6bd0ac6..f1ed2ffe3e 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -119,8 +119,8 @@ def add_parser_api_server(): ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) # multi-node serving args - ArgumentHelper.node_rank(parser) - ArgumentHelper.num_nodes(parser) + node_rank_act = ArgumentHelper.node_rank(pt_group) + num_nodes_act = ArgumentHelper.num_nodes(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') @@ -135,6 +135,8 @@ def add_parser_api_server(): tb_group._group_actions.append(max_prefill_token_num_act) tb_group._group_actions.append(quant_policy) tb_group._group_actions.append(model_format) + tb_group._group_actions.append(num_nodes_act) + tb_group._group_actions.append(node_rank_act) tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(disable_metrics) tb_group._group_actions.append(dp) @@ -143,6 +145,7 @@ def add_parser_api_server(): ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.communicator(tb_group) + ArgumentHelper.dist_init_addr(tb_group) # vlm args vision_group = parser.add_argument_group('Vision model arguments') @@ -239,6 +242,9 @@ def api_server(args): tp=args.tp, dp=args.dp, cp=args.cp, + nnodes=args.nnodes, + node_rank=args.node_rank, + dist_init_addr=args.dist_init_addr, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index a4f368be5f..c9060d1fa5 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -219,6 +219,12 @@ def num_nodes(parser): return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums') + @staticmethod + def dist_init_addr(parser): + """Add argument dist_init_addr to parser.""" + + return parser.add_argument('--dist-init-addr', type=str, default=None) + @staticmethod def session_id(parser): """Add argument session_id to parser.""" diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 69ac157652..f38d15c26f 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -246,6 +246,10 @@ class TurbomindEngineConfig: mlp_tp_size: int = None mlp_dp_size: int = None outer_dp_size: int = None + nnodes: int = 1 + node_rank: int = 0 + dist_init_addr: Optional[str] = None + devices: List[int] = None session_len: Optional[int] = None max_batch_size: int = None cache_max_entry_count: float = 0.8 diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 2be007aef7..60ba7d5140 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -399,7 +399,9 @@ def _build_stat_loggers(self): if getattr(self.backend_config, 'enable_metrics', False): from lmdeploy.metrics.loggers import LoggingStatLogger, PrometheusStatLogger - dp_rank = self.backend_config.dp_rank if self.backend_config.dp > 1 else 0 + + # currently, metrics in TM engine doesn't support dp + dp_rank = self.backend_config.dp_rank if self.backend == 'pytorch' else 0 logger.info(f'enable metrics, with dp: {self.backend_config.dp} dp_rank: {dp_rank}') self.stat_loggers = [ diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5dc5a2bfde..577d465784 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1228,6 +1228,9 @@ async def startup_event(): if VariableInterface.proxy_url is None: return + elif 
getattr(async_engine.engine, 'is_dummy', False): + logger.info('Dummy node started') + return try: import requests engine_config = VariableInterface.async_engine.backend_config diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index 5695cc2325..126bf5e800 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -63,7 +63,7 @@ class ModelConfig: attn_sink: bool = False qk_norm: bool = False size_per_head: int = 128 - group_size: int = 64 + group_size: int = 32 data_type: str = None weight_type: str = None expert_weight_type: str = None diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index d796848259..cf70f76a70 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -162,9 +162,10 @@ def _tofile(tensor, path): torch_tensor = torch_tensor.bfloat16() else: torch_tensor = torch_tensor.half() - for tm_tensor in tm_params[name]: - tm_tensor.copy_from(torch_tensor) - tm_params.pop(name) + if name in tm_params: + for tm_tensor in tm_params[name]: + tm_tensor.copy_from(torch_tensor) + tm_params.pop(name) else: tprint('skip export', name, param.shape) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 8ce9373258..5c3fb11ed3 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -5,6 +5,7 @@ import copy import json import math +import os import os.path as osp import sys from collections.abc import Sequence @@ -86,10 +87,11 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): + cfg.device_num = len(cfg.devices) * cfg.nnodes if cfg.devices else cfg.device_num if not complete_parallel_config(cfg): total = cfg.dp * cfg.tp if not cfg.device_num: - count = torch.cuda.device_count() + count = torch.cuda.device_count() * cfg.nnodes if total < count: count = total cfg.device_num = count @@ -106,7 +108,10 @@ def update_parallel_config(cfg: TurbomindEngineConfig): cfg.mlp_tp_size = mlp_tp_size * inner_tp_size assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num - cfg.devices = cfg.devices or list(range(cfg.device_num)) + # update devices + cfg.devices = cfg.devices or list(range(cfg.device_num // cfg.nnodes)) + cfg.devices = cfg.devices[:cfg.device_num // cfg.nnodes] + assert len(cfg.devices) == cfg.device_num // cfg.nnodes class TurboMind: @@ -142,14 +147,27 @@ def __init__(self, f' greater than 0, but got {_engine_config.max_batch_size}' update_parallel_config(_engine_config) - - self.gpu_count = _engine_config.device_num + if _engine_config.nnodes > 1: + logger.info(f'dist_init_addr={_engine_config.dist_init_addr}') + assert _engine_config.dist_init_addr is not None + hostname, port = _engine_config.dist_init_addr.split(':') + os.environ['LMDEPLOY_DIST_INIT_ADDR'] = hostname + os.environ['LMDEPLOY_DIST_INIT_PORT'] = port + # this will block the process and ignore signals until all ranks done + from torch.distributed import TCPStore + self.store = TCPStore(host_name=hostname, + port=int(port), + world_size=_engine_config.nnodes, + is_master=_engine_config.node_rank == 0) + + self.gpu_count = len(_engine_config.devices) self.devices = _engine_config.devices self._engine_created = False if not osp.exists(model_path): model_path = 
get_model(model_path, _engine_config.download_dir, _engine_config.revision) self.model_comm = self._from_hf(model_path=model_path, engine_config=_engine_config) + self.is_dummy = self.model_comm.is_dummy_node() self.tokenizer = Tokenizer(model_path) if not _engine_config.empty_init: self._load_weights() @@ -192,10 +210,8 @@ def _create_engine(self): def _create_weight(self, model_comm): """Allocate weight buffer, load params if from_workspace.""" - # TODO: support mpi - self.node_id = 0 - self.node_num = 1 - torch.cuda.synchronize() + engine_cfg = self.config_dict['engine_config'] + self.node_id = engine_cfg['node_rank'] # create weight def _create_weight_func(device_id): @@ -382,6 +398,8 @@ def close(self): if self.model_comm is not None: self.model_comm = None self._engine_created = False + if hasattr(self, 'store'): + del self.store def create_instance(self, cuda_stream_id=0): """Create a turbomind instance. @@ -527,11 +545,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea self.tm_model = tm_model self.cuda_stream_id = cuda_stream_id - self.node_id = tm_model.node_id - self.gpu_count = tm_model.gpu_count - - self.session_len = tm_model.session_len - # create model instances lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False) self._model_inst = None if lazy_init else self._create_model_instance(0) diff --git a/src/turbomind/comm/CMakeLists.txt b/src/turbomind/comm/CMakeLists.txt index 0a3b2b4ea3..d63374b945 100644 --- a/src/turbomind/comm/CMakeLists.txt +++ b/src/turbomind/comm/CMakeLists.txt @@ -22,6 +22,9 @@ if (BUILD_MULTI_GPU) target_link_libraries(device_comm INTERFACE nccl_comm) endif () + add_subdirectory(gloo) + target_link_libraries(host_comm INTERFACE gloo_comm) + if (BUILD_TEST) add_executable(test_comm test_comm.cu) target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils) diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt new file mode 100644 index 0000000000..395b7ee55a --- /dev/null +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +cmake_minimum_required(VERSION 3.8) + +include(FetchContent) +FetchContent_Declare( + gloo + GIT_REPOSITORY https://github.com/pytorch/gloo.git + GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4 +) + +# some settings of gloo, +set(GLOO_INSTALL OFF CACHE BOOL "" FORCE) +set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE) +set(USE_NCCL OFF) +set(BUILD_TEST OFF) +set(USE_IBVERBS OFF) +FetchContent_MakeAvailable(gloo) + +# gloo build doesn't add include directories as a target property... +target_include_directories(gloo PUBLIC + $ + $ # config.h generated at cmake config time +) + +target_compile_options(gloo PRIVATE + $<$:/W0> + $<$,$>:-w> +) + +add_library(gloo_comm STATIC + gloo_comm.cc + hybrid_comm.cc + tcp_store.cc +) +set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(gloo_comm PUBLIC gloo host_comm logger xgrammar) + +add_executable(test_ipc_comm test_ipc_comm.cc) +target_link_libraries(test_ipc_comm PRIVATE gloo_comm Threads::Threads) diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc new file mode 100644 index 0000000000..10bb4a7974 --- /dev/null +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -0,0 +1,381 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
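// Illustrative sketch, not part of this patch: how the group-id handshake implemented in
// this file is meant to be driven. Rank 0 creates the id (an "ip,port,prefix" string built
// by GlobalStoreFactory from LMDEPLOY_DIST_INIT_ADDR/PORT), exports it, and every rank
// imports the same string before creating a communicator. The backend name "gloo" and the
// out-of-band hand-off of the exported string are assumptions based on the test program below.
#include <sstream>
#include <string>
#include "src/turbomind/comm/host_comm.h"

void bootstrap_example(int rank, int n_ranks, std::string serialized_id)
{
    using namespace turbomind::comm;
    auto group_id = CreateHostGroupId("gloo");
    if (rank == 0) {
        group_id->Initialize();            // reads LMDEPLOY_DIST_INIT_ADDR / _PORT
        std::stringstream out;
        group_id->Export(out);             // writes "host,port,prefix"
        serialized_id = out.str();         // hand this string to the other ranks out of band
    }
    std::stringstream in(serialized_id);
    group_id->Import(in);
    HostComm comm = group_id->CreateCommunicator(n_ranks, rank);
    int value = rank;
    AllReduce(comm, &value, 1, RedOp::kSum);  // every rank ends with 0 + 1 + ... + (n_ranks - 1)
}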
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if GLOO_HAVE_TRANSPORT_IBVERBS +#include "gloo/transport/ibverbs/device.h" +#endif + +#include "src/turbomind/comm/gloo/tcp_store.h" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; +const char STORE_INFO_DELIM = ','; + +std::shared_ptr<::gloo::transport::Device> createGlooDevice() +{ +#if GLOO_HAVE_TRANSPORT_IBVERBS + if (auto transport = std::getenv("GLOO_DEVICE_TRANSPORT"); + transport != nullptr && strcmp(transport, "ibverbs") == 0) { + ::gloo::transport::ibverbs::attr ib_attr{}; + ib_attr.name = ""; + ib_attr.port = 1; + ib_attr.index = 3; // use IBV_GID_TYPE_ROCE_V2 and ipv4 + return ::gloo::transport::ibverbs::CreateDevice(ib_attr); + } +#endif + ::gloo::transport::tcp::attr attr; + if (auto ifname = std::getenv(GLOO_SOCKET_IFNAME_ENV); ifname) { + attr.iface = ifname; + } + else { + attr.hostname = ::gloo::getHostname(); + } + return ::gloo::transport::tcp::CreateDevice(attr); +} + +class Store: public ::gloo::rendezvous::PrefixStore { +public: + explicit Store(const std::string& host, int port, const std::string& prefix): + host_(host), port_(port), ::gloo::rendezvous::PrefixStore(prefix, nullptr) + { + store_ = std::make_shared(host_, port_); + }; + + ~Store() = default; + + std::shared_ptr New(const std::string& prefix) + { + std::string new_prefix = prefix + "/" + prefix_; + return std::make_shared(host_, port_, new_prefix); + } + +public: + std::string host_; + int port_; + + using ::gloo::rendezvous::PrefixStore::store_; + using ::gloo::rendezvous::PrefixStore::prefix_; +}; + +class GlobalStoreFactory { +public: + static GlobalStoreFactory& Instance() + { + static GlobalStoreFactory instance; + return instance; + } + + std::string New() + { + std::lock_guard lock(mutex_); + TM_CHECK(std::getenv("LMDEPLOY_DIST_INIT_ADDR") != nullptr) << "LMDEPLOY_DIST_INIT_ADDR not set"; + TM_CHECK(std::getenv("LMDEPLOY_DIST_INIT_PORT") != nullptr) << "LMDEPLOY_DIST_INIT_PORT not set"; + + std::string host = std::getenv("LMDEPLOY_DIST_INIT_ADDR"); + int port = std::stoi(std::getenv("LMDEPLOY_DIST_INIT_PORT")); + + std::stringstream ss; + ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++; + return ss.str(); + } + + std::shared_ptr Load(const std::string& info) + { + std::stringstream ss(info); + std::vector keys; + std::string local; + while (getline(ss, local, STORE_INFO_DELIM)) { + keys.push_back(std::move(local)); + } + TM_CHECK(keys.size() == 3); + + std::string host = keys[0]; + int port = stoi(keys[1]); + std::string prefix = keys[2]; + + return std::make_shared(host, port, prefix); + } + +private: + GlobalStoreFactory() {} + + std::mutex mutex_; + int prefix_{0}; +}; + +typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); + +struct GlooCommImpl: public HostCommImpl { + + struct SplitInfo { + int color; + int rank; + + bool operator<(const SplitInfo& other) const + { + return (color < other.color) || (color == other.color && rank < other.rank); + } + + bool operator==(const SplitInfo& other) const + { + return (color == other.color) && (rank == other.rank); + } + }; + + GlooCommImpl(std::shared_ptr store, int n_ranks, int rank): + store_{std::move(store)}, rank_{rank}, n_ranks_{n_ranks} + { + device_ = createGlooDevice(); + context_ = 
std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_); + context_->setTimeout(kTimeOut); + context_->connectFullMesh(store_, device_); + } + + ~GlooCommImpl() {} + + int rank() const override + { + return rank_; + } + + int n_ranks() const override + { + return n_ranks_; + } + + bool is_same_process() const override + { + return false; + } + + std::shared_ptr Split(int color, int key) override + { + auto vec = comm::AllGather(this, SplitInfo{color, rank_}); + auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) { // + return x.color == color; + }); + vec.erase(last, vec.end()); + std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) { // + return a < b; + }); + + auto new_prefix = std::to_string(color) + ":" + std::to_string(n_split_++); + auto new_store = store_->New(new_prefix); + int new_n_ranks = vec.size(); + int new_rank = std::find(vec.begin(), vec.end(), SplitInfo{color, rank_}) - vec.begin(); + return std::make_shared(new_store, new_n_ranks, new_rank); + } + + void Sync(bool blocking) override + { + ::gloo::BarrierOptions opts(context_); + ::gloo::barrier(opts); + } + + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override + { + // trivially copyable if no ser/des function + if (!ser || !des) { + return Broadcast(data, count, dtype, root); + } + + // broadcast buffer size + size_t size; + if (root == rank()) { + ser(data, 0, count, size, nullptr); + } + Broadcast(&size, 1, data_type_v, root); + + // serialize data on root rank + std::vector bytes; + bytes.reserve(size); + if (root == rank()) { + ser(data, 0, count, size, bytes.data()); + } + + // broadcast serialized data + Broadcast(bytes.data(), size, data_type_v, root); + + // deserialize data on all ranks + if (root != rank()) { + des(data, 0, count, bytes.data(), size); + } + } + + void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override + { + // trivially copyable if no ser/des function + if (!ser || !des) { + return AllGather(data, count, dtype); + } + + // get buffer size on each rank and find max size + size_t size; + ser(data, count * rank(), count, size, nullptr); + std::vector sizes(n_ranks()); + sizes[rank()] = size; + AllGather(sizes.data(), 1, data_type_v); + auto max_size = *std::max_element(sizes.begin(), sizes.end()); + + // serialize data on each rank + std::vector bytes(max_size * n_ranks()); + ser(data, count * rank(), count, size, bytes.data() + rank() * max_size); + + // gather serialized data + AllGather(bytes.data(), max_size, data_type_v); + + // deserialize data on each rank + for (int i = 0; i < n_ranks(); ++i) { + if (i != rank()) { + des(data, i * count, count, bytes.data() + i * max_size, sizes[i]); + } + } + } + + void Broadcast(void* data, int count, DataType dtype, int root) + { + ::gloo::BroadcastOptions opts(context_); + opts.setRoot(root); + opts.setOutput((char*)data, count * byte_size(dtype)); + ::gloo::broadcast(opts); + } + + void AllGather(void* data, int count, DataType dtype) + { + ::gloo::AllgatherOptions opts(context_); + opts.setOutput((char*)data, count * byte_size(dtype) * n_ranks_); + ::gloo::allgather(opts); + } + + static ReduceFunc getReduceFunc(DataType dtype, RedOp red_op) + { + + auto dispatch_op = [&](auto t) -> ReduceFunc { + using T = decltype(t); + switch (red_op) { + case RedOp::kSum: + return ::gloo::sum; + case RedOp::kMax: + return ::gloo::max; + case RedOp::kMin: + return ::gloo::min; + default: + return {}; + } + }; + + auto dispatch 
= [&]() -> ReduceFunc { + switch (dtype) { + case kInt32: + return dispatch_op(int32_t{}); + case kInt64: + return dispatch_op(int64_t{}); + case kUint32: + return dispatch_op(uint32_t{}); + case kUint64: + return dispatch_op(uint64_t{}); + default: + return {}; + } + }; + + if (auto fn = dispatch()) { + return fn; + } + else { + throw std::runtime_error("not implemented"); + return {}; + } + } + + void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override + { + ::gloo::AllreduceOptions opts(context_); + opts.setReduceFunction(getReduceFunc(dtype, red_op)); + switch (dtype) { + case kInt32: + opts.setOutput((int32_t*)data, count); + break; + case kInt64: + opts.setOutput((int64_t*)data, count); + break; + case kUint32: + opts.setOutput((uint32_t*)data, count); + break; + case kUint64: + opts.setOutput((uint64_t*)data, count); + break; + default: + throw std::runtime_error("not implemented"); + } + ::gloo::allreduce(opts); + } + + // there might be very long intervals between receiving requests. + static constexpr std::chrono::milliseconds kTimeOut = std::chrono::milliseconds(1000LL * 3600 * 24 * 365); + + int n_split_{}; + std::shared_ptr<::gloo::transport::Device> device_; + std::shared_ptr<::gloo::rendezvous::Context> context_; + std::shared_ptr store_; + int rank_; + int n_ranks_; +}; + +class GlooGroupId: public HostGroupId { + + void Initialize() override + { + info_ = GlobalStoreFactory::Instance().New(); + TM_LOG_INFO("[TM][COMM] GlooGroupId=%s", info_.c_str()); + } + + void Export(std::ostream& os) override + { + os << info_; + } + + void Import(std::istream& is) override + { + std::stringstream ss; + ss << is.rdbuf(); + info_ = ss.str(); + } + + HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override + { + TM_CHECK(info_ != ""); + auto impl = std::make_shared(GlobalStoreFactory::Instance().Load(info_), n_ranks, rank); + return std::static_pointer_cast(impl); + } + +private: + std::string info_; // ip,port,prefix + std::shared_ptr<::gloo::rendezvous::Store> store_; +}; + +std::unique_ptr CreateGlooGroupId() +{ + return std::make_unique(); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/hybrid_comm.cc b/src/turbomind/comm/gloo/hybrid_comm.cc new file mode 100644 index 0000000000..18d0e4e5bf --- /dev/null +++ b/src/turbomind/comm/gloo/hybrid_comm.cc @@ -0,0 +1,220 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
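// Illustrative sketch, not part of this patch: the hierarchical AllReduce pattern this file
// implements, simulated on plain integers for 2 nodes x 2 ranks per node. Phase 1 reduces
// inside each node, phase 2 reduces across nodes between the local-rank-0 peers, phase 3
// broadcasts the combined result inside each node.
#include <cassert>
#include <vector>

int main()
{
    std::vector<int> v = {0, 1, 2, 3};   // global rank r holds the value r
    // phase 1: intra-node sum (node 0 = ranks {0,1}, node 1 = ranks {2,3})
    std::vector<int> s = {0 + 1, 0 + 1, 2 + 3, 2 + 3};
    // phase 2: inter-node sum between the local-rank-0 ranks (global ranks 0 and 2)
    s[0] = s[2] = s[0] + s[2];           // 1 + 5 = 6
    // phase 3: intra-node broadcast from local rank 0
    s[1] = s[0];
    s[3] = s[2];
    for (int x : s) {
        assert(x == 0 + 1 + 2 + 3);      // every rank now holds the global sum
    }
    return 0;
}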
+ +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/check.h" + +namespace turbomind::comm { + +extern std::unique_ptr CreateThreadGroupId(); +extern std::unique_ptr CreateGlooGroupId(); + +struct HybridCommImpl: public HostCommImpl { + + HybridCommImpl(int n_ranks, int rank, int node_rank, HostGroupId* gloo_group_id, HostGroupId* thread_group_id): + n_ranks_{n_ranks}, // + rank_{rank}, + node_rank_(node_rank) + { + gloo_comm_ = gloo_group_id->CreateCommunicator(n_ranks, rank); + rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank); + same_process_ = rank_to_nodes_.front() == rank_to_nodes_.back(); + if (same_process_) { + intra_comm_ = thread_group_id->CreateCommunicator(n_ranks, rank); + } + else { + init_inter_comm(); + intra_comm_ = thread_group_id->CreateCommunicator(intra_n_ranks_, rank_to_intra_[rank_]); + } + } + + HybridCommImpl(std::shared_ptr gloo_comm, std::shared_ptr intra_comm, int node_rank): + gloo_comm_{std::move(gloo_comm)}, + intra_comm_{std::move(intra_comm)}, + rank_{gloo_comm_->rank()}, + n_ranks_{gloo_comm_->n_ranks()}, + node_rank_(node_rank) + { + rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank); + same_process_ = rank_to_nodes_.front() == rank_to_nodes_.back(); + if (same_process_) {} + else { + init_inter_comm(); + } + } + + void init_inter_comm() + { + int intra_n_ranks = 0; + int intra_rank = -1; + for (int r = 0; r < n_ranks_; ++r) { + if (rank_to_nodes_[r] == node_rank_) { + if (r == rank_) { + intra_rank = intra_n_ranks; + } + intra_n_ranks++; + } + } + + intra_n_ranks_ = intra_n_ranks; + gloo_comm_->AllReduce(&intra_n_ranks_, 1, DataType::kInt, RedOp::kMin); + TM_CHECK_EQ(intra_n_ranks_, intra_n_ranks) << "The number of ranks in each node should be same."; + TM_CHECK_GT(intra_rank, -1) << "Invalid intra_rank."; + rank_to_intra_ = ::turbomind::comm::AllGather(gloo_comm_, intra_rank); + + inter_comm_ = gloo_comm_->Split(rank_to_intra_[rank_], 0); + rank_to_inter_ = ::turbomind::comm::AllGather(gloo_comm_, inter_comm_->rank()); + } + + std::shared_ptr Split(int color, int key) override + { + if (!is_same_process()) { + auto new_gloo_comm = gloo_comm_->Split(color, key); + auto new_intra_comm = intra_comm_->Split(color, key); + return std::make_shared(new_gloo_comm, new_intra_comm, node_rank_); + } + else { + return intra_comm_->Split(color, key); + } + } + + int rank() const override + { + return rank_; + } + + int n_ranks() const override + { + return n_ranks_; + } + + bool is_same_process() const override + { + return same_process_; + } + + void Sync(bool blocking) override + { + if (!is_same_process() && rank_to_intra_[rank_] == 0) { + inter_comm_->Sync(blocking); + } + intra_comm_->Sync(blocking); + } + + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override + { + if (!ser || !des) { + return Broadcast(data, count, dtype, root, copy); + } + + if (rank_to_intra_[root] == rank_to_intra_[rank_]) { // same ith rank in node + inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy, ser, des); + } + intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy); + } + + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) + { + if (is_same_process()) { + return intra_comm_->Broadcast(data, count, dtype, root, copy); + } + + if (rank_to_intra_[root] == rank_to_intra_[rank_]) { // same ith rank in node + inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy); + } + 
intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy); + } + + void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override + { + if (!ser || !des) { + return AllGather(data, count, dtype, copy); + } + + return gloo_comm_->AllGather(data, count, dtype, copy, ser, des); + } + + void AllGather(void* data, int count, DataType dtype, copy_fn copy) + { + if (is_same_process()) { + return intra_comm_->AllGather(data, count, dtype, copy); + } + + // TODO: support allgatherv in gloo comm (each node may has different rank size) + return gloo_comm_->AllGather(data, count, dtype, copy); + } + + void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override + { + if (is_same_process()) { + return intra_comm_->AllReduce(data, count, dtype, red_op); + } + + intra_comm_->AllReduce(data, count, dtype, red_op); + if (rank_to_intra_[rank_] == 0) { + inter_comm_->AllReduce(data, count, dtype, red_op); + } + intra_comm_->Broadcast(data, byte_size(dtype) * count, data_type_v, 0, detail::copy_fn); + } + + HostComm gloo_comm_{}; // primitive comm, used for initializing inter_comm and intra_comm + HostComm inter_comm_{}; // inter-node comm + HostComm intra_comm_{}; // intra-node comm + + int rank_; // group rank + int n_ranks_; // group size + int node_rank_; // node rank + int intra_n_ranks_; + + std::vector rank_to_nodes_{}; // map group rank to node rank (not global) + std::vector rank_to_intra_{}; // map group rank to intra-node rank + std::vector rank_to_inter_{}; // map group rank to inter-node rank + + bool same_process_; +}; + +class HybridGroupId: public HostGroupId { +public: + HybridGroupId() + { + thread_group_id_ = CreateThreadGroupId(); + gloo_group_id_ = CreateGlooGroupId(); + } + + void Initialize() override + { + thread_group_id_->Initialize(); + gloo_group_id_->Initialize(); + } + + void Export(std::ostream& os) override + { + thread_group_id_->Export(os); + gloo_group_id_->Export(os); + } + + void Import(std::istream& is) override + { + thread_group_id_->Import(is); + gloo_group_id_->Import(is); + } + + HostComm CreateCommunicator(int n_ranks, int rank, int node_rank) + { + auto impl = std::make_shared(n_ranks, // + rank, + node_rank, + gloo_group_id_.get(), + thread_group_id_.get()); + return std::static_pointer_cast(impl); + } + + std::unique_ptr thread_group_id_; + std::unique_ptr gloo_group_id_; +}; + +std::unique_ptr CreateHybridGroupId() +{ + return std::make_unique(); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/tcp_store.cc b/src/turbomind/comm/gloo/tcp_store.cc new file mode 100644 index 0000000000..8de75f508b --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.cc @@ -0,0 +1,221 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
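// Illustrative sketch, not part of this patch: the wire framing this client uses to talk to
// torch.distributed's TCPStore backend. Every request starts with a 1-byte QueryType; keys
// and values are sent as a uint64 length prefix followed by the raw bytes. For example, a
// SET of key "k0" to the single byte 0x2a is framed as
//   [0x01][0x02 0x00 .. 0x00]["k0"][0x01 0x00 .. 0x00][0x2a]
// (0x01 = QueryType::SET; lengths are written in host byte order, little-endian on x86).
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

std::vector<uint8_t> frame_set_request(const std::string& key, const std::vector<uint8_t>& value)
{
    std::vector<uint8_t> out;
    auto put_u64 = [&](uint64_t n) {
        uint8_t tmp[sizeof(n)];
        std::memcpy(tmp, &n, sizeof(n));          // length prefix, host byte order
        out.insert(out.end(), tmp, tmp + sizeof(n));
    };
    out.push_back(0x01);                           // QueryType::SET (second enumerator)
    put_u64(key.size());
    out.insert(out.end(), key.begin(), key.end());
    put_u64(value.size());
    out.insert(out.end(), value.begin(), value.end());
    return out;
}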
+ +#include +#include +#include +#include + +#include +#include + +#include "src/turbomind/comm/gloo/tcp_store.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +namespace { + +// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.8.0-rc4/torch/csrc/distributed/c10d/TCPStoreBackend.hpp + +static const uint32_t validationMagicNumber = 0x3C85F7CE; + +enum class CheckResponseType : uint8_t +{ + READY, + NOT_READY +}; + +enum class QueryType : uint8_t +{ + VALIDATE, + SET, + COMPARE_SET, + GET, + ADD, + CHECK, + WAIT, + GETNUMKEYS, + DELETE_KEY, + APPEND, + MULTI_GET, + MULTI_SET, + CANCEL_WAIT, + PING, + QUEUE_PUSH, + QUEUE_POP, + QUEUE_LEN, +}; + +} // namespace + +struct Buffer { + std::vector buffer; + + template>> + void append(T val) + { + char* ptr = (char*)&val; + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + void append(const std::vector& vec) + { + append((uint64_t)vec.size()); + buffer.insert(buffer.end(), vec.begin(), vec.end()); + } + + void append(const std::string& str) + { + append((uint64_t)str.size()); + buffer.insert(buffer.end(), str.begin(), str.end()); + } + + const char* data() const + { + return buffer.data(); + } + + size_t count() const + { + return buffer.size(); + } +}; + +void validate(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::VALIDATE); + buffer.append(validationMagicNumber); + socket->write(buffer.data(), buffer.count()); +} + +void ping(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::PING); + uint32_t nonce = getpid(); + uint32_t returnedNonce = -1; + buffer.append(nonce); + socket->write(buffer.data(), buffer.count()); + int r = socket->read(&returnedNonce, sizeof(returnedNonce)); + if (nonce != returnedNonce) { + std::stringstream ss; + ss << "Ping failed, nonce=" << nonce << ", returnedNonce=" << returnedNonce << ", socket read=" << r; + throw std::runtime_error(ss.str()); + } +} + +TCPStore::TCPStore(const std::string& host, int port) +{ + auto retry = 0; + do { + try { + ::addrinfo hints{}, *res{}; + hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + int status = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); + + std::shared_ptr holder(res, [](addrinfo* p) { + if (p != nullptr) { + freeaddrinfo(p); + } + }); + + if (status != 0) { + throw std::runtime_error("getaddrinfo failed: " + std::string(gai_strerror(status))); + } + + for (::addrinfo* addr = res; addr != nullptr; addr = addr->ai_next) { + int fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (fd == -1) { + continue; + } + auto socket = std::make_shared<::gloo::transport::tcp::Socket>(fd); + socket->connect(addr->ai_addr, addr->ai_addrlen); + socket->noDelay(true); + socket->recvTimeout(std::chrono::milliseconds(5000)); + socket->sendTimeout(std::chrono::milliseconds(5000)); + validate(socket); // validate the connection + ping(socket); // check send/recv + socket_ = std::move(socket); + break; + } + + if (socket_ == nullptr) { + throw std::runtime_error("unable to connect to " + host + ":" + std::to_string(port)); + } + } + catch (const std::exception& e) { + TM_LOG_WARNING("[TM][COMM] Failed to connect to store after %d retries: %s", retry, e.what()); + std::this_thread::sleep_for(std::chrono::seconds(1)); + retry += 1; + } + } while (socket_ == nullptr); +} + +void TCPStore::set(const std::string& key, const 
std::vector& data) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::SET); + buffer.append(key); + buffer.append(data); + socket_->write(buffer.data(), buffer.count()); +} + +std::vector TCPStore::get(const std::string& key) +{ + wait({key}); + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::GET); + buffer.append(key); + socket_->write(buffer.data(), buffer.count()); + + uint64_t vec_size; + socket_->read(&vec_size, sizeof(vec_size)); + std::vector value(vec_size); + socket_->read(value.data(), value.size()); + return value; +} + +bool TCPStore::check(const std::vector& keys) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::CHECK); + buffer.append((uint64_t)keys.size()); + for (const auto& key : keys) { + buffer.append(key); + } + socket_->write(buffer.data(), buffer.count()); + + CheckResponseType response; + socket_->read(&response, sizeof(response)); + return response == CheckResponseType::READY; +} + +void TCPStore::wait(const std::vector& keys, const std::chrono::milliseconds& timeout) +{ + const auto start = std::chrono::steady_clock::now(); + while (!check(keys)) { + const auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + if (elapsed > timeout) { + std::stringstream ss; + ss << "Wait timeout for key(s): ["; + for (const auto& key : keys) { + ss << key << " "; + } + ss << "]"; + TM_LOG_ERROR("[TM][COMM] %s, elapsed %lld s", ss.str().c_str(), elapsed.count()); + throw std::runtime_error("Wait timeout for key(s): " + ss.str()); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } +} + +TCPStore::~TCPStore() = default; + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/tcp_store.h b/src/turbomind/comm/gloo/tcp_store.h new file mode 100644 index 0000000000..35dd1c05bf --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.h @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include +#include + +#include +#include + +namespace turbomind::comm { + +class TCPStore: public gloo::rendezvous::Store { +public: + explicit TCPStore(const std::string& host, int port); + + ~TCPStore(); + + void set(const std::string& key, const std::vector& data) override; + + std::vector get(const std::string& key) override; + + bool check(const std::vector& keys); + + void wait(const std::vector& keys) override + { + wait(keys, std::chrono::seconds(30)); + } + + void wait(const std::vector& keys, const std::chrono::milliseconds& timeout) override; + +private: + std::shared_ptr<::gloo::transport::tcp::Socket> socket_; + std::mutex mutex_; +}; + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/test_ipc_comm.cc b/src/turbomind/comm/gloo/test_ipc_comm.cc new file mode 100644 index 0000000000..4ab8490491 --- /dev/null +++ b/src/turbomind/comm/gloo/test_ipc_comm.cc @@ -0,0 +1,397 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
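// Illustrative usage, not part of this patch. With two hosts and 4 ranks per node, the test
// built from this file would be launched as (hostname, port and rank counts are assumptions):
//   node 0:  ./test_ipc_comm 2 0 4 10.0.0.1:29500
//   node 1:  ./test_ipc_comm 2 1 4 10.0.0.1:29500
// Both sides pass the same <init_addr>; node_rank 0 hosts the torch TCPStore started by the
// embedded Python helper below, and main() splits the address the same way as this helper:
#include <string>

void parse_init_addr(const std::string& init_addr, std::string& host, std::string& port)
{
    const auto pos = init_addr.find(':');  // e.g. "10.0.0.1:29500" -> host / port
    host = init_addr.substr(0, pos);
    port = init_addr.substr(pos + 1);
}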
+ +#include +#include +#include +#include +#include +#include + +#include "src/turbomind/comm/host_comm.h" + +using namespace turbomind::comm; + +#define TEST_TRIVIALLY_COPYABLE 1 + +// #define SKIP_SERIALIZE 0 // useless now + +// const std::string backend = ""; +const std::string backend = "hybrid"; +// const std::string backend = "gloo"; + +struct Store { + std::string hostname_; + std::string port_; + int nnodes_; + int node_rank_; + std::string py_script_; + std::string py_file_path_ = "/tmp/start_tcp_store.py"; + + std::thread thread_; + + Store(const std::string& hostname, const std::string& port, int nnodes, int node_rank): + hostname_(hostname), port_(port), nnodes_(nnodes), node_rank_(node_rank) + { + + int pid = getpid(); + + // clang-format off + py_script_ = +"import psutil\n" +"import os\n" +"import time\n" +"from torch.distributed import TCPStore\n" +"store = TCPStore(host_name='" + hostname_ + "',\n" +" port=" + port_ + ",\n" +" world_size=" + std::to_string(nnodes_) + ",\n" +" is_master=" + (node_rank_ == 0 ? "True" : "False") + ")\n" +"while True:\n" +" time.sleep(1)\n" +" if not psutil.pid_exists(" + std::to_string(pid) + "):\n" +" break\n" +" if not os.path.exists('/tmp/start_tcp_store.py'):\n" +" break\n"; + + // clang-format on + std::ofstream py_file(py_file_path_); + py_file << py_script_; + py_file.close(); + + std::string env_addr = "LMDEPLOY_DIST_INIT_ADDR=" + hostname_; + std::string env_port = "LMDEPLOY_DIST_INIT_PORT=" + port_; + setenv("LMDEPLOY_DIST_INIT_ADDR", hostname_.c_str(), 1); + setenv("LMDEPLOY_DIST_INIT_PORT", port_.c_str(), 1); + + start(); + // wait a moment for the store to start. + std::this_thread::sleep_for(std::chrono::seconds(3)); + } + + ~Store() + { + stop(); + } + + void start() + { + const std::string cmd = ("python " + py_file_path_); + thread_ = std::thread([](const std::string& cmd) { int result = system(cmd.c_str()); }, cmd); + } + + void stop() + { + int r = system("rm /tmp/start_tcp_store.py"); + thread_.join(); + } +}; + +struct TestGlooComm { + std::string hostname_; + std::string port_; + int nnodes_; + int node_rank_; + int n_ranks_per_node_; + + std::vector h_comm_; + + TestGlooComm(const std::string& host, const std::string& port, int nnodes, int node_rank, int n_ranks_per_node): + hostname_(host), port_(port), nnodes_(nnodes), node_rank_(node_rank), n_ranks_per_node_(n_ranks_per_node) + { + h_comm_.resize(n_ranks_per_node_); + } + + void init() + { + std::unique_ptr group_id = CreateHostGroupId(backend); + std::string group_id_data; + if (1) { // master + group_id->Initialize(); + std::stringstream ss; + group_id->Export(ss); + group_id_data = ss.str(); + } + + auto init = [&](int rank) { + // initialize host communicators + std::stringstream ss(group_id_data); + std::unique_ptr host_id = CreateHostGroupId(backend); + host_id->Import(ss); + h_comm_[rank % n_ranks_per_node_] = + host_id->CreateCommunicator(n_ranks_per_node_ * nnodes_, rank, node_rank_); + }; + + std::vector threads; + for (int i = 0; i < n_ranks_per_node_; ++i) { + threads.emplace_back(init, n_ranks_per_node_ * node_rank_ + i); + } + for (auto& t : threads) { + t.join(); + } + } + + void test_broadcast() + { + const int count = 10; + + auto fun = [&](HostComm& comm, int rank) { + for (int r = 0; r < comm->n_ranks(); ++r) { + +#if TEST_TRIVIALLY_COPYABLE + std::vector data(count); +#else + std::shared_ptr> data_ptr = std::make_shared>(count); + int* data = data_ptr->data(); +#endif + + for (int i = 0; i < count; ++i) { + data[i] = i + rank * count; // i + rank 
* count + } + +#if TEST_TRIVIALLY_COPYABLE + Broadcast(comm, data.data(), count, r); +#else + Broadcast(comm, data_ptr, r); + data = data_ptr->data(); +#endif + // check result + for (int i = 0; i < count; ++i) { + int expected = i + r * count; + if (data[i] != expected) { + printf("Rank %d: Broadcast failed at root %d, index %d, got %d, expected %d\n", + rank, + r, + i, + data[i], + expected); + } + } + } + }; + + std::vector threads; + for (size_t i = 0; i < n_ranks_per_node_; ++i) { + threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i); + } + for (auto& t : threads) { + t.join(); + } + } + + void test_allgather() + { + const int count = 40; + + auto fun = [&](HostComm& comm, int rank) { + +#if TEST_TRIVIALLY_COPYABLE + std::vector data(count * comm->n_ranks()); + for (int i = 0; i < count; ++i) { + data[i + count * comm->rank()] = i + rank * count; // i + rank * count + } +#else + std::vector>> data_ptrs(comm->n_ranks()); + data_ptrs[comm->rank()] = std::make_shared>(count); + int* data = data_ptrs[comm->rank()]->data(); + for (int i = 0; i < count; ++i) { + data[i] = i + rank * count; // i + rank * count + } +#endif + +#if TEST_TRIVIALLY_COPYABLE + AllGather(comm, data.data(), count); + for (int r = 0; r < comm->n_ranks(); ++r) { + for (int j = 0; j < count; ++j) { + int expected = j + r * count; + if (data[j + r * count] != expected) { + printf("Rank %d: AllGather failed, index %d, got %d, expected %d\n", + rank, + j + r * count, + data[j + r * count], + expected); + } + } + } +#else + AllGather(comm, data_ptrs.data(), 1); + for (int r = 0; r < comm->n_ranks(); ++r) { + data = data_ptrs[r]->data(); + for (int j = 0; j < count; ++j) { + int expected = j + r * count; + if (data[j] != expected) { + printf("Rank %d: AllGather failed, index %d, got %d, expected %d\n", + rank, + j + r * count, + data[j], + expected); + } + } + } +#endif + }; + + std::vector threads; + for (size_t i = 0; i < n_ranks_per_node_; ++i) { + threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i); + } + for (auto& t : threads) { + t.join(); + } + } + + void test_allreduce() + { + const int count = 10; + + auto fun = [&](HostComm& comm, int rank) { + std::vector data(count); + for (int i = 0; i < count; ++i) { + data[i] = i + rank * count; // i + rank * count + } + + AllReduce(comm, data.data(), count, RedOp::kSum); + for (int j = 0; j < count; ++j) { + int expected{}; + for (int r = 0; r < comm->n_ranks(); ++r) { + expected += j + r * count; + } + if (data[j] != expected) { + printf("Rank %d: AllReduce failed, index %d, got %d, expected %d\n", rank, j, data[j], expected); + } + } + }; + + std::vector threads; + for (size_t i = 0; i < n_ranks_per_node_; ++i) { + threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i); + } + for (auto& t : threads) { + t.join(); + } + } + + void test_perf() + { + const long kMinDurationNs = 2e9; // 2 second + const long kWarmupIter = 5; // warmup iter + const float kItersMultiplier = 1.2; + + std::vector count = {1024, 262144, 524288, 1048576, 2097152, 4194304, 67108864}; + // 1M, 2M, 4M, 8M, 16M, 256M + + if (node_rank_ == 0) { + printf("%10s %10s %10s %10s %11s %18s %10s\n", + "size(MB)", + "elements", + "avg(us)", + "p50(us)", + "p99(us)", + "bandwidth(GB/s)", + "iterations"); + } + + auto fun = [&](HostComm& comm, int rank, int n) { + +#if TEST_TRIVIALLY_COPYABLE + std::vector data(n); +#else + std::shared_ptr> sptr; + if (rank == 0) { + sptr = std::make_shared>(n); + } +#endif + + 
std::vector times; + + auto job = [&](int n_iters) { + times.clear(); + int64_t total = 0; + int64_t ns = 0; + comm->Sync(); + for (int i = 0; i < n_iters; ++i) { + auto start = std::chrono::high_resolution_clock::now(); +#if TEST_TRIVIALLY_COPYABLE + Broadcast(comm, data.data(), n, 0); +#else + Broadcast(comm, sptr, 0); +#endif + auto now = std::chrono::high_resolution_clock::now(); + int64_t ns = std::chrono::duration_cast(now - start).count(); + total += ns; + times.push_back(ns); + } + Broadcast(comm, total, 0); + return total; + }; + + auto warmup_dur = job(kWarmupIter) / kWarmupIter; + auto iter = (int)std::max(kMinDurationNs / warmup_dur * 0.5f, 100.f); + + while (1) { + auto dur = job(iter); + std::sort(times.begin(), times.end()); + + if (rank == 0) { + size_t bytes = n * sizeof(int); + int p50 = std::min(times.size() / 2, times.size() - 1); + int p99 = std::min((int)(times.size() * 0.99), (int)times.size() - 1); + printf("%10.5f %10d %10lld %10lld %10lld %18.3f %10lld\n", + bytes / 1024.f / 1024.f, + n, + static_cast(dur / 1e3f / iter), + static_cast(times[p50] / 1e3f), + static_cast(times[p99] / 1e3f), + (bytes * iter) / (dur / 1e9f) / (1024 * 1024 * 1024), + static_cast(iter)); + } + + if (dur >= kMinDurationNs) { + break; + } + iter = std::max(iter * kItersMultiplier, iter + 1.f); + } + }; + + for (auto n : count) { + std::vector threads; + for (size_t i = 0; i < n_ranks_per_node_; ++i) { + threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i, n); + } + for (auto& t : threads) { + t.join(); + } + } + } +}; + +// ./test_gloo_comm +int main(int argc, char* argv[]) +{ + if (argc != 5) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return -1; + } + + int nnodes = std::atoi(argv[1]); + int node_rank = std::atoi(argv[2]); + int n_ranks_per_node = std::atoi(argv[3]); + + const std::string init_addr = argv[4]; + auto pos = init_addr.find(":"); + const std::string host = init_addr.substr(0, pos); + const std::string port = init_addr.substr(pos + 1); + + Store store(host, port, nnodes, node_rank); + + { + TestGlooComm test(host, port, nnodes, node_rank, n_ranks_per_node); + test.init(); + + test.test_broadcast(); + test.test_allgather(); + test.test_allreduce(); + + // test.test_perf(); + } + + return 0; +} diff --git a/src/turbomind/comm/host_comm.cc b/src/turbomind/comm/host_comm.cc index 0d3cf367e2..bf756a660f 100644 --- a/src/turbomind/comm/host_comm.cc +++ b/src/turbomind/comm/host_comm.cc @@ -8,8 +8,21 @@ HostCommImpl::~HostCommImpl() = default; std::unique_ptr CreateThreadGroupId(); +std::unique_ptr CreateGlooGroupId(); + +std::unique_ptr CreateHybridGroupId(); + std::unique_ptr CreateHostGroupId(const std::string& backend) { +#ifdef BUILD_MULTI_GPU + if (backend == "hybrid") { + return CreateHybridGroupId(); + } + if (backend == "gloo") { + return CreateGlooGroupId(); + } +#endif + return CreateThreadGroupId(); } diff --git a/src/turbomind/comm/host_comm.h b/src/turbomind/comm/host_comm.h index e9e25d2b8f..f1da005583 100644 --- a/src/turbomind/comm/host_comm.h +++ b/src/turbomind/comm/host_comm.h @@ -3,12 +3,16 @@ #pragma once #include +#include #include #include +#include #include #include #include "src/turbomind/core/data_type.h" +#include "src/turbomind/core/serdes.h" +#include "src/turbomind/utils/logger.h" namespace turbomind::comm { @@ -23,6 +27,10 @@ typedef void (*copy_fn)(void* src, int n, void* dst, int offset); typedef void (*reduce_fn)(void* src, int n, void* dst, int offset); +typedef void (*ser_fn)(void* data, int offset, 
int n, size_t& size, void* out); + +typedef void (*des_fn)(void* data, int offset, int n, void* in, size_t size); + class HostCommImpl { public: virtual ~HostCommImpl(); @@ -37,9 +45,20 @@ class HostCommImpl { virtual void Sync(bool blocking = false) = 0; - virtual void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) = 0; - - virtual void AllGather(void* data, int count, DataType dtype, copy_fn copy) = 0; + virtual void Broadcast(void* data, // + int count, + DataType dtype, + int root, + copy_fn copy, + ser_fn ser = nullptr, + des_fn des = nullptr) = 0; + + virtual void AllGather(void* data, // + int count, + DataType dtype, + copy_fn copy, + ser_fn ser = nullptr, + des_fn des = nullptr) = 0; virtual void AllReduce(void* data, int count, DataType dtype, RedOp red_op) = 0; }; @@ -65,13 +84,40 @@ class HostComm { }; namespace detail { - template void copy_fn(void* src, int n, void* dst, int offset) { std::copy_n((T*)src + offset, n, (T*)dst + offset); } +template +void ser_fn(void* data, int offset, int n, size_t& size, void* out) +{ + if (out == nullptr) { + size = 0; + core::BinarySizeArchive sa; + for (int i = 0; i < n; ++i) { + sa&((T*)data)[offset + i]; + } + size = sa.size(); + } + else { + core::BinaryOutputArchive oa(core::ArrayWrapper((std::byte*)out, size)); + for (int i = 0; i < n; ++i) { + oa&((T*)data)[offset + i]; + } + } +} + +template +void des_fn(void* data, int offset, int n, void* in, size_t size) +{ + core::BinaryInputArchive ia(core::ArrayWrapper((std::byte*)in, size)); + for (int i = 0; i < n; ++i) { + ia&((T*)data)[offset + i]; + } +} + } // namespace detail ////////////////////////////////////////////////////////////////////////////////// @@ -88,7 +134,7 @@ void Broadcast(HostCommImpl* comm, T* data, int n, int root) comm->Broadcast(data, n, kNull, root, detail::copy_fn); } else { - throw std::runtime_error("not implemented"); + comm->Broadcast(data, n, kNull, root, detail::copy_fn, detail::ser_fn, detail::des_fn); } } } @@ -105,8 +151,7 @@ void AllGather(HostCommImpl* comm, T* data, int n) comm->AllGather(data, n, kNull, detail::copy_fn); } else { - /// serialize data - throw std::runtime_error("not implemented"); + comm->AllGather(data, n, kNull, detail::copy_fn, detail::ser_fn, detail::des_fn); } } } @@ -150,7 +195,7 @@ class HostGroupId { virtual void Export(std::ostream& os) = 0; virtual void Import(std::istream& is) = 0; - virtual HostComm CreateCommunicator(int n_ranks, int rank) = 0; + virtual HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) = 0; }; std::unique_ptr CreateHostGroupId(const std::string& backend); diff --git a/src/turbomind/comm/thread_comm.cc b/src/turbomind/comm/thread_comm.cc index 509681271e..ba3675f9f5 100644 --- a/src/turbomind/comm/thread_comm.cc +++ b/src/turbomind/comm/thread_comm.cc @@ -11,6 +11,7 @@ #include "src/turbomind/comm/host_comm.h" #include "src/turbomind/core/check.h" #include "src/turbomind/core/data_type.h" +#include "src/turbomind/core/serdes.h" namespace turbomind::comm { struct ThreadCommImpl: public HostCommImpl { @@ -131,7 +132,7 @@ struct ThreadCommImpl: public HostCommImpl { } } - void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) override + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override { TM_CHECK(copy); if (n_ranks_ == 1) { @@ -164,7 +165,7 @@ struct ThreadCommImpl: public HostCommImpl { } } - void AllGather(void* data, int count, DataType dtype, copy_fn copy) override + void 
AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override { TM_CHECK(copy); if (n_ranks_ == 1) { @@ -315,7 +316,7 @@ class ThreadGroupId: public HostGroupId { TM_CHECK((bool)internal_); } - HostComm CreateCommunicator(int n_ranks, int rank) override + HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override { auto init_shared_state = [&] { // internal_->state = std::make_shared(n_ranks); @@ -348,4 +349,16 @@ std::unique_ptr CreateThreadGroupId() return std::make_unique(); } +template +void save(Archive& ar, const std::shared_ptr& p) +{ + TM_CHECK(false) << "should never be called"; +} + +template +void load(Archive& ar, std::shared_ptr& p) +{ + TM_CHECK(false) << "should never be called"; +} + } // namespace turbomind::comm diff --git a/src/turbomind/core/buffer.h b/src/turbomind/core/buffer.h index cdb5c52f8d..438f2e97bb 100644 --- a/src/turbomind/core/buffer.h +++ b/src/turbomind/core/buffer.h @@ -10,6 +10,7 @@ #include "src/turbomind/core/common.h" #include "src/turbomind/core/context.h" #include "src/turbomind/core/data_type.h" +#include "src/turbomind/core/serdes.h" namespace turbomind::core { @@ -340,4 +341,27 @@ void Clear(Ref b_, const Stream& stream); void Clear(Ref b_); +// clang-format off +template +void save(Archive& ar, const Buffer& buffer) +{ + TM_CHECK(buffer.device().type == kCPU); + ar & buffer.size(); + ar & buffer.dtype(); + ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size()); +} + +template +void load(Archive& ar, Buffer& buffer) +{ + decltype(buffer.size()) size; + decltype(buffer.dtype()) dtype; + + ar & size; + ar & dtype; + buffer = Buffer(size, dtype, kCPU); + ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size()); +} +// clang-format on + } // namespace turbomind::core diff --git a/src/turbomind/core/layout.h b/src/turbomind/core/layout.h index a29c25a0dd..4dfb484cd7 100644 --- a/src/turbomind/core/layout.h +++ b/src/turbomind/core/layout.h @@ -173,4 +173,23 @@ inline std::string to_string(const Layout& x) return ss.str(); } +// clang-format off +template +void save(Archive& ar, const Layout& layout) +{ + ar & layout.shape(); + ar & layout.stride(); +} + +template +void load(Archive& ar, Layout& layout) +{ + vector shape; + vector stride; + ar & shape; + ar & stride; + layout = Layout(std::move(shape), std::move(stride)); +} +// clang-format on + } // namespace turbomind::core diff --git a/src/turbomind/core/serdes.h b/src/turbomind/core/serdes.h new file mode 100644 index 0000000000..7a6dc7d549 --- /dev/null +++ b/src/turbomind/core/serdes.h @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace turbomind::core { + +template typename F, class SFINAE, class... Args> +struct is_detected: std::false_type { +}; + +template typename F, class... 
Args> +struct is_detected>, Args...>: std::true_type { +}; + +template +using save_t = decltype(save(std::declval(), std::declval())); + +template +inline constexpr bool has_save_v = is_detected::value; + +template +using load_t = decltype(load(std::declval(), std::declval())); + +template +inline constexpr bool has_load_v = is_detected::value; + +template +using serdes_t = decltype(serdes(std::declval(), std::declval())); + +template +inline constexpr bool has_serdes_v = is_detected::value; + +template +class ArrayWrapper { +public: + ArrayWrapper(T* t, std::size_t size): t_{t}, size_{size} + { + static_assert(std::is_trivially_copyable_v, "ArrayWrapper requires trivially copyable type"); + } + + T* data() const + { + return t_; + } + + std::size_t count() const + { + return size_; + } + + T* const t_; + const std::size_t size_; +}; + +template +inline constexpr bool is_array_wrapper_v = std::false_type{}; + +template +inline constexpr bool is_array_wrapper_v> = std::true_type{}; + +template +struct OutputArchive { + static constexpr bool is_loading = false; + + template + void operator&(T&& x) + { + if constexpr (has_save_v) { + save(*this, (T &&) x); + } + else if constexpr (has_serdes_v) { + serdes(*this, (T &&) x); + } + else { + reinterpret_cast(this)->write((T &&) x); + } + } +}; + +template +struct InputArchive { + static constexpr bool is_loading = true; + + template + void operator&(T&& x) + { + if constexpr (has_load_v) { + load(*this, (T &&) x); + } + else if constexpr (has_serdes_v) { + serdes(*this, (T &&) x); + } + else { + reinterpret_cast(this)->read((T &&) x); + } + } +}; + +struct BinarySizeArchive: OutputArchive { + size_t size_{}; + + size_t size() + { + return size_; + } + + template + void write(const T& x) + { + static_assert(std::is_trivially_copyable_v); + size_ += sizeof(x); + } + + template + void write(const ArrayWrapper& arr) + { + static_assert(std::is_trivially_copyable_v); + size_ += sizeof(T) * arr.count(); + } +}; + +struct BinaryOutputArchive: OutputArchive { + + ArrayWrapper external_; + size_t ptr_; + + BinaryOutputArchive(ArrayWrapper external): external_{external}, ptr_{} {} + + template + void write(const T& x) + { + static_assert(std::is_trivially_copyable_v); + auto data = (const std::byte*)&x; + TM_CHECK_LE(ptr_ + sizeof(T), external_.count()); + std::copy_n(data, sizeof(T), external_.data() + ptr_); + ptr_ += sizeof(T); + } + + template + void write(const ArrayWrapper& arr) + { + static_assert(std::is_trivially_copyable_v); + auto data = (const std::byte*)arr.data(); + TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count()); + std::copy_n(data, sizeof(T) * arr.count(), external_.data() + ptr_); + ptr_ += sizeof(T) * arr.count(); + } +}; + +struct BinaryInputArchive: InputArchive { + + ArrayWrapper external_; + size_t ptr_; + + BinaryInputArchive(ArrayWrapper external): external_{external}, ptr_{} {} + + template + void read(T& x) + { + static_assert(std::is_trivially_copyable_v); + TM_CHECK_LE(ptr_ + sizeof(T), external_.count()); + std::copy_n(external_.data() + ptr_, sizeof(T), (std::byte*)&x); + ptr_ += sizeof(T); + } + + template + void read(ArrayWrapper&& arr) + { + static_assert(std::is_trivially_copyable_v); + TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count()); + std::copy_n(external_.data() + ptr_, sizeof(T) * arr.count(), (std::byte*)arr.data()); + ptr_ += sizeof(T) * arr.count(); + } +}; + +template +void save(Archive& ar, const std::vector& xs) +{ + // clang-format off + ar & xs.size(); + if constexpr 
(std::is_trivially_copyable_v) { + ar & ArrayWrapper(xs.data(), xs.size()); + } + else { + for (const auto& x : xs) { + ar & x; + } + } + // clang-format on +} + +template +void load(Archive& ar, std::vector& xs) +{ + // clang-format off + decltype(xs.size()) size; + ar & size; + xs.resize(size); + + if constexpr (std::is_trivially_copyable_v) { + ar & ArrayWrapper(xs.data(), size); + } else { + for (size_t i = 0; i < size; ++i) { + ar & xs[i]; + } + } + // clang-format on +} + +template +void save(Archive& ar, const std::string& s) +{ + // clang-format off + ar & s.size(); + ar & ArrayWrapper(s.data(), s.size()); + // clang-format on +} + +template +void load(Archive& ar, std::string& s) +{ + // clang-format off + decltype(s.size()) size; + ar & size; + s.resize(size); + ar & ArrayWrapper(s.data(), size); + // clang-format on +} + +template +void save(Archive& ar, const std::shared_ptr& p) +{ + // clang-format off + ar & (bool)p; + if (p) { + ar & (*p); + } + // clang-format on +} + +template +void load(Archive& ar, std::shared_ptr& p) +{ + // clang-format off + bool pred; + ar & pred; + if (pred) { + p = std::make_shared(); + ar & (*p); + } +} + +template +void serdes(Archive& ar, std::array& xs) +{ + // clang-format off + if constexpr (std::is_trivially_copyable_v) { + ar & ArrayWrapper(xs.data(), N); + } + else { + for (size_t i = 0; i < N; ++i) { + ar & xs[i]; + } + } + // clang-format on +} + +template +void serdes(Archive& ar, std::tuple& tpl) +{ + std::apply([&](auto&... elems) { ((ar & elems), ...); }, tpl); +} + +} // namespace turbomind::core diff --git a/src/turbomind/core/tensor.h b/src/turbomind/core/tensor.h index 499261312b..55480503f5 100644 --- a/src/turbomind/core/tensor.h +++ b/src/turbomind/core/tensor.h @@ -344,4 +344,50 @@ class TensorMap: public std::unordered_map { std::string get_out_of_range_msg(const std::string& key) const; }; +// clang-format off +template, int> = 0> +void save(Archive& ar, const T& tensor) +{ + TM_CHECK(tensor.size() == 0 || tensor.is_contiguous()); + ar & tensor.buffer(); // implicit convert to tensor + ar & tensor.layout(); +} + +template +void load(Archive& ar, Tensor& tensor) +{ + Buffer buffer; + Layout layout; + ar & buffer; + ar & layout; + tensor = Tensor{std::move(buffer), std::move(layout)}; +} + + +template +void save(Archive& ar, const TensorMap& map) +{ + ar & map.size(); + for (const auto& [k, t]: map) { + ar & k; + ar & t; + } +} + +template +void load(Archive& ar, TensorMap& map) +{ + map.clear(); + decltype(map.size()) size; + ar & size; + for (int i = 0; i < size; ++i) { + std::string k; + Tensor t; + ar & k; + ar & t; + map.emplace(std::move(k), std::move(t)); + } +} +// clang-format on + } // namespace turbomind::core diff --git a/src/turbomind/engine/gateway.cc b/src/turbomind/engine/gateway.cc index 3dd8c4b4cb..ff7846bff7 100644 --- a/src/turbomind/engine/gateway.cc +++ b/src/turbomind/engine/gateway.cc @@ -7,9 +7,13 @@ namespace turbomind { -Gateway::Gateway(int groups, int group_size, std::function()> ctx_factory): +Gateway::Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory): size_{groups * group_size}, group_size_{group_size}, + node_dp_ranks_{std::move(node_dp_ranks)}, queues_(size_), flags_(groups), ctx_factory_{ctx_factory}, diff --git a/src/turbomind/engine/gateway.h b/src/turbomind/engine/gateway.h index 8350822046..307a02c602 100644 --- a/src/turbomind/engine/gateway.h +++ b/src/turbomind/engine/gateway.h @@ -60,7 +60,10 @@ class SeqId2Rank { class Gateway { 
public: - Gateway(int groups, int group_size, std::function()> ctx_factory); + Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory); void shutdown(); @@ -72,8 +75,9 @@ class Gateway { // route to corresponding rank rank = seqid2rank_.find(r->session.id); } - else { - rank = next_.fetch_add(1, std::memory_order_relaxed) % size_; + else if (node_dp_ranks_.size() > 0) { + rank = next_.fetch_add(1, std::memory_order_relaxed) % node_dp_ranks_.size(); + rank = node_dp_ranks_[rank]; } if (rank >= 0) { @@ -188,6 +192,7 @@ class Gateway { std::vector> queues_; std::vector>> flags_; + std::vector node_dp_ranks_; std::function()> ctx_factory_; diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc index ba7ebe321f..542e9b236c 100644 --- a/src/turbomind/engine/model_request.cc +++ b/src/turbomind/engine/model_request.cc @@ -6,6 +6,8 @@ #include #include +#include + #include "src/turbomind/engine/model_request.h" #include "src/turbomind/engine/request.h" #include "src/turbomind/utils/constant.h" diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index aa50a48100..711da50949 100644 --- a/src/turbomind/engine/request.h +++ b/src/turbomind/engine/request.h @@ -10,11 +10,14 @@ #include #include -#include - #include "src/turbomind/core/core.h" +#include "src/turbomind/core/serdes.h" #include "src/turbomind/utils/metrics.h" +namespace xgrammar { +class GrammarMatcher; // forward declaration +} // namespace xgrammar + namespace turbomind { struct GenerationConfig { @@ -174,4 +177,80 @@ inline void UpdateState(Request& r, int status, int seq_len) } } +template +void serdes(Archive& ar, GenerationConfig& g) +{ + // clang-format off + ar & g.max_new_tokens; + ar & g.min_new_tokens; + ar & g.eos_ids; + ar & g.stop_ids[0]; + ar & g.stop_ids[1]; + ar & g.bad_ids[0]; + ar & g.bad_ids[1]; + ar & g.top_k; + ar & g.top_p; + ar & g.min_p; + ar & g.temperature; + ar & g.repetition_penalty; + ar & g.random_seed; + ar & g.output_logprobs; + ar & g.output_last_hidden_state; + ar & g.output_logits; + // clang-format on +} + +template +void save_req_output(Archive& ar, const TensorMap& map) +{ + // clang-format off + ar & map.size(); + for (const auto& [k, t] : map) { + TM_CHECK(t.device().type == kCPU); + ar & k; + ar & t.layout(); + ar & t.dtype(); + } + // clang-format on +} + +template +void load_req_output(Archive& ar, TensorMap& map) +{ + // clang-format off + decltype(map.size()) size; + ar & size; + for (int i = 0; i < size; ++i) { + std::string k; + Layout layout; + DataType dtype; + ar & k; + ar & layout; + ar & dtype; + map.emplace(std::move(k), Tensor{layout, dtype, kCPU}); + } + // clang-format on +} + +template +void serdes(Archive& ar, Request& r) +{ + // clang-format off + ar & r.id; + ar & r.unique_id; + ar & r.session; + ar & r.gen_cfg; + ar & r.stream_output; + ar & r.inputs; + if constexpr(Archive::is_loading) { + load_req_output(ar, r.outputs); + r.output_ids = r.outputs.at("output_ids"); + r.sequence_length = r.outputs.at("sequence_length"); + } else { + save_req_output(ar, r.outputs); + } + ar & r.ec; + // clang-format on +} + } // namespace turbomind diff --git a/src/turbomind/engine/request_queue.h b/src/turbomind/engine/request_queue.h index ccaa4af700..b5ad239c89 100644 --- a/src/turbomind/engine/request_queue.h +++ b/src/turbomind/engine/request_queue.h @@ -47,7 +47,7 @@ class RequestQueue { auto it = queue_.begin(); int count{}; while (rs.size() < max_rs_size && count < max_count && it != 
queue_.end()) { - if (!(*it)->session.start_flag) { + if ((*it)->session.start_flag) { rs.push_back(std::move(*it)); ++count; auto tmp = it; @@ -78,10 +78,10 @@ class RequestQueue { || flag_->load(std::memory_order_relaxed) == expected_ // || closed_; }); - if (closed_) { - abort = true; - return false; - } + } + if (closed_) { + abort = true; + return false; } bool is_first = false; diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu index b94a838643..a85240daf6 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.cu +++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu @@ -653,6 +653,11 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n return invoke(_Int<160>, _Int<8>, _Int<10>, _Int<2>); } } + else if (experts <= 512) { + if (experts_per_token <= 8) { + return invoke(_Int<512>, _Int<8>, _Int<16>, _Int<4>); + } + } return false; }; diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h index 45cc917976..e1f45aeb0d 100644 --- a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h +++ b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h @@ -18,6 +18,8 @@ #include +#include + #include "src/turbomind/layers/BaseDynamicDecodeLayer.h" #include "src/turbomind/engine/request.h" diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index f0122b69c4..7015ee5b65 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -26,6 +26,7 @@ #include "src/turbomind/core/allocator.h" #include "src/turbomind/core/buffer.h" #include "src/turbomind/core/context.h" +#include "src/turbomind/core/serdes.h" #include "src/turbomind/core/tensor.h" #include "src/turbomind/macro.h" @@ -1082,6 +1083,10 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat void LlamaBatch::OutputLastHiddenState(const Tensor& hidden_states, int first, int last) { + if (tp_rank_ != 0) { + return; + } + const auto& src_buf = hidden_states.buffer(); const auto data_type = src_buf.dtype(); int base = 0; @@ -1343,6 +1348,17 @@ struct RequestData { bool abort; }; +template +void serdes(Archive& ar, RequestData& r) +{ + // clang-format off + ar & r.infer; + ar & r.kill; + ar & r.cancel; + ar & r.abort; + // clang-format on +} + } // namespace void LlamaBatch::InternalThreadEntry() @@ -1367,8 +1383,17 @@ void LlamaBatch::InternalThreadEntry() NvtxScope _("pop"); const int free_slot_count = max_batch_size_ - state_->size + g.finished_count; const bool is_empty = (free_slot_count == max_batch_size_); - // Block if batch is empty AND no silbings are ready - gateway_->pop(req->infer, req->kill, free_slot_count, is_empty, req->abort, dp_rank_); + // Block if batch is empty AND no silbings are ready AND comm in same node + const bool blocking = is_empty && comm_.h_comm->is_same_process(); + int wait = 0; + do { + gateway_->pop(req->infer, req->kill, free_slot_count, blocking, req->abort, dp_rank_); + if (!comm_.h_comm->is_same_process()) { + bool empty_pop = req->infer.size() == 0 && req->kill.size() == 0 && req->abort == false; + wait = is_empty && empty_pop; + wait = AllReduce(comm_.h_dp_group, wait, comm::RedOp::kSum) == comm_.h_dp_group->n_ranks(); + } + } while (wait); } // Mark reqs to the same session_id as invalid and also interactive-mode reqs when // prefix caching is enabled(which are dangerous to the engine) @@ -1385,7 +1410,13 @@ void LlamaBatch::InternalThreadEntry() // 1. 
Wait while rank-0 is dequeueing // 2. Broadcast `ec` from rank-0 - Broadcast(comm_.h_tp_group, req, 0); + if (comm_.h_tp_group->n_ranks() > 1) { + Broadcast(comm_.h_tp_group, req, 0); + } + + if (!comm_.h_comm->is_same_process()) { + req->abort = AllReduce(comm_.h_comm, (int)req->abort, comm::RedOp::kSum) > 0; + } if (req->abort) { TM_LOG_INFO("[InternalThreadEntry] stop requested."); diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index e3cdd973ea..dae3214910 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -110,6 +110,10 @@ struct EngineParam { int mlp_tp_size; int mlp_tp_rank; + // multi-node + int nnodes; + int node_rank; + std::vector devices; }; diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index f4d090fefd..8ae98aa002 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -573,6 +573,7 @@ PYBIND11_MODULE(_turbomind, m) "device_id"_a, "tags"_a, "rank"_a) + .def("is_dummy_node", [](LlamaTritonModel* model) { return model->isDummyNode(); }) .def("__str__", &LlamaTritonModel::toString) .def("__repr__", &LlamaTritonModel::toString) .def("get_tensor_para_size", &LlamaTritonModel::getTensorParaSize) diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 56f9b2ddb9..1ba1845dbd 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -379,6 +379,10 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ engine_param_.devices = engine_reader["devices"].as>(); + // multi-node information + engine_param_.nnodes = engine_reader["nnodes"].as(); + engine_param_.node_rank = engine_reader["node_rank"].as(); + { auto tp = engine_param_.attn_tp_size * engine_param_.attn_cp_size; engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + tp - 1) / tp * tp; @@ -414,9 +418,6 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ handleMissingParams(); - gateway_ = std::make_shared(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory); - ffi_ctx_factory_ = ffi_ctx_factory; - weights_.resize(engine_param_.devices.size()); engines_.resize(engine_param_.devices.size()); contexts_.resize(engine_param_.devices.size()); @@ -434,24 +435,39 @@ LlamaTritonModel::LlamaTritonModel(std::string model_ // NOTE: This runs on Python main thread group_ids_.resize(engine_param_.outer_dp_size); for (size_t i = 0; i < group_ids_.size(); ++i) { - group_ids_[i] = comm::CreateHostGroupId(""); + const std::string group_backend = (engine_param_.nnodes == 1) ? 
"" : "hybrid"; + + group_ids_[i] = comm::CreateHostGroupId(group_backend); group_ids_[i]->Initialize(); } - const int device_num = engine_param_.outer_dp_size * comm_size_; - const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; + const int device_per_node = engine_param_.devices.size(); + const int device_offset = device_per_node * engine_param_.node_rank; + const int tp_cp_size = engine_param_.attn_tp_size * engine_param_.attn_cp_size; // comm layout: outer_dp x inner(dp, tp, cp) - engine_params_.resize(device_num, engine_param_); - for (int i = 0; i < device_num; ++i) { + engine_params_.resize(device_per_node, engine_param_); + for (int i = 0; i < device_per_node; ++i) { auto& e = engine_params_[i]; - e.outer_dp_rank = i / comm_size_; - e.attn_cp_rank = i % comm_size_ % e.attn_cp_size; - e.attn_tp_rank = i % tp_cp_size / e.attn_cp_size; - e.attn_dp_rank = i % comm_size_ / tp_cp_size; - e.mlp_tp_rank = i % comm_size_; + e.outer_dp_rank = (i + device_offset) / comm_size_; + e.attn_cp_rank = (i + device_offset) % comm_size_ % e.attn_cp_size; + e.attn_tp_rank = (i + device_offset) % tp_cp_size / e.attn_cp_size; + e.attn_dp_rank = (i + device_offset) % comm_size_ / tp_cp_size; + e.mlp_tp_rank = (i + device_offset) % comm_size_; } + for (int local_rank = 0; local_rank < device_per_node; ++local_rank) { + auto& e = engine_params_[local_rank]; + if (e.attn_tp_rank == 0 && e.attn_cp_rank == 0) { + node_dp_ranks_.push_back(e.outer_dp_rank * e.attn_dp_size + e.attn_dp_rank); + } + } + is_dummy_node_ = node_dp_ranks_.size() == 0; + + gateway_ = std::make_shared( + engine_param_.outer_dp_size, engine_param_.attn_dp_size, node_dp_ranks_, ffi_ctx_factory); + ffi_ctx_factory_ = ffi_ctx_factory; + TM_LOG_INFO("%s", toString().c_str()); } @@ -466,13 +482,13 @@ std::unique_ptr LlamaTritonModel::createModelInstance(int device_i void LlamaTritonModel::createSharedWeights(int device_id, int rank) { CudaDeviceGuard dev_guard(engine_param_.devices[device_id]); - weights_[rank] = - std::make_shared(dtype_, model_param_, engine_params_.at(rank), lora_param_, moe_param_); + weights_[device_id] = + std::make_shared(dtype_, model_param_, engine_params_.at(device_id), lora_param_, moe_param_); } TensorMap LlamaTritonModel::getParams(int device_id, int rank) { - const auto& tensor_ptr_map = TM_CHECK_NOTNULL(weights_[rank])->get_parameters(); + const auto& tensor_ptr_map = TM_CHECK_NOTNULL(weights_[device_id])->get_parameters(); TensorMap params; for (const auto& [name, tensor_ptr] : tensor_ptr_map) { params[name] = *tensor_ptr; @@ -504,7 +520,7 @@ Communicators LlamaTritonModel::createCommSplits(int rank) const int color_cp = inner_rank / engine_param_.attn_cp_size; const int color_dp = inner_rank % tp_cp_size; - comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); + comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank, engine_param_.node_rank); comm.h_tp_group = comm.h_comm->Split(color_tp, 0); comm.h_dp_group = comm.h_comm->Split(color_dp, 0); @@ -538,7 +554,7 @@ void LlamaTritonModel::createEngine(int device_id, int rank) core::ContextGuard guard{ctx->core_stream, ctx->allocator, Allocator{kCPUpinned}}; - const auto& engine_param = engine_params_.at(rank); + const auto& engine_param = engine_params_.at(device_id); // Get `h_comm` first as ctx will be moved later const auto h_comm = ctx->comm.h_comm; @@ -647,8 +663,8 @@ void LlamaTritonModel::wakeup(int device_id, const std::vector& tag if (keys.find("kv_cache") != keys.end()) { if 
(device_id == 0) {
-            gateway_ =
-                std::make_shared(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory_);
+            gateway_ = std::make_shared(
+                engine_param_.outer_dp_size, engine_param_.attn_dp_size, node_dp_ranks_, ffi_ctx_factory_);
         }
         TM_CHECK(contexts_[device_id] != nullptr);
         contexts_[device_id]->comm.h_comm->Sync();
@@ -656,6 +672,11 @@ void LlamaTritonModel::wakeup(int device_id, const std::vector& tag
     }
 }

+bool LlamaTritonModel::isDummyNode()
+{
+    return is_dummy_node_;
+}
+
 std::string LlamaTritonModel::toString()
 {
     std::stringstream ss;
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index 953dc22a65..5bfd4b38c3 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -56,6 +56,8 @@ class LlamaTritonModel {

     void wakeup(int device_id, const std::vector& tags, int rank);

+    bool isDummyNode();
+
     std::string toString();

     int getTensorParaSize();
@@ -85,6 +87,7 @@ class LlamaTritonModel {

     std::shared_ptr gateway_;
     std::function()> ffi_ctx_factory_;
+    std::vector node_dp_ranks_;

     // Weights & engine instances for the ranks
     std::vector> weights_;
@@ -93,6 +96,7 @@ class LlamaTritonModel {

     std::vector> contexts_;

     bool is_fp16_;
+    bool is_dummy_node_;

     std::string model_name_;
     std::string model_dir_;
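
The serdes layer this PR adds in src/turbomind/core/serdes.h routes `ar & x` to a free `save()`, `load()` or `serdes()` overload when one is detected for the type, and otherwise falls back to copying the object as raw bytes. Below is a minimal, self-contained sketch of that detection-idiom dispatch for the save path only; the identifiers (`PrintArchive`, `Path`, `is_detected_v`) are illustrative stand-ins rather than the actual turbomind names.

```cpp
// Sketch of detection-idiom dispatch, loosely modeled on the serdes archives
// in this PR. Not the real turbomind header; identifiers are illustrative.
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

// is_detected_v<Op, Args...> is true iff Op<Args...> is a well-formed type.
template<class, template<class...> class Op, class... Args>
struct detector: std::false_type {};

template<template<class...> class Op, class... Args>
struct detector<std::void_t<Op<Args...>>, Op, Args...>: std::true_type {};

template<template<class...> class Op, class... Args>
inline constexpr bool is_detected_v = detector<void, Op, Args...>::value;

// Probe for a free function save(archive, value) found by ADL.
template<class Archive, class T>
using save_t = decltype(save(std::declval<Archive&>(), std::declval<const T&>()));

struct PrintArchive {
    template<class T>
    void operator&(const T& x)
    {
        if constexpr (is_detected_v<save_t, PrintArchive, T>) {
            save(*this, x);  // user-provided overload wins
        }
        else {
            // Fallback: treat the value as raw bytes (must be trivially copyable).
            static_assert(std::is_trivially_copyable_v<T>);
            std::cout << "raw " << sizeof(T) << " byte(s)\n";
        }
    }
};

struct Point {
    float x, y;  // trivially copyable -> raw fallback
};

struct Path {
    std::vector<Point> pts;  // not trivially copyable -> needs save()
};

// Found by ADL from `ar & path`: size prefix, then each element.
void save(PrintArchive& ar, const Path& p)
{
    ar & p.pts.size();
    for (const auto& q : p.pts) {
        ar & q;
    }
}

int main()
{
    PrintArchive ar;
    Path         path{{{0.f, 0.f}, {1.f, 2.f}}};
    ar & path;  // dispatches to save(), which recurses into the raw fallback
}
```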
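
The three binary archives (a size-counting pass, a write pass into an externally supplied buffer, and a read pass out of it) exist so that a request can be measured, flattened into one contiguous allocation, and reconstructed on the other side. A stand-alone version of that measure/serialize/deserialize round trip might look like the following; the `Blob` payload and the archive names are assumptions made for the sketch, not the PR's classes.

```cpp
// Stand-alone sketch of the measure -> write -> read flow used by the new
// binary archives. Types and names here are illustrative only.
#include <cassert>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

struct SizeArchive {
    std::size_t size{};
    template<class T>
    void write(const T&) { size += sizeof(T); }            // scalar
    void write(const void*, std::size_t n) { size += n; }  // bulk range
};

struct WriteArchive {
    std::byte*  base;
    std::size_t ptr{};
    template<class T>
    void write(const T& x)
    {
        std::memcpy(base + ptr, &x, sizeof(T));
        ptr += sizeof(T);
    }
    void write(const void* data, std::size_t n)
    {
        std::memcpy(base + ptr, data, n);
        ptr += n;
    }
};

struct ReadArchive {
    const std::byte* base;
    std::size_t      ptr{};
    template<class T>
    void read(T& x)
    {
        std::memcpy(&x, base + ptr, sizeof(T));
        ptr += sizeof(T);
    }
    void read(void* data, std::size_t n)
    {
        std::memcpy(data, base + ptr, n);
        ptr += n;
    }
};

// Toy payload: size-prefixed string + size-prefixed trivially-copyable vector,
// mirroring how save()/load() for std::string and std::vector are laid out.
struct Blob {
    std::string      name;
    std::vector<int> ids;
};

template<class Ar>
void save(Ar& ar, const Blob& b)
{
    ar.write(b.name.size());
    ar.write(b.name.data(), b.name.size());
    ar.write(b.ids.size());
    ar.write(b.ids.data(), b.ids.size() * sizeof(int));
}

void load(ReadArchive& ar, Blob& b)
{
    std::size_t n{};
    ar.read(n);
    b.name.resize(n);
    ar.read(b.name.data(), n);
    ar.read(n);
    b.ids.resize(n);
    ar.read(b.ids.data(), n * sizeof(int));
}

int main()
{
    Blob in{"req-42", {1, 2, 3, 5, 8}};

    SizeArchive sz;  // pass 1: measure
    save(sz, in);
    std::vector<std::byte> buf(sz.size);  // one contiguous allocation

    WriteArchive out{buf.data()};  // pass 2: serialize
    save(out, in);

    Blob        restored;
    ReadArchive rd{buf.data()};  // pass 3: deserialize
    load(rd, restored);

    assert(restored.name == in.name && restored.ids == in.ids);
    std::cout << restored.name << " " << restored.ids.size() << " ids\n";
}
```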
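
For `std::tuple` the shared `serdes()` unpacks the tuple with `std::apply` and folds `ar & elem` over the elements. A tiny compilable illustration of that fold, with a counting archive standing in for the real ones:

```cpp
// Illustration of the std::apply + fold-expression trick used by serdes() for
// std::tuple. CountingArchive here is a stand-in, not a turbomind type.
#include <cstddef>
#include <iostream>
#include <tuple>

struct CountingArchive {
    std::size_t bytes{};
    template<class T>
    void operator&(const T&)
    {
        bytes += sizeof(T);
    }
};

template<class Archive, class... Ts>
void serdes(Archive& ar, std::tuple<Ts...>& tpl)
{
    // Applies `ar & elem` to every element, left to right.
    std::apply([&](auto&... elems) { ((ar & elems), ...); }, tpl);
}

int main()
{
    CountingArchive               ar;
    std::tuple<int, double, char> t{1, 2.0, 'x'};
    serdes(ar, t);
    std::cout << ar.bytes << " byte(s) counted\n";  // 4 + 8 + 1 = 13
}
```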
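
With `node_dp_ranks`, the gateway no longer spreads new sessions over every DP rank in the deployment; it round-robins only over the DP ranks hosted on the local node, and a node that hosts none never has new sessions routed to it. The routing is a simple modulo over the local list, as in this toy version:

```cpp
// Toy version of the node-local round-robin used for new sessions.
// `node_dp_ranks` would be the global DP ranks hosted on this node.
#include <atomic>
#include <iostream>
#include <vector>

int route_new_session(std::atomic<int>& next, const std::vector<int>& node_dp_ranks)
{
    if (node_dp_ranks.empty()) {
        return -1;  // dummy node: nothing is scheduled here
    }
    const int i = next.fetch_add(1, std::memory_order_relaxed) % (int)node_dp_ranks.size();
    return node_dp_ranks[i];  // global DP rank that receives the request
}

int main()
{
    std::atomic<int>       next{0};
    const std::vector<int> local_ranks{2, 3};  // e.g. the second node of a 4-DP-rank deployment

    for (int i = 0; i < 5; ++i) {
        std::cout << route_new_session(next, local_ranks) << ' ';  // prints: 2 3 2 3 2
    }
    std::cout << '\n';
}
```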
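
In the multi-process case, LlamaBatch can no longer block inside `gateway_->pop`, because sibling DP ranks in other processes may have work even when the local queue is empty. The patch therefore polls without blocking and uses an all-reduce over the DP group to decide whether every rank came up empty and the loop should spin again. A simplified shape of that loop, with the collective abstracted behind a callback (the function names here are illustrative, not the engine's real interfaces):

```cpp
// Simplified shape of the "poll until some rank has work" loop. The all-reduce
// is abstracted as a callback; everything here is an illustration.
#include <functional>
#include <iostream>
#include <vector>

struct PopResult {
    std::vector<int> infer;  // stand-ins for new inference / kill requests
    std::vector<int> kill;
    bool             abort{};
};

// pop_once: non-blocking dequeue for this DP rank.
// all_reduce_sum: sums an int over all DP ranks and returns the total.
PopResult poll_until_ready(const std::function<PopResult()>& pop_once,
                           const std::function<int(int)>&    all_reduce_sum,
                           int                               n_dp_ranks,
                           bool                              batch_is_empty)
{
    PopResult r;
    bool      wait = false;
    do {
        r = pop_once();
        const bool empty_pop = r.infer.empty() && r.kill.empty() && !r.abort;
        const int  my_idle   = (batch_is_empty && empty_pop) ? 1 : 0;
        // Every rank contributes whether it is idle; spin again only if all are.
        wait = all_reduce_sum(my_idle) == n_dp_ranks;
    } while (wait);
    return r;
}

int main()
{
    int  calls = 0;
    auto pop   = [&] {  // empty twice, then one request shows up
        PopResult r;
        if (++calls == 3) {
            r.infer.push_back(7);
        }
        return r;
    };
    auto allreduce = [](int v) { return v; };  // single-rank stand-in

    auto r = poll_until_ready(pop, allreduce, /*n_dp_ranks=*/1, /*batch_is_empty=*/true);
    std::cout << "polled " << calls << " times, got " << r.infer.size() << " request(s)\n";
}
```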
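
The multi-node branch of LlamaTritonModel derives each local device's position in the global outer_dp x (dp, tp, cp) layout from `node_rank` and the per-node device count, then records the DP ranks whose tp and cp ranks are both zero on this node; if none land here the node is flagged as a dummy node and exposed as such through `isDummyNode`. Assuming `comm_size` is the product of the attention dp, tp and cp sizes, the arithmetic can be checked with a small stand-alone program; the 2-node, 4-GPU, tp=8 configuration below is hypothetical:

```cpp
// Stand-alone check of the multi-node rank layout arithmetic.
// The configuration below is a hypothetical 2-node x 4-GPU, tp=8 deployment.
#include <iostream>
#include <vector>

int main()
{
    const int nnodes          = 2;
    const int device_per_node = 4;
    const int attn_tp_size    = 8;  // spans both nodes
    const int attn_cp_size    = 1;
    const int attn_dp_size    = 1;
    const int outer_dp_size   = 1;
    (void)outer_dp_size;

    const int comm_size  = attn_dp_size * attn_tp_size * attn_cp_size;  // ranks per outer-dp group
    const int tp_cp_size = attn_tp_size * attn_cp_size;

    for (int node_rank = 0; node_rank < nnodes; ++node_rank) {
        const int        device_offset = device_per_node * node_rank;
        std::vector<int> node_dp_ranks;

        for (int i = 0; i < device_per_node; ++i) {
            const int g             = i + device_offset;  // global rank of this local device
            const int outer_dp_rank = g / comm_size;
            const int attn_cp_rank  = g % comm_size % attn_cp_size;
            const int attn_tp_rank  = g % tp_cp_size / attn_cp_size;
            const int attn_dp_rank  = g % comm_size / tp_cp_size;

            // Only the (tp=0, cp=0) device of each DP group accepts new requests.
            if (attn_tp_rank == 0 && attn_cp_rank == 0) {
                node_dp_ranks.push_back(outer_dp_rank * attn_dp_size + attn_dp_rank);
            }
            std::cout << "node " << node_rank << " device " << i << ": outer_dp=" << outer_dp_rank
                      << " dp=" << attn_dp_rank << " tp=" << attn_tp_rank << " cp=" << attn_cp_rank << '\n';
        }
        const bool is_dummy_node = node_dp_ranks.empty();
        std::cout << "node " << node_rank << ": " << node_dp_ranks.size()
                  << " local DP rank(s), dummy=" << std::boolalpha << is_dummy_node << "\n\n";
    }
}
```

With this configuration, node 0 hosts DP rank 0 and node 1 hosts none, so node 1 would report itself as a dummy node.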