Merged

45 commits

3899022  init gloo support (irexyc, Mar 26, 2025)
916e44b  use pytorch tcpstore (irexyc, Apr 2, 2025)
bd5be0f  update gateway and support setting devices (irexyc, Apr 9, 2025)
ecc2623  fix build (irexyc, Apr 9, 2025)
edb4dfe  use tm cfg instead of env (irexyc, Apr 10, 2025)
216ad15  Merge remote-tracking branch 'origin/main' into gloo-comm (irexyc, Apr 23, 2025)
9b8d8a7  fix dp (irexyc, Apr 23, 2025)
22569cb  fix lint (irexyc, Apr 23, 2025)
eb03cbf  fix build (irexyc, Apr 24, 2025)
e0f2409  Merge remote-tracking branch 'origin/main' into gloo-comm (irexyc, Apr 24, 2025)
aca00e1  Merge remote-tracking branch 'origin/main' into gloo-comm (irexyc, Jul 11, 2025)
e78156d  fix ci (irexyc, Jul 11, 2025)
e06f256  update gloo version to match pytorch/v2.8.0-rc4 (irexyc, Jul 11, 2025)
dce9eb7  Merge remote-tracking branch 'github/main' into gloo-comm (irexyc, Nov 17, 2025)
63173f1  simplify devices setup (irexyc, Nov 17, 2025)
386b411  Merge remote-tracking branch 'github/main' into gloo-comm (irexyc, Nov 17, 2025)
ef343f9  change the size of engine_params_ to device_per_node (irexyc, Nov 17, 2025)
875d747  use dist_init_addr for init (irexyc, Nov 18, 2025)
723d14d  remove unused (irexyc, Nov 20, 2025)
24627d0  update (irexyc, Nov 26, 2025)
6851447  Merge remote-tracking branch 'github/main' into gloo-comm (irexyc, Nov 26, 2025)
6121732  optimize serialization (irexyc, Dec 2, 2025)
889aba5  buffer management (irexyc, Dec 2, 2025)
083bfd4  fix wait (irexyc, Dec 3, 2025)
be25f2d  remove constraint that each node must have attn_dp (irexyc, Dec 3, 2025)
6c77a59  add hybrid comm & optimize broadcast (irexyc, Dec 4, 2025)
8d67213  add test & benchmark code (irexyc, Dec 5, 2025)
f5e815f  add ibverbs transport (irexyc, Dec 5, 2025)
983f41e  remove grammar deps in irrelevant cmakelists (irexyc, Dec 8, 2025)
29c27b0  use serdes (irexyc, Dec 11, 2025)
fcb49cf  hide hostcomm implementation details (irexyc, Dec 11, 2025)
ec4f386  skip serialize buffer of Request.outputs (irexyc, Dec 11, 2025)
fdd1438  fix try_pop (irexyc, Dec 11, 2025)
d635c5e  use default 30mins timeout (irexyc, Dec 12, 2025)
2dfb965  support loading model with 512 experts (irexyc, Dec 17, 2025)
d4bdd69  remove unused (irexyc, Dec 18, 2025)
62338b2  remove ex archive (irexyc, Dec 19, 2025)
fb889df  use is_loading static var (irexyc, Dec 19, 2025)
eccf76c  fix dummy node logic (irexyc, Dec 19, 2025)
c6c77d7  use large timeout for broadcast request (irexyc, Dec 19, 2025)
f786edd  add comments to metrics (irexyc, Dec 19, 2025)
1a94ae1  use hybrid comm as default for multi nodes (irexyc, Dec 19, 2025)
b457e1b  update inter comm split in hybrid comm (irexyc, Dec 19, 2025)
b387512  remove unused (irexyc, Dec 19, 2025)
b16d9ba  fix lint (irexyc, Dec 19, 2025)
10 changes: 8 additions & 2 deletions lmdeploy/cli/serve.py
@@ -119,8 +119,8 @@ def add_parser_api_server():
ArgumentHelper.role(pt_group)
ArgumentHelper.migration_backend(pt_group)
# multi-node serving args
ArgumentHelper.node_rank(parser)
ArgumentHelper.num_nodes(parser)
node_rank_act = ArgumentHelper.node_rank(pt_group)
num_nodes_act = ArgumentHelper.num_nodes(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -135,6 +135,8 @@ def add_parser_api_server():
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(num_nodes_act)
tb_group._group_actions.append(node_rank_act)
tb_group._group_actions.append(hf_overrides)
tb_group._group_actions.append(disable_metrics)
tb_group._group_actions.append(dp)
@@ -143,6 +145,7 @@ def add_parser_api_server():
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.dist_init_addr(tb_group)

# vlm args
vision_group = parser.add_argument_group('Vision model arguments')
@@ -239,6 +242,9 @@ def api_server(args):
tp=args.tp,
dp=args.dp,
cp=args.cp,
nnodes=args.nnodes,
node_rank=args.node_rank,
dist_init_addr=args.dist_init_addr,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
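The node_rank/num_nodes handling above leans on an argparse detail: add_argument returns the Action object, and argument groups only affect --help layout, so the same Action can be appended to a second group's (private) _group_actions to list it under both engines' help. A minimal sketch with a hypothetical flag:

from argparse import ArgumentParser

parser = ArgumentParser('api_server')
pt_group = parser.add_argument_group('PyTorch engine arguments')
tb_group = parser.add_argument_group('TurboMind engine arguments')

# add_argument returns the Action; keep a handle to it
node_rank_act = pt_group.add_argument('--node-rank', type=int, default=0)

# the same Action now also appears under the TurboMind group in --help
tb_group._group_actions.append(node_rank_act)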
6 changes: 6 additions & 0 deletions lmdeploy/cli/utils.py
@@ -219,6 +219,12 @@ def num_nodes(parser):

return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums')

@staticmethod
def dist_init_addr(parser):
"""Add argument dist_init_addr to parser."""

return parser.add_argument('--dist-init-addr', type=str, default=None)

@staticmethod
def session_id(parser):
"""Add argument session_id to parser."""
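The new flag takes a plain host:port string. A minimal sketch of the round trip, with a hypothetical address (the actual split happens later in TurboMind.__init__, see below):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--dist-init-addr', type=str, default=None)

args = parser.parse_args(['--dist-init-addr', '10.0.0.1:29500'])
host, port = args.dist_init_addr.split(':')
assert (host, port) == ('10.0.0.1', '29500')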
4 changes: 4 additions & 0 deletions lmdeploy/messages.py
@@ -246,6 +246,10 @@ class TurbomindEngineConfig:
mlp_tp_size: int = None
mlp_dp_size: int = None
outer_dp_size: int = None
nnodes: int = 1
node_rank: int = 0
dist_init_addr: Optional[str] = None
devices: List[int] = None
session_len: Optional[int] = None
max_batch_size: int = None
cache_max_entry_count: float = 0.8
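A hedged sketch of how the new fields might be filled on node 0 of a two-node deployment; the address and device ids are hypothetical, and the field semantics are inferred from the turbomind.py changes below:

from lmdeploy import TurbomindEngineConfig

cfg = TurbomindEngineConfig(
    tp=16,                             # parallel degree spanning both nodes
    nnodes=2,                          # total number of nodes
    node_rank=0,                       # this node's rank; node 1 passes 1
    dist_init_addr='10.0.0.1:29500',   # TCPStore rendezvous address (host:port)
    devices=[0, 1, 2, 3, 4, 5, 6, 7],  # local GPU ids to use on this node
)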
4 changes: 3 additions & 1 deletion lmdeploy/serve/async_engine.py
@@ -399,7 +399,9 @@ def _build_stat_loggers(self):

if getattr(self.backend_config, 'enable_metrics', False):
from lmdeploy.metrics.loggers import LoggingStatLogger, PrometheusStatLogger
dp_rank = self.backend_config.dp_rank if self.backend_config.dp > 1 else 0

# currently, metrics in TM engine doesn't support dp
dp_rank = self.backend_config.dp_rank if self.backend == 'pytorch' else 0

logger.info(f'enable metrics, with dp: {self.backend_config.dp} dp_rank: {dp_rank}')
self.stat_loggers = [
3 changes: 3 additions & 0 deletions lmdeploy/serve/openai/api_server.py
@@ -1228,6 +1228,9 @@ async def startup_event():

if VariableInterface.proxy_url is None:
return
elif getattr(async_engine.engine, 'is_dummy', False):
logger.info('Dummy node started')
return
try:
import requests
engine_config = VariableInterface.async_engine.backend_config
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/config.py
@@ -63,7 +63,7 @@ class ModelConfig:
attn_sink: bool = False
qk_norm: bool = False
size_per_head: int = 128
group_size: int = 64
group_size: int = 32
data_type: str = None
weight_type: str = None
expert_weight_type: str = None
7 changes: 4 additions & 3 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -162,9 +162,10 @@ def _tofile(tensor, path):
torch_tensor = torch_tensor.bfloat16()
else:
torch_tensor = torch_tensor.half()
for tm_tensor in tm_params[name]:
tm_tensor.copy_from(torch_tensor)
tm_params.pop(name)
if name in tm_params:
for tm_tensor in tm_params[name]:
tm_tensor.copy_from(torch_tensor)
tm_params.pop(name)
else:
tprint('skip export', name, param.shape)

39 changes: 26 additions & 13 deletions lmdeploy/turbomind/turbomind.py
@@ -5,6 +5,7 @@
import copy
import json
import math
import os
import os.path as osp
import sys
from collections.abc import Sequence
@@ -86,10 +87,11 @@ def complete_parallel_config(cfg: TurbomindEngineConfig):


def update_parallel_config(cfg: TurbomindEngineConfig):
cfg.device_num = len(cfg.devices) * cfg.nnodes if cfg.devices else cfg.device_num
if not complete_parallel_config(cfg):
total = cfg.dp * cfg.tp
if not cfg.device_num:
count = torch.cuda.device_count()
count = torch.cuda.device_count() * cfg.nnodes
if total < count:
count = total
cfg.device_num = count
@@ -106,7 +108,10 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
cfg.mlp_tp_size = mlp_tp_size * inner_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num
cfg.devices = cfg.devices or list(range(cfg.device_num))
# update devices
cfg.devices = cfg.devices or list(range(cfg.device_num // cfg.nnodes))
cfg.devices = cfg.devices[:cfg.device_num // cfg.nnodes]
assert len(cfg.devices) == cfg.device_num // cfg.nnodes


class TurboMind:
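A worked instance of the per-node device bookkeeping above, under assumed values (dp=1, tp=16, nnodes=2, devices unset, 8 visible GPUs per node):

# stand-ins for the config fields consumed by update_parallel_config
nnodes = 2
device_num = 1 * 16   # dp * tp, capped at torch.cuda.device_count() * nnodes
devices = None

# every node drives device_num // nnodes local GPUs
devices = devices or list(range(device_num // nnodes))
devices = devices[:device_num // nnodes]
assert devices == [0, 1, 2, 3, 4, 5, 6, 7]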
@@ -142,14 +147,27 @@ def __init__(self,
f' greater than 0, but got {_engine_config.max_batch_size}'

update_parallel_config(_engine_config)

self.gpu_count = _engine_config.device_num
if _engine_config.nnodes > 1:
logger.info(f'dist_init_addr={_engine_config.dist_init_addr}')
assert _engine_config.dist_init_addr is not None
hostname, port = _engine_config.dist_init_addr.split(':')
os.environ['LMDEPLOY_DIST_INIT_ADDR'] = hostname
os.environ['LMDEPLOY_DIST_INIT_PORT'] = port
# this will block the process and ignore signals until all ranks done
from torch.distributed import TCPStore
self.store = TCPStore(host_name=hostname,
port=int(port),
world_size=_engine_config.nnodes,
is_master=_engine_config.node_rank == 0)

self.gpu_count = len(_engine_config.devices)
self.devices = _engine_config.devices
self._engine_created = False

if not osp.exists(model_path):
model_path = get_model(model_path, _engine_config.download_dir, _engine_config.revision)
self.model_comm = self._from_hf(model_path=model_path, engine_config=_engine_config)
self.is_dummy = self.model_comm.is_dummy_node()
self.tokenizer = Tokenizer(model_path)
if not _engine_config.empty_init:
self._load_weights()
@@ -192,10 +210,8 @@ def _create_engine(self):
def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

# TODO: support mpi
self.node_id = 0
self.node_num = 1
torch.cuda.synchronize()
engine_cfg = self.config_dict['engine_config']
self.node_id = engine_cfg['node_rank']

# create weight
def _create_weight_func(device_id):
@@ -382,6 +398,8 @@ def close(self):
if self.model_comm is not None:
self.model_comm = None
self._engine_created = False
if hasattr(self, 'store'):
del self.store

def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
@@ -527,11 +545,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
self.tm_model = tm_model
self.cuda_stream_id = cuda_stream_id

self.node_id = tm_model.node_id
self.gpu_count = tm_model.gpu_count

self.session_len = tm_model.session_len

# create model instances
lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
self._model_inst = None if lazy_init else self._create_model_instance(0)
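For context, a minimal sketch of the rendezvous TurboMind.__init__ performs above; the host, port, and key are hypothetical. The master rank hosts the store, and with wait_for_workers left at its default of True the constructor blocks until all world_size participants have connected, matching the "block the process" comment in the diff:

from torch.distributed import TCPStore

nnodes, node_rank = 2, 0            # run with node_rank=1 on the other node
store = TCPStore(host_name='10.0.0.1',
                 port=29500,
                 world_size=nnodes,
                 is_master=(node_rank == 0))

# the store doubles as a small cross-node key/value channel
store.set('status', 'ready')
print(store.get('status'))          # b'ready'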
3 changes: 3 additions & 0 deletions src/turbomind/comm/CMakeLists.txt
@@ -22,6 +22,9 @@ if (BUILD_MULTI_GPU)
target_link_libraries(device_comm INTERFACE nccl_comm)
endif ()

add_subdirectory(gloo)
target_link_libraries(host_comm INTERFACE gloo_comm)

if (BUILD_TEST)
add_executable(test_comm test_comm.cu)
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
39 changes: 39 additions & 0 deletions src/turbomind/comm/gloo/CMakeLists.txt
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.8)

include(FetchContent)
FetchContent_Declare(
gloo
GIT_REPOSITORY https://github.com/pytorch/gloo.git
GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4
)

# some settings of gloo,
set(GLOO_INSTALL OFF CACHE BOOL "" FORCE)
set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE)
set(USE_NCCL OFF)
set(BUILD_TEST OFF)
set(USE_IBVERBS OFF)
FetchContent_MakeAvailable(gloo)

# gloo build doesn't add include directories as a target property...
target_include_directories(gloo PUBLIC
$<BUILD_INTERFACE:${gloo_SOURCE_DIR}>
$<BUILD_INTERFACE:${gloo_BINARY_DIR}> # config.h generated at cmake config time
)

target_compile_options(gloo PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:/W0>
$<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-w>
)

add_library(gloo_comm STATIC
gloo_comm.cc
hybrid_comm.cc
tcp_store.cc
)
set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(gloo_comm PUBLIC gloo host_comm logger xgrammar)

add_executable(test_ipc_comm test_ipc_comm.cc)
target_link_libraries(test_ipc_comm PRIVATE gloo_comm Threads::Threads)