From 897547620d1fc82524a89d45822f6cd26a65631f Mon Sep 17 00:00:00 2001
From: irexyc
Date: Thu, 27 Feb 2025 06:26:52 +0000
Subject: [PATCH 1/2] support passing model parameters

---
 lmdeploy/messages.py                |  3 ++-
 lmdeploy/turbomind/deploy/loader.py | 39 ++++++++++++++++++++++++++++-
 lmdeploy/turbomind/turbomind.py     |  8 +++++-
 3 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index f4a9b813a5..4ab6549bf8 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import enum
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 import torch
 from pydantic.dataclasses import dataclass as pydantic_dataclass
@@ -223,6 +223,7 @@ class TurbomindEngineConfig:
     max_prefill_token_num: int = 8192
     num_tokens_per_iter: int = 0
     max_prefill_iters: int = 1
+    model_params_que: Any = None
 
     def __post_init__(self):
         """Check input validation."""
diff --git a/lmdeploy/turbomind/deploy/loader.py b/lmdeploy/turbomind/deploy/loader.py
index 8ad440a5a5..4ff3f4644f 100644
--- a/lmdeploy/turbomind/deploy/loader.py
+++ b/lmdeploy/turbomind/deploy/loader.py
@@ -6,7 +6,9 @@
 from collections import defaultdict
 from functools import partial
 from glob import glob
-from typing import Iterator, Tuple
+from multiprocessing.queues import Queue as MpQueue
+from queue import Queue
+from typing import Iterator, Tuple, Union
 
 import torch
 from safetensors import safe_open
@@ -144,9 +146,44 @@ def items(self):
                 yield (idx, params.pop(idx))
 
 
+class QueueLoader:
+
+    def __init__(self, queue: Union[Queue, MpQueue], pattern: str):
+        self.que = queue
+        self.pattern = pattern
+        self.item_count = defaultdict(int)
+
+    def items(self):
+        params = defaultdict(dict)
+        # the first item should be all keys of weight
+        keys = self.que.get()
+        for k in keys:
+            match = re.findall(self.pattern, k)
+            if match:
+                self.item_count[int(match[0])] += 1
+        # load weights from queue
+        for state_dict in iter(self.que.get, None):
+            misc = []
+            for k, v in state_dict.items():
+                match = re.findall(self.pattern, k)
+                if not match:
+                    misc.append((k, v))
+                else:
+                    idx = int(match[0])
+                    param = params[idx]
+                    param[k] = v
+                    if len(param) == self.item_count[idx]:
+                        yield (idx, params.pop(idx))
+            if misc:
+                yield (-1, {k: v for k, v in misc})
+
+
 def create_loader(model_path: str, pattern: str) -> BaseLoader:
     args = (model_path, pattern)
 
+    if isinstance(model_path, Queue) or isinstance(model_path, MpQueue):
+        return QueueLoader(*args)
+
     if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
         return SafetensorsLoader(*args, index_name=SAFE_WEIGHT_INDEX_NAME)
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 6e797caf3d..268c2441b5 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -90,7 +90,7 @@ def __init__(self,
         self.model_name = model_name
         self.chat_template_name = chat_template_name
 
-        _engine_config = copy.deepcopy(engine_config)
+        _engine_config = copy.copy(engine_config)
         if _engine_config is None:
             _engine_config = TurbomindEngineConfig()
         if _engine_config.max_batch_size is None:
@@ -200,6 +200,10 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
 
         from .deploy.converter import get_tm_model
         tm_model = get_tm_model(model_path, self.model_name, self.chat_template_name, engine_config)
+        # get model_params_que
+        model_params_que = engine_config.model_params_que
+        engine_config.model_params_que = None
+
         self._postprocess_config(tm_model.tm_config, engine_config)
 
         model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
@@ -214,6 +218,8 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
         tm_params = tm_model.tm_params
         self._get_model_params(model_comm, tm_params)
         logger.warning(f'get {len(tm_params)} model params')
+        if model_params_que is not None:
+            tm_model.input_model.model_path = model_params_que
         tm_model.export()
         # there should be no left turbomind params.
         if len(tm_params) > 0:
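The new `QueueLoader` expects the queue to carry the complete list of weight names first, then partial state_dicts, and finally `None` to signal the end. A minimal producer sketch, assuming a plain `torch.nn.Module` on the sending side; the helper name `feed_model_params` and the chunk size are illustrative, not part of the patch:

```python
# Illustrative producer for the queue consumed by QueueLoader (not part of the patch).
from queue import Queue

import torch


def feed_model_params(model: torch.nn.Module, que: Queue, chunk_size: int = 8):
    state_dict = model.state_dict()
    names = list(state_dict.keys())
    # 1) first item: every weight name, so the loader can count tensors per layer
    que.put(names)
    # 2) following items: partial state_dicts, grouped however is convenient
    for i in range(0, len(names), chunk_size):
        que.put({k: state_dict[k] for k in names[i:i + chunk_size]})
    # 3) last item: None marks the end of the queue
    que.put(None)
```
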
From de0095ef51f8a614c68b59f2696e81e04eb0cd87 Mon Sep 17 00:00:00 2001
From: irexyc
Date: Thu, 6 Mar 2025 02:57:36 +0000
Subject: [PATCH 2/2] fix comments

---
 lmdeploy/api.py                     |  8 ++++++++
 lmdeploy/messages.py                |  3 +--
 lmdeploy/serve/async_engine.py      |  8 ++++++++
 lmdeploy/serve/openai/api_server.py |  4 ++++
 lmdeploy/turbomind/deploy/loader.py |  4 ++--
 lmdeploy/turbomind/turbomind.py     | 21 ++++++++++++++-------
 6 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/lmdeploy/api.py b/lmdeploy/api.py
index 3377e7800b..b4bc6183e8 100644
--- a/lmdeploy/api.py
+++ b/lmdeploy/api.py
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
+from multiprocessing.queues import Queue as MpQueue
+from queue import Queue
 from typing import List, Literal, Optional, Union
 
 from .archs import autoget_backend_config, get_task
@@ -92,6 +94,7 @@ def serve(model_path: str,
           model_name: Optional[str] = None,
           backend: Literal['turbomind', 'pytorch'] = 'turbomind',
           backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
+          model_params_que: Optional[Union[Queue, MpQueue]] = None,
           chat_template_config: Optional[ChatTemplateConfig] = None,
           server_name: str = '0.0.0.0',
           server_port: int = 23333,
@@ -122,6 +125,10 @@ def serve(model_path: str,
             `turbomind` backend.
         backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
             config instance. Default to none.
+        model_params_que (queue.Queue | multiprocessing.queues.Queue): model parameters.
+            The first item should be a list of all weight names of the model
+            (state_dict().keys()), the following items should be partial state_dicts,
+            and the last item should be None, indicating the end of the queue.
         chat_template_config (ChatTemplateConfig): chat template configuration.
             Default to None.
         server_name (str): host ip for serving
@@ -156,6 +163,7 @@ def serve(model_path: str,
                   kwargs=dict(model_name=model_name,
                               backend=backend,
                               backend_config=backend_config,
+                              model_params_que=model_params_que,
                               chat_template_config=chat_template_config,
                               server_name=server_name,
                               server_port=server_port,
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 3fbe6092bb..4df4204022 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import enum
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Literal, Optional
+from typing import Callable, Dict, List, Literal, Optional
 
 import torch
 from pydantic.dataclasses import dataclass as pydantic_dataclass
@@ -224,7 +224,6 @@ class TurbomindEngineConfig:
     num_tokens_per_iter: int = 0
     max_prefill_iters: int = 1
     communicator: str = 'nccl'
-    model_params_que: Any = None
 
     def __post_init__(self):
         """Check input validation."""
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index ce22a3c6e1..86e00fc653 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -12,6 +12,7 @@
 from copy import deepcopy
 from functools import partial
 from itertools import count
+from multiprocessing.queues import Queue as MpQueue
 from queue import Queue
 from threading import Thread
 from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union
@@ -244,6 +245,10 @@ class AsyncEngine(LogitsMixin):
             `turbomind` backend.
         backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
             config instance. Default to none.
+        model_params_que (queue.Queue | multiprocessing.queues.Queue): model parameters.
+            The first item should be a list of all weight names of the model
+            (state_dict().keys()), the following items should be partial state_dicts,
+            and the last item should be None, indicating the end of the queue.
         chat_template_config (ChatTemplateConfig): chat template configuration.
             Default to None.
         max_log_len (int): Max number of prompt characters or prompt tokens
@@ -255,6 +260,7 @@ def __init__(self,
                  model_name: Optional[str] = None,
                  backend: Literal['turbomind', 'pytorch'] = 'turbomind',
                  backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
+                 model_params_que: Optional[Union[Queue, MpQueue]] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None,
                  max_log_len: int = None,
                  **kwargs) -> None:
@@ -275,6 +281,8 @@ def __init__(self,
         self.arch, _ = get_model_arch(model_path)
 
         # build backend engine
+        assert model_params_que is None or backend == 'turbomind', 'only support turbomind backend'
+        kwargs.update(model_params_que=model_params_que)
         if backend == 'turbomind':
             self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)
         elif backend == 'pytorch':
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 2fb17e4c6f..56a7ce1bc1 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -6,6 +6,8 @@
 import time
 from functools import partial
 from http import HTTPStatus
+from multiprocessing.queues import Queue as MpQueue
+from queue import Queue
 from typing import AsyncGenerator, Dict, List, Literal, Optional, Union
 
 import uvicorn
@@ -925,6 +927,7 @@ def serve(model_path: str,
           model_name: Optional[str] = None,
           backend: Literal['turbomind', 'pytorch'] = 'turbomind',
           backend_config: Optional[Union[PytorchEngineConfig, TurbomindEngineConfig]] = None,
+          model_params_que: Optional[Union[Queue, MpQueue]] = None,
          chat_template_config: Optional[ChatTemplateConfig] = None,
           server_name: str = '0.0.0.0',
           server_port: int = 23333,
@@ -1032,6 +1035,7 @@ def serve(model_path: str,
                              model_name=model_name,
                              backend=backend,
                              backend_config=backend_config,
+                             model_params_que=model_params_que,
                              chat_template_config=chat_template_config,
                              max_log_len=max_log_len,
                              **kwargs)
diff --git a/lmdeploy/turbomind/deploy/loader.py b/lmdeploy/turbomind/deploy/loader.py
index 4ff3f4644f..5277d55917 100644
--- a/lmdeploy/turbomind/deploy/loader.py
+++ b/lmdeploy/turbomind/deploy/loader.py
@@ -146,7 +146,7 @@ def items(self):
                 yield (idx, params.pop(idx))
 
 
-class QueueLoader:
+class StateDictLoader:
 
     def __init__(self, queue: Union[Queue, MpQueue], pattern: str):
         self.que = queue
         self.pattern = pattern
@@ -182,7 +182,7 @@ def create_loader(model_path: str, pattern: str) -> BaseLoader:
     args = (model_path, pattern)
 
     if isinstance(model_path, Queue) or isinstance(model_path, MpQueue):
-        return QueueLoader(*args)
+        return StateDictLoader(*args)
 
     if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
         return SafetensorsLoader(*args, index_name=SAFE_WEIGHT_INDEX_NAME)
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 9c150f53d0..178fba0935 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -9,8 +9,9 @@
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
 from functools import partial
+from multiprocessing.queues import Queue as MpQueue
 from queue import Queue
-from typing import Dict, List
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -76,6 +77,10 @@ class TurboMind:
             engine
         model_source (int): the source of the model, which is either
             turbomind model, or a transformers model
+        model_params_que (queue.Queue | multiprocessing.queues.Queue): model parameters.
+            The first item should be a list of all weight names of the model
+            (state_dict().keys()), the following items should be partial state_dicts,
+            and the last item should be None, indicating the end of the queue.
     """
 
     def __init__(self,
@@ -85,6 +90,7 @@ def __init__(self,
                  chat_template_name: str = None,
                  engine_config: TurbomindEngineConfig = None,
                  model_source: ModelSource = ModelSource.WORKSPACE,
+                 model_params_que: Optional[Union[Queue, MpQueue]] = None,
                  **kwargs):
         self.model_name = model_name
         self.chat_template_name = chat_template_name
@@ -107,7 +113,8 @@ def __init__(self,
                 model_path = get_model(model_path, _engine_config.download_dir, _engine_config.revision)
             self.model_comm = self._from_hf(model_source=model_source,
                                             model_path=model_path,
-                                            engine_config=_engine_config)
+                                            engine_config=_engine_config,
+                                            model_params_que=model_params_que)
 
         with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
             ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
@@ -187,7 +194,11 @@ def _postprocess_config(self, tm_config, engine_config):
         logger.info(f'turbomind model config:\n\n'
                     f'{json.dumps(self.config_dict, indent=2)}')
 
-    def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: TurbomindEngineConfig):
+    def _from_hf(self,
+                 model_source: ModelSource,
+                 model_path: str,
+                 engine_config: TurbomindEngineConfig,
+                 model_params_que: Optional[Union[Queue, MpQueue]] = None):
         """Load model which is in hf format."""
         assert model_source == ModelSource.HF_MODEL, \
             f'{model_source} is not supported'
@@ -198,10 +209,6 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
 
         from .deploy.converter import get_tm_model
         tm_model = get_tm_model(model_path, self.model_name, self.chat_template_name, engine_config)
-        # get model_params_que
-        model_params_que = engine_config.model_params_que
-        engine_config.model_params_que = None
-
         self._postprocess_config(tm_model.tm_config, engine_config)
 
         model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
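
With both patches applied, live weights can be handed to the turbomind backend through `model_params_que` instead of being read from a checkpoint directory. A minimal end-to-end sketch, reusing the producer above; the model id, `tp` value, and process layout are placeholders rather than anything prescribed by the patches:

```python
# Illustrative wiring of model_params_que into lmdeploy.serve (ids and values are placeholders).
import torch.multiprocessing as mp

from lmdeploy import TurbomindEngineConfig, serve

if __name__ == '__main__':
    # torch.multiprocessing's Queue is a multiprocessing.queues.Queue, which create_loader accepts
    que = mp.Queue()
    # A producer process (e.g. a trainer holding the weights) fills the queue with the
    # protocol above: the weight-name list, then partial state_dicts, then None, e.g.
    # mp.Process(target=feed_model_params, args=(model, que)).start()
    client = serve('internlm/internlm2_5-7b-chat',  # placeholder model id
                   backend='turbomind',
                   backend_config=TurbomindEngineConfig(tp=1),
                   model_params_que=que)
```

As the new assertion in `AsyncEngine` makes explicit, the argument is only honored by the turbomind backend.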