Skip to content
14 changes: 8 additions & 6 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Any, Dict, List, Tuple
from typing import Any

import torch
from torch.profiler import record_function
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
max_tokens: int,
num_blocks: int,
is_decoding: bool,
pool: Tuple[int, int],
pool: tuple[int, int],
model_config: ModelConfig,
device: torch.device,
decode_query_len: int = 1,
Expand All @@ -89,6 +89,8 @@ def __init__(
mla_index_topk=getattr(self.model_config, 'mla_index_topk', None),
decode_query_len=decode_query_len,
use_fa3_decoding=model_config.model_paradigm == 'ar_spec',
is_ssm=len(model_config.states_shapes) > 0,
use_mrope=model_config.use_mrope,
)
self.device = device
self.max_batches = max_batches
Expand Down Expand Up @@ -153,7 +155,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
self.enable_graph = self.check_enable_graph()

self.graph_pool_handle = torch.cuda.graph_pool_handle()
self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
self._runner_map: dict[Any, CUDASingleGraphRunner] = dict()
self.has_try_compile_model: bool = False

# strategy factory
Expand Down Expand Up @@ -187,7 +189,7 @@ def _get_capture_tokens(self, batch_size: int):
return size
assert False, f'Unsupported batch_size={batch_size}'

def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list,
attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):
"""Get graph key."""
context = self.ctx_mgr.current_context()
Expand Down Expand Up @@ -261,7 +263,7 @@ def __call__(self, **kwargs):
@record_function('prepare_inputs_for_generation')
def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
past_key_values: list[list[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
Expand Down Expand Up @@ -303,6 +305,6 @@ def update_inputs(self, inputs):
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
def get_capture_batch_sizes(self) -> list[int]:
"""Capture batch sizes."""
return _get_capture_batch_size_impl(self.cache_config.max_batches)
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/check_env/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .base import BaseChecker

MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '5.2.0'
MAX_TRANSFORMERS_VERSION = '5.3.0'


class TransformersChecker(BaseChecker):
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ class ModelConfig:
# quant config
quant_config: 'QuantizationConfig' = None

    # flag indicating whether this model uses mrope
use_mrope: bool = False

def get_head_size(self):
"""Get head size."""
return self.head_dim
Expand Down
13 changes: 8 additions & 5 deletions lmdeploy/pytorch/configurations/glm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
hf_config.scoring_func = 'sigmoid'
if not hasattr(hf_config, 'moe_layer_freq'):
hf_config.moe_layer_freq = 1
return super().build(hf_config,
model_path=model_path,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg = super().build(hf_config,
model_path=model_path,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg.use_mrope = True
return cfg


class Glm4MoeModelConfigBuilder(DefaultModelConfigBuilder):
Expand Down Expand Up @@ -58,6 +60,7 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg.use_mrope = True
cfg.model_paradigm = model_paradigm
cfg.num_layers = num_layers
return cfg
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/configurations/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
dtype = torch.bfloat16
cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
cfg.check_env_func = _check_env_qwen3_next

cfg.use_mrope = True
return cfg
1 change: 1 addition & 0 deletions lmdeploy/pytorch/configurations/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)
setattr(hf_config, 'dtype', hf_config.text_config.dtype)
cfg.hf_config = hf_config
cfg.use_mrope = True
return cfg
32 changes: 18 additions & 14 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gc
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any

import numpy as np
import torch
Expand All @@ -28,7 +28,7 @@

logger = get_logger('lmdeploy')

SeqList = List[SchedulerSequence]
SeqList = list[SchedulerSequence]


@dataclass
Expand All @@ -37,15 +37,15 @@ class InferOutput:

session_id: int
resp: Response
token_ids: Union[np.ndarray, List[int]]
token_ids: np.ndarray | list[int]
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
logprobs: torch.Tensor = None

# send cache blocks back for migration in Disaggregated LLM Serving
# when Prefill Engine is Done.
cache_block_ids: List[int] = None
cache_block_ids: list[int] = None

# for logging
req_metrics: RequestMetrics = None
Expand All @@ -54,10 +54,13 @@ class InferOutput:
routed_experts: torch.Tensor = None


def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):
def _build_seq_meta(model_config: ModelConfig, cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):
from lmdeploy.pytorch.messages import SequenceMeta

seq_meta = SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy)
seq_meta = SequenceMeta(cache_config.block_size,
strategy=seq_strategy,
sampling_strategy=sampling_strategy,
use_mrope=model_config.use_mrope)
return seq_meta


Expand Down Expand Up @@ -156,7 +159,8 @@ def __init__(
self.input_processor = self.executor.get_input_processor()
cache_config = self.executor.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
self.seq_meta = _build_seq_meta(cache_config,
self.seq_meta = _build_seq_meta(model_config=self.model_config,
cache_config=cache_config,
seq_strategy=self.seq_strategy,
sampling_strategy=self.sampling_strategy)
self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta)
Expand Down Expand Up @@ -229,7 +233,7 @@ def from_pretrained(cls,
speculative_config=speculative_config,
)

def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig):
def _download_adapters(self, adapters: dict[str, str], engine_config: PytorchEngineConfig):
"""Download adapters."""
download_dir = engine_config.download_dir
revision = engine_config.revision
Expand Down Expand Up @@ -274,7 +278,7 @@ def _get_max_session_len(self):
session_len = min(max_tokens, session_len)
return session_len

def _on_add_session(self, reqs: List[Request], **kwargs):
def _on_add_session(self, reqs: list[Request], **kwargs):
"""On add session callback."""
for req in reqs:
session_id = req.data['session_id']
Expand All @@ -286,7 +290,7 @@ def _on_add_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_stop_session(self, reqs: List[Request], **kwargs):
def _on_stop_session(self, reqs: list[Request], **kwargs):
"""On stop session callback."""
for req in reqs:
session_id = req.data['session_id']
Expand All @@ -305,7 +309,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_end_session(self, reqs: List[Request], **kwargs):
def _on_end_session(self, reqs: list[Request], **kwargs):
"""On end session callback."""
for req in reqs:
session_id = req.data['session_id']
Expand All @@ -321,7 +325,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_add_message(self, reqs: List[Request], **kwargs):
def _on_add_message(self, reqs: list[Request], **kwargs):
"""On add message callback."""
valid_reqs = []
for req in reqs:
Expand Down Expand Up @@ -359,7 +363,7 @@ def _on_add_message(self, reqs: List[Request], **kwargs):
if len(valid_reqs) > 0:
self._add_message(valid_reqs)

def _add_message(self, reqs: List[Request]):
def _add_message(self, reqs: list[Request]):

def __update_max_new_tokens(msg):
"""Update max new tokens."""
Expand Down Expand Up @@ -440,7 +444,7 @@ def sleep(self, level: int = 1):
"""Sleep."""
self.executor.sleep(level)

def wakeup(self, tags: Optional[List[str]] = None):
def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.executor.wakeup(tags)

Expand Down
32 changes: 23 additions & 9 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING

import numpy as np
import torch
Expand Down Expand Up @@ -54,10 +54,12 @@ class InputsMakerConfig:
dp: int = 1
spec_decoding: bool = False
enable_chunked_prefill: bool = False
use_mrope: bool = False

@staticmethod
def from_engine(engine: 'Engine'):
cache_config = engine.cache_config
model_config = engine.model_config
return InputsMakerConfig(
spec_decoding=engine.specdecode_config is not None,
max_batches=cache_config.max_batches,
Expand All @@ -66,6 +68,7 @@ def from_engine(engine: 'Engine'):
is_ssm=len(cache_config.states_shapes) > 0,
dp=engine.dist_config.dp,
enable_chunked_prefill=engine.misc_config.enable_chunked_prefill,
use_mrope=model_config.use_mrope,
)


Expand Down Expand Up @@ -219,8 +222,8 @@ def __init__(

# running seqs
# mark the seqs that have been sent to executor
self.running_seqs: List['SchedulerSequence'] = []
self.to_evict_seqs: List['SchedulerSequence'] = []
self.running_seqs: list['SchedulerSequence'] = []
self.to_evict_seqs: list['SchedulerSequence'] = []

# long context chunker
self.long_context_chunker = LongContextChunker(config.max_prefill_token_num)
Expand Down Expand Up @@ -379,14 +382,19 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
state_offsets = torch.tensor([msg.logical_state for msg in messages])
model_inputs.state_offsets = state_offsets

if self.config.use_mrope:
mrope_pos_ids = [msg.mrope_pos_ids for msg in messages]
mrope_pos_ids = torch.as_tensor(np.concatenate(mrope_pos_ids)).T
model_inputs.mrope_pos_ids = mrope_pos_ids

return model_inputs

@torch.inference_mode()
@record_function('create_model_inputs_long_context')
def create_model_inputs_long_context(self,
seq: 'SchedulerSequence',
chunk_size: int,
multimodals: Optional['MultiModalInputs'] = None):
multimodals: 'MultiModalInputs|None' = None):
"""Create model inputs for long context messages."""
token_ids = seq.token_ids[:chunk_size]
input_ids = torch.as_tensor(token_ids)[None]
Expand Down Expand Up @@ -436,6 +444,12 @@ def create_model_inputs_long_context(self,
if self.config.is_ssm:
model_inputs.state_offsets = torch.tensor([seq.logical_state])

# mrope
if self.config.use_mrope:
mrope_pos_ids = seq.mrope_pos_ids[:chunk_size]
mrope_pos_ids = torch.as_tensor(mrope_pos_ids).T
model_inputs.mrope_pos_ids = mrope_pos_ids

return model_inputs

@torch.inference_mode()
Expand All @@ -453,8 +467,8 @@ def create_model_inputs_delta(self):

valid_mask = np.array(valid_mask)
indices_cpu = np.arange(0, batch_size)[valid_mask]
valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
if len(valid_seqs) == 0:
return None, valid_seqs, invalid_seqs

Expand Down Expand Up @@ -498,8 +512,8 @@ def create_model_inputs_delta_valid_only(self):

valid_mask = np.array(valid_mask, dtype=bool)
indices_cpu = np.arange(0, batch_size)[valid_mask]
valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]

num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
max_q_seqlen = num_decode_tokens
Expand All @@ -523,7 +537,7 @@ def create_model_inputs_delta_valid_only(self):

return output, valid_seqs, invalid_seqs

def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]):
def update_running_seqs(self, running: 'SeqList', inputs: 'ModelInputs|None'):
"""Update running seqs."""
if self.config.role == EngineRole.Prefill:
# p node will not update running seqs
Expand Down
12 changes: 9 additions & 3 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ def __init__(
# long context
self._prev_chunk_output: Dict = None

# make dummy meta
self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config)

@contextmanager
def all_context(self):
device_mgr = get_device_manager()
Expand All @@ -427,10 +430,11 @@ def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheCo
self.cache_config = cache_config
self.spec_agent.set_cache_config(spec_cache_config)

def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig | None = None):
"""Set model config."""
self.model_config = model_config
self.spec_agent.set_model_config(spec_model_config)
self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config)

def get_free_mem(self):
"""Gather available memory."""
Expand Down Expand Up @@ -461,7 +465,8 @@ def warmup(self):
inputs = self.inputs_strategy.make_dummy(max_batches,
is_decoding=False,
device='cuda',
vocab_size=self.model_config.vocab_size)
vocab_size=self.model_config.vocab_size,
meta=self.make_dummy_meta)
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
Expand All @@ -480,7 +485,8 @@ def warmup(self):
inputs = self.inputs_strategy.make_dummy(num_tokens,
is_decoding=True,
device='cuda',
vocab_size=self.model_config.vocab_size)
vocab_size=self.model_config.vocab_size,
meta=self.make_dummy_meta)
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
Expand Down
Loading
Loading