From 34d4fa67251774e52339dfee2a123db403cb8024 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 2 Mar 2026 17:55:04 +0800 Subject: [PATCH 1/9] built-in mrope pos ids support --- lmdeploy/pytorch/config.py | 3 + lmdeploy/pytorch/configurations/glm4.py | 12 +- lmdeploy/pytorch/configurations/qwen3_5.py | 2 + lmdeploy/pytorch/configurations/qwen3_vl.py | 1 + lmdeploy/pytorch/engine/engine.py | 10 +- lmdeploy/pytorch/engine/inputs_maker.py | 14 ++ lmdeploy/pytorch/engine/model_agent/agent.py | 9 +- .../engine/model_agent/inputs_maker.py | 6 +- lmdeploy/pytorch/messages.py | 82 +++++++++++ lmdeploy/pytorch/model_inputs.py | 54 ++++--- lmdeploy/pytorch/models/glm4_1v.py | 130 +---------------- lmdeploy/pytorch/models/qwen2_5_vl.py | 133 +----------------- lmdeploy/pytorch/models/qwen2_vl.py | 109 +++----------- lmdeploy/pytorch/models/qwen3_5.py | 88 ------------ lmdeploy/pytorch/models/qwen3_vl.py | 86 ----------- lmdeploy/pytorch/multimodal/__init__.py | 4 +- lmdeploy/pytorch/multimodal/data_type.py | 21 ++- lmdeploy/pytorch/multimodal/image_type.py | 15 -- lmdeploy/pytorch/paging/scheduler.py | 2 +- lmdeploy/pytorch/spec_decode/spec_agent.py | 40 +++--- lmdeploy/pytorch/strategies/ar/model_agent.py | 7 + .../pytorch/strategies/ar/model_inputs.py | 21 ++- lmdeploy/pytorch/strategies/ar/sequence.py | 2 + .../strategies/ar_spec/model_inputs.py | 6 +- .../pytorch/strategies/ar_spec/sequence.py | 2 + .../pytorch/strategies/base/model_inputs.py | 35 ++++- .../pytorch/strategies/dllm/model_inputs.py | 20 +-- 27 files changed, 298 insertions(+), 616 deletions(-) delete mode 100644 lmdeploy/pytorch/multimodal/image_type.py diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 79ee15d903..252c5592aa 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -337,6 +337,9 @@ class ModelConfig: # quant config quant_config: 'QuantizationConfig' = None + # flags mark if this model use mrope + use_mrope: bool = False + def get_head_size(self): 
"""Get head size.""" return self.head_dim diff --git a/lmdeploy/pytorch/configurations/glm4.py b/lmdeploy/pytorch/configurations/glm4.py index 475520388e..587ecbf155 100644 --- a/lmdeploy/pytorch/configurations/glm4.py +++ b/lmdeploy/pytorch/configurations/glm4.py @@ -17,8 +17,10 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, hf_config.scoring_func = 'sigmoid' if not hasattr(hf_config, 'moe_layer_freq'): hf_config.moe_layer_freq = 1 - return super().build(hf_config, - model_path=model_path, - is_draft_model=is_draft_model, - spec_method=spec_method, - **kwargs) + cfg = super().build(hf_config, + model_path=model_path, + is_draft_model=is_draft_model, + spec_method=spec_method, + **kwargs) + cfg.use_mrope = True + return cfg diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py b/lmdeploy/pytorch/configurations/qwen3_5.py index 9c1b0111fa..0bb87545bb 100644 --- a/lmdeploy/pytorch/configurations/qwen3_5.py +++ b/lmdeploy/pytorch/configurations/qwen3_5.py @@ -45,4 +45,6 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): dtype = torch.bfloat16 cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)] cfg.check_env_func = _check_env_qwen3_next + + cfg.use_mrope = True return cfg diff --git a/lmdeploy/pytorch/configurations/qwen3_vl.py b/lmdeploy/pytorch/configurations/qwen3_vl.py index 6b78efcd0b..212d34c721 100644 --- a/lmdeploy/pytorch/configurations/qwen3_vl.py +++ b/lmdeploy/pytorch/configurations/qwen3_vl.py @@ -22,4 +22,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs): cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs) setattr(hf_config, 'dtype', hf_config.text_config.dtype) cfg.hf_config = hf_config + cfg.use_mrope = True return cfg diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 71f061f5ef..ad1fae4901 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ 
-54,10 +54,13 @@ class InferOutput: routed_experts: torch.Tensor = None -def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any): +def _build_seq_meta(model_config: ModelConfig, cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta - seq_meta = SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy) + seq_meta = SequenceMeta(cache_config.block_size, + strategy=seq_strategy, + sampling_strategy=sampling_strategy, + use_mrope=model_config.use_mrope) return seq_meta @@ -156,7 +159,8 @@ def __init__( self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.seq_meta = _build_seq_meta(cache_config, + self.seq_meta = _build_seq_meta(model_config=self.model_config, + cache_config=cache_config, seq_strategy=self.seq_strategy, sampling_strategy=self.sampling_strategy) self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 506a372250..57094f5ea7 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -54,10 +54,12 @@ class InputsMakerConfig: dp: int = 1 spec_decoding: bool = False enable_chunked_prefill: bool = False + use_mrope: bool = False @staticmethod def from_engine(engine: 'Engine'): cache_config = engine.cache_config + model_config = engine.model_config return InputsMakerConfig( spec_decoding=engine.specdecode_config is not None, max_batches=cache_config.max_batches, @@ -66,6 +68,7 @@ def from_engine(engine: 'Engine'): is_ssm=len(cache_config.states_shapes) > 0, dp=engine.dist_config.dp, enable_chunked_prefill=engine.misc_config.enable_chunked_prefill, + use_mrope=model_config.use_mrope, ) @@ -379,6 +382,11 @@ def create_model_inputs(self, 
messages: 'SeqList', is_prefill: bool): state_offsets = torch.tensor([msg.logical_state for msg in messages]) model_inputs.state_offsets = state_offsets + if self.config.use_mrope: + mrope_pos_ids = [msg.mrope_pos_ids for msg in messages] + mrope_pos_ids = torch.as_tensor(np.concatenate(mrope_pos_ids)).T + model_inputs.mrope_pos_ids = mrope_pos_ids + return model_inputs @torch.inference_mode() @@ -436,6 +444,12 @@ def create_model_inputs_long_context(self, if self.config.is_ssm: model_inputs.state_offsets = torch.tensor([seq.logical_state]) + # mrope + if self.config.use_mrope: + mrope_pos_ids = seq.mrope_pos_ids[:chunk_size] + mrope_pos_ids = torch.as_tensor(mrope_pos_ids).T + model_inputs.mrope_pos_ids = mrope_pos_ids + return model_inputs @torch.inference_mode() diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 45c9ddb436..7a09579e05 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -415,6 +415,9 @@ def __init__( # long context self._prev_chunk_output: Dict = None + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config) + @contextmanager def all_context(self): device_mgr = get_device_manager() @@ -461,7 +464,8 @@ def warmup(self): inputs = self.inputs_strategy.make_dummy(max_batches, is_decoding=False, device='cuda', - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) @@ -480,7 +484,8 @@ def warmup(self): inputs = self.inputs_strategy.make_dummy(num_tokens, is_decoding=True, device='cuda', - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) diff --git 
a/lmdeploy/pytorch/engine/model_agent/inputs_maker.py b/lmdeploy/pytorch/engine/model_agent/inputs_maker.py index ba1dfb7a63..d3cb10bf78 100644 --- a/lmdeploy/pytorch/engine/model_agent/inputs_maker.py +++ b/lmdeploy/pytorch/engine/model_agent/inputs_maker.py @@ -43,6 +43,9 @@ def __init__(self, model_agent: 'BaseModelAgent'): self._ready_event = torch.cuda.Event() self._ready_event.record() + # other + self.make_dummy_meta = model_agent.make_dummy_meta + def _make_dummy_forward_inputs(self): """Make dummy forward inputs.""" is_decoding = self.cache_config.role != EngineRole.Prefill @@ -52,7 +55,8 @@ def _make_dummy_forward_inputs(self): model_inputs = self.inputs_strategy.make_dummy(batch_size, is_decoding, device=self.device, - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) forward_inputs = dict(inputs=model_inputs, ) return forward_inputs diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index c020403fa8..be8aa2b61a 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -176,6 +176,7 @@ class SequenceMeta: block_size: int strategy: 'SequenceStrategy' = None sampling_strategy: 'SamplingStrategy' = None + use_mrope: bool = False class SequenceManager: @@ -532,6 +533,25 @@ def clone(self): return ret +class HistoryMropePosIds(_HistoryDataBase): + """History mrope position ids.""" + ALLOC_SIZE = 64 + + def __init__(self, pos_ids: np.ndarray | None = None, dtype: np.dtype = np.int64): + super().__init__(pos_ids, dtype) + + def _create_empty_array(self, dtype): + """Create empty array. + + Override in subclass for different shapes. 
+ """ + return np.empty((self.ALLOC_SIZE, 3), dtype=dtype) + + def _get_pad_width(self, reserve_size: int): + """Get pad width for multi-dimensional array.""" + return ((0, reserve_size), (0, 0)) + + class HistoryMultiModals: def __init__(self, multimodals: MultiModalInputs = None): @@ -617,6 +637,9 @@ class SchedulerSequence: # logits all_logits: HistoryLogits = field(default_factory=HistoryLogits) + # mrope + history_mrope_pos_ids: HistoryMropePosIds = field(default_factory=HistoryMropePosIds) + def __post_init__(self): """Post init.""" self._seq_meta: SequenceMeta = self.session.seq_meta @@ -756,6 +779,13 @@ def logits(self): """Get logits.""" return self.all_logits.get_logits() + @property + def mrope_pos_ids(self): + """Get mrope pos ids.""" + start = self.num_history_ids + end = start + self._num_token_ids + return self.history_mrope_pos_ids[start:end] + def append_logits(self, logits: Union[Tensor, np.ndarray]): """Append logits.""" if not self.return_logits: @@ -797,6 +827,58 @@ def _update_multimodals(self, multimodals: MultiModalInputs): multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) self.history_multimodals.add_inputs(multimodals) + def _update_mrope_pos_ids(self): + """Update mrope pos ids.""" + if not self._seq_meta.use_mrope: + return + + num_rope_pos = len(self.history_mrope_pos_ids) + num_appends = self.num_all_ids - num_rope_pos + + if num_appends == 0: + return + + if num_rope_pos == 0: + next_pos = 0 + else: + next_pos = self.history_mrope_pos_ids[-1].max() + 1 + + multimodals = self.history_multimodals.get_datas(num_rope_pos, self.num_all_ids) + if multimodals is None or len(multimodals) == 0: + if num_appends == 1: + pos_ids = np.array([[next_pos] * 3], dtype=np.int64) + else: + pos_ids = np.arange(next_pos, next_pos + num_appends, dtype=np.int64) + pos_ids = pos_ids[:, None].repeat(3, axis=1) + else: + pos_ids = [] + assert len(multimodals) == 1 + modal_datas = list(multimodals.values())[0] + mm_offset = 
next_pos + for modal_data in modal_datas: + mm_start = modal_data.start + mm_offset + + # tokens + if next_pos < mm_start: + text_pos_ids = np.arange(next_pos, mm_start, dtype=np.int64) + pos_ids.append(text_pos_ids[:, None].repeat(3, axis=1)) + + # imgs + mm_pos_ids = modal_data.mrope_pos_ids + assert mm_pos_ids is not None, ( + 'MROPE position ids is required for multimodal inputs when use_mrope is True.') + new_pos = mm_pos_ids[-1].max() + 1 + next_pos = mm_start + new_pos + mm_offset = mm_offset + new_pos - mm_pos_ids.shape[0] + pos_ids.append(mm_pos_ids + mm_start) + + # add final text part + text_pos_ids = np.arange(next_pos, num_appends + mm_offset, dtype=np.int64) + pos_ids.append(text_pos_ids[:, None].repeat(3, axis=1)) + pos_ids = np.concatenate(pos_ids, axis=0) + + self.history_mrope_pos_ids.append(pos_ids) + def update_token_ids(self, token_ids: Tensor, multimodals: MultiModalInputs = None, diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 6ee9fb5133..f26699f1c7 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -182,18 +182,21 @@ class ModelInputs: max_q_seqlen: int max_kv_seqlen: int sum_kv_seqlen: int - local_adapter_ids: torch.Tensor = None - vision_inputs: VisionModelInputs = None - model_metas: List[Dict[str, Any]] = None - dp_meta: 'DPMeta' = None + local_adapter_ids: torch.Tensor | None = None + vision_inputs: VisionModelInputs | None = None + model_metas: List[Dict[str, Any]] | None = None + dp_meta: DPMeta | None = None enable_microbatch: bool = False is_dummy: bool = False - state_offsets: torch.Tensor = None - target_hidden_states: torch.Tensor = None - target_position_ids: torch.Tensor = None + state_offsets: torch.Tensor | None = None + target_hidden_states: torch.Tensor | None = None + target_position_ids: torch.Tensor | None = None is_chunk: bool = False is_first_chunk: bool = True + # mrope, shape(3, sum_seqlens) + mrope_pos_ids: torch.Tensor | None = None + def 
step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): """Update input ids.""" assert self.is_decoding @@ -205,6 +208,9 @@ def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): if input_ids.dim() == 1: input_ids = input_ids[None, :] self.input_ids = input_ids + + if self.mrope_pos_ids is not None: + self.mrope_pos_ids = self.mrope_pos_ids + self.seq_length[None] return self @torch.inference_mode() @@ -251,23 +257,26 @@ class StepContext: kv_caches: List is_decoding: bool sum_kv_seqlen: int - max_kv_seqlen: int = None - local_adapter_ids: torch.LongTensor = None - input_embeddings: torch.Tensor = None - input_embedding_indexing: torch.Tensor = None - input_multimodals: List[MultiModalTensor] = None - vision_inputs: VisionModelInputs = None + max_kv_seqlen: int | None = None + local_adapter_ids: torch.LongTensor | None = None + input_embeddings: torch.Tensor | None = None + input_embedding_indexing: torch.Tensor | None = None + input_multimodals: List[MultiModalTensor] | None = None + vision_inputs: VisionModelInputs | None = None attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 - model_metas: List[Dict[str, Any]] = None - dp_meta: DPMeta = None + model_metas: List[Dict[str, Any]] | None = None + dp_meta: DPMeta | None = None enable_microbatch: bool = False # for draft model - target_hidden_states: torch.Tensor = None + target_hidden_states: torch.Tensor | None = None # states for ssm - state_caches: List = None - state_offsets: torch.LongTensor = None + state_caches: List | None = None + state_offsets: torch.LongTensor | None = None + + # mrope + mrope_pos_ids: torch.Tensor | None = None _outputs: Dict = field(default_factory=dict) @@ -277,8 +286,8 @@ def new( inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List = None, - state_caches: List = None, + kv_caches: List | None = None, + state_caches: List | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build step 
context. @@ -334,6 +343,7 @@ def new( state_caches=state_caches, state_offsets=inputs.state_offsets, target_hidden_states=inputs.target_hidden_states, + mrope_pos_ids=inputs.mrope_pos_ids, ) ret = get_backend().update_step_context(ret) @@ -406,8 +416,8 @@ def build_context( inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List = None, - state_caches: List = None, + kv_caches: List | None = None, + state_caches: List | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build context.""" diff --git a/lmdeploy/pytorch/models/glm4_1v.py b/lmdeploy/pytorch/models/glm4_1v.py index 9b89164bef..b8fbcee1e6 100644 --- a/lmdeploy/pytorch/models/glm4_1v.py +++ b/lmdeploy/pytorch/models/glm4_1v.py @@ -2,21 +2,21 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .glm4 import Glm4DecoderLayer +from .qwen2_vl import Qwen2VLInputProcessor as Glm4vInputProcessor from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -749,130 +749,6 @@ 
def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): return new_inputs - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = 
model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor - - -class Glm4vInputProcessor(BaseModelInputProcessor): - """Glm4v input processor.""" - - def __init__(self, config: PretrainedConfig) -> None: - self.config = config - - def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, - **kwargs) -> PreprocessInputResult: - """Prepare multimodal input.""" - if input_multimodals is None or len(input_multimodals) == 0: - return input_ids, input_multimodals - - input_imgs = [] - for input_mm in input_multimodals: - pixel_values = input_mm['pixel_values'] - image_grid_thw = input_mm['image_grid_thw'] - offset = input_mm['offset'] - start = offset - image_token_id = input_mm['image_token_id'] - 
num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() - - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) - input_imgs.append(mm_data) - - result = PreprocessInputResult( - input_ids=input_ids, - input_multimodals=dict(image=input_imgs), - ) - return result diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py index 53b683154a..851d2cf526 100644 --- a/lmdeploy/pytorch/models/qwen2_5_vl.py +++ b/lmdeploy/pytorch/models/qwen2_5_vl.py @@ -2,17 +2,17 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.qwen2_vl import Qwen2Model -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.models.qwen2_vl import Qwen2VLInputProcessor as Qwen2_5_VLInputProcessor from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -609,133 +609,6 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): return new_inputs - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is 
None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, 
h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor - - -InputMultiModalType = List[Dict[str, Any]] - - -class Qwen2_5_VLInputProcessor(BaseModelInputProcessor): - """Qwen2 input processor.""" - - def __init__(self, config: PretrainedConfig) -> None: - self.config = config - - def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, - **kwargs) -> PreprocessInputResult: - """Prepare multimodal input.""" - if input_multimodals is None or len(input_multimodals) == 0: - return input_ids, input_multimodals - - input_imgs = [] - for input_mm in input_multimodals: - pixel_values = input_mm['pixel_values'] - image_grid_thw = input_mm['image_grid_thw'] - offset = input_mm['offset'] - start = offset - image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() - - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start 
+ num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) - input_imgs.append(mm_data) - - result = PreprocessInputResult( - input_ids=input_ids, - input_multimodals=dict(image=input_imgs), - ) - return result diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index a6f08af680..6e38268d95 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple +import numpy as np import torch from torch import nn from transformers.configuration_utils import PretrainedConfig @@ -810,92 +811,6 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): return new_inputs - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, 
device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) 
- def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor @@ -910,6 +825,23 @@ class Qwen2VLInputProcessor(BaseModelInputProcessor): def __init__(self, config: PretrainedConfig) -> None: self.config = config + @staticmethod + def _get_multimodal_pos_ids(grid_thw: Sequence[int]) -> np.ndarray: + """Get mrope ids.""" + t, h, w = grid_thw + h = h // 2 + w = w // 2 + stride = np.array([h * w, w, 1])[None] + size = np.array([t, h, w])[None] + pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1) + pos_ids = pos_ids // stride % size + return pos_ids + + @staticmethod + def make_mrope(grid_thw: torch.Tensor, ): + img_pos_ids = Qwen2VLInputProcessor._get_multimodal_pos_ids(grid_thw[0].tolist()) + return img_pos_ids + def preprocess_input(self, input_ids: List[int], input_multimodals: List[Dict[str, Any]] = None, @@ -929,9 +861,12 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() + mrope_pos_ids = self.make_mrope(image_grid_thw) + mm_data = MultiModalTensor(data=pixel_values, start=start, end=start + num_pad, + mrope_pos_ids=mrope_pos_ids, meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) input_imgs.append(mm_data) diff --git a/lmdeploy/pytorch/models/qwen3_5.py b/lmdeploy/pytorch/models/qwen3_5.py index 826b14a646..e5c81ed203 100644 --- a/lmdeploy/pytorch/models/qwen3_5.py +++ b/lmdeploy/pytorch/models/qwen3_5.py @@ -1219,94 +1219,6 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs): return new_inputs - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = 
self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas_cpu = torch.tensor(mrope_deltas, device='cpu') - if (mrope_deltas_cpu == mrope_deltas_cpu[0]).all(): - mrope_deltas = position_ids.new_full((len(mrope_deltas), ), mrope_deltas[0]) - else: - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - 
mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, past_key_values: List[List[torch.Tensor]], inputs_embeds: torch.Tensor | None, - context: StepContext): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 01c7cf8cdf..fc8415be8a 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -759,92 +759,6 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): return new_inputs - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = 
mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, 
dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py index c3e8c6a16f..54f819bba0 100644 --- a/lmdeploy/pytorch/multimodal/__init__.py +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .data_type import MultiModalData, MultiModalTensor +from .data_type import MultiModalTensor -__all__ = ['MultiModalData', 'MultiModalTensor'] +__all__ = ['MultiModalTensor'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index dd3ec9a37d..9b35d37bdb 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -1,26 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass, fields -from typing import Any, Dict, List, Union +from typing import Any, Dict, List +import numpy as np from torch import Tensor - -class MultiModalData: - pass - - -MultiModalDataList = List[MultiModalData] - -NestedTensor = Union[Tensor, List[Tensor]] +NestedTensor = Tensor | List[Tensor] @dataclass class MultiModalTensor: data: NestedTensor start: int - end: int = None - encoder_len: int = None - meta: Dict[str, Any] = None + end: int | None = None + encoder_len: int | None = None + meta: Dict[str, Any] | None = None + + # for qwen-vl + mrope_pos_ids: np.ndarray | None = None def __post_init__(self): if self.end is None: diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py deleted file mode 100644 index 19211a381f..0000000000 --- a/lmdeploy/pytorch/multimodal/image_type.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from dataclasses import dataclass -from typing import Any, ClassVar, Dict - -from PIL import Image - -from .data_type import MultiModalData - - -@dataclass -class ImageData(MultiModalData): - data: Image - loc: int - meta: Dict[str, Any] = None - type: ClassVar[str] = 'image' diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 9208b7cdf2..8606a6344a 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -46,7 +46,7 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - seq_meta: SequenceMeta = None, + seq_meta: SequenceMeta | None = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 51739d05d5..0b17ef0949 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -45,6 +45,9 @@ def __init__( self.cache_config = 
specdecode_config.cache_config self.num_spec_tokens = specdecode_config.num_speculative_tokens + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(self.model_config) + def set_cache_config(self, cache_config: CacheConfig): """Set all cache config.""" self.cache_config = cache_config @@ -194,7 +197,8 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): device='cuda', vocab_size=self.model_config.vocab_size, target_hidden_size=target_hidden_size, - target_dtype=self.model_config.dtype) + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) @@ -203,26 +207,24 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): for batch_size in capture_batch_sizes: # decode with num_spec_tokens + 1 per seq - inputs = self.inputs_strategy.make_dummy( - batch_size, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size, - max_q_seqlen=self.num_spec_tokens + 1, - target_hidden_size=target_hidden_size, - target_dtype=self.model_config.dtype, - ) + inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=self.num_spec_tokens + 1, + target_hidden_size=target_hidden_size, + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) # decode 1 tokens per sequence - inputs = self.inputs_strategy.make_dummy( - batch_size, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size, - max_q_seqlen=1, - target_hidden_size=self.model_config.hidden_size, - target_dtype=self.model_config.dtype, - ) + inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=1, + target_hidden_size=self.model_config.hidden_size, + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) def 
reset_graph_runner(self): diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 9c7abb5887..5678a11aab 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -25,6 +25,12 @@ def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, state_offsets = inputs.state_offsets if state_offsets is not None: state_offsets = state_offsets.clone() + + # mrope + mrope_pos_ids = inputs.mrope_pos_ids + if mrope_pos_ids is not None: + index = inputs.seq_length.cumsum(0) - 1 + mrope_pos_ids = mrope_pos_ids[:, index] + 1 return ModelInputs( input_ids=input_ids, seq_length=torch.full_like(inputs.seq_length, max_q_seqlen), @@ -38,6 +44,7 @@ def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, local_adapter_ids=inputs.local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py index 7c1910311a..1fcff049e0 100644 --- a/lmdeploy/pytorch/strategies/ar/model_inputs.py +++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs def merge_model_inputs(inputs: ModelInputs, other: ModelInputs) -> ModelInputs: @@ -51,6 +51,12 @@ def __try_pad_block_offsets(block_offsets: torch.Tensor, target_size: int): if inputs.state_offsets is not None: state_offsets = torch.cat([inputs.state_offsets, other.state_offsets], dim=0) + # mrope + mrope_pos_ids = None + if inputs.mrope_pos_ids is not None: + assert other.mrope_pos_ids is not None + mrope_pos_ids = torch.cat([inputs.mrope_pos_ids, other.mrope_pos_ids], dim=1) + return ModelInputs( 
input_ids=input_ids, seq_length=seq_length, @@ -64,6 +70,7 @@ def __try_pad_block_offsets(block_offsets: torch.Tensor, target_size: int): local_adapter_ids=local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) @@ -74,14 +81,16 @@ def make_dummy(self, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Create dummy model inputs.""" return make_dummy_inputs(batch_size, max_q_seqlen=1, is_decoding=is_decoding, device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) @record_function('ModelInputs.merge') def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs: @@ -140,6 +149,11 @@ def index_select(inputs: ModelInputs, if target_position_ids is not None: target_position_ids = target_position_ids[indices] + # mrope + mrope_pos_ids = inputs.mrope_pos_ids + if mrope_pos_ids is not None: + mrope_pos_ids = mrope_pos_ids[:, indices] + # return new inputs return ModelInputs( input_ids=input_ids, @@ -156,6 +170,7 @@ def index_select(inputs: ModelInputs, state_offsets=state_offsets, target_hidden_states=target_hidden_states, target_position_ids=target_position_ids, + mrope_pos_ids=mrope_pos_ids, ) @record_function('ModelInputs.update_inputs') diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index b9b277f961..39754da69f 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -57,6 +57,8 @@ def update_token_ids(self, if model_meta is not None: self.model_meta = model_meta + self._update_mrope_pos_ids() + def set_step(self, step: int): """Set step.""" num_all_ids = self.num_all_ids diff --git a/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py index b8ffc94352..aecf6caf05 100644 
--- a/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py @@ -5,7 +5,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta from ..ar.model_inputs import merge_model_inputs -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs class ARSpecModelInputsStrategy(ModelInputsStrategy): @@ -23,6 +23,7 @@ def make_dummy( max_q_seqlen: int = 1, target_hidden_size: int = None, target_dtype: torch.dtype = torch.bfloat16, + meta: MakeDummyMeta | None = None, ) -> ModelInputs: """Create dummy model inputs.""" inputs = make_dummy_inputs(batch_size, @@ -30,7 +31,8 @@ def make_dummy( is_decoding=is_decoding, device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) if target_hidden_size is not None: inputs.target_hidden_states = torch.randn((1, batch_size * max_q_seqlen, target_hidden_size), dtype=target_dtype, diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py index 7089bce3d0..828310dd3b 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/sequence.py +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -136,6 +136,8 @@ def update_token_ids(self, if model_meta is not None: self.model_meta = model_meta + self._update_mrope_pos_ids() + class ARSpecSequenceStrategy(ARSequenceStrategy): diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py index 8c18420e0b..1f06772735 100644 --- a/lmdeploy/pytorch/strategies/base/model_inputs.py +++ b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -1,20 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import ABC, abstractmethod +from dataclasses import dataclass import torch from torch.profiler import record_function +from lmdeploy.pytorch.config import ModelConfig from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta +@dataclass +class MakeDummyMeta: + """Make dummy meta for model inputs strategy.""" + # Add any fields needed for making dummy inputs + use_ssm: bool = False + use_mrope: bool = False + + @record_function('make_dummy_input') def make_dummy_inputs(batch_size: int, max_q_seqlen: int, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1): + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Make dummy inputs global implement.""" + if meta is None: + meta = MakeDummyMeta() num_tokens = batch_size * max_q_seqlen max_kv_seqlen = max_q_seqlen input_ids = torch.randint(0, vocab_size, ( @@ -26,7 +39,14 @@ def make_dummy_inputs(batch_size: int, block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) - state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device) + + state_offsets = None + if meta.use_ssm: + state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device) + + mrope_pos_ids = None + if meta.use_mrope: + mrope_pos_ids = torch.zeros(3, num_tokens, dtype=torch.long, device=device) return ModelInputs( input_ids=input_ids, @@ -41,18 +61,27 @@ def make_dummy_inputs(batch_size: int, local_adapter_ids=local_adapter_ids, is_dummy=True, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) class ModelInputsStrategy(ABC): + def create_make_dummy_meta(self, model_config: ModelConfig): + """Create make dummy meta.""" + return MakeDummyMeta( + use_ssm=len(model_config.states_shapes) > 0, + use_mrope=model_config.use_mrope, + ) + 
@abstractmethod def make_dummy(self, batch_size: int, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Create dummy model inputs.""" pass diff --git a/lmdeploy/pytorch/strategies/dllm/model_inputs.py b/lmdeploy/pytorch/strategies/dllm/model_inputs.py index 151b952d0a..6cf291360f 100644 --- a/lmdeploy/pytorch/strategies/dllm/model_inputs.py +++ b/lmdeploy/pytorch/strategies/dllm/model_inputs.py @@ -2,7 +2,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta from ..ar.model_inputs import merge_model_inputs -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs class DLLMModelInputsStrategy(ModelInputsStrategy): @@ -10,19 +10,23 @@ class DLLMModelInputsStrategy(ModelInputsStrategy): def __init__(self, block_size: int): self.block_size = block_size - def make_dummy(self, - batch_size: int, - is_decoding: bool, - device: str = 'cpu', - dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + def make_dummy( + self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1, + meta: MakeDummyMeta | None = None, + ) -> ModelInputs: """Create dummy model inputs.""" return make_dummy_inputs(batch_size, max_q_seqlen=self.block_size, is_decoding=is_decoding, device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs: """Merge model inputs.""" From e35c90586921174ca96c86ae7810848475581997 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 3 Mar 2026 11:31:53 +0800 Subject: [PATCH 2/9] remove make and fill buffer --- .../pytorch/backends/cuda/graph_runner.py | 2 + lmdeploy/pytorch/models/glm4_1v.py | 34 +-------------- 
lmdeploy/pytorch/models/qwen2_5_vl.py | 34 +-------------- lmdeploy/pytorch/models/qwen2_vl.py | 34 +-------------- lmdeploy/pytorch/models/qwen3_5.py | 42 +------------------ lmdeploy/pytorch/models/qwen3_next.py | 25 +---------- lmdeploy/pytorch/models/qwen3_vl.py | 34 +-------------- lmdeploy/pytorch/models/utils/cudagraph.py | 35 +++++++++++++++- 8 files changed, 42 insertions(+), 198 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 58d093cf9b..c31fdb61e2 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -89,6 +89,8 @@ def __init__( mla_index_topk=getattr(self.model_config, 'mla_index_topk', None), decode_query_len=decode_query_len, use_fa3_decoding=model_config.model_paradigm == 'ar_spec', + is_ssm=len(model_config.states_shapes) > 0, + use_mrope=model_config.use_mrope, ) self.device = device self.max_batches = max_batches diff --git a/lmdeploy/pytorch/models/glm4_1v.py b/lmdeploy/pytorch/models/glm4_1v.py index b8fbcee1e6..dbefc845d5 100644 --- a/lmdeploy/pytorch/models/glm4_1v.py +++ b/lmdeploy/pytorch/models/glm4_1v.py @@ -17,7 +17,7 @@ from .glm4 import Glm4DecoderLayer from .qwen2_vl import Qwen2VLInputProcessor as Glm4vInputProcessor -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -717,38 +717,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] 
= mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py index 7f31994f69..7a3e3b73d6 100644 --- a/lmdeploy/pytorch/models/qwen2_5_vl.py +++ b/lmdeploy/pytorch/models/qwen2_5_vl.py @@ -18,7 +18,7 @@ from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import add_prefix -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -568,38 +568,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - 
input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 4e2accb220..65834660ec 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -16,7 +16,7 @@ build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, build_embedding, vlm_model @@ -772,38 +772,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids 
is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen3_5.py b/lmdeploy/pytorch/models/qwen3_5.py index 6fc7f05999..a88c1cb0dd 100644 --- a/lmdeploy/pytorch/models/qwen3_5.py +++ b/lmdeploy/pytorch/models/qwen3_5.py @@ -24,7 +24,7 @@ from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3_5InputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3_5VisionAttention -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -1175,46 +1175,6 @@ def __skip_layers(name): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - - max_batchs = graph_meta.max_batchs - device = graph_meta.device - max_tokens = graph_meta.max_tokens - - 
input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device) - input_buffers['state_ids'] = state_ids - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - input_buffers = graph_meta.input_buffers - new_inputs = super().fill_buffers_cudagraph(graph_meta, *args, **kwargs) - state_ids = kwargs['state_ids'] - input_buffers['state_ids'].fill_(-1) - input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) - new_inputs['state_ids'] = input_buffers['state_ids'] - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py index 99a2c26c07..100d24c88a 100644 --- a/lmdeploy/pytorch/models/qwen3_next.py +++ b/lmdeploy/pytorch/models/qwen3_next.py @@ -18,7 +18,7 @@ from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import 
CudaGraphMixin from .utils.model import DeployModelMixinV1, build_embedding @@ -705,29 +705,6 @@ def prepare_inputs_for_generation( state_ids=context.state_offsets, ) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_batchs = graph_meta.max_batchs - device = graph_meta.device - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device) - input_buffers['state_ids'] = state_ids - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - input_buffers = graph_meta.input_buffers - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - state_ids = kwargs['state_ids'] - input_buffers['state_ids'].fill_(-1) - input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) - new_inputs['state_ids'] = input_buffers['state_ids'] - - return new_inputs - def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], expert_params_mapping: List): """Load weight experts.""" diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 0b2b9acb9c..323d757e95 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -21,7 +21,7 @@ from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3VLInputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention from .qwen3 import Qwen3model -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -718,38 +718,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, 
**kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 2b5a4dc8ad..f704fb4646 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -80,6 +80,8 @@ class CudaGraphMeta: mla_index_topk: Optional[int] = None decode_query_len: int = 1 use_fa3_decoding: bool = False + is_ssm: bool = False + use_mrope: bool = False class CudaGraphMixin: @@ -190,6 +192,15 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_valu max_seqlen_k=decode_query_len, cache_seqlens=input_buffers['kv_seqlens']) + # mrope + if graph_meta.use_mrope: + input_buffers['mrope_position_ids'] = 
torch.zeros(3, max_tokens, dtype=torch.int64, device=device) + + # ssm + if graph_meta.is_ssm: + state_ids = torch.full((max_batches, ), -1, dtype=torch.int64, device=device) + input_buffers['state_ids'] = state_ids + return input_buffers @record_function('fill_buffers_cudagraph') @@ -269,6 +280,7 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p past_key_values=past_key_values, attn_metadata=attn_metadata, ) + new_inputs.update(kwargs) new_inputs['input_ids'] = input_buffers['input_ids'] new_inputs['position_ids'] = input_buffers['position_ids'] @@ -276,7 +288,20 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p if inputs_embeds is not None: new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] - new_inputs.update(kwargs) + # mrope + if graph_meta.use_mrope: + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] + + # ssm + if graph_meta.is_ssm: + state_ids = kwargs['state_ids'] + input_buffers['state_ids'].fill_(-1) + input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) + new_inputs['state_ids'] = input_buffers['state_ids'] + return new_inputs def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepContext): @@ -293,6 +318,14 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte context.kv_seqlens = input_buffers['kv_seqlens'] context.q_start_loc = input_buffers['q_start_loc'] + # mrope + if graph_meta.use_mrope: + context.mrope_pos_ids = input_buffers['mrope_position_ids'] + + # ssm + if graph_meta.is_ssm: + context.state_offsets = input_buffers['state_ids'] + def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs): """Get outputs from buffers.""" num_tokens = input_ids.size(-1) From 
dc5f0525f980aa9a42bea164c75bfe6d0e01baff Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 3 Mar 2026 11:52:00 +0800 Subject: [PATCH 3/9] fix context --- lmdeploy/pytorch/model_inputs.py | 4 ++-- lmdeploy/pytorch/models/utils/cudagraph.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index d8c8eda2ce..e35d4f9e36 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -276,7 +276,7 @@ class StepContext: state_offsets: torch.LongTensor | None = None # mrope - mrope_pos_ids: torch.Tensor | None = None + mrope_position_ids: torch.Tensor | None = None _outputs: Dict = field(default_factory=dict) @@ -343,7 +343,7 @@ def new( state_caches=state_caches, state_offsets=inputs.state_offsets, target_hidden_states=inputs.target_hidden_states, - mrope_pos_ids=inputs.mrope_pos_ids, + mrope_position_ids=inputs.mrope_pos_ids, ) ret = get_backend().update_step_context(ret) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index f704fb4646..4c2da7ab46 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -320,7 +320,7 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte # mrope if graph_meta.use_mrope: - context.mrope_pos_ids = input_buffers['mrope_position_ids'] + context.mrope_position_ids = input_buffers['mrope_position_ids'] # ssm if graph_meta.is_ssm: From 3a46034e55faecc8712c3c16f9848a666fa184a2 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 4 Mar 2026 11:03:00 +0800 Subject: [PATCH 4/9] update typehint --- .../pytorch/backends/cuda/graph_runner.py | 12 +++--- lmdeploy/pytorch/engine/engine.py | 22 +++++----- lmdeploy/pytorch/engine/inputs_maker.py | 18 ++++---- lmdeploy/pytorch/model_inputs.py | 42 +++++++++---------- lmdeploy/pytorch/models/utils/cudagraph.py | 25 +++++------ lmdeploy/pytorch/multimodal/data_type.py | 
8 ++-- lmdeploy/pytorch/paging/scheduler.py | 9 ++-- 7 files changed, 68 insertions(+), 68 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index c31fdb61e2..2d64c88d70 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools -from typing import Any, Dict, List, Tuple +from typing import Any import torch from torch.profiler import record_function @@ -66,7 +66,7 @@ def __init__( max_tokens: int, num_blocks: int, is_decoding: bool, - pool: Tuple[int, int], + pool: tuple[int, int], model_config: ModelConfig, device: torch.device, decode_query_len: int = 1, @@ -155,7 +155,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.enable_graph = self.check_enable_graph() self.graph_pool_handle = torch.cuda.graph_pool_handle() - self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() + self._runner_map: dict[Any, CUDASingleGraphRunner] = dict() self.has_try_compile_model: bool = False # strategy factory @@ -189,7 +189,7 @@ def _get_capture_tokens(self, batch_size: int): return size assert False, f'Unsupported batch_size={batch_size}' - def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, + def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list, attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs): """Get graph key.""" context = self.ctx_mgr.current_context() @@ -263,7 +263,7 @@ def __call__(self, **kwargs): @record_function('prepare_inputs_for_generation') def prepare_inputs_for_generation( self, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], inputs_embeds: torch.Tensor = None, context: StepContext = None, ): @@ -305,6 +305,6 @@ def update_inputs(self, inputs): 
dp_meta.sync_tp_size(tp_size) return inputs - def get_capture_batch_sizes(self) -> List[int]: + def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index ad1fae4901..de160a0276 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -3,7 +3,7 @@ import gc import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any import numpy as np import torch @@ -28,7 +28,7 @@ logger = get_logger('lmdeploy') -SeqList = List[SchedulerSequence] +SeqList = list[SchedulerSequence] @dataclass @@ -37,7 +37,7 @@ class InferOutput: session_id: int resp: Response - token_ids: Union[np.ndarray, List[int]] + token_ids: np.ndarray | list[int] meta: Any = None finish: bool = False logits: torch.Tensor = None @@ -45,7 +45,7 @@ class InferOutput: # send cache blocks back for migration in Disaggregated LLM Serving # when Prefill Engine is Done. 
- cache_block_ids: List[int] = None + cache_block_ids: list[int] = None # for logging req_metrics: RequestMetrics = None @@ -233,7 +233,7 @@ def from_pretrained(cls, speculative_config=speculative_config, ) - def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig): + def _download_adapters(self, adapters: dict[str, str], engine_config: PytorchEngineConfig): """Download adapters.""" download_dir = engine_config.download_dir revision = engine_config.revision @@ -278,7 +278,7 @@ def _get_max_session_len(self): session_len = min(max_tokens, session_len) return session_len - def _on_add_session(self, reqs: List[Request], **kwargs): + def _on_add_session(self, reqs: list[Request], **kwargs): """On add session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -290,7 +290,7 @@ def _on_add_session(self, reqs: List[Request], **kwargs): if resp: self._response(req.resp, resp_type) - def _on_stop_session(self, reqs: List[Request], **kwargs): + def _on_stop_session(self, reqs: list[Request], **kwargs): """On stop session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -309,7 +309,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs): if resp: self._response(req.resp, resp_type) - def _on_end_session(self, reqs: List[Request], **kwargs): + def _on_end_session(self, reqs: list[Request], **kwargs): """On end session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -325,7 +325,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs): if resp: self._response(req.resp, resp_type) - def _on_add_message(self, reqs: List[Request], **kwargs): + def _on_add_message(self, reqs: list[Request], **kwargs): """On add message callback.""" valid_reqs = [] for req in reqs: @@ -363,7 +363,7 @@ def _on_add_message(self, reqs: List[Request], **kwargs): if len(valid_reqs) > 0: self._add_message(valid_reqs) - def _add_message(self, reqs: List[Request]): + def _add_message(self, reqs: 
list[Request]): def __update_max_new_tokens(msg): """Update max new tokens.""" @@ -444,7 +444,7 @@ def sleep(self, level: int = 1): """Sleep.""" self.executor.sleep(level) - def wakeup(self, tags: Optional[List[str]] = None): + def wakeup(self, tags: list[str] | None = None): """Wakeup.""" self.executor.wakeup(tags) diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 57094f5ea7..dff099d892 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import numpy as np import torch @@ -222,8 +222,8 @@ def __init__( # running seqs # mark the seqs that have been sent to executor - self.running_seqs: List['SchedulerSequence'] = [] - self.to_evict_seqs: List['SchedulerSequence'] = [] + self.running_seqs: list['SchedulerSequence'] = [] + self.to_evict_seqs: list['SchedulerSequence'] = [] # long context chunker self.long_context_chunker = LongContextChunker(config.max_prefill_token_num) @@ -394,7 +394,7 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool): def create_model_inputs_long_context(self, seq: 'SchedulerSequence', chunk_size: int, - multimodals: Optional['MultiModalInputs'] = None): + multimodals: 'MultiModalInputs|None' = None): """Create model inputs for long context messages.""" token_ids = seq.token_ids[:chunk_size] input_ids = torch.as_tensor(token_ids)[None] @@ -467,8 +467,8 @@ def create_model_inputs_delta(self): valid_mask = np.array(valid_mask) indices_cpu = np.arange(0, batch_size)[valid_mask] - valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] - invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] + valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in 
indices_cpu] + invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] if len(valid_seqs) == 0: return None, valid_seqs, invalid_seqs @@ -512,8 +512,8 @@ def create_model_inputs_delta_valid_only(self): valid_mask = np.array(valid_mask, dtype=bool) indices_cpu = np.arange(0, batch_size)[valid_mask] - valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] - invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] + valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] + invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] num_decode_tokens = self.engine_strategy.get_num_decode_tokens() max_q_seqlen = num_decode_tokens @@ -537,7 +537,7 @@ def create_model_inputs_delta_valid_only(self): return output, valid_seqs, invalid_seqs - def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]): + def update_running_seqs(self, running: 'SeqList', inputs: 'ModelInputs|None'): """Update running seqs.""" if self.config.role == EngineRole.Prefill: # p node will not update running seqs diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index e35d4f9e36..7351487989 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch @@ -20,11 +20,11 @@ @dataclass class DPMeta: - tp_sizes: List[int] = None - moe_tp_sizes: List[int] = None + tp_sizes: list[int] = None + moe_tp_sizes: list[int] = None @staticmethod - def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist.DistContext, layer_type: str): + def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: list[int], dist_ctx: dist.DistContext, layer_type: str): """Gather tp size.""" attn_tp = dist_ctx.dist_config.attn_tp if tp > 1 and tp != attn_tp: @@ -38,7 +38,7 @@ def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist return tp_sizes @classmethod - def build(cls, seqlen: int, num_tokens: List[int]): + def build(cls, seqlen: int, num_tokens: list[int]): """Get dp meta.""" dist_ctx = dist.get_dist_manager().current_context() dist_config = dist_ctx.dist_config @@ -63,10 +63,10 @@ def sync_tp_size(self, tp_size: int): class VisionModelInputs: """Vision model inputs.""" history_lengths: torch.LongTensor = None - input_embeddings: List[List[torch.Tensor]] = None - input_embedding_ranges: List[torch.LongTensor] = None + input_embeddings: list[list[torch.Tensor]] = None + input_embedding_ranges: list[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None - input_multimodals: List[MultiModalTensor] = None + input_multimodals: list[MultiModalTensor] = None def to_device(self, device: str, non_blocking: bool = False): """To device.""" @@ -125,7 +125,7 @@ def get_inputs(self, history_lengths: torch.Tensor, seq_lengths: torch.Tensor): class ModelInputsDelta: """Delta of ModelInputs.""" # valid indices - indices: Optional[torch.Tensor] + indices: torch.Tensor | None # new block offsets block_offsets: torch.Tensor # cpu copy of indices @@ -135,7 +135,7 @@ class ModelInputsDelta: 
sum_kv_seqlen: int is_decoding: bool = True # sliding window - num_ignored_history: Optional[torch.Tensor] = None + num_ignored_history: torch.Tensor | None = None @property def seq_length(self): @@ -184,7 +184,7 @@ class ModelInputs: sum_kv_seqlen: int local_adapter_ids: torch.Tensor | None = None vision_inputs: VisionModelInputs | None = None - model_metas: List[Dict[str, Any]] | None = None + model_metas: list[dict[str, Any]] | None = None dp_meta: DPMeta | None = None enable_microbatch: bool = False is_dummy: bool = False @@ -228,7 +228,7 @@ def to_device(self, device: str, non_blocking: bool = False): return ModelInputs(**out_dict) - def build_dp_meta(self, num_tokens: List[int]): + def build_dp_meta(self, num_tokens: list[int]): """Build dp meta.""" self.dp_meta = DPMeta.build(self.input_ids.numel(), num_tokens) @@ -254,31 +254,31 @@ class StepContext: q_seqlens: torch.LongTensor kv_seqlens: torch.IntTensor q_start_loc: torch.LongTensor - kv_caches: List + kv_caches: list is_decoding: bool sum_kv_seqlen: int max_kv_seqlen: int | None = None local_adapter_ids: torch.LongTensor | None = None input_embeddings: torch.Tensor | None = None input_embedding_indexing: torch.Tensor | None = None - input_multimodals: List[MultiModalTensor] | None = None + input_multimodals: list[MultiModalTensor] | None = None vision_inputs: VisionModelInputs | None = None attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 - model_metas: List[Dict[str, Any]] | None = None + model_metas: list[dict[str, Any]] | None = None dp_meta: DPMeta | None = None enable_microbatch: bool = False # for draft model target_hidden_states: torch.Tensor | None = None # states for ssm - state_caches: List | None = None + state_caches: list | None = None state_offsets: torch.LongTensor | None = None # mrope mrope_position_ids: torch.Tensor | None = None - _outputs: Dict = field(default_factory=dict) + _outputs: dict = field(default_factory=dict) @classmethod def new( @@ -286,8 +286,8 @@ def new( 
inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List | None = None, - state_caches: List | None = None, + kv_caches: list | None = None, + state_caches: list | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build step context. @@ -418,8 +418,8 @@ def build_context( inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List | None = None, - state_caches: List | None = None, + kv_caches: list | None = None, + state_caches: list | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build context.""" diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 4c2da7ab46..f4c7394cd4 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any import torch from torch import Tensor @@ -8,7 +8,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager -BuffType = Dict[str, Tensor] +BuffType = dict[str, Tensor] def _get_meta_flashattn( @@ -21,9 +21,9 @@ def _get_meta_flashattn( cache_seqlens: torch.Tensor, qkv_dtype=torch.bfloat16, headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k_new: torch.Tensor | None = None, + page_size: int | None = None, causal=True, window_size=(-1, -1), # -1 means infinite context window num_splits=0, @@ -77,7 +77,7 @@ class CudaGraphMeta: vocab_size: int = 1 use_mla_fp8_cache: bool = False use_flash_mla: bool = False - mla_index_topk: Optional[int] = None + mla_index_topk: int | None = None decode_query_len: int = 1 use_fa3_decoding: bool = False is_ssm: bool = False @@ -91,7 +91,7 @@ def support_cuda_graph( self, input_ids: 
torch.Tensor, position_ids: torch.Tensor, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, **kwargs, @@ -104,7 +104,7 @@ def make_output_buffers(self, output): if isinstance(output, torch.Tensor): output_buffers = dict(hidden_states=output) else: - assert isinstance(output, Dict) + assert isinstance(output, dict) output_buffers = output return output_buffers @@ -140,7 +140,8 @@ def update_meta_flashattn(self, graph_meta: CudaGraphMeta, block_size: int, max_ ) return scheduler_metadata - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: List, **kwargs) -> BuffType: + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: list[list[torch.Tensor]], + **kwargs) -> BuffType: """Make cudagraph buffers from forward inputs.""" max_batches = graph_meta.max_batchs max_tokens = graph_meta.max_tokens @@ -205,8 +206,8 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_valu @record_function('fill_buffers_cudagraph') def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor, - past_key_values: List, attn_metadata: Any, inputs_embeds: Tensor, - **kwargs) -> Dict[str, Tensor]: + past_key_values: list[list[torch.Tensor]], attn_metadata: Any, inputs_embeds: Tensor, + **kwargs) -> dict[str, Tensor]: """Fill cudagraph buffers from forward inputs.""" block_offsets: Tensor = attn_metadata.block_offsets @@ -326,7 +327,7 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte if graph_meta.is_ssm: context.state_offsets = input_buffers['state_ids'] - def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs): + def get_outputs_cudagraph(self, output_buffers: dict[str, torch.Tensor], input_ids: Tensor, **kwargs): """Get outputs from buffers.""" num_tokens = input_ids.size(-1) outputs = 
dict() diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index 9b35d37bdb..23c96a806d 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass, fields -from typing import Any, Dict, List +from typing import Any import numpy as np from torch import Tensor -NestedTensor = Tensor | List[Tensor] +NestedTensor = Tensor | list[Tensor] @dataclass @@ -14,7 +14,7 @@ class MultiModalTensor: start: int end: int | None = None encoder_len: int | None = None - meta: Dict[str, Any] | None = None + meta: dict[str, Any] | None = None # for qwen-vl mrope_pos_ids: np.ndarray | None = None @@ -53,4 +53,4 @@ def to_device(self, device: str, non_blocking: bool = False): return MultiModalTensor(**out_dict) -MultiModalInputs = Dict[str, List[MultiModalTensor]] +MultiModalInputs = dict[str, list[MultiModalTensor]] diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 8606a6344a..c9611a0cba 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -4,7 +4,6 @@ from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass -from typing import Dict, List from torch.profiler import record_function @@ -20,8 +19,8 @@ logger = get_logger('lmdeploy') -MapType = Dict[int, int] -SeqList = List[SchedulerSequence] +MapType = dict[int, int] +SeqList = list[SchedulerSequence] @dataclass @@ -50,10 +49,10 @@ def __init__( ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.sessions: Dict[int, SchedulerSession] = OrderedDict() + self.sessions: dict[int, SchedulerSession] = OrderedDict() # For Disaggregation - self.locked_sessions: Dict[int, SchedulerSession] = OrderedDict() + self.locked_sessions: dict[int, SchedulerSession] = OrderedDict() 
self.block_manager = build_block_manager(cache_config) self.block_trie = BlockTrie(self.cache_config, self.block_manager) From 2146481a9eca19411cbb8f684fc3d0de19df7c8b Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 4 Mar 2026 11:26:14 +0800 Subject: [PATCH 5/9] fix comment of copilot --- lmdeploy/pytorch/engine/model_agent/agent.py | 3 ++- lmdeploy/pytorch/model_inputs.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 654d64186c..2886da873c 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -430,10 +430,11 @@ def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheCo self.cache_config = cache_config self.spec_agent.set_cache_config(spec_cache_config) - def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig | None = None): """Set model config.""" self.model_config = model_config self.spec_agent.set_model_config(spec_model_config) + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config) def get_free_mem(self): """Gather available memory.""" diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 7351487989..c51bdfe5ce 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -210,7 +210,7 @@ def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): self.input_ids = input_ids if self.mrope_pos_ids is not None: - self.mrope_pos_ids = self.mrope_pos_ids + self.seq_length[None] + self.mrope_pos_ids = self.mrope_pos_ids + step_seqlens[None] return self @torch.inference_mode() From db5daf00fd8337a19cdfc2f8006b0c1d35e8315e Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 4 Mar 2026 11:32:58 +0800 Subject: [PATCH 6/9] fix comment of copilot2 --- 
lmdeploy/pytorch/spec_decode/spec_agent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 0b17ef0949..af99df1d63 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -55,6 +55,9 @@ def set_cache_config(self, cache_config: CacheConfig): def set_model_config(self, model_config: ModelConfig): """Set model config.""" self.model_config = model_config + if model_config is not None: + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(self.model_config) def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None): """Build draft model.""" From 1d5224bb5f47222373f5f824bdd44d20098c6e37 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 4 Mar 2026 18:44:12 +0800 Subject: [PATCH 7/9] remove index select --- lmdeploy/pytorch/nn/gated_delta.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/nn/gated_delta.py b/lmdeploy/pytorch/nn/gated_delta.py index c61dcab6b5..01038f889f 100644 --- a/lmdeploy/pytorch/nn/gated_delta.py +++ b/lmdeploy/pytorch/nn/gated_delta.py @@ -76,19 +76,16 @@ def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, state_ids = gated_delta_meta.state_ids assert x.dim() == 3 - x = x.transpose(-2, -1) if weight.dim() == 3: assert weight.size(1) == 1 weight = weight[:, 0] # fill conv state - # TODO: find efficient way to fill conv state without gather + scatter - final_state = conv_state.index_select(0, state_ids) - batch_size = conv_state.size(0) - conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1) - torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=final_state) + final_state = x[0, conv_idx].transpose(-2, -1) conv_state = conv_state.index_copy_(0, state_ids, final_state) + # note that we have not set init states + x = x.transpose(-2, -1) out = self.causal_conv1d_fn( x, 
weight, From 51b4b17f06f87b781170ad4a60009e2ed811f5f0 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 5 Mar 2026 10:58:01 +0800 Subject: [PATCH 8/9] rename mmdata --- lmdeploy/pytorch/model_inputs.py | 6 +++--- lmdeploy/pytorch/models/chatglm2.py | 10 +++++----- lmdeploy/pytorch/models/cogvlm.py | 10 +++++----- lmdeploy/pytorch/models/deepseek_vl2.py | 16 ++++++++-------- lmdeploy/pytorch/models/gemma3_vl.py | 10 +++++----- lmdeploy/pytorch/models/interns1_pro.py | 18 +++++++++--------- lmdeploy/pytorch/models/internvl.py | 10 +++++----- lmdeploy/pytorch/models/internvl3_hf.py | 10 +++++----- lmdeploy/pytorch/models/llama4.py | 10 +++++----- lmdeploy/pytorch/models/llava.py | 18 +++++++++--------- lmdeploy/pytorch/models/phi3_v.py | 10 +++++----- lmdeploy/pytorch/models/qwen2_vl.py | 12 ++++++------ lmdeploy/pytorch/multimodal/__init__.py | 4 ++-- lmdeploy/pytorch/multimodal/data_type.py | 6 +++--- 14 files changed, 75 insertions(+), 75 deletions(-) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index c51bdfe5ce..d1519a46ad 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -11,7 +11,7 @@ import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig, QuantizationConfig -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.utils import CtxMgrBase, singleton if TYPE_CHECKING: @@ -66,7 +66,7 @@ class VisionModelInputs: input_embeddings: list[list[torch.Tensor]] = None input_embedding_ranges: list[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None - input_multimodals: list[MultiModalTensor] = None + input_multimodals: list[MultiModalData] = None def to_device(self, device: str, non_blocking: bool = False): """To device.""" @@ -261,7 +261,7 @@ class StepContext: 
local_adapter_ids: torch.LongTensor | None = None input_embeddings: torch.Tensor | None = None input_embedding_indexing: torch.Tensor | None = None - input_multimodals: list[MultiModalTensor] | None = None + input_multimodals: list[MultiModalData] | None = None vision_inputs: VisionModelInputs | None = None attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 56e3169bb7..b690217564 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -9,7 +9,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding, build_rotary_params) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj, @@ -866,10 +866,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index ad8adc9739..9b8d0b5472 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.distributed import get_tp_world_rank from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs 
import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) @@ -901,10 +901,10 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/deepseek_vl2.py b/lmdeploy/pytorch/models/deepseek_vl2.py index 290b9a4fc0..b778c6ebeb 100644 --- a/lmdeploy/pytorch/models/deepseek_vl2.py +++ b/lmdeploy/pytorch/models/deepseek_vl2.py @@ -11,7 +11,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .deepseek_v2 import DeepseekV2ForCausalLM @@ -440,13 +440,13 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict( - image_token_id=image_token_id, - images_spatial_crop=images_spatial_crop, - )) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_token_id=image_token_id, + 
images_spatial_crop=images_spatial_crop, + )) input_imgs.append(mm_data) diff --git a/lmdeploy/pytorch/models/gemma3_vl.py b/lmdeploy/pytorch/models/gemma3_vl.py index 8f4ea8e972..cff9615df2 100644 --- a/lmdeploy/pytorch/models/gemma3_vl.py +++ b/lmdeploy/pytorch/models/gemma3_vl.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config @@ -108,10 +108,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/interns1_pro.py b/lmdeploy/pytorch/models/interns1_pro.py index 51ed9deaf6..77f7b57f93 100644 --- a/lmdeploy/pytorch/models/interns1_pro.py +++ b/lmdeploy/pytorch/models/interns1_pro.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .interns1_pro_ts import InternS1ProTimeSeriesModel @@ -383,10 +383,10 @@ def preprocess_input(self, ts_sr = input_mm['ts_sr'] num_pad = input_mm['num_ts_tokens'] - mm_data = MultiModalTensor(data=ts_values, - 
start=offset, - end=offset + num_pad, - meta=dict(ts_token_id=ts_token_id, ts_lens=ts_lens, ts_sr=ts_sr)) + mm_data = MultiModalData(data=ts_values, + start=offset, + end=offset + num_pad, + meta=dict(ts_token_id=ts_token_id, ts_lens=ts_lens, ts_sr=ts_sr)) else: pixel_values = input_mm['pixel_values'].to(self.dtype) image_grid_thw = input_mm['image_grid_thw'] @@ -397,10 +397,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 5b6c261dd2..43b80644f6 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -992,10 +992,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + 
start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py index 7cd4cd940c..4ea2eb2f45 100644 --- a/lmdeploy/pytorch/models/internvl3_hf.py +++ b/lmdeploy/pytorch/models/internvl3_hf.py @@ -13,7 +13,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -736,10 +736,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/llama4.py b/lmdeploy/pytorch/models/llama4.py index 4b3c2196bc..e7711b83d3 100644 --- a/lmdeploy/pytorch/models/llama4.py +++ b/lmdeploy/pytorch/models/llama4.py @@ -8,7 +8,7 @@ import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import 
MultiModalData from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) @@ -1033,10 +1033,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index e87242df4c..4004441050 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -11,7 +11,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -555,10 +555,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( @@ -834,10 +834,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - 
end=offset + num_pad, - meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py index c6804d5586..aff3f78935 100644 --- a/lmdeploy/pytorch/models/phi3_v.py +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .phi3 import Phi3ForCausalLM, Phi3Model @@ -379,10 +379,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 65834660ec..78fdae83d1 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -9,7 +9,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import 
(ApplyRotaryEmb, Attention, FlashAttention, LayerNorm, RMSNorm, SiluAndMul, build_rotary_embedding_from_config) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, @@ -824,11 +824,11 @@ def preprocess_input(self, mrope_pos_ids = self.make_mrope(image_grid_thw) - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - mrope_pos_ids=mrope_pos_ids, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=start, + end=start + num_pad, + mrope_pos_ids=mrope_pos_ids, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py index 54f819bba0..fc2d5890d9 100644 --- a/lmdeploy/pytorch/multimodal/__init__.py +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .data_type import MultiModalTensor +from .data_type import MultiModalData -__all__ = ['MultiModalTensor'] +__all__ = ['MultiModalData'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index 23c96a806d..34c545ce0e 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -9,7 +9,7 @@ @dataclass -class MultiModalTensor: +class MultiModalData: data: NestedTensor start: int end: int | None = None @@ -50,7 +50,7 @@ def to_device(self, device: str, non_blocking: bool = False): new_meta[k] = v out_dict['meta'] = new_meta - return MultiModalTensor(**out_dict) + return MultiModalData(**out_dict) -MultiModalInputs = dict[str, list[MultiModalTensor]] +MultiModalInputs = dict[str, list[MultiModalData]] From 3e8780821291846d1497046e6da21422aff17425 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 12 Mar 2026 14:38:55 +0800 Subject: [PATCH 9/9] merge main --- docker/Dockerfile | 20 +- docker/Dockerfile_dev | 17 +- docker/build.sh | 18 +- docker/install.sh | 129 +-- docker/prepare_wheel.sh | 74 +- lmdeploy/archs.py | 4 +- lmdeploy/messages.py | 6 + lmdeploy/pytorch/engine/engine_loop.py | 3 +- lmdeploy/pytorch/engine/logits_process.py | 188 ++++- lmdeploy/pytorch/envs.py | 3 + lmdeploy/pytorch/messages.py | 22 +- lmdeploy/pytorch/strategies/ar/sampling.py | 41 +- lmdeploy/pytorch/strategies/dllm/sampling.py | 24 +- lmdeploy/serve/core/async_engine.py | 2 +- lmdeploy/serve/managers/session_manager.py | 18 +- lmdeploy/serve/openai/api_server.py | 2 + lmdeploy/serve/openai/protocol.py | 4 +- lmdeploy/serve/openai/tool_parser/__init__.py | 2 + .../openai/tool_parser/qwen3coder_parser.py | 235 ++++++ lmdeploy/turbomind/deploy/config.py | 12 + lmdeploy/turbomind/deploy/converter.py | 20 +- lmdeploy/turbomind/deploy/module.py | 198 ++++- .../turbomind/deploy/source_model/llama.py | 7 +- .../turbomind/deploy/source_model/qwen.py | 247 ++++++ lmdeploy/turbomind/supported_models.py | 3 + 
lmdeploy/utils.py | 31 +- src/turbomind/core/module.h | 5 + src/turbomind/engine/request.h | 9 + .../kernels/attention/CMakeLists.txt | 58 +- src/turbomind/kernels/attention/arch.h | 19 + src/turbomind/kernels/attention/attention.cu | 56 +- .../kernels/attention/attention_config.h | 101 --- .../kernels/attention/attention_template.h | 19 +- .../codegen/attention_sm70_128_f16.cu | 16 - .../codegen/attention_sm70_576_f16.cu | 13 - .../codegen/attention_sm70_64_f16.cu | 16 - .../codegen/attention_sm75_128_f16.cu | 17 - .../codegen/attention_sm75_576_f16.cu | 13 - .../codegen/attention_sm75_64_f16.cu | 17 - .../codegen/attention_sm80_128_bf16.cu | 16 - .../codegen/attention_sm80_128_f16.cu | 16 - .../attention/codegen/attention_sm80_192.cu | 16 - .../codegen/attention_sm80_576_bf16.cu | 13 - .../codegen/attention_sm80_576_f16.cu | 13 - .../codegen/attention_sm80_64_bf16.cu | 16 - .../codegen/attention_sm80_64_f16.cu | 16 - .../codegen/decoding_sm70_128_f16_f16.cu | 16 - .../codegen/decoding_sm70_128_f16_u4.cu | 17 - .../codegen/decoding_sm70_128_f16_u8.cu | 17 - .../codegen/decoding_sm70_576_f16_f16.cu | 16 - .../codegen/decoding_sm70_576_f16_u4.cu | 17 - .../codegen/decoding_sm70_576_f16_u8.cu | 17 - .../codegen/decoding_sm70_64_f16_f16.cu | 16 - .../codegen/decoding_sm70_64_f16_u4.cu | 17 - .../codegen/decoding_sm70_64_f16_u8.cu | 17 - .../codegen/decoding_sm75_128_f16_f16.cu | 14 - .../codegen/decoding_sm75_128_f16_u4.cu | 14 - .../codegen/decoding_sm75_128_f16_u8.cu | 14 - .../codegen/decoding_sm75_576_f16_f16.cu | 14 - .../codegen/decoding_sm75_576_f16_u4.cu | 15 - .../codegen/decoding_sm75_576_f16_u8.cu | 15 - .../codegen/decoding_sm75_64_f16_f16.cu | 14 - .../codegen/decoding_sm75_64_f16_u4.cu | 14 - .../codegen/decoding_sm75_64_f16_u8.cu | 14 - .../codegen/decoding_sm80_128_bf16_bf16.cu | 22 - .../codegen/decoding_sm80_128_bf16_u4.cu | 14 - .../codegen/decoding_sm80_128_bf16_u8.cu | 14 - .../codegen/decoding_sm80_128_f16_f16.cu | 18 - 
.../codegen/decoding_sm80_128_f16_u4.cu | 14 - .../codegen/decoding_sm80_128_f16_u8.cu | 14 - .../attention/codegen/decoding_sm80_192.cu | 20 - .../codegen/decoding_sm80_576_bf16_bf16.cu | 22 - .../codegen/decoding_sm80_576_bf16_u4.cu | 14 - .../codegen/decoding_sm80_576_bf16_u8.cu | 14 - .../codegen/decoding_sm80_576_f16_f16.cu | 18 - .../codegen/decoding_sm80_576_f16_u4.cu | 15 - .../codegen/decoding_sm80_576_f16_u8.cu | 15 - .../codegen/decoding_sm80_64_bf16_bf16.cu | 22 - .../codegen/decoding_sm80_64_bf16_u4.cu | 14 - .../codegen/decoding_sm80_64_bf16_u8.cu | 14 - .../codegen/decoding_sm80_64_f16_f16.cu | 18 - .../codegen/decoding_sm80_64_f16_u4.cu | 14 - .../codegen/decoding_sm80_64_f16_u8.cu | 14 - src/turbomind/kernels/attention/decoding.cu | 139 +--- .../kernels/attention/decoding_config.h | 98 --- .../kernels/attention/decoding_template.h | 19 +- src/turbomind/kernels/attention/desc.h | 75 ++ src/turbomind/kernels/attention/impl_16816.h | 3 + src/turbomind/kernels/attention/impl_1688.h | 2 + src/turbomind/kernels/attention/impl_81616.h | 104 ++- src/turbomind/kernels/attention/impl_884.h | 12 +- src/turbomind/kernels/attention/impl_simt.h | 89 ++- .../kernels/attention/iterator_sm80.h | 1 + src/turbomind/kernels/attention/kernel.h | 47 ++ .../kernels/attention/kernel/CMakeLists.txt | 37 + .../attention/kernel/attention_sm70_128.cu | 30 + .../attention/kernel/attention_sm70_256.cu | 30 + .../attention/kernel/attention_sm70_576.cu | 31 + .../attention/kernel/attention_sm70_64.cu | 30 + .../attention/kernel/attention_sm75_128.cu | 30 + .../attention/kernel/attention_sm75_256.cu | 30 + .../attention/kernel/attention_sm75_576.cu | 30 + .../attention/kernel/attention_sm75_64.cu | 30 + .../attention/kernel/attention_sm80_128.cu | 36 + .../attention/kernel/attention_sm80_192.cu | 36 + .../attention/kernel/attention_sm80_256.cu | 36 + .../attention/kernel/attention_sm80_576.cu | 36 + .../attention/kernel/attention_sm80_64.cu | 37 + 
.../attention/kernel/decoding_sm70_128.cu | 44 ++ .../attention/kernel/decoding_sm70_256.cu | 44 ++ .../attention/kernel/decoding_sm70_576.cu | 31 + .../attention/kernel/decoding_sm70_64.cu | 44 ++ .../attention/kernel/decoding_sm75_128.cu | 40 + .../attention/kernel/decoding_sm75_256.cu | 41 + .../attention/kernel/decoding_sm75_576.cu | 31 + .../attention/kernel/decoding_sm75_64.cu | 40 + .../attention/kernel/decoding_sm80_128.cu | 59 ++ .../attention/kernel/decoding_sm80_192.cu | 41 + .../attention/kernel/decoding_sm80_256.cu | 59 ++ .../attention/kernel/decoding_sm80_576.cu | 47 ++ .../attention/kernel/decoding_sm80_64.cu | 59 ++ src/turbomind/kernels/attention/kernel_impl.h | 80 ++ .../kernels/attention/kv_cache_utils_v2.cu | 46 +- .../kernels/attention/mainloop_sm80.h | 104 ++- src/turbomind/kernels/attention/reduce.cu | 2 + src/turbomind/kernels/attention/registrar.h | 46 ++ src/turbomind/kernels/attention/registry.cu | 115 +++ src/turbomind/kernels/attention/registry.h | 45 ++ src/turbomind/kernels/core/thread_map.h | 21 +- src/turbomind/kernels/norm/rms_norm.cu | 63 +- src/turbomind/models/CMakeLists.txt | 5 +- src/turbomind/models/llama/CMakeLists.txt | 3 + .../models/llama/GatedDeltaNetLayer.cc | 312 ++++++++ .../models/llama/GatedDeltaNetLayer.h | 66 ++ .../models/llama/GatedDeltaNetWeight.cc | 164 ++++ .../models/llama/GatedDeltaNetWeight.h | 52 ++ .../models/llama/LlamaDecoderLayerWeight.cc | 82 +- .../models/llama/LlamaDecoderLayerWeight.h | 2 + .../models/llama/LlamaDenseWeight.cc | 14 +- src/turbomind/models/llama/LlamaDenseWeight.h | 3 +- .../models/llama/gated_delta_net_kernels.cu | 741 ++++++++++++++++++ .../models/llama/gated_delta_net_kernels.h | 128 +++ src/turbomind/models/llama/llama_kernels.cu | 29 + src/turbomind/models/llama/llama_kernels.h | 7 + src/turbomind/models/llama/llama_params.h | 16 + src/turbomind/models/llama/moe_ffn_layer.cc | 3 +- .../models/llama/unified_attention_layer.cc | 30 +- 
src/turbomind/models/llama/unified_decoder.cc | 31 +- src/turbomind/models/llama/unified_decoder.h | 2 + src/turbomind/turbomind.cc | 29 + tests/pytorch/engine/test_logits_process.py | 63 ++ tests/test_lmdeploy/test_qwen3coder_parser.py | 255 ++++++ 152 files changed, 4925 insertions(+), 1645 deletions(-) create mode 100644 lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py delete mode 100644 src/turbomind/kernels/attention/attention_config.h delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm70_128_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm70_576_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm75_128_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm75_576_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_128_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_128_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_192.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_576_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_576_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u4.cu delete mode 
100644 src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u8.cu delete 
mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu delete mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu delete mode 100644 src/turbomind/kernels/attention/decoding_config.h create mode 100644 src/turbomind/kernels/attention/desc.h create mode 100644 src/turbomind/kernels/attention/kernel.h create mode 100644 src/turbomind/kernels/attention/kernel/CMakeLists.txt create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm70_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm70_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm70_576.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm70_64.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm75_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm75_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm75_576.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm75_64.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm80_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm80_192.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm80_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm80_576.cu 
create mode 100644 src/turbomind/kernels/attention/kernel/attention_sm80_64.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu create mode 100644 src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu create mode 100644 src/turbomind/kernels/attention/kernel_impl.h create mode 100644 src/turbomind/kernels/attention/registrar.h create mode 100644 src/turbomind/kernels/attention/registry.cu create mode 100644 src/turbomind/kernels/attention/registry.h create mode 100644 src/turbomind/models/llama/GatedDeltaNetLayer.cc create mode 100644 src/turbomind/models/llama/GatedDeltaNetLayer.h create mode 100644 src/turbomind/models/llama/GatedDeltaNetWeight.cc create mode 100644 src/turbomind/models/llama/GatedDeltaNetWeight.h create mode 100644 src/turbomind/models/llama/gated_delta_net_kernels.cu create mode 100644 src/turbomind/models/llama/gated_delta_net_kernels.h create mode 100644 tests/test_lmdeploy/test_qwen3coder_parser.py diff --git a/docker/Dockerfile b/docker/Dockerfile index 7c0d86b716..43bf995263 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -8,16 +8,12 @@ ENV 
CUDA_VERSION_SHORT=cu130 FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8 ENV CUDA_VERSION_SHORT=cu128 -FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS cu12 -ENV CUDA_VERSION_SHORT=cu124 - -FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS cu11 -ENV CUDA_VERSION_SHORT=cu118 +FROM nvidia/cuda:12.6.3-devel-ubuntu22.04 AS cu12 +ENV CUDA_VERSION_SHORT=cu126 # Builder image FROM ${CUDA_VERSION} AS dev ARG PYTHON_VERSION=3.10 -ARG NCCL_BRANCH=v2.26.6-1 ENV PATH=/opt/py3/bin:/root/.local/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive @@ -42,6 +38,9 @@ FROM dev AS builder COPY . /opt/lmdeploy WORKDIR /opt/lmdeploy +RUN --mount=type=cache,target=/root/.cache \ + pip install -r requirements/runtime_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} + RUN --mount=type=cache,target=/root/.cache \ docker/build.sh @@ -55,11 +54,8 @@ ENV CUDA_VERSION_SHORT=cu130 FROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS cu12.8-base ENV CUDA_VERSION_SHORT=cu128 -FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS cu12-base -ENV CUDA_VERSION_SHORT=cu124 - -FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS cu11-base -ENV CUDA_VERSION_SHORT=cu118 +FROM nvidia/cuda:12.6.3-base-ubuntu22.04 AS cu12-base +ENV CUDA_VERSION_SHORT=cu126 FROM ${CUDA_VERSION}-base AS final ARG PYTHON_VERSION=3.10 @@ -71,11 +67,9 @@ COPY requirements /tmp/requirements COPY docker/install.sh /tmp/install.sh RUN --mount=type=cache,target=/root/.cache \ --mount=type=cache,target=/wheels,from=builder,source=/wheels \ - --mount=type=cache,target=/nccl,from=builder,source=/nccl \ /tmp/install.sh # explicitly set ptxas path for triton ENV PATH=/opt/py3/bin:$PATH ENV TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas -ENV LD_LIBRARY_PATH=/nccl/lib:$LD_LIBRARY_PATH FROM ${IMAGE_TYPE} diff --git a/docker/Dockerfile_dev b/docker/Dockerfile_dev index 53b1919844..8e71e3ce8e 100644 --- a/docker/Dockerfile_dev +++ b/docker/Dockerfile_dev @@ -4,7 +4,8 @@ FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8 ENV 
DEBIAN_FRONTEND=noninteractive \ TZ=Etc/UTC \ PATH=/opt/py3/bin:/root/.local/bin:${PATH} \ - TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas + TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \ + CUDA_VERSION_SHORT=cu128 # Install dependencies and create python virtual environment RUN --mount=type=cache,target=/var/cache/apt \ @@ -26,7 +27,6 @@ RUN --mount=type=cache,target=/var/cache/apt \ # Should be in the lmdeploy root directory when building docker image COPY . /opt/lmdeploy - WORKDIR /opt/lmdeploy # install lmdeploy and its dependencies @@ -34,8 +34,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/cu128 && \ uv pip install -e . -# install flash_attn RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl + docker/prepare_wheel.sh + +RUN --mount=type=cache,target=/root/.cache/uv \ + cp -r requirements /tmp/requirements && \ + docker/install.sh -RUN uv cache clean +# Clean up to reduce image size +RUN uv cache clean && \ + rm -rf /wheels /tmp/* /var/tmp/* /root/.cache/uv/* && \ + find /opt/lmdeploy -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true && \ + find /opt/lmdeploy -type f -name "*.pyc" -delete 2>/dev/null || true diff --git a/docker/build.sh b/docker/build.sh index ef282cac82..ec492e6643 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -1,25 +1,11 @@ #!/bin/bash -ex -mkdir -p /wheels /nccl +mkdir -p /wheels if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then pip install nvidia-nccl-cu13 -elif [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then - pip install nvidia-nccl-cu12 else - NVCC_GENCODE="-gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_90,code=compute_90" - pushd /tmp >/dev/null - git 
clone --depth=1 --branch ${NCCL_BRANCH} https://github.com/NVIDIA/nccl.git - pushd nccl >/dev/null - make NVCC_GENCODE="$NVCC_GENCODE" -j$(nproc) src.build - mkdir -p /nccl/include /nccl/lib - mv build/include/* /nccl/include/ - mv build/lib/lib* /nccl/lib/ - popd >/dev/null - popd >/dev/null - rm -rf /tmp/nccl - export LD_LIBRARY_PATH=/nccl/lib:$LD_LIBRARY_PATH + pip install nvidia-nccl-cu12 fi -pip install --upgrade pip build python3 -m build -w -o /wheels -v . diff --git a/docker/install.sh b/docker/install.sh index 54b97948f3..0cfc4cc7d2 100755 --- a/docker/install.sh +++ b/docker/install.sh @@ -1,40 +1,41 @@ #!/bin/bash -ex -# install system packages -export DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC -sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list -apt-get update -y -apt-get install -y --no-install-recommends \ - tzdata wget curl ssh sudo git-core vim libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1 - -if [[ ${PYTHON_VERSION} != "3.10" ]]; then - apt-get install -y --no-install-recommends software-properties-common - add-apt-repository -y ppa:deadsnakes/ppa +# Skip system setup if virtual env already exists (e.g., in dev image) +if [ ! 
-f "/opt/py3/bin/python" ]; then + # install system packages + export DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC + sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list apt-get update -y + apt-get install -y --no-install-recommends \ + tzdata wget curl ssh sudo git-core vim libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1 + + if [[ ${PYTHON_VERSION} != "3.10" ]]; then + apt-get install -y --no-install-recommends software-properties-common + add-apt-repository -y ppa:deadsnakes/ppa + apt-get update -y + fi + + # install python, create virtual env + apt-get install -y --no-install-recommends \ + python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv + + pushd /opt >/dev/null + python${PYTHON_VERSION} -m venv py3 + popd >/dev/null + + # install CUDA build tools + if [[ "${CUDA_VERSION_SHORT}" = "cu126" ]]; then + apt-get install -y --no-install-recommends cuda-minimal-build-12-6 numactl dkms + elif [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then + apt-get install -y --no-install-recommends cuda-minimal-build-12-8 numactl dkms + elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then + apt-get install -y --no-install-recommends cuda-minimal-build-13-0 numactl dkms + fi + + apt-get clean -y + rm -rf /var/lib/apt/lists/* fi -# install python, create virtual env -apt-get install -y --no-install-recommends \ - python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv - -pushd /opt >/dev/null - python${PYTHON_VERSION} -m venv py3 -popd >/dev/null - -# install CUDA build tools -if [[ "${CUDA_VERSION_SHORT}" = "cu118" ]]; then - apt-get install -y --no-install-recommends cuda-minimal-build-11-8 -elif [[ "${CUDA_VERSION_SHORT}" = "cu124" ]]; then - apt-get install -y --no-install-recommends cuda-minimal-build-12-4 numactl dkms -elif [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then - apt-get install -y --no-install-recommends cuda-minimal-build-12-8 numactl dkms 
-elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then - apt-get install -y --no-install-recommends cuda-minimal-build-13-0 numactl dkms -fi - -apt-get clean -y -rm -rf /var/lib/apt/lists/* - # install GDRCopy debs if [ "$(ls -A /wheels/*.deb 2>/dev/null)" ]; then dpkg -i /wheels/*.deb @@ -43,74 +44,30 @@ fi # install python packages export PATH=/opt/py3/bin:$PATH -if [[ "${CUDA_VERSION_SHORT}" = "cu118" ]]; then - FA_VERSION=2.7.3 - TORCH_VERSION="<2.7" -elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then - FA_VERSION=2.8.3 - TORCH_VERSION="==2.9.0" -else - FA_VERSION=2.8.3 - # pin torch version to avoid build and runtime mismatch, o.w. deep_gemm undefined symbol error - TORCH_VERSION="==2.8.0" -fi - pip install -U pip wheel setuptools if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then pip install nvidia-nvshmem-cu13==3.4.5 -elif [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then +else pip install nvidia-nvshmem-cu12==3.4.5 fi -pip install torch${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} pip install /wheels/*.whl - -if [[ "${CUDA_VERSION_SHORT}" != "cu118" ]] && [[ "${PYTHON_VERSION}" != "3.9" ]]; then - pip install cuda-python dlblas==0.0.7 dlslime==0.0.2.post1 -fi +pip install dlblas==0.0.7 dlslime==0.0.2.post1 # install pre-built flash attention 3 wheel -pip install ninja einops packaging -if [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then - FA3_WHEELS_URL="https://windreamer.github.io/flash-attention3-wheels/cu128_torch280" - pip install --no-index flash_attn_3 --find-links ${FA3_WHEELS_URL} -elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then - FA3_WHEELS_URL="https://windreamer.github.io/flash-attention3-wheels/cu130_torch290" - pip install --no-index flash_attn_3 --find-links ${FA3_WHEELS_URL} -fi - -# install pre-built flash attention wheel - -PLATFORM="linux_x86_64" -PY_VERSION=$(python3 - <<'PY' -import torch, sys -torch_ver = '.'.join(torch.__version__.split('.')[:2]) -cuda_ver = torch.version.cuda.split('.')[0] 
-cxx11abi = str(torch.compiled_with_cxx11_abi()).upper() -py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}" -print(f"{torch_ver} {cuda_ver} {cxx11abi} {py_tag}") -PY -) - -read TORCH_VER CUDA_VER CXX11ABI PY_TAG <<< "$PY_VERSION" +TORCH_VER=$(python3 -c "import torch; print(''.join(torch.__version__.split('+')[0].split('.')))") -if [[ "${CUDA_VER}" == "12" ]]; then - WHEEL="flash_attn-${FA_VERSION}+cu${CUDA_VER}torch${TORCH_VER}cxx11abi${CXX11ABI}-${PY_TAG}-${PY_TAG}-${PLATFORM}.whl" - BASE_URL="https://github.com/Dao-AILab/flash-attention/releases/download/v${FA_VERSION}" - FULL_URL="${BASE_URL}/${WHEEL}" - - pip install "$FULL_URL" -fi +pip install ninja einops packaging +FA3_WHEELS_URL="https://windreamer.github.io/flash-attention3-wheels/${CUDA_VERSION_SHORT}_torch${TORCH_VER}" +pip install --no-index flash_attn_3 --find-links ${FA3_WHEELS_URL} # install requirements/serve.txt dependencies such as timm -pip install -r /tmp/requirements/serve.txt +if [ -f /tmp/requirements/serve.txt ]; then + pip install -r /tmp/requirements/serve.txt +fi -# copy nccl -if [[ "${CUDA_VERSION_SHORT}" = "cu118" ]]; then - rm -rf /opt/py3/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl - cp -R /nccl /opt/py3/lib/python${PYTHON_VERSION}/site-packages/nvidia/ -elif [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then +if [[ "${CUDA_VERSION_SHORT}" = "cu128" ]]; then # As described in https://github.com/InternLM/lmdeploy/pull/4313, # window registration may cause memory leaks in NCCL 2.27, NCCL 2.28+ resolves the issue, # but turbomind engine will use nccl GIN for EP in future, which is brought in since 2.29 diff --git a/docker/prepare_wheel.sh b/docker/prepare_wheel.sh index 86bec51078..6c64e23ecb 100755 --- a/docker/prepare_wheel.sh +++ b/docker/prepare_wheel.sh @@ -2,17 +2,7 @@ export PATH=/opt/py3/bin:$PATH -if [[ "${CUDA_VERSION_SHORT}" = "cu118" ]]; then - TORCH_VERSION="<2.7" -elif [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then - TORCH_VERSION="==2.9.0" -else - # pin 
torch version to avoid build and runtime mismatch, o.w. deep_gemm undefined symbol error - TORCH_VERSION="==2.8.0" -fi - pip install "cmake<4.0" wheel ninja setuptools packaging -pip install torch${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT} if [[ ${PYTHON_VERSION} = "3.13" ]]; then curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -22,35 +12,37 @@ if [[ ${PYTHON_VERSION} = "3.13" ]]; then pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/google/sentencepiece.git@v0.2.0#subdirectory=python" fi -if [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then - GDRCOPY_VERSION=2.5.1 - DEEP_EP_VERSION=9af0e0d # v1.2.1 - DEEP_GEMM_VERSION=c9f8b34 # v2.1.1.post3 - FLASH_MLA_VERSION=1408756 # no release, pick the latest commit - - # DeepEP - if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then - export CPLUS_INCLUDE_PATH="/usr/local/cuda/include/cccl":${CPLUS_INCLUDE_PATH} - pip install nvidia-nvshmem-cu13==3.4.5 - else - pip install nvidia-nvshmem-cu12==3.4.5 - fi - pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepEP.git@${DEEP_EP_VERSION}" - - # DeepGEMM - pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepGEMM.git@${DEEP_GEMM_VERSION}" - - # FlashMLA - # sm100 compilation for Flash MLA requires NVCC 12.9 or higher - FLASH_MLA_DISABLE_SM100=1 pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/FlashMLA.git@${FLASH_MLA_VERSION}" - - # GDRCopy debs - apt-get update -y \ - && apt-get install -y --no-install-recommends build-essential devscripts debhelper fakeroot pkg-config dkms - - wget -q https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \ - && tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \ - && cd gdrcopy-${GDRCOPY_VERSION}/packages \ - && CUDA=/usr/local/cuda ./build-deb-packages.sh \ - && mv ./*.deb /wheels +GDRCOPY_VERSION=2.5.1 
+DEEP_EP_VERSION=9af0e0d # v1.2.1 +DEEP_GEMM_VERSION=c9f8b34 # v2.1.1.post3 +FLASH_MLA_VERSION=1408756 # no release, pick the latest commit + +# DeepEP +if [[ "${CUDA_VERSION_SHORT}" = "cu130" ]]; then + export CPLUS_INCLUDE_PATH="/usr/local/cuda/include/cccl":${CPLUS_INCLUDE_PATH} + pip install nvidia-nvshmem-cu13==3.4.5 +else + pip install nvidia-nvshmem-cu12==3.4.5 fi +pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepEP.git@${DEEP_EP_VERSION}" + +# DeepGEMM +pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/DeepGEMM.git@${DEEP_GEMM_VERSION}" + +# FlashMLA +# sm100 compilation for Flash MLA requires NVCC 12.9 or higher +FLASH_MLA_DISABLE_SM100=1 pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/deepseek-ai/FlashMLA.git@${FLASH_MLA_VERSION}" + +# GDRCopy debs +apt-get update -y \ +&& apt-get install -y --no-install-recommends build-essential devscripts debhelper fakeroot pkg-config dkms + +wget -q https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \ +&& tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \ +&& cd gdrcopy-${GDRCOPY_VERSION}/packages \ +&& CUDA=/usr/local/cuda ./build-deb-packages.sh \ +&& mv ./*.deb /wheels + +# Clean up build artifacts +cd / && rm -rf gdrcopy-${GDRCOPY_VERSION} +apt-get clean -y && rm -rf /var/lib/apt/lists/* diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index 59aeff3c10..a4fe0d2333 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -43,7 +43,7 @@ def autoget_backend(model_path: str) -> Literal['turbomind', 'pytorch']: if is_turbomind_installed: if not turbomind_has: logger.warning('Fallback to pytorch engine because ' - f'`{model_path}` not supported by turbomind' + f'{model_path!r} not supported by turbomind' ' engine.') else: logger.warning('Fallback to pytorch engine because turbomind engine is not ' @@ -121,6 +121,8 @@ def check_vl_llm(config: dict) -> 
bool: return True elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config: return True + elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration']: + return False elif arch in supported_archs: return True return False diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 8b54705e31..b029d98c26 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -95,6 +95,8 @@ class GenerationConfig: } logits_processors: Custom logit processors. + repetition_ngram_size: The size of n-grams to consider for repetition early stop. + repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop. """ n: int = 1 @@ -129,6 +131,10 @@ class GenerationConfig: # router replay return_routed_experts: bool = False + # ngram, generation would stop if latest [size] tokens are repeated for [threshold] times + repetition_ngram_size: int = 0 + repetition_ngram_threshold: int = 0 + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to stop_token_ids/bad_token_ids.""" diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 1106f20c9c..d0b6a5e2d6 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -146,8 +146,7 @@ def _log_resps(outputs: List[InferOutput]): if logger.level <= logging.DEBUG: session_ids = [out.session_id for out in outputs] logger.debug(f'Response sessions: {session_ids}') - elif logger.level <= logging.INFO: - logger.info(f'Response: num_outputs={len(outputs)}.') + logger.debug(f'Response: num_outputs={len(outputs)}.') def _send_resp(self, out: InferOutput): """Send response.""" diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 821978d504..f6f290fc29 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py 
@@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple +from functools import lru_cache +from typing import Any +import numpy as np import torch from lmdeploy.messages import LogitsProcessor +from lmdeploy.pytorch import envs from ..messages import SchedulerSequence from .guided_process import GuidedDecodingManager @@ -53,7 +56,7 @@ def _process_bad_words_(scores: torch.Tensor, return scores -def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor): +def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor): """Process repetition penalty.""" score = torch.gather(scores, 1, input_ids) penalty = penalty.to(score.dtype) @@ -92,6 +95,131 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: return scores +@lru_cache +def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1): + return torch.ones(fill, dtype=dtype, device=device) + + +def ngram( + token_ids: torch.Tensor, + n: torch.Tensor | None, + threshold: torch.Tensor, + max_n: int, + max_window_size: int, +): + """Compute n-gram matches between sliding windows and a target sequence. + + For each batch, performs cosine similarity checking between: + - All sliding windows of length `max_n` from the full sequence + - The last `max_n` tokens of the sequence (target window) + + A match is counted when both: + 1. Cosine similarity ≈ 1 (normalized vectors match) + 2. Vector lengths match (preventing zero/normalization artifacts) + + Parameters + ---------- + token_ids : torch.Tensor + Input token IDs of shape (batch_size, seq_len). + Values are typically ≥0 (0 may represent padding/special tokens). + n : torch.Tensor + Effective n-gram length for each batch element, shape (batch_size,). 
+ When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked. + threshold : torch.Tensor + Minimum number of matching windows required for validity, shape (batch_size,). + max_n : int + Maximum n-gram length (window size for matching). + max_window_size: int + Maximum window size for matching. + + Returns + ------- + matched_mask : torch.Tensor + Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating + which sliding windows match the target n-gram. + found : torch.Tensor + Boolean tensor of shape (batch_size,) indicating whether each batch + element has at least `threshold` matches. + """ + + batch_size, seq_len = token_ids.size() + if seq_len < max_n: + # Not enough tokens to form a single n-gram + matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device) + found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) + return matched_mask, found + # token_ids could be 0, so we add 2 to avoid div 0 + token_ids = (token_ids + 2).to(torch.float32).log2() + + # Trim to max_window_size + if seq_len >= max_window_size: + token_ids = token_ids[:, -max_window_size:] + max_window_size = token_ids.size(1) + + # normalize ids + # we would set n=None if n shared same value. 
Read lmdeploy/pytorch/strategies/ar/sampling.py for more details + same_n = n is None + norm = token_ids[:, -max_n:] + if not same_n: + # fill 0 for n < max_n + mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1)) + norm = norm * mask.to(torch.float32) + norm = norm.norm(2, dim=-1, keepdim=True) + normed_ids = token_ids / norm + + # concate p1 and p2 so we can check distance and vector in one conv1d + normed_n_ids = normed_ids[:, -max_n:] + normed_ids_p2 = normed_ids * normed_ids + ones_ids = torch.ones_like(normed_n_ids) + if not same_n: + # fill 0 for n < max_n + normed_n_ids = normed_n_ids * mask.to(torch.float32) + ones_ids = ones_ids * mask.to(torch.float32) + normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0) + normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0) + + # check cos distance & check vector length + match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0] + match_norm, match_ones = match_norm.chunk(2, dim=0) + + # both match result should be close to 1 + one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1) + matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) + + # threshold + count = matched_mask.sum(-1) + found = (count >= threshold) & (threshold > 0) + + return matched_mask, found + + +def _filter_repetition_ngram_( + scores: torch.Tensor, + stop_words: torch.Tensor, + generated_ids: torch.Tensor, + n: torch.Tensor | None, + threshold: torch.Tensor, + max_n: int, + max_ngram_window_size: int, +): + """Filter ngram. + + if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist. 
+ """ + if stop_words is None or stop_words.numel() == 0: + return scores + # use first stop words + _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size) + stop_words = stop_words[:, 0] + # fill all scores -inf + scores.masked_fill_(found[:, None], -float('inf')) + # set stop words to 0 + stop_scores = scores.gather(1, stop_words[:, None]) + stop_scores.masked_fill_(found[:, None], 0) + scores.scatter_(1, stop_words[:, None], stop_scores) + return scores + + def _multinomial_sampling(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, @@ -101,14 +229,14 @@ def _multinomial_sampling(scores: torch.Tensor, return multinomial_sampling(scores, seeds, offsets, indices) -SeqList = List[SchedulerSequence] +SeqList = list[SchedulerSequence] @dataclass class SamplingInputsDelta: num_ignore_eos: torch.Tensor = None random_offsets: torch.Tensor = None - all_ids: Optional[torch.Tensor] = None + all_ids: None | torch.Tensor = None @dataclass @@ -126,18 +254,28 @@ class SamplingInputs: random_offsets: torch.Tensor = None max_top_k: int = 1 min_top_p: float = 1.0 - response_formats: Tuple[str] = () - logits_processors: List[List[LogitsProcessor]] = None - max_num_logprobs: Optional[int] = None - all_ids: Optional[torch.Tensor] = None + response_formats: list[str, ...] 
= () + logits_processors: list[list[LogitsProcessor]] = None + max_num_logprobs: None | int = None + all_ids: None | torch.Tensor = None num_ignore_eos: torch.Tensor = None batch_size: int = 0 - session_ctx: Optional[List[Dict[str, Any]]] = None - session_to_cleanup: Optional[List[int]] = None + session_ctx: None | list[dict[str, Any]] = None + session_to_cleanup: None | list[int] = None + # for repetition_penalty and ngram + generated_ids: torch.Tensor | None = None + generated_ids_cpu: np.ndarray | None = None + + # n gram + repetition_ngram_size: torch.Tensor | None = None + repetition_ngram_threshold: torch.Tensor | None = None + max_repetition_ngram_size: int = 0 def to_device(self, device: str, non_blocking: bool = False): """To device.""" out_dict = dict() + if self.generated_ids is None and self.generated_ids_cpu is not None: + self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) for f in fields(self): k = f.name v = getattr(self, k) @@ -192,8 +330,8 @@ class FusedLogitsProcessor: def __init__( self, sampling_inputs: SamplingInputs, - logprobs_mode: Optional[str] = None, - guided_decoding_manager: Optional[GuidedDecodingManager] = None, + logprobs_mode: None | str = None, + guided_decoding_manager: None | GuidedDecodingManager = None, ): self.sampling_inputs: SamplingInputs = sampling_inputs self.logprobs_mode = logprobs_mode @@ -213,10 +351,10 @@ async def _wait_stream_once(self): if not stream.query(): await asyncio.sleep(0) - async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: + async def __call__(self, scores: torch.Tensor) -> torch.Tensor: r""" Args: - scores (torch.FloatTensor): + scores (torch.Tensor): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token @@ -224,7 +362,7 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: Return: - torch.FloatTensor: The processed prediction scores. 
+ torch.Tensor: The processed prediction scores. """ @@ -262,7 +400,23 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: repetition_penalty = sampling_inputs.repetition_penalty if repetition_penalty is not None: - scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty) + generated_ids = sampling_inputs.generated_ids + scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) + + if sampling_inputs.max_repetition_ngram_size > 0: + generated_ids = sampling_inputs.generated_ids + assert generated_ids is not None + assert sampling_inputs.repetition_ngram_threshold is not None + max_repetition_ngram_window_size = envs.repetition_window_size + scores = _filter_repetition_ngram_( + scores, + sampling_inputs.stop_words, + generated_ids, + sampling_inputs.repetition_ngram_size, + sampling_inputs.repetition_ngram_threshold, + sampling_inputs.max_repetition_ngram_size, + max_repetition_ngram_window_size, + ) temperature = sampling_inputs.temperature if temperature is not None: @@ -346,7 +500,7 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens return logprobs, indices.to(torch.int32) - def cleanup_sessions(self, session_ids: List[int]): + def cleanup_sessions(self, session_ids: list[int]): if self.guided_decoding_manager: for session_id in session_ids: self.guided_decoding_manager.remove_processor(session_id) diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index 100b921280..0e44e19ac1 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -157,6 +157,9 @@ def _patched_get_env( # model format scale_fmt = os.getenv('LMDEPLOY_SCALE_FMT', None) + # repetition check + repetition_window_size = env_to_int('LMDEPLOY_REPETITION_WINDOW_SIZE', 1024) + def get_all_envs(): """Get all environment variables.""" diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index be8aa2b61a..288cfce1de 100644 --- a/lmdeploy/pytorch/messages.py +++ 
b/lmdeploy/pytorch/messages.py @@ -2,7 +2,7 @@ import enum from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List import numpy as np import torch @@ -56,13 +56,17 @@ class SamplingParam: bad_words: List[int] = field(default_factory=list) max_new_tokens: int = 512 min_new_tokens: int = 0 - response_format: Optional[str] = None - logits_processors: Optional[List[LogitsProcessor]] = None + response_format: None | str = None + logits_processors: None | List[LogitsProcessor] = None out_logits: bool = False out_last_hidden_states: bool = False num_logprobs: int = -1 return_routed_experts: bool = False + # ngram + repetition_ngram_size: int = 0 + repetition_ngram_threshold: int = 0 + @classmethod def from_gen_config(cls, gen_config: GenerationConfig): """From gen config.""" @@ -144,6 +148,8 @@ def from_gen_config(cls, gen_config: GenerationConfig): out_logits=(output_logits is not None), num_logprobs=logprobs, return_routed_experts=gen_config.return_routed_experts, + repetition_ngram_size=gen_config.repetition_ngram_size, + repetition_ngram_threshold=gen_config.repetition_ngram_threshold, ) @@ -263,7 +269,7 @@ def add_sequence(self, adapter_name: str = None, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, - migration_request: Optional[MigrationRequest] = None, + migration_request: None | MigrationRequest = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" @@ -624,7 +630,7 @@ class SchedulerSequence: model_meta: Dict[str, Any] = None # For Disaggregation - migration_request: Optional[MigrationRequest] = None + migration_request: None | MigrationRequest = None resp_cache: bool = False preserve_cache: bool = False @@ -721,7 +727,7 @@ def routed_experts(self) -> np.ndarray: else: return None - def append_routed_experts(self, routed_experts: 
Union[Tensor, np.ndarray]): + def append_routed_experts(self, routed_experts: Tensor | np.ndarray): """Append routed experts.""" if not self.return_routed_experts: return @@ -786,7 +792,7 @@ def mrope_pos_ids(self): end = start + self._num_token_ids return self.history_mrope_pos_ids[start:end] - def append_logits(self, logits: Union[Tensor, np.ndarray]): + def append_logits(self, logits: Tensor | np.ndarray): """Append logits.""" if not self.return_logits: return @@ -806,7 +812,7 @@ def get_input_multimodals(self): def record_event( self, event_type: EventType, - timestamp: Optional[float] = None, + timestamp: None | float = None, ) -> None: self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 3766377311..8a8c45d032 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import List import numpy as np import torch @@ -11,12 +10,12 @@ from ..base.sampling import SamplingStrategy -SeqList = List[SchedulerSequence] +SeqList = list[SchedulerSequence] def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs): """Gather history.""" - if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): + if not any(sampling_inputs.logits_processors): return None batch = len(seqs) max_len = max(seq.num_valid_ids for seq in seqs) @@ -30,6 +29,22 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) return output +def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None: + """Gather history.""" + if sampling_inputs.repetition_penalty is None and sampling_inputs.max_repetition_ngram_size == 0: + return None + batch = len(seqs) + max_len = max(seq.num_new_tokens for seq in seqs) + output = np.full((batch, max_len), pad_id, dtype=np.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_new_tokens + if h_len == 0: + continue + h_ids = seq.generated_ids + output[idx, -h_len:] = h_ids + return output + + def _get_num_ignore_eos(seqs: SeqList): """Get num ignore eos.""" ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] @@ -62,6 +77,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: num_logprobs = [None] * batch_size session_to_cleanup = self.session_to_cleanup self.session_to_cleanup = [] + repetition_ngram_sizes = [None] * batch_size + repetition_ngram_thresholds = [None] * batch_size def __gather_params(): """Gather params.""" @@ -85,6 +102,8 @@ def __gather_params(): stop_words[idx] = sw logits_processors[idx] = param.logits_processors num_logprobs[idx] = param.num_logprobs + repetition_ngram_sizes[idx] = param.repetition_ngram_size + repetition_ngram_thresholds[idx] = param.repetition_ngram_threshold def __get_topp(top_p): """Get topp.""" @@ -165,6 
+184,18 @@ def __get_bad_words(bad_words): 'seq_id': seq.seq_id, } for seq in seqs] + # repetition ngram + max_repetition_ngram_size = max(repetition_ngram_sizes) + if max_repetition_ngram_size == 0: + repetition_ngram_sizes = None + repetition_ngram_thresholds = None + else: + repetition_ngram_sizes = torch.tensor(repetition_ngram_sizes) + repetition_ngram_thresholds = torch.tensor(repetition_ngram_thresholds) + repetition_ngram_same_n = (repetition_ngram_sizes == max_repetition_ngram_size).all().item() + if repetition_ngram_same_n: + repetition_ngram_sizes = None + sampling_input = SamplingInputs( temperature=temperature, bad_words=bad_words, @@ -185,10 +216,14 @@ def __get_bad_words(bad_words): batch_size=batch_size, session_ctx=session_ctx, session_to_cleanup=session_to_cleanup, + repetition_ngram_size=repetition_ngram_sizes, + repetition_ngram_threshold=repetition_ngram_thresholds, + max_repetition_ngram_size=max_repetition_ngram_size, ) pad_token_id = self.pad_token_id sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) + sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input) sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) return sampling_input diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index 5a027e922d..d7c8bc4716 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import List +import numpy as np import torch from torch.profiler import record_function @@ -42,15 +43,34 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: 'random_offsets', 'all_ids', 'num_ignore_eos', + 'ngram_size', + 'ngram_threshold', ] for name in update_attr_names: attr = getattr(out, name) if attr is None: continue - repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) - attr = attr[None].repeat(*repeats).flatten(0, 1) + if attr.dim() == 1: + repeats = (dllm_block_length, 1) + attr = attr[None].repeat(*repeats).flatten(0, 1) + elif attr.dim() == 2: + repeats = (1, dllm_block_length, 1) + attr = attr[:, None].repeat(*repeats).flatten(0, 1) + else: + repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) + attr = attr[None].repeat(*repeats).flatten(0, 1) setattr(out, name, attr) + # update generated_ids_cpu + if out.generated_ids_cpu is not None: + generated_ids_cpu = out.generated_ids_cpu + if generated_ids_cpu.shape[1] == 0: + out.generated_ids_cpu = np.repeat(generated_ids_cpu, dllm_block_length, axis=0) + else: + generated_ids_cpu = np.repeat(generated_ids_cpu[:, None], dllm_block_length, axis=1) + generated_ids_cpu = np.reshape(generated_ids_cpu, (-1, generated_ids_cpu.shape[-1])) + out.generated_ids_cpu = generated_ids_cpu + if len(out.response_formats) > 0: new_resp_formats = [] for resp in out.response_formats: diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index 2a68483ad4..05c3485fdb 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -394,7 +394,7 @@ def is_error(status): metrics_processor.increase_total_requests() async with session.request_handle() as handle: if epoch != self.epoch: - logger.debug(f'[generate] session {session_id} got aborted before starting inference') + logger.info(f'[generate] session {session_id} got aborted before starting inference') # TODO(lvhan): metrics_processor.increase_failed_requests('abort') 
metrics_processor.increase_completed_requests() yield GenOut(response='', diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py index 64eea7c22b..7dfefc767c 100644 --- a/lmdeploy/serve/managers/session_manager.py +++ b/lmdeploy/serve/managers/session_manager.py @@ -84,11 +84,11 @@ async def request_handle(self): except SafeRunException: pass except (asyncio.CancelledError, GeneratorExit) as e: - logger.error(f'[request_handle] session {self.session_id} exception caught: {e}') + logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}') await self._handle.async_cancel(self.session_id) except Exception as e: - logger.error(f'Session {self.session_id} failed to acquire an inference instance: {e}') - raise e + logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}') + raise finally: logger.debug(f'[request_handle] session {self.session_id} releasing the instance') # Return inference instance if it was acquired @@ -104,8 +104,6 @@ async def async_abort(self): logger.info(f'[session] Aborting session {self.session_id}') if self._handle is not None: await self._handle.async_cancel(self.session_id) - # DO NOT reset the session here because it might be used by other components. - # Leave the cleanup to the caller. async def async_close(self): """End the session.""" @@ -117,8 +115,8 @@ async def async_close(self): async with self.request_handle() as handle: try: await handle.async_end(self.session_id) - except (Exception, asyncio.CancelledError, GeneratorExit): - logger.exception('[async_end] exception caught') + except (Exception, asyncio.CancelledError, GeneratorExit) as e: + logger.exception(f'[async_close] exception caught: {e}') self.reset() def abort(self): @@ -219,16 +217,14 @@ async def async_abort_all(self): tasks.append(session.async_abort()) await asyncio.gather(*tasks, return_exceptions=True) # "abort all" is designed for async RL. 
The aborted sessions will be no longer used, - # so we reset and clear the sessions here. - for session in list(self.sessions.values()): - session.reset() + # so we clear the sessions here. self.sessions.clear() def has(self, session_id): return session_id in self.sessions def remove(self, session: Session): - self.sessions.pop(session.session_id) + self.sessions.pop(session.session_id, None) def clear(self): self.sessions.clear() diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 3e37caffe5..7b8b2cd9db 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -962,6 +962,8 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): spaces_between_special_tokens=request.spaces_between_special_tokens, include_stop_str_in_output=request.include_stop_str_in_output, return_routed_experts=request.return_routed_experts, + repetition_ngram_size=request.repetition_ngram_size, + repetition_ngram_threshold=request.repetition_ngram_threshold, ) result_generator = VariableInterface.async_engine.generate( diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 9749ade086..5f3d252e7b 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -226,7 +226,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[ChoiceLogprobs] = None - finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None + finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None class ChatCompletionResponse(BaseModel): @@ -453,6 +453,8 @@ class GenerateReqInput(BaseModel): spaces_between_special_tokens: Optional[bool] = True include_stop_str_in_output: Optional[bool] = False return_routed_experts: Optional[bool] = False + repetition_ngram_size: int = 0 + repetition_ngram_threshold: int = 0 # kwargs for hf processor mm_processor_kwargs: 
Optional[dict[str, Any]] = Field( default=None, diff --git a/lmdeploy/serve/openai/tool_parser/__init__.py b/lmdeploy/serve/openai/tool_parser/__init__.py index ced46b833c..e1e2b2726e 100644 --- a/lmdeploy/serve/openai/tool_parser/__init__.py +++ b/lmdeploy/serve/openai/tool_parser/__init__.py @@ -3,12 +3,14 @@ from .llama3_parser import Llama3JsonToolParser from .qwen2d5_parser import Qwen2d5ToolParser from .qwen3_parser import Qwen3ToolParser +from .qwen3coder_parser import Qwen3CoderToolParser from .tool_parser import ToolParser, ToolParserManager __all__ = [ 'Internlm2ToolParser', 'Qwen2d5ToolParser', 'Qwen3ToolParser', + 'Qwen3CoderToolParser', 'ToolParser', 'ToolParserManager', 'Llama3JsonToolParser', diff --git a/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py b/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py new file mode 100644 index 0000000000..24ee53c7a8 --- /dev/null +++ b/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json +import re +from dataclasses import dataclass +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import shortuuid + +from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, + ExtractedToolCallInformation, FunctionCall, ToolCall) +from lmdeploy.utils import get_logger + +from .tool_parser import ToolParser, ToolParserManager + +logger = get_logger('lmdeploy') + + +@dataclass +class ParserState(object): + """Maintains the state of parsing during tool call extraction.""" + position: int = 0 # Current position in the text being parsed + current_index: int = -1 # Index of the current tool call + + id: str = '' # ID of the current tool call + + def reset_tool_call(self): + """Called when `` finish tag occurred.""" + self.id = '' + + +@ToolParserManager.register_module(['qwen3coder']) +class Qwen3CoderToolParser(ToolParser): + """Parser for Qwen3 Coder model's tool call format. + + Handles the extraction of tool calls from Qwen3 Coder's output format, which uses purely XML tags for function names + and parameters, e.g., arg_value + + """ + + def __init__(self, tokenizer: object): + super().__init__(tokenizer) + self.tool_start_token = '' + self.tool_end_token = '' + self.func_prefix = '(.*?)', re.DOTALL) + + def _split(self, parser_state: ParserState, parsing_content: str) -> Tuple[str, str, bool]: + """Split content into tuple: (text_content, tool_content, has_tool_end)""" + try: + start_idx = parsing_content.index(self.tool_start_token) + parser_state.position += start_idx + except ValueError: + parser_state.position += len(parsing_content) + return parsing_content, '', False + + try: + end_idx = parsing_content.index(self.tool_end_token) + except ValueError: + return parsing_content[:start_idx], parsing_content[start_idx:], False + + rem = end_idx - start_idx + parser_state.position += rem + len(self.tool_end_token) + return parsing_content[:start_idx], 
parsing_content[start_idx:end_idx + len(self.tool_end_token)], True + + def _extract_params(self, content: str) -> Tuple[Optional[str], Dict[str, Any], bool]: + """Parse XML tool content into components.""" + content = content.replace(self.tool_start_token, '').replace(self.tool_end_token, '').strip() + + func_name = None + func_start = content.find(self.func_prefix) + if func_start != -1: + name_start = func_start + len(self.func_prefix) + terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1] + if terminators: + func_name = content[name_start:min(terminators)].strip() + + args_dict = {} + search_idx = 0 + while True: + param_start = content.find(self.param_prefix, search_idx) + if param_start == -1: + break + + name_start = param_start + len(self.param_prefix) + terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1] + if not terminators: + break + + name_end = min(terminators) + param_name = content[name_start:name_end].strip() + + val_start = name_end + 1 + val_end = content.find(self.param_end_token, val_start) + if val_end == -1: + break + + param_val_str = content[val_start:val_end].strip() + + if param_val_str.lower() == 'null': + val = None + elif param_val_str.lower() == 'true': + val = True + elif param_val_str.lower() == 'false': + val = False + else: + try: + val = json.loads(param_val_str) + except json.JSONDecodeError: + val = param_val_str + args_dict[param_name] = val + search_idx = val_end + len(self.param_end_token) + + is_func_closed = self.func_end_token in content + return func_name, args_dict, is_func_closed + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + parser_state = getattr(request, 
'_tool_parser_state', None) + if parser_state is None: + parser_state = ParserState() + setattr(request, '_tool_parser_state', parser_state) + + split_result = self._split(parser_state, current_text[parser_state.position:]) + text_content, tool_content, has_tool_end = split_result + + delta = DeltaMessage() + if text_content: + delta.content = text_content + + if tool_content: + if not parser_state.id: + parser_state.id = f'chatcmpl-tool-{shortuuid.random()}' + parser_state.current_index += 1 + parser_state.has_emitted_name = False + parser_state.has_emitted_json_start = False + parser_state.json_closed = False + parser_state.emitted_params = set() + + func_name, args_dict, is_func_closed = self._extract_params(tool_content) + + fcall_delta = DeltaFunctionCall() + has_updates = False + + if func_name and not getattr(parser_state, 'has_emitted_name', False): + fcall_delta.name = func_name + parser_state.has_emitted_name = True + has_updates = True + + json_fragments = [] + if not getattr(parser_state, 'has_emitted_json_start', False): + if args_dict or is_func_closed: + json_fragments.append('{') + parser_state.has_emitted_json_start = True + + for k, v in args_dict.items(): + if k not in parser_state.emitted_params: + prefix = ', ' if len(parser_state.emitted_params) > 0 else '' + serialized = json.dumps(v, ensure_ascii=False) + json_fragments.append(f'{prefix}"{k}": {serialized}') + parser_state.emitted_params.add(k) + + if is_func_closed and not getattr(parser_state, 'json_closed', False): + if getattr(parser_state, 'has_emitted_json_start', False): + json_fragments.append('}') + parser_state.json_closed = True + + joined_fragments = ''.join(json_fragments) + if joined_fragments: + fcall_delta.arguments = joined_fragments + has_updates = True + + if has_updates: + parsed_delta = DeltaToolCall( + id=parser_state.id, + index=parser_state.current_index, + function=fcall_delta, + ) + delta.tool_calls = [parsed_delta] + + if has_tool_end: + 
parser_state.reset_tool_call() + # Prepare for the next tool call + if hasattr(parser_state, 'has_emitted_name'): + delattr(parser_state, 'has_emitted_name') + delattr(parser_state, 'has_emitted_json_start') + delattr(parser_state, 'json_closed') + delattr(parser_state, 'emitted_params') + + return delta + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + buf = [] + scan_pos = 0 + tool_calls = [] + + for idx, match in enumerate(self.tool_call_pat.finditer(text)): + buf.append(text[scan_pos:match.start()]) + scan_pos = match.end() + + tool_content = match.group(1) + func_name, args_dict, _ = self._extract_params(tool_content) + + if func_name: + tool_calls.append( + ToolCall(function=FunctionCall( + name=func_name, arguments=json.dumps(args_dict, ensure_ascii=False) if args_dict else '{}'))) + + if scan_pos < len(text): + buf.append(text[scan_pos:]) + + text = ''.join(buf) + + return ExtractedToolCallInformation( + content=text, + tool_calls=tool_calls, + tools_called=bool(tool_calls), + ) diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index 714c166ea3..c243554d7d 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -91,6 +91,18 @@ class ModelConfig: kv_lora_rank: int = 0 qk_rope_dim: int = 0 v_head_dim: int = 0 + # Qwen 3.5 + layer_types: List[str] = field(default_factory=list) + linear_key_head_dim: int = 0 + linear_value_head_dim: int = 0 + linear_conv_kernel_dim: int = 0 + linear_num_key_heads: int = 0 + linear_num_value_heads: int = 0 + attn_output_gate: bool = False + # Per-layer expert weight type override: layer indices whose + # MoE experts are unquantized (fp16) despite expert_weight_type=int4. + # Populated from modules_to_not_convert patterns like 'model.layers.0.'. 
+ unquantized_expert_layers: List[int] = field(default_factory=list) # tuning tune_layer_num: int = 1 diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 1eda377ad2..9d697ef567 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -124,6 +124,17 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s modules_to_not_convert = getattr(quant_config, 'modules_to_not_convert', None) or [] if any('self_attn' in m for m in modules_to_not_convert): weight_type = dtype + if any('shared_expert' in m for m in modules_to_not_convert): + ffn_weight_type = dtype + # Detect per-layer exclusions like 'model.layers.0.' which mean + # ALL weights in that layer (including MoE experts) are fp16. + import re as _re + unquantized_expert_layers = [] + for m in modules_to_not_convert: + _m = _re.match(r'model\.layers\.(\d+)\.?$', m) + if _m: + unquantized_expert_layers.append(int(_m.group(1))) + config.model_config.unquantized_expert_layers = unquantized_expert_layers config.model_config.model_arch = model_arch config.model_config.data_type = dtype @@ -230,9 +241,12 @@ def get_tm_model(model_path, tm_cfg.model_config.chat_template = chat_template_name tm_cfg.model_config.model_name = model_name - tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size - tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size - tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size + if engine_config.attn_tp_size is not None: + tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size + if engine_config.attn_cp_size is not None: + tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size + if engine_config.mlp_tp_size is not None: + tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model, cfg=tm_cfg, diff --git a/lmdeploy/turbomind/deploy/module.py 
b/lmdeploy/turbomind/deploy/module.py index 2f91657b98..330fbacc9e 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -22,6 +22,32 @@ def permute_v2(x: torch.Tensor, size_per_head: int = 128): return x.view(-1, head_num, 2, size_per_head // 2).transpose(2, 3).reshape(x.shape) +def permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int): + """Permute only the first rotary_dim elements of each head. + + Used when partial_rotary_factor < 1.0: only the rotary portion needs interleaving for TurboMind's RoPE kernel + layout. + """ + assert x.size(-1) > 1 + assert rotary_dim % 2 == 0, f'rotary_dim must be even, got {rotary_dim}' + assert rotary_dim <= size_per_head, f'rotary_dim ({rotary_dim}) must be <= size_per_head ({size_per_head})' + output_dims = x.size(-1) + assert output_dims % size_per_head == 0, (f'output_dims ({output_dims}) must be divisible by ' + f'size_per_head ({size_per_head})') + head_num = output_dims // size_per_head + orig_shape = x.shape + if x.dim() == 1: + x = x.unsqueeze(0) + x = x.view(x.size(0), head_num, size_per_head) + rotary = x[:, :, :rotary_dim] + passthrough = x[:, :, rotary_dim:] + # Interleave rotary part: [2, rotary_dim//2] -> [rotary_dim//2, 2] + rotary = rotary.view(x.size(0), head_num, 2, rotary_dim // 2).transpose(2, 3).contiguous() + rotary = rotary.view(x.size(0), head_num, rotary_dim) + x = torch.cat([rotary, passthrough], dim=-1) + return x.reshape(orig_shape) + + def merge_qkv_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int): """ Contract: x.size(-1) is output dims @@ -39,6 +65,24 @@ def reshape(x): return qkv +def merge_qkvg_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, tp: int): + """Merge Q, K, V, and Gate with gate appended after V. + + Layout per tp-shard: [Q | K | V | Gate]. 
+ """ + + def reshape(x): + return x.view(x.size(0), tp, -1) if q.dim() == 2 else x.view(tp, -1) + + qkvg = torch.cat(tuple(map(reshape, (q, k, v, gate))), dim=-1) + + qkvg = qkvg.view(-1, qkvg.size(-1) * tp) + if q.dim() == 1: + qkvg.squeeze_() + + return qkvg + + def transpose(x): return x.t() if x is not None else x @@ -193,17 +237,56 @@ def __init__(self, model: BaseOutputModel): self.qk_norm = model.model_config.qk_norm self.attn_sink = model.model_config.attn_sink self.group_size = max(1, model.model_config.group_size) + self.attn_output_gate = model.model_config.attn_output_gate + rope_param = model.attention_config.rope_param + self.rope_dim = rope_param.dim if rope_param else self.head_dim + self.head_num = model.model_config.head_num + + def _split_q_gate(self, q): + """Split interleaved Q+gate tensor into separate Q and gate. + + HF layout: [Q_head0, Gate_head0, Q_head1, Gate_head1, ...] + Returns: (q_real, gate) each with shape [..., num_heads * head_dim] + """ + output_dims = q.size(-1) + head_num = output_dims // (self.head_dim * 2) + orig_shape = list(q.shape) + if q.dim() == 1: + q = q.unsqueeze(0) + q = q.view(q.size(0), head_num, 2, self.head_dim) + q_real = q[:, :, 0, :].contiguous() + gate = q[:, :, 1, :].contiguous() + new_last_dim = head_num * self.head_dim + q_real = q_real.reshape(-1, new_last_dim) + gate = gate.reshape(-1, new_last_dim) + if len(orig_shape) == 1: + q_real = q_real.squeeze(0) + gate = gate.squeeze(0) + return q_real, gate def _reorder_and_merge(self, qkvo, gs: int): q, k, v, o = qkvo + gate = None + # When attn_output_gate, Q is interleaved [Q0, G0, Q1, G1, ...] 
+ # Split into separate Q and gate before permuting + if self.attn_output_gate and q is not None: + q, gate = self._split_q_gate(q) # reorder output dim for tm's rotary embedding layout if self.model.permute_qk: if gs == 1: - q = permute_v2(q, self.head_dim) - k = permute_v2(k, self.head_dim) + if self.rope_dim < self.head_dim: + q = permute_v2_partial(q, self.head_dim, self.rope_dim) + k = permute_v2_partial(k, self.head_dim, self.rope_dim) + else: + q = permute_v2(q, self.head_dim) + k = permute_v2(k, self.head_dim) else: assert gs % self.head_dim == 0 - qkv = merge_qkv_v2(q, k, v, self.tp) + # Merge QKV with gate appended at end if present + if gate is not None: + qkv = merge_qkvg_v2(q, k, v, gate, self.tp) + else: + qkv = merge_qkv_v2(q, k, v, self.tp) # zero bias for `wo` when `w_qkv` has bias but `wo` doesn't if o is None and q.dim() == 1: o = torch.zeros_like(q) @@ -265,11 +348,16 @@ def apply(self, i: int, r: BaseReader): e(self._export, partial(r.attn, i), i) if self.qk_norm: q, k = r.qk_norm(i) - if self.model.permute_qk: - q = permute_v2(q, self.head_dim) - k = permute_v2(k, self.head_dim) - self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1]) - self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1]) + if q is not None and k is not None: + if self.model.permute_qk: + if self.rope_dim < self.head_dim: + q = permute_v2_partial(q, self.head_dim, self.rope_dim) + k = permute_v2_partial(k, self.head_dim, self.rope_dim) + else: + q = permute_v2(q, self.head_dim) + k = permute_v2(k, self.head_dim) + self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1]) + self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1]) if self.attn_sink: sinks = r.attn_sinks(i) self.model.save_split(sinks, self._attn.format(i, 'sinks', '')[:-1], split_dim=-1, split_num=self.tp) @@ -393,6 +481,98 @@ def apply(self, i: int, r: BaseReader): self.model.save_split(k, self._layernorm.format(i, 'kv')) +class LinearAttn(Module): + _linear_attn = 
'layers.{0}.linear_attn.{1}.{2}' + + def __init__(self, model: BaseOutputModel): + self.model = model + self.tp = model.attn_tp_size + cfg = model.model_config + self.key_dim = cfg.linear_num_key_heads * cfg.linear_key_head_dim + self.value_dim = cfg.linear_num_value_heads * cfg.linear_value_head_dim + + def _tp_interleave_qkv(self, tensor, dim): + """Split a concatenated [Q, K, V] tensor into components, reshape each + for TP interleaving, and re-concatenate. + + in_proj_qkv layout along ``dim``: Q(key_dim) | K(key_dim) | V(value_dim). + A naive split doesn't respect component boundaries when key_dim and + value_dim differ. This method splits Q/K/V, reshapes each to + ``(tp, -1)`` along ``dim``, concatenates per-TP-shard, then flattens + so that a subsequent ``save_split(split_dim=dim)`` gives each rank the + correct portion. + """ + if dim < 0: + dim = tensor.dim() + dim + q, k, v = torch.split(tensor, [self.key_dim, self.key_dim, self.value_dim], dim=dim) + + def reshape(x): + # Move TP axis to a new dimension right after ``dim`` + shape = list(x.shape) + d = shape[dim] + new_shape = shape[:dim] + [self.tp, d // self.tp] + shape[dim + 1:] + return x.view(new_shape) + + parts = torch.cat([reshape(q), reshape(k), reshape(v)], dim=dim + 1) + # Collapse tp and per-shard dims back + shape = list(parts.shape) + final_shape = shape[:dim] + [shape[dim] * shape[dim + 1]] + shape[dim + 2:] + return parts.reshape(final_shape) + + def apply(self, i: int, r: BaseReader): + layer_types = getattr(self.model.model_config, 'layer_types', []) + if i >= len(layer_types) or layer_types[i] != 'linear_attention': + return + + for kind in ['weight', 'bias']: + weights = r.linear_attn(i, kind) + if not weights: + continue + + names = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias'] + for name, tensor in zip(names, weights): + if tensor is None: + continue + if name == 'conv1d': + # conv1d shape: (conv_dim, 1, d_conv) where + # conv_dim = 
key_dim*2 + value_dim. Interleave Q/K/V + # portions along dim 0 before splitting for TP. + tensor = self._tp_interleave_qkv(tensor, dim=0) + self.model.save_split(tensor, + self._linear_attn.format(i, name, kind), + split_dim=0, + split_num=self.tp) + elif name in ['A_log', 'dt_bias']: + # Split per-head params across TP ranks (use -1 to + # avoid the 1-D copy shortcut in save_split). + self.model.save_split(tensor, + self._linear_attn.format(i, name, kind), + split_dim=-1, + split_num=self.tp) + elif name == 'out_proj': + self.model.save_split(transpose(tensor), + self._linear_attn.format(i, name, kind), + split_dim=0, + split_num=self.tp) + elif name == 'in_proj_qkv': + # in_proj_qkv: (conv_dim, hidden) where conv_dim = + # key_dim*2 + value_dim. After transpose the QKV + # components are along dim -1. Interleave for TP so + # each shard gets the correct Q/K/V slice. + t = transpose(tensor) + t = self._tp_interleave_qkv(t, dim=-1) + self.model.save_split(t, self._linear_attn.format(i, name, kind), split_dim=-1, split_num=self.tp) + else: + self.model.save_split(transpose(tensor), + self._linear_attn.format(i, name, kind), + split_dim=-1, + split_num=self.tp) + + norm = r.linear_norm(i, 'weight') + if norm is not None: + self.model.export_weight(norm, f'layers.{i}.linear_attn.norm.weight') + + class Misc(Module): """ requires: @@ -437,6 +617,8 @@ def __init__(self, model: BaseOutputModel): modules.append(MLA) else: modules.append(Attn) + if getattr(model.model_config, 'layer_types', []): + modules.append(LinearAttn) if model.model_config.inter_size: modules.append(Ffn) if model.model_config.expert_num: diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py index 307a9cc5c8..339b084f9a 100644 --- a/lmdeploy/turbomind/deploy/source_model/llama.py +++ b/lmdeploy/turbomind/deploy/source_model/llama.py @@ -137,7 +137,10 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict): model_config = 
model_config.text_config elif hasattr(model_config, 'llm_config'): model_config = model_config.llm_config - self.model_config = model_config.to_dict() + if hasattr(model_config, 'to_dict'): + self.model_config = model_config.to_dict() + else: + self.model_config = model_config self.fp8_quant = kwargs.get('fp8_quant', False) def readers(self): @@ -155,7 +158,7 @@ def model_info(self): norm_eps = model_arg['rms_norm_eps'] attn_head_num = model_arg['num_attention_heads'] vocab_size = model_arg['vocab_size'] - inter_size = model_arg['intermediate_size'] + inter_size = model_arg.get('intermediate_size', 0) if 'num_key_value_heads' in model_arg: kv_head_num = model_arg['num_key_value_heads'] else: diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py index 6f457a177b..0c3d815c06 100644 --- a/lmdeploy/turbomind/deploy/source_model/qwen.py +++ b/lmdeploy/turbomind/deploy/source_model/qwen.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import json import os.path as osp +import re import torch from ..config import RopeParam +from ..loader import create_loader from .base import INPUT_MODELS from .llama import LlamaModel, LlamaReader @@ -216,3 +218,248 @@ def model_info(self): inter_size=0, # no shared expert norm_topk_prob=cfg.get('norm_topk_prob', False)) return info + + +class Qwen3_5ReaderMixin: + """Mixin providing linear attention weight reading for Qwen3.5 models. + + Qwen3.5 uses a zero-centered RMSNorm: ``output = norm(x) * (1 + weight)`` + where weight is initialized to zeros. TurboMind's RMSNorm kernel computes + ``norm(x) * weight`` (standard LLaMA style), so we add 1 to every + RMSNorm weight during export. The GDN-internal norm + (``Qwen3_5MoeRMSNormGated``) uses standard weight and is NOT affected. + """ + + attn_layer_pattern = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.' 
+ + _LINEAR_ATTN_KEYS = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if any(k.startswith('model.language_model.') for k in self.params.keys()): + self.attn_layer_prefix = 'model.language_model.layers' + self.tok_embeddings_key = 'model.language_model.embed_tokens.weight' + self.norm_weight_key = 'model.language_model.norm.weight' + + # ---- zero-centered RMSNorm: add 1 to weights during export ---- + + def attn_norm(self, i: int): + w = super().attn_norm(i) + if w is not None: + w = w.float() + 1.0 + return w + + def ffn_norm(self, i: int): + w = super().ffn_norm(i) + if w is not None: + w = w.float() + 1.0 + return w + + def norm_weight(self): + w = super().norm_weight() + if w is not None: + w = w.float() + 1.0 + return w + + def qk_norm(self, i: int): + result = super().qk_norm(i) + return tuple(w.float() + 1.0 if w is not None else w for w in result) + + # ---- handle mixed QKV(fp16) + O(AWQ) attention layers ------- + + def _attn(self, i: int, kind: str): + """Override to handle mixed QKV(fp16) + O(AWQ) attention layers. + + Some AWQ-quantized Qwen3.5 models keep QKV in fp16 while quantizing only the O projection. TurboMind requires + uniform weight types per layer, so we dequantize O to fp16 at export time. + """ + prefix = f'{self.attn_layer_prefix}.{i}.self_attn' + q_is_fp16 = f'{prefix}.q_proj.weight' in self.params + o_is_awq = f'{prefix}.o_proj.qweight' in self.params + + if not (q_is_fp16 and o_is_awq): + # Not a mixed-format layer, use standard behaviour. + return super()._attn(i, kind) + + # Mixed format detected: QKV are fp16 but O is AWQ. + if kind == 'weight': + # Get fp16 QKV the normal way, then dequantize O. 
+ q, k, v, _ = super()._attn(i, kind) + o = self._awq_dequant(f'{prefix}.o_proj') + o = self.transform(o, kind) + return (q, k, v, o) + + # For any quant kind (qweight/scales/qzeros), return all None + # so that the AWQ handler skips this layer entirely — the O + # weight is already handled via dequantization above. + return (None, None, None, None) + + def _awq_dequant(self, prefix: str): + """Dequantize an AWQ-quantized linear layer to fp16. + + AWQ stores weights in transposed form relative to PyTorch's + convention ([in, out] vs [out, in]), so we transpose here to + match the fp16 ``.weight`` layout that downstream export + expects. + """ + from lmdeploy.pytorch.backends.default.awq_modules import dequantize_gemm + qweight = self.params[f'{prefix}.qweight'] + scales = self.params[f'{prefix}.scales'] + qzeros = self.params[f'{prefix}.qzeros'] + group_size = qweight.shape[0] // scales.shape[0] + w = dequantize_gemm(qweight, qzeros, scales, 4, group_size) + return w.t() # [in, out] → [out, in] (PyTorch convention) + + def linear_attn(self, i: int, kind: str): + if not kind: + return self.filter(r'linear_attn\.', i) + # Always return a fixed-length tuple with None placeholders to + # preserve positional alignment with the name list in module.py. + result = [] + for key in self._LINEAR_ATTN_KEYS: + prefix = f'{self.attn_layer_prefix}.{i}.linear_attn.{key}' + tensor = self.params.get(f'{prefix}.{kind}') + # A_log and dt_bias are bare nn.Parameter (no .weight suffix) + if tensor is None: + tensor = self.params.get(prefix) + # If requesting weight but only AWQ qweight exists, + # dequantize on the fly so LinearAttn gets fp16 tensors. 
+ if tensor is None and kind == 'weight': + if f'{prefix}.qweight' in self.params: + tensor = self._awq_dequant(prefix) + if tensor is not None: + tensor = self.transform(tensor, kind) + result.append(tensor) # keep None to preserve alignment + if all(t is None for t in result): + return tuple() + return tuple(result) + + def linear_norm(self, i: int, kind: str = 'weight'): + tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.linear_attn.norm.{kind}') + if tensor is not None: + return self.transform(tensor, kind) + return None + + +class Qwen3_5Reader(Qwen3_5ReaderMixin, Qwen3Reader): + pass + + +@INPUT_MODELS.register_module(name='qwen3_5') +class Qwen3_5Model(Qwen3Model): + Reader = Qwen3_5Reader + + def model_info(self): + if 'text_config' in self.model_config: + self.model_config = self.model_config['text_config'] + cfg = self.model_config + info = super().model_info() + # MoE parameters (same as Qwen2MoeModel / Qwen3MoeModel) + info['expert_num'] = cfg.get('num_experts', 0) + info['expert_inter_size'] = cfg.get('moe_intermediate_size', 0) + info['experts_per_token'] = cfg.get('num_experts_per_tok', 0) + # For MoE models, inter_size is the shared expert intermediate size; + # for dense models, keep the value from super() (intermediate_size). 
+ shared_expert_size = cfg.get('shared_expert_intermediate_size') + if shared_expert_size is not None: + info['inter_size'] = shared_expert_size + info['moe_shared_gate'] = True + # Qwen3.5 uses sigmoid MoE routing (not softmax) + info['scoring_func'] = 'sigmoid' + info['norm_topk_prob'] = True + # Fix RoPE dim for partial_rotary_factor + rope_params = cfg.get('rope_parameters', {}) + partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0)) + if partial_rotary_factor < 1.0: + info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor) + # Linear attention parameters + info['layer_types'] = cfg.get('layer_types', []) + info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0) + info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0) + info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0) + info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0) + info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0) + # attn_output_gate doubles Q projection for full-attention layers + info['attn_output_gate'] = cfg.get('attn_output_gate', False) + return info + + +class Qwen3_5MoeReader(Qwen3_5ReaderMixin, Qwen3MoeReader): + + def _unpacked_moe_expert(self, e: int, i: int, kind: str): + prefix = f'{self.attn_layer_prefix}.{i}.mlp.experts' + gate_up = self.params.get(f'{prefix}.gate_up_proj.{kind}') + down = self.params.get(f'{prefix}.down_proj.{kind}') + if gate_up is None or down is None: + return None + + # Packed Qwen3.5 MoE checkpoints store all experts in the first + # dimension. Slice one expert before transform so quantized policies + # still see a 2D tensor. 
+ gate_up = self.transform(gate_up[e], kind) + down = self.transform(down[e], kind) + gate, up = gate_up.chunk(2, dim=0) + return (gate, down, up) + + def moe_ffn_expert(self, e=None, i=None, kind=None): + if not kind: + return self.filter(r'experts', i) + unpacked = self._unpacked_moe_expert(e, i, kind) + if unpacked is not None: + return unpacked + + return super().moe_ffn_expert(e, i, kind) + + +@INPUT_MODELS.register_module(name='qwen3_5-moe') +class Qwen3_5MoeModel(Qwen3MoeModel): + Reader = Qwen3_5MoeReader + + @staticmethod + def map_packed_qwen35_experts(name: str): + """Map packed expert names to weight names, i.e., + "mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj.weight" so that + class Weight in parameter.py can classify them.""" + s = re.sub(r'(mlp\.experts\.(?:gate_up|down)_proj)$', r'\1.weight', name) + return s + + def readers(self): + pattern = getattr(self.Reader, 'attn_layer_pattern', self.Reader.attn_layer_patten) + loader = create_loader(self.model_path, pattern, []) + + has_packed_gate_up = any('mlp.experts.gate_up_proj' in k for k in loader.index.keys()) + has_packed_down = any('mlp.experts.down_proj' in k for k in loader.index.keys()) + if has_packed_gate_up and has_packed_down: + loader.mappings = [self.map_packed_qwen35_experts] + + for i, param in loader.items(): + reader = self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant) + yield i, reader + torch.cuda.empty_cache() + + def model_info(self): + if 'text_config' in self.model_config: + self.model_config = self.model_config['text_config'] + cfg = self.model_config + info = super().model_info() + # Shared expert params (missing from Qwen3MoeModel base) + info['inter_size'] = cfg.get('shared_expert_intermediate_size', 0) + info['moe_shared_gate'] = True + # Qwen3.5 uses sigmoid MoE routing (not softmax) + info['scoring_func'] = 'sigmoid' + info['norm_topk_prob'] = True + # Fix RoPE dim for partial_rotary_factor + rope_params = 
cfg.get('rope_parameters', {}) + partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0)) + if partial_rotary_factor < 1.0: + info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor) + # Linear attention parameters + info['layer_types'] = cfg.get('layer_types', []) + info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0) + info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0) + info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0) + info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0) + info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0) + # attn_output_gate doubles Q projection for full-attention layers + info['attn_output_gate'] = cfg.get('attn_output_gate', False) + return info diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index c257a31c06..732b38c84d 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -32,6 +32,9 @@ # Qwen3 Qwen3ForCausalLM='qwen3', Qwen3MoeForCausalLM='qwen3-moe', + # Qwen 3.5 + Qwen3_5ForConditionalGeneration='qwen3_5', + Qwen3_5MoeForConditionalGeneration='qwen3_5-moe', # mistral MistralForCausalLM='llama', # llava diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 779289627d..5e06ab5ae9 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -8,7 +8,6 @@ from contextlib import contextmanager from dataclasses import dataclass from logging import Logger, LogRecord -from typing import List, Optional, Tuple, Union import torch from transformers import PretrainedConfig @@ -26,7 +25,7 @@ class _ASNI_COLOR: # copy from: https://github.com/termcolor/termcolor @functools.cache -def can_colorize(*, no_color: Optional[bool] = None, force_color: Optional[bool] = None) -> bool: +def can_colorize(*, no_color: bool | None = None, force_color: bool | None = None) -> bool: """Check env vars and for tty/dumb terminal.""" 
import io if no_color is not None and no_color: @@ -110,8 +109,8 @@ def filter(self, record: LogRecord) -> bool: ' - %(message)s' -def get_logger(name: Optional[str] = None, - log_file: Optional[str] = None, +def get_logger(name: str | None = None, + log_file: str | None = None, log_level: int = logging.INFO, file_mode: str = 'a', log_formatter: str = _FORMAT) -> Logger: @@ -178,7 +177,7 @@ def get_logger(name: Optional[str] = None, return logger -def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str: +def filter_suffix(response: str, suffixes: list[str] | None = None) -> str: """Filter response with suffixes. Args: @@ -197,12 +196,12 @@ def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str: # TODO remove stop_word_offsets stuff and make it clean -def _stop_words(stop_words: List[Union[int, str]], tokenizer: object): +def _stop_words(stop_words: list[int | str], tokenizer: object): """Return list of stop-words to numpy.ndarray.""" import numpy as np if stop_words is None: return None - assert isinstance(stop_words, List) and \ + assert isinstance(stop_words, list) and \ all(isinstance(elem, (str, int)) for elem in stop_words), \ f'stop_words must be a list but got {type(stop_words)}' stop_indexes = [] @@ -211,7 +210,7 @@ def _stop_words(stop_words: List[Union[int, str]], tokenizer: object): stop_indexes += tokenizer.indexes_containing_token(stop_word) elif isinstance(stop_word, int): stop_indexes.append(stop_word) - assert isinstance(stop_indexes, List) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words' + assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words' # each id in stop_indexes represents a stop word # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for # detailed explanation about fastertransformer's stop_indexes @@ -297,7 +296,7 @@ async def __tmp(): # modified from 
https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150 # noqa def _get_and_verify_max_len( hf_config: PretrainedConfig, - max_model_len: Optional[int], + max_model_len: int | None, ) -> int: """Get and verify the model's maximum length.""" @@ -326,7 +325,11 @@ def _get_and_verify_max_len( ] max_len_key = None for key in possible_keys: - max_len = getattr(hf_config, key, None) + max_len = None + if hasattr(hf_config, key): + max_len = getattr(hf_config, key) + elif key in hf_config: + max_len = hf_config[key] if max_len is not None: max_len_key = key if max_len < derived_max_model_len \ else max_len_key @@ -503,9 +506,9 @@ class FlattenedTensorBucket: def __init__( self, - named_tensors: List[Tuple[str, torch.Tensor]] = None, + named_tensors: list[tuple[str, torch.Tensor]] | None = None, flattened_tensor: torch.Tensor = None, - metadata: List[FlattenedTensorMetadata] = None, + metadata: list[FlattenedTensorMetadata] | None = None, ): """Initialize a tensor bucket from a list of named tensors or from pre- flattened data. @@ -548,11 +551,11 @@ def get_flattened_tensor(self) -> torch.Tensor: """Get the flattened tensor containing multiple tensors.""" return self.flattened_tensor - def get_metadata(self) -> List[FlattenedTensorMetadata]: + def get_metadata(self) -> list[FlattenedTensorMetadata]: """Get all metadatas for all tensors in the bucket.""" return self.metadata - def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]: """Reconstruct original tensors.""" # preallocate the result list reconstructed = [None] * len(self.metadata) diff --git a/src/turbomind/core/module.h b/src/turbomind/core/module.h index aaa520e845..147a3d6593 100644 --- a/src/turbomind/core/module.h +++ b/src/turbomind/core/module.h @@ -1,3 +1,6 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#ifndef TURBOMIND_CORE_MODULE_H +#define TURBOMIND_CORE_MODULE_H #include "src/turbomind/core/tensor.h" @@ -34,3 +37,5 @@ class Module { }; } // namespace turbomind::core + +#endif // TURBOMIND_CORE_MODULE_H diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index 1c1f226412..dedb822c10 100644 --- a/src/turbomind/engine/request.h +++ b/src/turbomind/engine/request.h @@ -173,6 +173,15 @@ struct RequestCache { float rope_base = 0.f; + // Qwen 3.5 Gated DeltaNet linear attention persistent states. + // These are Tensor objects whose GPU memory is auto-managed by the RequestCache lifecycle. + // Allocated by GatedDeltaNetLayer::Run(kAdd), used/updated in Forward(), freed when + // RequestCache is destroyed (when the request completes or is canceled). + // conv_states: (num_linear_layers, conv_dim, d_conv) — per-channel rolling conv history + // recurrent_states: (num_linear_layers, num_v_heads, key_head_dim, value_head_dim) — SSM state + Tensor conv_states; + Tensor recurrent_states; + Interval output_hidden_states; Interval output_logits; }; diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index 3c4e8d8f6a..d517b5752c 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -1,67 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+add_subdirectory(kernel) + add_library(attention STATIC attention.cu decoding.cu - reduce.cu kv_cache_utils_v2.cu - utils.cc cp_utils.cu - codegen/attention_sm70_128_f16.cu - codegen/attention_sm75_128_f16.cu - codegen/attention_sm80_128_bf16.cu - codegen/attention_sm80_128_f16.cu - codegen/decoding_sm70_128_f16_f16.cu - codegen/decoding_sm70_128_f16_u4.cu - codegen/decoding_sm70_128_f16_u8.cu - codegen/decoding_sm75_128_f16_f16.cu - codegen/decoding_sm75_128_f16_u4.cu - codegen/decoding_sm75_128_f16_u8.cu - codegen/decoding_sm80_128_bf16_bf16.cu - codegen/decoding_sm80_128_bf16_u4.cu - codegen/decoding_sm80_128_bf16_u8.cu - codegen/decoding_sm80_128_f16_f16.cu - codegen/decoding_sm80_128_f16_u4.cu - codegen/decoding_sm80_128_f16_u8.cu - codegen/attention_sm70_64_f16.cu - codegen/attention_sm75_64_f16.cu - codegen/attention_sm80_64_bf16.cu - codegen/attention_sm80_64_f16.cu - codegen/decoding_sm70_64_f16_f16.cu - codegen/decoding_sm70_64_f16_u4.cu - codegen/decoding_sm70_64_f16_u8.cu - codegen/decoding_sm75_64_f16_f16.cu - codegen/decoding_sm75_64_f16_u4.cu - codegen/decoding_sm75_64_f16_u8.cu - codegen/decoding_sm80_64_bf16_bf16.cu - codegen/decoding_sm80_64_bf16_u4.cu - codegen/decoding_sm80_64_bf16_u8.cu - codegen/decoding_sm80_64_f16_f16.cu - codegen/decoding_sm80_64_f16_u4.cu - codegen/decoding_sm80_64_f16_u8.cu - codegen/attention_sm80_192.cu - codegen/decoding_sm80_192.cu - codegen/attention_sm70_576_f16.cu - codegen/attention_sm75_576_f16.cu - codegen/attention_sm80_576_f16.cu - codegen/attention_sm80_576_bf16.cu - codegen/decoding_sm70_576_f16_f16.cu - codegen/decoding_sm70_576_f16_u4.cu - codegen/decoding_sm70_576_f16_u8.cu - codegen/decoding_sm75_576_f16_f16.cu - codegen/decoding_sm75_576_f16_u4.cu - codegen/decoding_sm75_576_f16_u8.cu - codegen/decoding_sm80_576_f16_f16.cu - codegen/decoding_sm80_576_f16_u4.cu - codegen/decoding_sm80_576_f16_u8.cu - codegen/decoding_sm80_576_bf16_bf16.cu - codegen/decoding_sm80_576_bf16_u4.cu - 
codegen/decoding_sm80_576_bf16_u8.cu + registry.cu ) set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_compile_options(attention PRIVATE -O3 - $<$:-use_fast_math --expt-relaxed-constexpr>) + $<$:-use_fast_math --expt-relaxed-constexpr -Xptxas=-v --threads 16>) +target_link_libraries(attention PUBLIC $) target_link_libraries(attention PRIVATE nvidia::cutlass::cutlass) if (BUILD_TEST) diff --git a/src/turbomind/kernels/attention/arch.h b/src/turbomind/kernels/attention/arch.h index b883089460..e18679f992 100644 --- a/src/turbomind/kernels/attention/arch.h +++ b/src/turbomind/kernels/attention/arch.h @@ -15,12 +15,31 @@ struct Arch { }; struct Sm70: Arch<700, 750> { + static constexpr int value = 700; }; struct Sm75: Arch<750, 800> { + static constexpr int value = 750; }; struct Sm80: Arch<800> { + static constexpr int value = 800; }; +inline bool is_arch_compatible(int karch, int darch) +{ + switch (karch) { + case 0: + return true; + case 700: + return Sm70::is_compatible(darch); + case 750: + return Sm75::is_compatible(darch); + case 800: + return Sm80::is_compatible(darch); + default: + return false; + } +} + } // namespace turbomind::arch diff --git a/src/turbomind/kernels/attention/attention.cu b/src/turbomind/kernels/attention/attention.cu index 61a9b4b38f..8637af2a54 100644 --- a/src/turbomind/kernels/attention/attention.cu +++ b/src/turbomind/kernels/attention/attention.cu @@ -1,62 +1,28 @@ // Copyright (c) OpenMMLab. All rights reserved. 
#include "attention.h" -#include "attention_config.h" -#include "src/turbomind/kernels/attention/arch.h" -#include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/kernels/attention/registry.h" #include "src/turbomind/utils/cuda_utils.h" namespace turbomind { -template -void invokeAttention(const typename Kernel::ParamType& params); - template void dispatchAttention(const AttentionParams& params) { using namespace attention; - auto dispatch = [&](const auto dim) { - constexpr int kHeadDim = dim; - if (params.arch >= 80) { - using Config = AttentionConfig; - return invokeAttention(params); - } - if constexpr (!std::is_same_v) { - if (params.arch == 75) { - return invokeAttention::Kernel>( - params); - } - else if (params.arch >= 70) { - return invokeAttention::Kernel>( - params); - } - } - else { - if (params.arch < 80) { - TM_LOG_ERROR( - "CUDA architecture sm%d does not support data type 'bfloat16'. Please specify dtype 'float16'", - params.arch); - } - } - FT_CHECK(0); - }; - if (params.size_per_head == 64) { - return dispatch(std::integral_constant{}); - } - else if (params.size_per_head == 128) { - return dispatch(std::integral_constant{}); - } - else if (params.size_per_head == 576) { - return dispatch(std::integral_constant{}); - } + auto& reg = Registry::instance(); + AttnDesc desc{}; + desc.mode = AttnDesc::kPrefill; + desc.head_dim = params.size_per_head; + desc.data_type = data_type_v; + + auto* kernel = reg.Find(desc); - if (params.size_per_head == 192) { - using Config = AttentionConfig; - return invokeAttention(params); - } + TM_CHECK(kernel) << "No attention kernel found: " + to_string(desc); - FT_CHECK(0); + kernel->Launch(¶ms, reg.sm_count()); } template void dispatchAttention(const AttentionParams& params); diff --git a/src/turbomind/kernels/attention/attention_config.h b/src/turbomind/kernels/attention/attention_config.h deleted file mode 100644 index bf399bb35b..0000000000 --- 
a/src/turbomind/kernels/attention/attention_config.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#pragma once - -#include "arch.h" -#include "block_iterator.h" -#include "cta_map.h" -#include "impl_16816.h" -#include "impl_1688.h" -#include "impl_884.h" -#include "impl_simt.h" -#include "linear_iterator.h" -#include "mainloop_sm70.h" -#include "mainloop_sm80.h" -#include "src/turbomind/kernels/attention/attention_universal.h" -#include "src/turbomind/kernels/attention/impl.h" - -namespace turbomind::attention { - -enum class CacheType -{ - kLinear, - kBlock, -}; - -template -struct AttentionConfig { - static_assert(sizeof(T) == 0, "config not found"); -}; - -template -using GetCacheIterFactory = std::conditional_t, - GetBlockIterFactory>; - -struct Base_64x64_16x64 { - static constexpr int CTA_Q = 64; - static constexpr int CTA_S = 64; - static constexpr int WARP_Q = 16; - static constexpr int WARP_S = 64; -}; - -template -struct AttentionConfig: Base_64x64_16x64 { - using Attention = Impl; - using CacheIter = LinearIteratorFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, AttentionCtaMap>; -}; - -template -struct AttentionConfig { - static constexpr int CTA_Q = 64; - static constexpr int CTA_S = 128; - static constexpr int WARP_Q = 16; - static constexpr int WARP_S = 128; - - using Attention = Impl; - using CacheIter = LinearIteratorFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, AttentionCtaMap>; -}; - -template -struct AttentionConfig: Base_64x64_16x64 { - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, AttentionCtaMap>; -}; - -template -struct AttentionConfig: Base_64x64_16x64 { - using Attention = Impl; - using CacheIter = GetCacheIterFactory; - using Kernel = AttentionUniversal, CacheIter, AttentionCtaMap>; -}; - -template -struct AttentionConfig { - // MMA_884 config for Volta V100 with HeadDim=576 
(GLM-4.7-Flash) - // CTA_Q=64 with WARP_Q=16 gives 4 warps (matching the generic Sm70 config), - // maximizing tensor core utilization and latency hiding. - // CTA_S=32 (reduced from 64 to fit shared memory with HeadDim=576). - // Shared memory: max(Q=64×580×2=72.5KB, K+V+P=32×580×2+32×576×2+64×36×2=76.8KB) - // = 76.8KB < 96KB V100 limit - static constexpr int CTA_Q = 64; - static constexpr int CTA_S = 32; - static constexpr int WARP_Q = 16; - static constexpr int WARP_S = 32; - - using Attention = Impl; - using CacheIter = GetCacheIterFactory; - using Kernel = AttentionUniversal, CacheIter, AttentionCtaMap>; -}; - -template -struct AttentionConfig: Base_64x64_16x64 { - using Attention = Impl; - using CacheIter = GetCacheIterFactory; - using Kernel = AttentionUniversal, CacheIter, AttentionCtaMap>; -}; - -} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/attention_template.h b/src/turbomind/kernels/attention/attention_template.h index a23c3415ef..0a37920c73 100644 --- a/src/turbomind/kernels/attention/attention_template.h +++ b/src/turbomind/kernels/attention/attention_template.h @@ -10,7 +10,7 @@ namespace turbomind { template -void invokeAttention(const typename Kernel::ParamType& params) +void invokeAttention(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas) { static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); @@ -29,21 +29,6 @@ void invokeAttention(const typename Kernel::ParamType& params) static const auto kernel_func = &attention_kernel; - thread_local const int2 caps = [&] { - auto err = cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - if (err) { - std::cout << cudaGetErrorString(err) << "\n"; - std::abort(); - } - int device_id{}; - cudaGetDevice(&device_id); - int sm_count{}; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id); - int max_active_ctas{}; - 
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_ctas, kernel_func, block.x, kSmemSize); - return int2{sm_count, max_active_ctas}; - }(); - const int max_cp_k_len = cdiv(params.max_k_len, (int)params.cp_size); const int tile_count = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S); const int max_split_count = std::min(params.max_split_k, tile_count); @@ -55,7 +40,7 @@ void invokeAttention(const typename Kernel::ParamType& params) dim3 grid = cta_map.get_grid_shape(); const int grid_size = grid.x * grid.y * grid.z; - const int split_cnt = GetSplitCount(max_split_count, grid_size, caps.y, caps.x, 8); + const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 8); // printf("max split cnt: %d, split cnt: %d\n", max_split_count, split_cnt); diff --git a/src/turbomind/kernels/attention/codegen/attention_sm70_128_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm70_128_f16.cu deleted file mode 100644 index 55b6b0298a..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm70_128_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm70_576_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm70_576_f16.cu deleted file mode 100644 index 1b57ceca53..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm70_576_f16.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu deleted file mode 100644 index 2277390c23..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm75_128_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm75_128_f16.cu deleted file mode 100644 index 19e0826c48..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm75_128_f16.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// ! register spill -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm75_576_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm75_576_f16.cu deleted file mode 100644 index cb1918cea8..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm75_576_f16.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu deleted file mode 100644 index cef945015a..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// ! register spill -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_128_bf16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_128_bf16.cu deleted file mode 100644 index 95c9375983..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_128_bf16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_128_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_128_f16.cu deleted file mode 100644 index 62aacae21c..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_128_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu deleted file mode 100644 index ceeafa7a6d..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_576_bf16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_576_bf16.cu deleted file mode 100644 index ddfe281a9a..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_576_bf16.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_576_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_576_f16.cu deleted file mode 100644 index 754357bf9e..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_576_f16.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu deleted file mode 100644 index b5456bd194..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu deleted file mode 100644 index e628fc7442..0000000000 --- a/src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_config.h" -#include "../attention_template.h" - -namespace turbomind { - -using namespace attention; - -template void invokeAttention::Kernel>( - const AttentionParams& params); - -// template void invokeAttention::Kernel>( -// const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_f16.cu deleted file mode 100644 index 5b0745aeaa..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u4.cu deleted file mode 100644 index 135ba35552..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u4.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u8.cu deleted file mode 100644 index 0a3ee89f6c..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_128_f16_u8.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_f16.cu deleted file mode 100644 index 49b0e945a8..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u4.cu deleted file mode 100644 index 28580041d3..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u4.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u8.cu deleted file mode 100644 index abc82166bc..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_576_f16_u8.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu deleted file mode 100644 index 12558aeae6..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu deleted file mode 100644 index 25b49f9590..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu deleted file mode 100644 index 824cd5b02e..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_f16.cu deleted file mode 100644 index 3378f7db4a..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_f16.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u4.cu deleted file mode 100644 index f0570a5ea6..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u8.cu deleted file mode 100644 index 4e376fe62a..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_128_f16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_f16.cu deleted file mode 100644 index 727e7baaae..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_f16.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u4.cu deleted file mode 100644 index a067f3996d..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u4.cu +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u8.cu deleted file mode 100644 index d9f57db244..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_576_f16_u8.cu +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu deleted file mode 100644 index 456e6e18d7..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu deleted file mode 100644 index 171e59f5f1..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu deleted file mode 100644 index 1d6d40ed3a..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_bf16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_bf16.cu deleted file mode 100644 index 25f36e77b7..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_bf16.cu +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u4.cu deleted file mode 100644 index a0c89d9cf2..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u8.cu deleted file mode 100644 index 39bb09e820..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_bf16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_f16.cu deleted file mode 100644 index 6bc558d953..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_f16.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u4.cu deleted file mode 100644 index 4f2e6952a8..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u8.cu deleted file mode 100644 index 83748f7671..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_128_f16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu deleted file mode 100644 index 214e6748d9..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_bf16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_bf16.cu deleted file mode 100644 index 41fc8587f4..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_bf16.cu +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) OpenMMLab. 
All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u4.cu deleted file mode 100644 index 9d7d518030..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u8.cu deleted file mode 100644 index 60ef9cd3a9..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_bf16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_f16.cu deleted file mode 100644 index 09aa841062..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_f16.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u4.cu deleted file mode 100644 index 74fc554740..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u4.cu +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u8.cu deleted file mode 100644 index a7072df1df..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_576_f16_u8.cu +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../attention_params.h" -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu deleted file mode 100644 index b657034c4c..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -template bool -invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu deleted file mode 100644 index a5c0b34b7f..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu deleted file mode 100644 index a7dd3050b1..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu deleted file mode 100644 index e73be11e62..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -template bool invokeDecoding>(const AttentionParams& params); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu deleted file mode 100644 index c7c560e98d..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu deleted file mode 100644 index 06f6ce5600..0000000000 --- a/src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "../decoding_config.h" -#include "../decoding_template.h" - -namespace turbomind { - -using namespace attention; - -template bool invokeDecoding>(const AttentionParams&); - -template bool invokeDecoding>(const AttentionParams&); - -} // namespace turbomind diff --git a/src/turbomind/kernels/attention/decoding.cu b/src/turbomind/kernels/attention/decoding.cu index a2ace9e758..62356424d1 100644 --- a/src/turbomind/kernels/attention/decoding.cu +++ b/src/turbomind/kernels/attention/decoding.cu @@ -1,146 +1,39 @@ // Copyright (c) OpenMMLab. All rights reserved. 
-#include -#include - #include "decoding.h" -#include "decoding_config.h" -#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/kernels/attention/registry.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/cuda_utils.h" namespace turbomind { -template -bool invokeDecoding(const typename Kernel::ParamType& params); - -template -using seq = std::integer_sequence; - -template -constexpr auto get_kv_type(std::integral_constant) -{ - if constexpr (is_kv_int8) { - return int8_t{}; - } - else { - return T{}; - } -} - template void dispatchDecoding(const AttentionParams& params) { + using namespace attention; + const bool is_kv_int8 = params.quant_policy & QuantPolicy::kCacheKVInt8; const bool is_kv_int4 = params.quant_policy & QuantPolicy::kCacheKVInt4; const int query_group_sz = params.num_heads / params.num_kv_heads; - using namespace attention; - - /// TODO: we need better Qh dispatching, when #waves < 1, smaller Qh may outperform larger Qh due to better - // concurrency - auto dispatch_h = [&](auto arch, auto kv, const auto dim) -> bool { - using Arch = decltype(arch); - using Tkv = decltype(kv); - constexpr int kHeadDim = dim; - - if (kHeadDim == 576) { - return invokeDecoding>(params); // faster than Qh=16 - } - - if (0) {} - else if (query_group_sz > 8) { - return invokeDecoding>(params); - } - else if (query_group_sz == 8) { - return invokeDecoding>(params); - } - else if (query_group_sz == 7) { - return invokeDecoding>(params); - } - else if (query_group_sz == 6) { - return invokeDecoding>(params); - } - else if (query_group_sz == 5) { - return invokeDecoding>(params); - } - else if (query_group_sz == 4) { - return invokeDecoding>(params); - } - else if (query_group_sz == 3) { - return invokeDecoding>(params); - } - else if (query_group_sz == 2) { - return invokeDecoding>(params); - } - else { - return invokeDecoding>(params); - } - return false; - }; - - auto 
dispatch_kv = [&](auto arch, const auto dim) -> bool { - FT_CHECK(!(is_kv_int4 && is_kv_int8)); - if (is_kv_int4) { - return dispatch_h(arch, uint4_t{}, dim); - } - else if (is_kv_int8) { - return dispatch_h(arch, uint8_t{}, dim); - } - else { - return dispatch_h(arch, T{}, dim); - } - return false; - }; - - auto dispatch_head_dim = [&](auto arch) { - if (params.size_per_head == 128) { - return dispatch_kv(arch, std::integral_constant{}); - } - else if (params.size_per_head == 64) { - return dispatch_kv(arch, std::integral_constant{}); - } - else if (params.size_per_head == 576) { - return dispatch_kv(arch, std::integral_constant{}); - } - return false; - }; - - auto dispatch = [&]() { - if (params.arch >= 80) { - return dispatch_head_dim(arch::Sm80{}); - } - - if constexpr (!std::is_same_v) { - if (params.arch == 75) { - return dispatch_head_dim(arch::Sm75{}); - } - else if (params.arch >= 70) { - return dispatch_head_dim(arch::Sm70{}); - } - } + FT_CHECK(!(is_kv_int4 && is_kv_int8)); - return false; - }; + int kv_quant = is_kv_int4 ? 4 : (is_kv_int8 ? 
8 : 0); - if (params.size_per_head == 192) { + AttnDesc desc{}; + desc.mode = AttnDesc::kDecoding; + desc.head_dim = params.size_per_head; + desc.data_type = data_type_v; + desc.kv_quant = kv_quant; + desc.query_group_sz = query_group_sz; - if (is_kv_int8) { - invokeDecoding>(params); - } - else if (is_kv_int4) { - FT_CHECK_WITH_INFO(!is_kv_int4, "not implemented"); - // invokeDecoding>(params); - } - else { - invokeDecoding>(params); - } - return; - } + auto& reg = Registry::instance(); + auto* kernel = reg.Find(desc); - auto success = dispatch(); + TM_CHECK(kernel) << "No decoding kernel found: " + to_string(desc); - FT_CHECK(success); + kernel->Launch(¶ms, reg.sm_count()); } template void dispatchDecoding(const AttentionParams& params); diff --git a/src/turbomind/kernels/attention/decoding_config.h b/src/turbomind/kernels/attention/decoding_config.h deleted file mode 100644 index 15096ee147..0000000000 --- a/src/turbomind/kernels/attention/decoding_config.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#pragma once - -#include "arch.h" -#include "block_iterator.h" -#include "cta_map.h" -#include "impl_81616.h" -#include "impl_simt.h" -#include "mainloop_sm70.h" -#include "mainloop_sm80.h" -#include "src/turbomind/kernels/attention/attention_universal.h" -#include "src/turbomind/kernels/attention/impl.h" -#include "src/turbomind/kernels/attention/mainloop.h" - -namespace turbomind::attention { - -template -struct DecodingConfig { - static_assert(sizeof(T) == 0, "config not found"); -}; - -template -using Decoding = typename DecodingConfig::Kernel; - -////////////////////////////////////////////////////////////// -template -struct DecodingConfig 2)>> { - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; -}; - -template -struct DecodingConfig 2)>> { - static constexpr int Qh = (Qh_ + 7) / 8 * 8; - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; -}; - -template -struct DecodingConfig> { - static constexpr int Qh = (Qh_ + 7) / 8 * 8; - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; -}; - -template -struct DecodingConfig { - static constexpr int Qh = (Qh_ + 7) / 8 * 8; - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; -}; - -////////////////////////////////////////////////////////////// - -template -struct DecodingConfig { - static constexpr int Qh = (Qh_ + 7) / 8 * 8; - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, CacheIter, DecodingCtaMap>; -}; - -////////////////////////////////////////////////////////////// - -template -struct DecodingConfig { - // CTA_S reduced from 64 to 32 so shared memory fits within V100's 96 KB limit. 
- static constexpr int kH = Qh % 3 == 0 ? 3 : (Qh % 2 == 0 ? 2 : 1); - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, CacheIter, DecodingCtaMap>; -}; - -template -struct DecodingConfig { - // Qh >= 4 is not beneficial for sm_70 - static constexpr int kH = Qh % 3 == 0 ? 3 : (Qh % 2 == 0 ? 2 : 1); - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, CacheIter, DecodingCtaMap>; -}; - -template -struct DecodingConfig { - static constexpr int Qh = 1; - static constexpr int HeadDim = 192; - - using Attention = Impl; - using CacheIter = GetBlockIterFactory; - using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; -}; - -} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/decoding_template.h b/src/turbomind/kernels/attention/decoding_template.h index ba1f8374d8..2088d7ac98 100644 --- a/src/turbomind/kernels/attention/decoding_template.h +++ b/src/turbomind/kernels/attention/decoding_template.h @@ -10,7 +10,7 @@ namespace turbomind { template -bool invokeDecoding(const typename Kernel::ParamType& params) +bool invokeDecoding(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas) { static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage); @@ -34,21 +34,6 @@ bool invokeDecoding(const typename Kernel::ParamType& params) auto kernel_func = &attention_kernel; - thread_local const int2 caps = [&] { - auto err = cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - if (err) { - std::cout << cudaGetErrorString(err) << "\n"; - std::abort(); - } - int device_id{}; - cudaGetDevice(&device_id); - int sm_count{}; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id); - int max_active_ctas{}; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_ctas, kernel_func, block.x, kSmemSize); - return int2{sm_count, max_active_ctas}; - 
}(); - const int q_group_size = params.num_heads / params.num_kv_heads; const int q_head_per_cta = std::min(q_group_size, Kernel::CTA_H); @@ -61,7 +46,7 @@ bool invokeDecoding(const typename Kernel::ParamType& params) dim3 grid = CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, 1, cta_per_q_group); const int grid_size = grid.x * grid.y * grid.z; - const int split_cnt = GetSplitCount(max_split_count, grid_size, caps.y, caps.x, 4); + const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 4); grid = CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, split_cnt, cta_per_q_group); diff --git a/src/turbomind/kernels/attention/desc.h b/src/turbomind/kernels/attention/desc.h new file mode 100644 index 0000000000..ecda582d3d --- /dev/null +++ b/src/turbomind/kernels/attention/desc.h @@ -0,0 +1,75 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/turbomind/core/data_type.h" +#include +#include +#include + +namespace turbomind::attention { + +struct AttnDesc { + enum Mode + { + kPrefill, + kDecoding + }; + Mode mode; + int head_dim; + DataType data_type; + int kv_quant; // 0=none, 8=int8, 4=int4 + int query_group_sz; // num_heads/num_kv_heads for decoding; 0 for prefill +}; + +inline std::string to_string(const AttnDesc& d) +{ + std::ostringstream ss; + ss << (d.mode == AttnDesc::kPrefill ? 
"prefill" : "decode"); + ss << "_d" << d.head_dim; + ss << "_" << to_string(d.data_type); + if (d.mode == AttnDesc::kDecoding) { + if (d.kv_quant == 8) + ss << "_kvint8"; + else if (d.kv_quant == 4) + ss << "_kvint4"; + ss << "_gs" << d.query_group_sz; + } + return ss.str(); +} + +struct KernelDesc { + AttnDesc::Mode mode; + int arch; // 700, 750, 800 + int head_dim; + DataType data_type; + int kv_quant; // 0=none, 8=int8, 4=int4 + int qh; // query heads per CTA (1 for prefill) +}; + +struct KernelInfo { + int dynamic_smem_size; + int max_active_ctas; + int num_warps; + std::string name; + cudaFuncAttributes attr; +}; + +inline std::string to_string(const KernelDesc& d) +{ + std::ostringstream ss; + ss << (d.mode == AttnDesc::kPrefill ? "prefill" : "decode"); + ss << "_sm" << d.arch / 10; + ss << "_d" << d.head_dim; + ss << "_" << to_string(d.data_type); + if (d.mode == AttnDesc::kDecoding) { + if (d.kv_quant == 8) + ss << "_kvint8"; + else if (d.kv_quant == 4) + ss << "_kvint4"; + ss << "_qh" << d.qh; + } + return ss.str(); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/impl_16816.h b/src/turbomind/kernels/attention/impl_16816.h index 3cc42c4b90..959bf4a7d3 100644 --- a/src/turbomind/kernels/attention/impl_16816.h +++ b/src/turbomind/kernels/attention/impl_16816.h @@ -6,6 +6,7 @@ #include "src/turbomind/kernels/attention/impl_m16n8.h" #include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/kernels/core/layout.h" +#include "src/turbomind/kernels/core/mma.h" #include "src/turbomind/kernels/core/smem.h" #include "src/turbomind/kernels/core/thread_map.h" @@ -17,6 +18,8 @@ struct Impl; + static constexpr bool MLA = HeadDim == 576; + using Base::OP_M; using Base::OP_N; using Base::K_M; diff --git a/src/turbomind/kernels/attention/impl_1688.h b/src/turbomind/kernels/attention/impl_1688.h index a822c58039..e1452015e5 100644 --- a/src/turbomind/kernels/attention/impl_1688.h +++ 
b/src/turbomind/kernels/attention/impl_1688.h @@ -17,6 +17,8 @@ struct Impl; + static constexpr bool MLA = HeadDim == 576; + using Base::OP_M; using Base::OP_N; using Base::K_M; diff --git a/src/turbomind/kernels/attention/impl_81616.h b/src/turbomind/kernels/attention/impl_81616.h index f865f1bc3a..debc9a28d4 100644 --- a/src/turbomind/kernels/attention/impl_81616.h +++ b/src/turbomind/kernels/attention/impl_81616.h @@ -29,6 +29,8 @@ struct Impl; + static constexpr bool MLA = HeadDim == 576; + static constexpr int CTA_H = CTA_H_; static constexpr int CTA_Q = CTA_Q_; static constexpr int CTA_S = CTA_S_; @@ -43,7 +45,7 @@ struct Impl[K_N][kWarpCntS][4]; + using SmemM = Array[K_N][kWarpCntH][kWarpCntS][4]; - using SmemO = Array[V_M][V_N][kWarpCntS][WARP_SIZE]; + using SmemO = Array[V_M][V_N][kWarpCntH][kWarpCntS][WARP_SIZE]; static constexpr bool kUseSmemQ = false; static constexpr bool kUseSmemP = false; @@ -146,7 +148,10 @@ struct Impl 1) { + __syncthreads(); + } + else if constexpr (kQuantKV) { // Thread layout of KV & KVp is different within warp boundary __syncwarp(); } } @@ -165,11 +170,23 @@ struct Impl 1) { + return {warp_id % kWarpCntS, warp_id / kWarpCntS}; + } + else { + return {warp_id, 0}; + } + } + template __device__ static void ForeachS(Fragment& S, Func&& func) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const auto warp_ids = get_warp_ids(); + const int lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL for (int m = 0; m < K_M; ++m) { PRAGMA_UNROLL @@ -178,8 +195,8 @@ struct Impl __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func) { - const int lane_id = threadIdx.x % WARP_SIZE; + const auto warp_ids = get_warp_ids(); + const int lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL for (int n = 0; n < K_N; ++n) { // Q PRAGMA_UNROLL for (int q = 0; q < 2; ++q) { - const int hi = lane_id % 4 * 2 + n * OP_N + q * 1; + const int hi = lane_id % 4 * 2 + n * OP_N + q * 1 + 
warp_ids.y * WARP_H; const int ri = lane_id / 4 * 1; ((Func &&) func)(hi, /*qi*/ 0, ri, frag_M[n][q], frag_L[n][q]); } @@ -207,14 +226,15 @@ struct Impl sQ{smem_Q}; - const int lane_id = threadIdx.x % WARP_SIZE; + const auto warp_ids = get_warp_ids(); + const int lane_id = threadIdx.x % WARP_SIZE; if constexpr (!kQuantKV) { PRAGMA_UNROLL for (int n = 0; n < K_N; ++n) { PRAGMA_UNROLL for (int k = 0; k < K_K; k += 2) { // 16x16 tile - const int hi = n * OP_N + lane_id % 8; + const int hi = n * OP_N + lane_id % 8 + warp_ids.y * WARP_H; const int di = k * OP_K + lane_id / 8 * 8; ldsm_x4((Array&)frag_Q[n][k], cast_smem_ptr_to_uint(&sQ(hi, di))); } @@ -229,7 +249,7 @@ struct Impl&)frag_Q[n][k + x][d * 2], &sQ(hi, di)); } @@ -263,21 +283,21 @@ struct Impl tmp_O; - Load(tmp_O, storage.O[m][n][w][lane_id].data()); + Load(tmp_O, storage.O[m][n][warp_ids.y][w][lane_id].data()); frag_O[m][n] = frag_O[m][n] + tmp_O; } - frag_L[n] = frag_L[n] + storage.L[n][w][lane_id % 4]; + frag_L[n] = frag_L[n] + storage.L[n][warp_ids.y][w][lane_id % 4]; } // PRAGMA_UNROLL // for (int q = 0; q < 2; ++q) { @@ -710,8 +734,8 @@ struct Impl __device__ static void StoreO(FragO& frag_O, const FragL& frag_L, SharedStorage& storage, Func&& func) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const auto warp_ids = get_warp_ids(); + const int lane_id = threadIdx.x % WARP_SIZE; FragL inv_L; PRAGMA_UNROLL @@ -739,13 +763,13 @@ struct Impl; + // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE + using Map = std::conditional_t, + RakedThreadMap>; Array tmp_O[Map::kIterS][Map::kIterC]; - const int2 offset = Map::get_offset(warp_id, lane_id); + + const int warp_id = threadIdx.x / WARP_SIZE; + const int2 offset = Map::get_offset(warp_id, lane_id); + PRAGMA_UNROLL for (int s = 0; s < Map::kIterS; ++s) { PRAGMA_UNROLL diff --git a/src/turbomind/kernels/attention/impl_884.h b/src/turbomind/kernels/attention/impl_884.h 
index e8d7e7a689..651c31a117 100644 --- a/src/turbomind/kernels/attention/impl_884.h +++ b/src/turbomind/kernels/attention/impl_884.h @@ -9,6 +9,7 @@ #include "src/turbomind/kernels/core/thread_map.h" #include +#include namespace turbomind::attention { @@ -17,6 +18,8 @@ struct Impl; - using ThreadMapKV = RakedThreadMap; + // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE + using ThreadMapQ = std::conditional_t, + RakedThreadMap>; + using ThreadMapKV = std::conditional_t, + RakedThreadMap>; using ThreadMapKVp = void; diff --git a/src/turbomind/kernels/attention/impl_simt.h b/src/turbomind/kernels/attention/impl_simt.h index 444b67e2c8..93497cf359 100644 --- a/src/turbomind/kernels/attention/impl_simt.h +++ b/src/turbomind/kernels/attention/impl_simt.h @@ -7,6 +7,7 @@ #include #include "src/turbomind/kernels/core/array_ops.h" +#include "src/turbomind/kernels/core/common.h" #include "src/turbomind/kernels/core/layout.h" #include "src/turbomind/kernels/core/thread_map.h" @@ -32,6 +33,8 @@ struct Impl; + static constexpr bool MLA = HeadDim == 576; + static constexpr int CTA_H = CTA_H_; static constexpr int CTA_Q = CTA_Q_; static constexpr int CTA_S = CTA_S_; @@ -44,9 +47,7 @@ struct Impl 1) { + __syncthreads(); + } if constexpr (kQuantKV) { // Thread layout of KV & KVp is different within warp boundary __syncwarp(); } @@ -198,13 +202,25 @@ struct Impl 1) { + return {warp_id % kWarpCntS, warp_id / kWarpCntS}; + } + else { + return {warp_id, 0}; + } + } + template __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func) { + const auto warp_ids = get_warp_ids(); PRAGMA_UNROLL for (int m = 0; m < K_M; ++m) { // Q - const int hi = m * OP_H; - const int ri = threadIdx.x; + const int hi = m * OP_H + warp_ids.y * WARP_H; + const int ri = threadIdx.x % (WARP_SIZE * kWarpCntS); ((Func &&) func)(hi, 0, ri, frag_M[m][0], frag_L[m][0]); } } @@ -212,16 +228,15 @@ struct Impl __device__ static void ForeachS(Fragment& S, 
Func&& func) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int warp_id_s = warp_id % kWarpCntS; - const int warp_id_h = warp_id / kWarpCntS; + const auto warp_ids = get_warp_ids(); + const int lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL for (int m = 0; m < K_M; ++m) { PRAGMA_UNROLL for (int n = 0; n < K_N; ++n) { - const int hi = m * OP_H + warp_id_h * WARP_H; - const int si = lane_id / T_D * S_S_thr + n * S_S + warp_id_s * WARP_S; + const int hi = m * OP_H + warp_ids.y * WARP_H; + const int si = lane_id / T_D * S_S_thr + n * S_S + warp_ids.x * WARP_S; const int ri = lane_id % T_D; ((Func &&) func)(hi, /*qi*/ 0, si, ri, S[m][n][0]); } @@ -230,10 +245,8 @@ struct Impl&)frag_O[m][n][d]); } } @@ -573,7 +583,7 @@ struct Impl tmp_O; - Lds(tmp_O, storage.O[m][n][d / 4][warp_id_h][s + lane_id / 8][lane_id % T_D].data()); + Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s + lane_id / 8][lane_id % T_D].data()); using namespace ops; (Array&)frag_O[m][n][d] = (Array&)frag_O[m][n][d] + tmp_O; } @@ -610,7 +620,7 @@ struct Impl tmp_O; - Lds(tmp_O, storage.O[m][n][d / 4][warp_id_h][s][lane_id % T_D].data()); + Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s][lane_id % T_D].data()); using namespace ops; (Array&)frag_O[m][n][d] = (Array&)frag_O[m][n][d] + tmp_O; } @@ -619,7 +629,7 @@ struct Impl #include diff --git a/src/turbomind/kernels/attention/kernel.h b/src/turbomind/kernels/attention/kernel.h new file mode 100644 index 0000000000..2c25dd58ae --- /dev/null +++ b/src/turbomind/kernels/attention/kernel.h @@ -0,0 +1,47 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#pragma once + +#include "src/turbomind/kernels/attention/desc.h" + +namespace turbomind::attention { + +class Kernel { +public: + Kernel(): desc_{}, info_{} {} + + virtual ~Kernel() = default; + + virtual bool Launch(const void* params, int sm_count) const = 0; + + const KernelDesc& desc() const noexcept + { + return desc_; + } + + const KernelInfo& info() const noexcept + { + return info_; + } + + int arch() const noexcept + { + return desc_.arch; + } + + int smem_size() const noexcept + { + return info_.attr.sharedSizeBytes + info_.dynamic_smem_size; + } + + const std::string& name() const + { + return info_.name; + } + +protected: + KernelDesc desc_; + KernelInfo info_; +}; + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/CMakeLists.txt b/src/turbomind/kernels/attention/kernel/CMakeLists.txt new file mode 100644 index 0000000000..7f1094b427 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +add_library(attention_kernels STATIC + ../utils.cc + ../reduce.cu + attention_sm70_64.cu + attention_sm70_128.cu + attention_sm70_256.cu + attention_sm70_576.cu + attention_sm75_64.cu + attention_sm75_128.cu + attention_sm75_256.cu + attention_sm75_576.cu + attention_sm80_64.cu + attention_sm80_128.cu + attention_sm80_192.cu + attention_sm80_256.cu + attention_sm80_576.cu + decoding_sm70_64.cu + decoding_sm70_128.cu + decoding_sm70_256.cu + decoding_sm70_576.cu + decoding_sm75_64.cu + decoding_sm75_128.cu + decoding_sm75_256.cu + decoding_sm75_576.cu + decoding_sm80_64.cu + decoding_sm80_128.cu + decoding_sm80_192.cu + decoding_sm80_256.cu + decoding_sm80_576.cu + ) +set_property(TARGET attention_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_compile_options(attention_kernels PRIVATE -O3 + $<$:-use_fast_math --expt-relaxed-constexpr -Xptxas=-v --threads 8>) +target_link_libraries(attention_kernels PRIVATE nvidia::cutlass::cutlass) diff --git a/src/turbomind/kernels/attention/kernel/attention_sm70_128.cu b/src/turbomind/kernels/attention/kernel/attention_sm70_128.cu new file mode 100644 index 0000000000..4a52c1199c --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm70_128.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_884.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm70, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm70_256.cu b/src/turbomind/kernels/attention/kernel/attention_sm70_256.cu new file mode 100644 index 0000000000..62c7f90e80 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm70_256.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_884.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm70, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm70_576.cu b/src/turbomind/kernels/attention/kernel/attention_sm70_576.cu new file mode 100644 index 0000000000..f62ab5d4d2 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm70_576.cu @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_884.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +// HeadDim=576 on Sm70: kCTA_S=32, WARP_S=kCTA_S to fit within V100's 96 KB shared memory limit +constexpr int kHeadDim = 576; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 32; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm70, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm70_64.cu b/src/turbomind/kernels/attention/kernel/attention_sm70_64.cu new file mode 100644 index 0000000000..107d2537a9 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm70_64.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_884.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 64; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm70, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm75_128.cu b/src/turbomind/kernels/attention/kernel/attention_sm75_128.cu new file mode 100644 index 0000000000..c6e90cb9ad --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm75_128.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_1688.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm75, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm75_256.cu b/src/turbomind/kernels/attention/kernel/attention_sm75_256.cu new file mode 100644 index 0000000000..e659cd4193 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm75_256.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_1688.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm75, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm75_576.cu b/src/turbomind/kernels/attention/kernel/attention_sm75_576.cu new file mode 100644 index 0000000000..bc98b7b5ac --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm75_576.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_1688.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 576; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 32; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm75, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm75_64.cu b/src/turbomind/kernels/attention/kernel/attention_sm75_64.cu new file mode 100644 index 0000000000..6f6b5deb9c --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm75_64.cu @@ -0,0 +1,30 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_1688.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 64; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm75, + Mainloop>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { c.add>(); }); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm80_128.cu b/src/turbomind/kernels/attention/kernel/attention_sm80_128.cu new file mode 100644 index 0000000000..684a102e99 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm80_128.cu @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_16816.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); +#if ENABLE_BF16 + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm80_192.cu b/src/turbomind/kernels/attention/kernel/attention_sm80_192.cu new file mode 100644 index 0000000000..aee9e2bbe5 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm80_192.cu @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_16816.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 192; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); +#if ENABLE_BF16 + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm80_256.cu b/src/turbomind/kernels/attention/kernel/attention_sm80_256.cu new file mode 100644 index 0000000000..76bdd1c227 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm80_256.cu @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_16816.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); +#if ENABLE_BF16 + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm80_576.cu b/src/turbomind/kernels/attention/kernel/attention_sm80_576.cu new file mode 100644 index 0000000000..f0ba8e94cd --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm80_576.cu @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_16816.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 576; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 32; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); +#if ENABLE_BF16 + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/attention_sm80_64.cu b/src/turbomind/kernels/attention/kernel/attention_sm80_64.cu new file mode 100644 index 0000000000..6ddbf345ca --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/attention_sm80_64.cu @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_16816.h" +#include "src/turbomind/kernels/attention/linear_iterator.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +// HeadDim=64 special case: kCTA_S=128, WARP_S=kCTA_S +constexpr int kHeadDim = 64; +constexpr int kCTA_Q = 64; +constexpr int kCTA_S = 128; +constexpr int kWARP_Q = 16; +constexpr int kStages = 2; + +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + LinearIteratorFactory, + AttentionCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); +#if ENABLE_BF16 + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu b/src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu new file mode 100644 index 0000000000..eef50f01f4 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu @@ -0,0 +1,44 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 2; + +// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 
2 : 1) +// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9} +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu b/src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu new file mode 100644 index 0000000000..5442ffac1d --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu @@ -0,0 +1,44 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 3; + +// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 
2 : 1) +// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9} +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu b/src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu new file mode 100644 index 0000000000..6f1914abdd --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 576; + +// CTA_H=2, CTA_S=16, WARP_H=1, WARP_S=8, Stages=2 +template +using KT = AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu b/src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu new file mode 100644 index 0000000000..20b0db10ba --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu @@ -0,0 +1,44 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 2; + +// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1) +// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9} +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu b/src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu new file mode 100644 index 0000000000..fd0c9af948 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu @@ -0,0 +1,40 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 2; + +// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16 +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu b/src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu new file mode 100644 index 0000000000..9aa4d0401b --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 3; + +// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16 +// For 256 head dim, we use Qh=1 and Qh=9 (which maps to 16) +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); // Qh=9 maps to 16 + + c.add>(); + c.add>(); // Qh=9 maps to 16 + + c.add>(); + c.add>(); // Qh=9 maps to 16 +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu b/src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu new file mode 100644 index 0000000000..182d5d7b39 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 576; + +// MLA config for all Tkv: CTA_H=16, CTA_S=16, WARP_H=8, WARP_S=16, Stages=2 +template +using KT = AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu b/src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu new file mode 100644 index 0000000000..b1a7424552 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu @@ -0,0 +1,40 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm70.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 2; + +// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16 +template +using KT = + AttentionUniversal>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + + c.add>(); + c.add>(); + + c.add>(); + c.add>(); +}); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu b/src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu new file mode 100644 index 0000000000..a94895e7ad --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu @@ -0,0 +1,59 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 128; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; + +template +using KT = AttentionUniversal; + +// T==Tkv, Qh<=2: SIMT, stages=3 +template +using Decoding_SIMT = KT, Impl>, + GetBlockIterFactory>; + +// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv +// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16 +template +using Decoding_MMA = + KT, Impl>, + GetBlockIterFactory>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + +#if ENABLE_BF16 + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu b/src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu new file mode 100644 index 0000000000..10b67c8bcd --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu @@ -0,0 +1,41 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 192; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; +constexpr int kStages = 3; +constexpr int kQh = 1; + +// HeadDim=192 uses SIMT+kStages for all Tkv (incl. uint8_t), kQh=1 only +template +using KT = AttentionUniversal< + arch::Sm80, + Mainloop, Impl>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + +#if ENABLE_BF16 + c.add>(); + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu b/src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu new file mode 100644 index 0000000000..fd47b611a1 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu @@ -0,0 +1,59 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 256; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; + +template +using KT = AttentionUniversal; + +// T==Tkv, Qh<=2: SIMT, stages=3 +template +using Decoding_SIMT = KT, Impl>, + GetBlockIterFactory>; + +// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv +// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16 +template +using Decoding_MMA = + KT, Impl>, + GetBlockIterFactory>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + +#if ENABLE_BF16 + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu b/src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu new file mode 100644 index 0000000000..1ce1ce08c3 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu @@ -0,0 +1,47 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 576; + +// Non-quant MLA config: CTA_H=16, CTA_S=32, WARP_H=8, WARP_S=16, Stages=2 +template +using Decoding_F = + AttentionUniversal, Impl>, + GetBlockIterFactory, + DecodingCtaMap>; + +// Quant config: CTA_H=8, CTA_S=64, WARP_H=8, WARP_S=16, Stages=5 +template +using Decoding_Q = + AttentionUniversal, Impl>, + GetBlockIterFactory, + DecodingCtaMap>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + +#if ENABLE_BF16 + c.add>(); + c.add>(); + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu b/src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu new file mode 100644 index 0000000000..d3c29dc6ee --- /dev/null +++ b/src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu @@ -0,0 +1,59 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/attention_universal.h" +#include "src/turbomind/kernels/attention/block_iterator.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/impl_81616.h" +#include "src/turbomind/kernels/attention/impl_simt.h" +#include "src/turbomind/kernels/attention/mainloop.h" +#include "src/turbomind/kernels/attention/mainloop_sm80.h" +#include "src/turbomind/kernels/attention/registrar.h" + +namespace turbomind::attention { + +constexpr int kHeadDim = 64; +constexpr int kCTA_S = 64; +constexpr int kWARP_S = 16; + +template +using KT = AttentionUniversal; + +// T==Tkv, Qh<=2: SIMT, stages=3 +template +using Decoding_SIMT = KT, Impl>, + GetBlockIterFactory>; + +// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv +// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16 +template +using Decoding_MMA = + KT, Impl>, + GetBlockIterFactory>; + +namespace { +Registrar reg([](Collector& c) { + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + +#if ENABLE_BF16 + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); + c.add>(); +#endif +}); +} // namespace + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kernel_impl.h b/src/turbomind/kernels/attention/kernel_impl.h new file mode 100644 index 0000000000..e16760efc8 --- /dev/null +++ b/src/turbomind/kernels/attention/kernel_impl.h @@ -0,0 +1,80 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#pragma once + +#include + +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/kernels/attention/attention_template.h" +#include "src/turbomind/kernels/attention/cta_map.h" +#include "src/turbomind/kernels/attention/decoding_template.h" +#include "src/turbomind/kernels/attention/kernel.h" +#include "src/turbomind/kernels/core/common.h" + +namespace turbomind::attention { + +template +constexpr int kv_quant_from_type() +{ + if constexpr (std::is_same_v) { + return 8; + } + else if constexpr (std::is_same_v) { + return 4; + } + else { + return 0; + } +} + +template +class KernelImpl: public Kernel { + static constexpr bool kIsDecoding = std::is_same_v; + +public: + KernelImpl() + { + desc_.mode = kIsDecoding ? AttnDesc::kDecoding : AttnDesc::kPrefill; + desc_.arch = K::Arch::value; + desc_.head_dim = K::kHeadDim; + desc_.data_type = data_type_v; + + if constexpr (kIsDecoding) { + desc_.kv_quant = kv_quant_from_type(); + desc_.qh = K::CTA_H; + } + else { + desc_.kv_quant = 0; + desc_.qh = 1; + } + + auto func = &attention_kernel; + info_.dynamic_smem_size = sizeof(typename K::SharedStorage); + + cudaFuncGetAttributes(&info_.attr, func); + + if (info_.dynamic_smem_size > (48 << 10)) { + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size); + } + + info_.num_warps = K::kWarpCount; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &info_.max_active_ctas, func, info_.num_warps * WARP_SIZE, info_.dynamic_smem_size); + + info_.name = to_string(desc_); + } + + bool Launch(const void* params, int sm_count) const override + { + const auto& p = *static_cast(params); + if constexpr (kIsDecoding) { + return invokeDecoding(p, sm_count, info_.max_active_ctas); + } + else { + invokeAttention(p, sm_count, info_.max_active_ctas); + return true; + } + } +}; + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 
7b91d1e0b3..bc1b09eb91 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -226,20 +226,21 @@ void invokeProcessKV_v2(char** blocks, int quant_policy, cudaStream_t stream) { - constexpr int WARPS = 4; - constexpr int CTA_S = 64; - - int block = WARPS * WARP_SIZE; - dim3 grid((max_q_len + CTA_S - 1) / CTA_S, head_num, batch_size); auto invoke = [&](auto tkv, const auto dim) { using Tkv = decltype(tkv); - constexpr int kHeadDim = dim; - FT_CHECK(head_dim == kHeadDim); - + constexpr int kHeadDim = dim; constexpr bool kShareKV = kHeadDim == 576; + constexpr int WARPS = 4; + constexpr int CTA_S = kShareKV ? 32 : 64; + + int block = WARPS * WARP_SIZE; + dim3 grid(cdiv(max_q_len, CTA_S), head_num, batch_size); + + TM_CHECK_EQ(head_dim, kHeadDim); + block::Layout block_layout{block::Config{head_num, block_seq_len}}; ProcessKV_v2<<>>(blocks, @@ -262,7 +263,8 @@ void invokeProcessKV_v2(char** blocks, }; auto dispatch = [&](auto tkv) { - if (head_dim == 64) { + if (0) {} + else if (head_dim == 64) { return invoke(tkv, std::integral_constant{}); } else if (head_dim == 128) { @@ -271,6 +273,9 @@ void invokeProcessKV_v2(char** blocks, else if (head_dim == 192) { return invoke(tkv, std::integral_constant{}); } + else if (head_dim == 256) { + return invoke(tkv, std::integral_constant{}); + } else if (head_dim == 576) { return invoke(tkv, std::integral_constant{}); } @@ -470,20 +475,21 @@ void invokeFlattenKV_v2(T* k, int quant_policy, cudaStream_t stream) { - constexpr int kWarpCnt = 4; - constexpr int CTA_S = 64; - - constexpr int block = kWarpCnt * WARP_SIZE; - const dim3 grid((max_seq_len + CTA_S - 1) / CTA_S, head_num, batch_size); auto invoke = [&](auto tkv, const auto dim) { using Tkv = decltype(tkv); - constexpr int kHeadDim = dim; - FT_CHECK(head_dim == kHeadDim); - + constexpr int kHeadDim = dim; constexpr bool kShareKV = kHeadDim == 576; + constexpr int kWarpCnt = 4; + constexpr int CTA_S = 
kShareKV ? 32 : 64; + + constexpr int block = kWarpCnt * WARP_SIZE; + const dim3 grid((max_seq_len + CTA_S - 1) / CTA_S, head_num, batch_size); + + TM_CHECK_EQ(head_dim, kHeadDim); + block::Layout block_layout{block::Config{head_num, block_seq_len}}; flattenKV_v2<<>>(k, @@ -503,7 +509,8 @@ void invokeFlattenKV_v2(T* k, }; auto dispatch = [&](auto tkv) { - if (head_dim == 64) { + if (0) {} + else if (head_dim == 64) { return invoke(tkv, std::integral_constant{}); } else if (head_dim == 128) { @@ -512,6 +519,9 @@ void invokeFlattenKV_v2(T* k, else if (head_dim == 192) { return invoke(tkv, std::integral_constant{}); } + else if (head_dim == 256) { + return invoke(tkv, std::integral_constant{}); + } else if (head_dim == 576) { return invoke(tkv, std::integral_constant{}); } diff --git a/src/turbomind/kernels/attention/mainloop_sm80.h b/src/turbomind/kernels/attention/mainloop_sm80.h index 3b07b717e4..378c2e5595 100644 --- a/src/turbomind/kernels/attention/mainloop_sm80.h +++ b/src/turbomind/kernels/attention/mainloop_sm80.h @@ -61,7 +61,10 @@ struct Mainloop, Impl_> { template __device__ void operator()(Args&&... 
args) { - Run(Sm80_CpAsync{}, std::integral_constant{}, ((Args &&) args)...); + Run(Sm80_CpAsync{}, + std::integral_constant{}, + std::integral_constant{}, + ((Args &&) args)...); } template @@ -93,6 +96,7 @@ struct Mainloop, Impl_> { template __device__ void Run(Sm80_CpAsync, std::integral_constant, + std::false_type, // is MLA FragQ& frag_Q, CacheIter& cache_iter, FragO& frag_O, @@ -219,6 +223,7 @@ struct Mainloop, Impl_> { template __device__ void Run(Sm80_CpAsync<2>, std::integral_constant, + std::false_type, // is MLA FragQ& frag_Q, CacheIter& cache_iter, FragO& frag_O, @@ -324,6 +329,7 @@ struct Mainloop, Impl_> { template __device__ void Run(Sm80_CpAsync<2>, std::integral_constant, + std::false_type, // is MLA FragQ& frag_Q, CacheIter& cache_iter_, FragO& frag_O, @@ -442,6 +448,102 @@ struct Mainloop, Impl_> { } #endif + // Simplified MLA implementation + template + __device__ void Run(Sm80_CpAsync, + std::integral_constant, + std::true_type, // is MLA + FragQ& frag_Q, + CacheIter& cache_iter, + FragO& frag_O, + FragM& frag_M, + FragL& frag_L, + int offset_Q, + int offset_K, + int max_step, + int tile_iter, + int mask_iter_back, + int mask_iter_front, + int window_size, + float qk_scale, + SharedStorage& storage, + const StoreS& store_S) + { + GmemIterK gmem_KV{}; + + Impl::SetSmemKV(gmem_KV, gmem_KV, storage, false); + + PipeIter pipe_iter; + + PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + gmem_KV.ClearSmem((++pipe_iter).w); + } + + Impl::Sync(); + + gmem_KV.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w); + __pipeline_commit(); + cache_iter.Advance(); + + PRAGMA_UNROLL + for (int stages = 1; stages < Stages - 1; ++stages) { + gmem_KV.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w); + __pipeline_commit(); + cache_iter.Advance(); + } + + typename Impl::StateQK state_QK{storage, frag_Q}; + typename Impl::StatePV state_PV{storage}; + + Wait(); + state_QK.Load(0, (++pipe_iter).r); + + auto loop = [&](auto is_mask) { + 
__align__(16) FragS frag_S{}; + + gmem_KV.Prefetch(false_c, cache_iter, CTA_S, pipe_iter.w); + __pipeline_commit(); + cache_iter.Advance(); + + Impl::ComputeQK( + state_QK, frag_S, pipe_iter.r, [](int) {}, [] {}); + + if constexpr (is_mask) { + ApplyCasualMask(frag_S, offset_Q, offset_K, window_size); + } + + Impl::Softmax(frag_S, frag_M, frag_L, frag_O, qk_scale); + + Impl::ConvertStoP(frag_S, state_PV.frag_P, storage); + + state_PV.Load(0, pipe_iter.r); + Impl::ComputePV( + state_PV, frag_O, pipe_iter.r, [](int) {}, [] {}); + + Wait(); + state_QK.Load(0, (++pipe_iter).r); + + offset_K -= CTA_S; + }; + + for (int mask_iter = mask_iter_back; tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) { + loop(true_c); + } + + PRAGMA_NO_UNROLL + for (; tile_iter > mask_iter_front; --tile_iter) { + loop(false_c); + } + + for (; tile_iter > 0; --tile_iter) { + loop(true_c); + } + + __pipeline_commit(); + __pipeline_wait_prior(0); + } + __device__ void Wait() { __pipeline_wait_prior(Stages - 2); diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index bc81e3c517..b761c85b11 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -294,12 +294,14 @@ void invokeReduceV3(T* out, INSTANTIATE_invokeReduceV3(64, half); INSTANTIATE_invokeReduceV3(128, half); INSTANTIATE_invokeReduceV3(192, half); +INSTANTIATE_invokeReduceV3(256, half); INSTANTIATE_invokeReduceV3(576, half); #if ENABLE_BF16 INSTANTIATE_invokeReduceV3(64, nv_bfloat16); INSTANTIATE_invokeReduceV3(128, nv_bfloat16); INSTANTIATE_invokeReduceV3(192, nv_bfloat16); +INSTANTIATE_invokeReduceV3(256, nv_bfloat16); INSTANTIATE_invokeReduceV3(576, nv_bfloat16); #endif diff --git a/src/turbomind/kernels/attention/registrar.h b/src/turbomind/kernels/attention/registrar.h new file mode 100644 index 0000000000..61f6dcecf3 --- /dev/null +++ b/src/turbomind/kernels/attention/registrar.h @@ -0,0 +1,46 @@ +// Copyright (c) OpenMMLab. 
All rights reserved. + +#pragma once + +#include +#include +#include + +#include "src/turbomind/kernels/attention/kernel_impl.h" + +namespace turbomind::attention { + +class Collector { +public: + template + void add() + { + kernels_.emplace_back(std::make_unique>()); + // std::cout << "add kernel: " << to_string(kernels_.back()->desc()) << std::endl; + } + + std::vector> release() + { + return std::move(kernels_); + } + +private: + std::vector> kernels_; +}; + +using RegisterFn = std::function; + +inline std::vector& gKernelFactories() +{ + static std::vector v; + return v; +} + +struct Registrar { + explicit Registrar(RegisterFn fn) + { + gKernelFactories().push_back(std::move(fn)); + } +}; + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/registry.cu b/src/turbomind/kernels/attention/registry.cu new file mode 100644 index 0000000000..cd6ad40a86 --- /dev/null +++ b/src/turbomind/kernels/attention/registry.cu @@ -0,0 +1,115 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/turbomind/kernels/attention/registry.h" + +#include +#include +#include +#include + +#include "src/turbomind/core/check.h" +#include "src/turbomind/kernels/attention/arch.h" +#include "src/turbomind/kernels/attention/registrar.h" +#include "src/turbomind/kernels/core/math.h" + +namespace turbomind::attention { + +namespace { + +constexpr float kMaxWasteRatio = 1.f; + +} // namespace + +Registry::Registry(std::shared_ptr device_prop): + device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10} +{ + for (auto& register_fn : gKernelFactories()) { + Collector collector; + register_fn(collector); + for (auto& k : collector.release()) { + Add(std::move(k)); + } + } +} + +bool Registry::Add(std::unique_ptr kernel) +{ + bool is_valid = true; + + if (!arch::is_arch_compatible(kernel->arch(), arch_)) { + is_valid = false; + } + + if ((int)device_prop_->sharedMemPerBlockOptin < kernel->smem_size()) { + is_valid = false; + } + + if (is_valid) { + ptrs_.push_back(kernels_.emplace_back(std::move(kernel)).get()); + } + + return is_valid; +} + +const Kernel* Registry::Find(const AttnDesc& desc) const +{ + const int threshold = static_cast(kMaxWasteRatio * desc.query_group_sz); + + const Kernel* best = nullptr; + std::tuple cost{}; + + for (const auto* k : ptrs_) { + const auto& d = k->desc(); + if (d.mode != desc.mode || d.head_dim != desc.head_dim // + || d.data_type != desc.data_type || d.kv_quant != desc.kv_quant) { + continue; + } + if (desc.mode == AttnDesc::kDecoding) { + const int ctas = cdiv(desc.query_group_sz, d.qh); + const int waste = d.qh * ctas - desc.query_group_sz; + + const auto v = std::make_tuple(waste > threshold, ctas, waste); + if (!best || v < cost) { + best = k; + cost = v; + } + } + else { // attention, return on first match + return k; + } + } + return best; +} + +Registry& Registry::instance() +{ + struct DeviceState { + std::unique_ptr registry; + std::once_flag flag; + }; + + static std::vector> 
states = [] { + int count{}; + TM_CHECK_EQ(cudaGetDeviceCount(&count), cudaSuccess); + std::vector> vec(count); + for (auto& s : vec) { + s = std::make_unique(); + } + return vec; + }(); + + int device_id{}; + TM_CHECK_EQ(cudaGetDevice(&device_id), cudaSuccess); + + auto& state = *states.at(device_id); + + std::call_once(state.flag, [&]() { + auto prop = std::make_shared(); + TM_CHECK_EQ(cudaGetDeviceProperties(prop.get(), device_id), cudaSuccess); + state.registry = std::make_unique(std::move(prop)); + }); + + return *TM_CHECK_NOTNULL(state.registry); +} + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/registry.h b/src/turbomind/kernels/attention/registry.h new file mode 100644 index 0000000000..b739a8f11f --- /dev/null +++ b/src/turbomind/kernels/attention/registry.h @@ -0,0 +1,45 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include +#include + +#include "src/turbomind/kernels/attention/kernel_impl.h" + +namespace turbomind::attention { + +class Registry { +public: + explicit Registry(std::shared_ptr device_prop); + + template + [[maybe_unused]] bool Add() + { + return Add(std::make_unique>()); + } + + const Kernel* Find(const AttnDesc& desc) const; + + [[nodiscard]] const std::vector& kernels() const + { + return ptrs_; + } + + int sm_count() const noexcept + { + return device_prop_->multiProcessorCount; + } + + static Registry& instance(); + +private: + bool Add(std::unique_ptr kernel); + + std::shared_ptr device_prop_; + int arch_; + std::vector> kernels_; + std::vector ptrs_; +}; + +} // namespace turbomind::attention diff --git a/src/turbomind/kernels/core/thread_map.h b/src/turbomind/kernels/core/thread_map.h index 1271aefcc0..0e3cd58c77 100644 --- a/src/turbomind/kernels/core/thread_map.h +++ b/src/turbomind/kernels/core/thread_map.h @@ -52,7 +52,7 @@ struct ThreadMapQ { } }; -template +template struct RakedThreadMap { static constexpr int kDimC = DimC; static constexpr int kDimS = DimS; @@ 
-63,26 +63,25 @@ struct RakedThreadMap { static constexpr int kWarpThreadC = WarpThreadC; static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC; - static_assert(kWarpThreadC <= WARP_SIZE); + static_assert(WARP_SIZE % kWarpThreadC == 0); static constexpr int kWarpAccessC = kWarpThreadC * kAccessC; static constexpr int kWarpAccessS = kWarpThreadS; - static constexpr int kWarpIterC = (kDimC + kWarpAccessC - 1) / kWarpAccessC; - static constexpr int kWarpIterS = kDimS / kWarpAccessS; + static constexpr int kWarpIterC = cdiv(kDimC, kWarpAccessC); + static constexpr int kWarpIterS = cdiv(kDimS, kWarpAccessS); - static constexpr int kWarpC = 1; - static constexpr int kWarpS = kWarpCount; + static constexpr int kWarpC = WarpC; + static constexpr int kWarpS = kWarpCount / kWarpC; - static constexpr int kIterC = kWarpIterC / kWarpC; - static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1); + static_assert(kWarpCount % kWarpC == 0); + + static constexpr int kIterC = cdiv(kWarpIterC, kWarpC); + static constexpr int kIterS = cdiv(kWarpIterS, kWarpS); // Allow partial tile when there is ONLY 1 iteration static_assert(kDimC % kWarpAccessC == 0 || kIterC == 1); - static_assert(kIterC > 0); - static_assert(kIterS > 0); - static constexpr bool kPartialC = kDimC % kWarpAccessC != 0; static constexpr int kFootprintC = kWarpAccessC * kIterC; diff --git a/src/turbomind/kernels/norm/rms_norm.cu b/src/turbomind/kernels/norm/rms_norm.cu index ed263ddacf..032c417a6b 100644 --- a/src/turbomind/kernels/norm/rms_norm.cu +++ b/src/turbomind/kernels/norm/rms_norm.cu @@ -192,24 +192,32 @@ void invokeQkRMSNorm(void* data, cudaStream_t stream) { - constexpr constant<128> max_dim{}; - TM_CHECK_LE(head_dim, max_dim); - auto invoke = [&](auto t) { using T = decltype(t); - constexpr int vec_size = sizeof(uint4) / sizeof(T); - // Captured constexpr may not be constant to MSVC - constexpr int thr_per_qk = max_dim.value / vec_size; + auto launch = [&](auto max_dim_c) { + constexpr int 
kMaxDim = std::decay_t::value; + TM_CHECK_LE(head_dim, kMaxDim); - FT_CHECK(head_dim % vec_size == 0); + constexpr int vec_size = sizeof(uint4) / sizeof(T); + constexpr int thr_per_qk = kMaxDim / vec_size; - const int threads = thr_per_qk * n * (int64_t)token_num; - const int block_dim = 512; - const int grid_dim = cdiv(threads, block_dim); + FT_CHECK(head_dim % vec_size == 0); - kernel::RMSNormQK<<>>( - (T*)data, ld, (const T*)weight, head_dim, n, token_num, eps, 1.f / head_dim); + const int threads = thr_per_qk * n * (int64_t)token_num; + const int block_dim = 512; + const int grid_dim = cdiv(threads, block_dim); + + kernel::RMSNormQK<<>>( + (T*)data, ld, (const T*)weight, head_dim, n, token_num, eps, 1.f / head_dim); + }; + + if (head_dim <= 128) { + launch(constant<128>{}); + } + else { + launch(constant<256>{}); + } }; TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke); @@ -227,23 +235,32 @@ void invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st) auto data = x.raw_data(); auto stride = x.stride(0); - constexpr constant<128> max_dim{}; - TM_CHECK_LE(head_dim, max_dim); - auto invoke = [&](auto t) { using T = decltype(t); - constexpr int vec_size = sizeof(uint4) / sizeof(T); - constexpr int thr_per_qk = max_dim.value / vec_size; + auto launch = [&](auto max_dim_c) { + constexpr int kMaxDim = std::decay_t::value; + TM_CHECK_LE(head_dim, kMaxDim); + + constexpr int vec_size = sizeof(uint4) / sizeof(T); + constexpr int thr_per_qk = kMaxDim / vec_size; + + TM_CHECK(head_dim % vec_size == 0); - TM_CHECK(head_dim % vec_size == 0); + const int threads = token_num * head_num * thr_per_qk; + const int block_dim = 512; + const int grid_dim = cdiv(threads, block_dim); - const int threads = token_num * head_num * thr_per_qk; - const int block_dim = 512; - const int grid_dim = cdiv(threads, block_dim); + kernel::RMSNormQK<<>>( + (T*)data, stride, (const T*)w.raw_data(), head_dim, head_num, token_num, eps, 1.f / head_dim); + }; - kernel::RMSNormQK<<>>( - (T*)data, 
stride, (const T*)w.raw_data(), head_dim, head_num, token_num, eps, 1.f / head_dim); + if (head_dim <= 128) { + launch(constant<128>{}); + } + else { + launch(constant<256>{}); + } }; TM_DISPATCH_PRIMARY_DTYPES(x.dtype(), invoke); diff --git a/src/turbomind/models/CMakeLists.txt b/src/turbomind/models/CMakeLists.txt index 87efb5eec6..c31d4a00fb 100644 --- a/src/turbomind/models/CMakeLists.txt +++ b/src/turbomind/models/CMakeLists.txt @@ -17,7 +17,10 @@ add_library(models STATIC llama/unified_attention_layer.cc llama/llama_kernels.cu llama/llama_utils.cu - llama/mla_utils.cu) + llama/mla_utils.cu + llama/GatedDeltaNetWeight.cc + llama/GatedDeltaNetLayer.cc + llama/gated_delta_net_kernels.cu) set_property(TARGET models PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET models PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(models PUBLIC diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index a740ee18e4..aaa7467f64 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -22,6 +22,9 @@ add_library(Llama STATIC llama_kernels.cu llama_utils.cu mla_utils.cu + GatedDeltaNetWeight.cc + GatedDeltaNetLayer.cc + gated_delta_net_kernels.cu ) set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/models/llama/GatedDeltaNetLayer.cc b/src/turbomind/models/llama/GatedDeltaNetLayer.cc new file mode 100644 index 0000000000..894fee6cb3 --- /dev/null +++ b/src/turbomind/models/llama/GatedDeltaNetLayer.cc @@ -0,0 +1,312 @@ +#include "src/turbomind/models/llama/GatedDeltaNetLayer.h" +#include "src/turbomind/core/check.h" +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/models/llama/gated_delta_net_kernels.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind { + 
+GatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam& model, + const AttentionParam& attn, + const EngineParam& engine, + int tp_size, + const Context& ctx, + int phases): + hidden_units_(model.hidden_units), + num_k_heads_(model.linear_num_key_heads / tp_size), + num_v_heads_(model.linear_num_value_heads / tp_size), + key_head_dim_(model.linear_key_head_dim > 0 ? model.linear_key_head_dim : model.head_dim), + value_head_dim_(model.linear_value_head_dim > 0 ? model.linear_value_head_dim : model.head_dim), + d_conv_(model.linear_conv_kernel_dim > 0 ? model.linear_conv_kernel_dim : 4), + key_dim_(num_k_heads_ * key_head_dim_), + value_dim_(num_v_heads_ * value_head_dim_), + conv_dim_(key_dim_ * 2 + value_dim_), + norm_eps_(model.norm_eps), + dtype_(model.data_type), + linear_(*ctx.linear) +{ + layer_types_ = model.layer_types; + num_linear_layers_ = 0; + for (auto t : layer_types_) { + if (t == 1) + ++num_linear_layers_; + } + + TM_LOG_INFO("GatedDeltaNetLayer: num_k=%d num_v=%d k_dim=%d v_dim=%d " + "conv_dim=%d d_conv=%d num_linear_layers=%d", + num_k_heads_, + num_v_heads_, + key_dim_, + value_dim_, + conv_dim_, + d_conv_, + num_linear_layers_); + + for (int i = 0; i < phases; ++i) { + phase_data_.push_back(std::make_shared()); + } +} + +GatedDeltaNetLayer::~GatedDeltaNetLayer() = default; + +void GatedDeltaNetLayer::Run(BatchOp op, int phase, TensorMap& env) +{ + if (op == BatchOp::kAdd) { + Buffer_ rc = env.at("requests").buffer(); + const auto dtype = dtype_; + for (int i = 0; i < rc.size(); ++i) { + auto& c = *rc[i]; + if (num_linear_layers_ > 0) { + c.conv_states = Tensor{{num_linear_layers_, conv_dim_, d_conv_}, dtype, kDEVICE}; + Clear(c.conv_states); + c.recurrent_states = + Tensor{{num_linear_layers_, num_v_heads_, key_head_dim_, value_head_dim_}, dtype, kDEVICE}; + Clear(c.recurrent_states); + } + } + } + else if (op == BatchOp::kSetup) { + Setup(phase, env); + } + else if (op == BatchOp::kPrepare) { + auto& d = *phase_data_.at(phase); + d.q_offsets 
= env.at("q_offsets").buffer().borrow(); + } +} + +void GatedDeltaNetLayer::Setup(int phase, TensorMap& env) +{ + auto& d = *phase_data_.at(phase); + const auto& batch = *env.at("batch").data()[0]; + + d.batch_size = batch.rc.size(); + d.rc.resize(d.batch_size); + d.input_lens.resize(d.batch_size); + for (int i = 0; i < d.batch_size; ++i) { + d.rc[i] = batch.rc[i].get(); + d.input_lens[i] = batch.rc[i]->input_len; + } +} + +static int linear_layer_index(int layer_id, const std::vector& layer_types) +{ + int idx = 0; + for (int i = 0; i < layer_id && i < (int)layer_types.size(); ++i) { + if (layer_types[i] == 1) + ++idx; + } + return idx; +} + +void GatedDeltaNetLayer::Forward(ForwardParam p) +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + + const int token_num = p.input.shape(0); + if (token_num == 0) + return; + + const auto dtype = p.input.dtype(); + const auto device = p.input.device(); + const auto stream = core::Context::stream().handle(); + const auto& weights = *p.weights; + + auto& pd = *phase_data_.at(p.phase); + + auto dispatch = [&](auto t) { + using T = decltype(t); + + // ================================================================= + // 1. Single fused input projection: reads p.input once from HBM. + // Output columns are ordered: [qkv | z | b | a] + // where the split dims are: conv_dim_, value_dim_, v_heads_tp_, v_heads_tp_ + // ================================================================= + const int v_heads_tp = num_v_heads_; // already TP-sharded + Tensor all_proj = linear_.Forward(p.input, weights.in_proj_all); + sync_check_cuda_error(); + + // Pointer-arithmetic slices — no copies, no allocations. 
+ auto all_data = all_proj.data(); + // Column offsets per token (all_proj is token-major, row-major): + // [0, conv_dim_) -> mixed_qkv + // [conv_dim_, +value_dim_) -> z + // [conv_dim_+value_dim_, +v_heads_tp) -> b (beta logit) + // [conv_dim_+value_dim_+v_heads_tp, +v_heads_tp) -> a (alpha/dt) + const int all_col = conv_dim_ + value_dim_ + v_heads_tp * 2; + // const T* sub-pointers are derived per-request below; stride = all_col. + + // ================================================================= + // 2. Compute beta and g for all tokens + // b_raw and a_raw are sliced from the fused projection output. + // Stride between tokens is all_col elements. + // ================================================================= + const int bg_total = token_num * num_v_heads_; + const int b_offset = conv_dim_ + value_dim_; // column offset to b logits + const int a_offset = b_offset + v_heads_tp; // column offset to a logits + Tensor beta{{token_num, num_v_heads_}, dtype, device}; + Tensor g_tensor{{token_num, num_v_heads_}, dtype, device}; + + // Gather b and a columns into contiguous buffers for the kernel. + // Each has shape (token_num, v_heads_tp) but is strided inside all_proj. + // Use the existing invokeComputeBetaG which needs contiguous b/a input. + // We copy the b/a columns compactly first. 
+ Tensor b_contig{{token_num, v_heads_tp}, dtype, device}; + Tensor a_contig{{token_num, v_heads_tp}, dtype, device}; + check_cuda_error(cudaMemcpy2DAsync(b_contig.data(), + v_heads_tp * sizeof(T), + all_data + b_offset, + all_col * sizeof(T), + v_heads_tp * sizeof(T), + token_num, + cudaMemcpyDeviceToDevice, + stream)); + check_cuda_error(cudaMemcpy2DAsync(a_contig.data(), + v_heads_tp * sizeof(T), + all_data + a_offset, + all_col * sizeof(T), + v_heads_tp * sizeof(T), + token_num, + cudaMemcpyDeviceToDevice, + stream)); + + invokeComputeBetaG(beta.data(), + g_tensor.data(), + b_contig.data(), + a_contig.data(), + weights.A_log.data(), + weights.dt_bias.data(), + bg_total, + num_v_heads_, + stream); + sync_check_cuda_error(); + + // ================================================================= + // 3. Process each request independently + // ================================================================= + Tensor attn_out{{token_num, value_dim_}, dtype, device}; + + int token_offset = 0; + for (int req = 0; req < pd.batch_size; ++req) { + auto& rc = *pd.rc[req]; + const int seq_len = pd.input_lens[req]; + if (seq_len == 0) + continue; + + // Slice per-request portions from the fused projection output. + // qkv lives in columns [0, conv_dim_) of all_proj with row-stride all_col. + // z lives in columns [conv_dim_, conv_dim_+value_dim_). + // b/a are already extracted into contiguous beta/g_tensor above. 
+ T* qkv_row_ptr = all_data + token_offset * all_col; // first token, col 0 + T* z_row_ptr = all_data + token_offset * all_col + conv_dim_; + T* beta_ptr = beta.data() + token_offset * num_v_heads_; + T* g_ptr = g_tensor.data() + token_offset * num_v_heads_; + T* out_ptr = attn_out.data() + token_offset * value_dim_; + + const int state_layer_idx = linear_layer_index(p.layer_id, layer_types_); + + T* conv_state_ptr = nullptr; + T* recurrent_state_ptr = nullptr; + if (rc.conv_states) { + conv_state_ptr = rc.conv_states.data() + state_layer_idx * (conv_dim_ * d_conv_); + } + if (rc.recurrent_states) { + recurrent_state_ptr = + rc.recurrent_states.data() + state_layer_idx * (num_v_heads_ * key_head_dim_ * value_head_dim_); + } + + // ----- 3a. Fused Causal Conv1d + SiLU ----- + // conv_out shape: (seq_len, conv_dim_) — packed [Q|K|V] per token. + // No transpose: the new delta-rule kernels read this row-major layout directly. + Tensor conv_out{{seq_len, conv_dim_}, dtype, device}; + invokeFusedConv1dSiLU(conv_out.data(), + qkv_row_ptr, + weights.conv1d.data(), + (const T*)nullptr, + conv_state_ptr, + 1, + conv_dim_, + seq_len, + d_conv_, + all_col, + stream); + sync_check_cuda_error(); + + // ----- 3b. Gated Delta Rule ----- + // The kernels handle: + // • Strided Q/K/V access directly from the packed conv_out buffer + // (eliminates the three cudaMemcpy2DAsync strided-copy passes). + // • In-kernel L2 normalization of Q and K + // (eliminates two invokeL2Norm kernel launches). + // • GQA: kh = h / (num_v_heads / num_k_heads) computed per-thread + // (eliminates the invokeRepeatInterleave allocation). + const int k_dim_total = key_dim_; // num_k_heads * key_head_dim (per TP) + + if (seq_len == 1) { + // Decode: single-step recurrent update using persistent state. 
+ invokeRecurrentGatedDeltaRule(out_ptr, + conv_out.data(), + beta_ptr, + g_ptr, + recurrent_state_ptr, + 1, // batch_size = 1 per request + num_v_heads_, + num_k_heads_, + key_head_dim_, + value_head_dim_, + k_dim_total, + stream); + } + else { + // Prefill: process all timesteps in a SINGLE kernel launch. + // (Previously: O(seq_len) kernel launches via host-side for-loop.) + invokeGatedDeltaRulePrefill(out_ptr, + conv_out.data(), + beta_ptr, + g_ptr, + recurrent_state_ptr, + seq_len, + num_v_heads_, + num_k_heads_, + key_head_dim_, + value_head_dim_, + k_dim_total, + stream); + } + sync_check_cuda_error(); + + // ----- 3c. RMSNormGated ----- + const int N = seq_len * num_v_heads_; + invokeRMSNormGated(out_ptr, + z_row_ptr, + weights.norm.data(), + norm_eps_, + N, + value_head_dim_, + all_col, + num_v_heads_, + stream); + sync_check_cuda_error(); + + token_offset += seq_len; + } + + // ================================================================= + // 4. Output projection (all tokens at once) + // ================================================================= + (void)linear_.Forward(attn_out, weights.out_proj, p.output); + sync_check_cuda_error(); + }; + + if (dtype == kHalf) { + dispatch(half{}); + } + else if (dtype == kBfloat16) { + dispatch(nv_bfloat16{}); + } + else { + TM_CHECK(0) << "Unsupported dtype for GatedDeltaNetLayer"; + } +} + +} // namespace turbomind diff --git a/src/turbomind/models/llama/GatedDeltaNetLayer.h b/src/turbomind/models/llama/GatedDeltaNetLayer.h new file mode 100644 index 0000000000..b74d8a1476 --- /dev/null +++ b/src/turbomind/models/llama/GatedDeltaNetLayer.h @@ -0,0 +1,66 @@ +#pragma once + +#include "src/turbomind/core/tensor.h" +#include "src/turbomind/engine/batch.h" +#include "src/turbomind/models/llama/GatedDeltaNetWeight.h" +#include "src/turbomind/models/llama/LlamaLinear.h" +#include "src/turbomind/models/llama/context.h" +#include "src/turbomind/models/llama/llama_params.h" + +namespace turbomind { + +class 
GatedDeltaNetLayer { +public: + struct ForwardParam { + int phase; + Tensor input; + Tensor output; + const GatedDeltaNetWeight* weights; + int layer_id; + }; + + GatedDeltaNetLayer(const ModelParam& model, + const AttentionParam& attn, + const EngineParam& engine, + int tp_size, + const Context& ctx, + int phases); + + ~GatedDeltaNetLayer(); + + void Run(BatchOp op, int phase, TensorMap& env); + + void Forward(ForwardParam p); + +private: + void Setup(int phase, TensorMap& env); + + // Model dimensions + int hidden_units_; + int num_k_heads_; + int num_v_heads_; + int key_head_dim_; + int value_head_dim_; + int d_conv_; + int key_dim_; // num_k_heads * key_head_dim + int value_dim_; // num_v_heads * value_head_dim + int conv_dim_; // key_dim * 2 + value_dim + int num_linear_layers_; // count of linear attention layers for state sizing + std::vector layer_types_; // model layer types for index mapping + + float norm_eps_; + DataType dtype_; + + LlamaLinear& linear_; + + // Per-phase batch data (mirrors UnifiedAttentionLayer pattern) + struct PhaseData { + std::vector rc; // borrowed batch RequestCache pointers + std::vector input_lens; // snapshot of input_len per request (captured at Setup time) + int batch_size = 0; + Buffer_ q_offsets; // cumulative token offsets, device buffer + }; + std::vector> phase_data_; +}; + +} // namespace turbomind diff --git a/src/turbomind/models/llama/GatedDeltaNetWeight.cc b/src/turbomind/models/llama/GatedDeltaNetWeight.cc new file mode 100644 index 0000000000..b1e8fed695 --- /dev/null +++ b/src/turbomind/models/llama/GatedDeltaNetWeight.cc @@ -0,0 +1,164 @@ +#include "src/turbomind/models/llama/GatedDeltaNetWeight.h" +#include "src/turbomind/utils/cuda_utils.h" + +namespace turbomind { + +GatedDeltaNetWeight::GatedDeltaNetWeight(int hidden_dim, + int num_k_heads, + int num_v_heads, + int key_head_dim, + int value_head_dim, + int d_conv, + bool bias, + int tp_size, + int tp_rank, + DataType data_type, + DataType weight_type, + int 
group_size): + tp_rank_(tp_rank), tp_size_(tp_size) +{ + const int key_dim = num_k_heads * key_head_dim / tp_size; + const int value_dim = num_v_heads * value_head_dim / tp_size; + const int v_heads_tp = num_v_heads / tp_size; + const int conv_dim = key_dim * 2 + value_dim; + + // GatedDeltaNet projections are stored as plain dense weights in the checkpoint + // (dense_wtype = data_type avoids quantization path for these projections). + const DataType dense_wtype = data_type; + const int dense_gsz = 0; + + // Individual projections registered for checkpoint loading + in_proj_qkv.emplace(hidden_dim, conv_dim, data_type, bias, dense_wtype, dense_gsz); + in_proj_z.emplace(hidden_dim, value_dim, data_type, bias, dense_wtype, dense_gsz); + in_proj_b.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz); + in_proj_a.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz); + out_proj.emplace(value_dim, hidden_dim, data_type, bias, dense_wtype, dense_gsz); + + register_module("in_proj_qkv", in_proj_qkv, tp_rank_); + register_module("in_proj_z", in_proj_z, tp_rank_); + register_module("in_proj_b", in_proj_b, tp_rank_); + register_module("in_proj_a", in_proj_a, tp_rank_); + register_module("out_proj", out_proj, tp_rank_); + + // conv1d: depthwise weights, shape (conv_dim, d_conv) + conv1d = Tensor{{conv_dim, d_conv}, data_type, kDEVICE}; + register_parameter("conv1d." + std::to_string(tp_rank_) + ".weight", conv1d); + + // A_log: log-space decay per head, shape (num_v_heads/tp,) + A_log = Tensor{{v_heads_tp}, data_type, kDEVICE}; + register_parameter("A_log." + std::to_string(tp_rank_) + ".weight", A_log); + + // dt_bias: per head, shape (num_v_heads/tp,) + dt_bias = Tensor{{v_heads_tp}, data_type, kDEVICE}; + register_parameter("dt_bias." 
+ std::to_string(tp_rank_) + ".weight", dt_bias); + + // norm: RMSNormGated weight, shape (value_head_dim,) + norm = Tensor{{value_head_dim}, data_type, kDEVICE}; + register_parameter("norm.weight", norm); +} + +// --------------------------------------------------------------------------- +// Row-wise concatenation of 4 weight matrices into a single pre-allocated +// destination tensor. +// +// Each source weight has shape (input_dim, out_dim_i) in row-major storage. +// The destination has shape (input_dim, sum_i out_dim_i) and rows are filled +// by concatenating the corresponding source rows in order. +// +// Implemented with cudaMemcpy2DAsync so that no extra temporary is needed: +// each source "column block" is scattered into the correct column range of +// the destination in one pass per source. +// --------------------------------------------------------------------------- +static void +concat_weights_4(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, Tensor& dst, cudaStream_t st) +{ + // Tensors are (K=input_dim, M=output_dim) in row-major order. + // Each row of `dst` is [a_row | b_row | c_row | d_row]. 
+ const int K = dst.shape(0); + const int M_a = a.shape(1); + const int M_b = b.shape(1); + const int M_c = c.shape(1); + const int M_d = d.shape(1); + const int M_dst = dst.shape(1); // M_a + M_b + M_c + M_d + const int elem_sz = byte_size(dst.dtype(), 1); + + // Pitch of the destination row in bytes + const size_t dst_pitch = (size_t)M_dst * elem_sz; + const size_t src_pitch_a = (size_t)M_a * elem_sz; + const size_t src_pitch_b = (size_t)M_b * elem_sz; + const size_t src_pitch_c = (size_t)M_c * elem_sz; + const size_t src_pitch_d = (size_t)M_d * elem_sz; + + char* dst_ptr = reinterpret_cast(dst.raw_data()); + + // Columns [0, M_a) + check_cuda_error( + cudaMemcpy2DAsync(dst_ptr, dst_pitch, a.raw_data(), src_pitch_a, src_pitch_a, K, cudaMemcpyDefault, st)); + + // Columns [M_a, M_a+M_b) + check_cuda_error(cudaMemcpy2DAsync( + dst_ptr + src_pitch_a, dst_pitch, b.raw_data(), src_pitch_b, src_pitch_b, K, cudaMemcpyDefault, st)); + + // Columns [M_a+M_b, M_a+M_b+M_c) + check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b, + dst_pitch, + c.raw_data(), + src_pitch_c, + src_pitch_c, + K, + cudaMemcpyDefault, + st)); + + // Columns [M_a+M_b+M_c, M_dst) + check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b + src_pitch_c, + dst_pitch, + d.raw_data(), + src_pitch_d, + src_pitch_d, + K, + cudaMemcpyDefault, + st)); + sync_check_cuda_error(); +} + +void GatedDeltaNetWeight::prepare() +{ + auto stream = core::Context::stream().handle(); + + // Preprocess individual weights (converts blockscale FP8, etc.) 
+ in_proj_qkv.preprocess(); + in_proj_z.preprocess(); + in_proj_b.preprocess(); + in_proj_a.preprocess(); + out_proj.preprocess(); + out_proj.prepare(); + + // Build the fused input projection weight: + // shape (hidden_dim, conv_dim + value_dim + 2*v_heads_tp) + // = [in_proj_qkv | in_proj_z | in_proj_b | in_proj_a] (column-wise) + const int out_all = in_proj_qkv.output_dim // + + in_proj_z.output_dim // + + in_proj_b.output_dim // + + in_proj_a.output_dim; + + in_proj_all.emplace(in_proj_qkv.input_dim, + out_all, + in_proj_qkv.data_type, + /*bias=*/false, + in_proj_qkv.weight_type, + in_proj_qkv.group_size); + + concat_weights_4( + in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, in_proj_all.weight, stream); + + // Prepare (convert/repack) the fused weight for GEMM + in_proj_all.prepare(); + + // Release the now-redundant individual weight tensors to free HBM + in_proj_qkv = {}; + in_proj_z = {}; + in_proj_b = {}; + in_proj_a = {}; +} + +} // namespace turbomind diff --git a/src/turbomind/models/llama/GatedDeltaNetWeight.h b/src/turbomind/models/llama/GatedDeltaNetWeight.h new file mode 100644 index 0000000000..2bc3aad778 --- /dev/null +++ b/src/turbomind/models/llama/GatedDeltaNetWeight.h @@ -0,0 +1,52 @@ +#pragma once + +#include "src/turbomind/core/core.h" +#include "src/turbomind/core/module.h" +#include "src/turbomind/models/llama/LlamaDenseWeight.h" + +namespace turbomind { + +struct GatedDeltaNetWeight: public core::Module { + + GatedDeltaNetWeight() = default; + + GatedDeltaNetWeight(int hidden_dim, + int num_k_heads, + int num_v_heads, + int key_head_dim, + int value_head_dim, + int d_conv, + bool bias, + int tp_size, + int tp_rank, + DataType data_type, + DataType weight_type, + int group_size); + + void prepare(); + + // Individual projections – populated at load time from the checkpoint. + // After prepare() completes they are released (null-ed) to free HBM. 
+ LlamaDenseWeight in_proj_qkv; // hidden -> key_dim*2 + value_dim + LlamaDenseWeight in_proj_z; // hidden -> value_dim (output gate) + LlamaDenseWeight in_proj_b; // hidden -> num_v_heads (beta, per-head scalar) + LlamaDenseWeight in_proj_a; // hidden -> num_v_heads (alpha/dt, per-head scalar) + + // Fused projection: hidden -> (conv_dim + value_dim + 2*v_heads_tp). + // Built from the four above in prepare(); used for all inference GEMMs. + // Reduces p.input HBM reads from 4× to 1× per forward pass. + LlamaDenseWeight in_proj_all; + + LlamaDenseWeight out_proj; // value_dim -> hidden + + // Non-dense parameters + Tensor conv1d; // depthwise conv weights: (conv_dim, 1, d_conv) flattened + Tensor A_log; // log-space decay: (num_v_heads,) + Tensor dt_bias; // dt bias: (num_v_heads,) + Tensor norm; // RMSNormGated weight: (value_head_dim,) + + int tp_rank_; + int tp_size_; +}; + +} // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 5c82a377f0..f9774671f1 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -68,31 +68,52 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight( mlp_tp_size_(engine.mlp_tp_size), mlp_tp_rank_(engine.mlp_tp_rank) { - // Attention uses weight_type (fp16 in mixed quant scenarios) - self_attn_weights.reset(new LlamaAttentionWeight{hidden_units_, - size_per_head_, - head_num_, - kv_head_num_, - model.mla, - attn_bias_, - model.qk_norm, - attn_tp_size_, - attn_tp_rank_, - data_type_, - weight_type_, - model.group_size, - model.window_size.empty() ? 0 : model.window_size.at(layer_id), - model.attn_sink}); - register_module("attention", *self_attn_weights); - - // Dense layers (no MoE) use weight_type for FFN, since models like - // GLM-4.7-Flash-AWQ exclude the entire dense layer from quantization - // ("model.layers.0." in modules_to_not_convert). 
MoE layers use - // ffn_weight_type for their shared experts (int4 for mixed AWQ, - // bfloat16 for GptOss mxfp4, same as weight_type otherwise). + bool is_linear_attention = false; + if (layer_id < (int)model.layer_types.size() && model.layer_types[layer_id] == 1) { + is_linear_attention = true; + } + + if (is_linear_attention) { + linear_attn_weights.reset( + new GatedDeltaNetWeight{hidden_units_, + model.linear_num_key_heads, + model.linear_num_value_heads, + model.linear_key_head_dim, + model.linear_value_head_dim, + model.linear_conv_kernel_dim > 0 ? model.linear_conv_kernel_dim : 4, + attn_bias_, + attn_tp_size_, + attn_tp_rank_, + data_type_, + weight_type_, + model.group_size}); + register_module("linear_attn", *linear_attn_weights); + } + else { + // Attention uses weight_type (fp16 in mixed quant scenarios) + self_attn_weights.reset(new LlamaAttentionWeight{hidden_units_, + size_per_head_, + head_num_, + kv_head_num_, + model.mla, + attn_bias_, + model.qk_norm, + attn_tp_size_, + attn_tp_rank_, + data_type_, + weight_type_, + model.group_size, + model.window_size.empty() ? 0 : model.window_size.at(layer_id), + model.attn_sink, + model.attn_output_gate}); + register_module("attention", *self_attn_weights); + } + + // FFN uses ffn_weight_type, except for layers fully excluded from + // quantization (e.g. 'model.layers.0.' in modules_to_not_convert) + // where all weights—including FFN—are in data_type (fp16). if (inter_size_) { - const bool is_moe_layer = layer_id < (int)moe_param.expert_num.size() && moe_param.expert_num[layer_id]; - const DataType ffn_wtype = is_moe_layer ? model.ffn_weight_type : weight_type_; + const DataType ffn_wtype = model.unquantized_expert_layers.count(layer_id) ? 
data_type_ : model.ffn_weight_type; const bool is_cublas_gemm = byte_size(ffn_wtype, 8) == 16; ffn_weights.reset(new LlamaFfnWeight{ hidden_units_, @@ -110,13 +131,16 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight( } // MoE routed experts use expert_weight_type (int4 for AWQ, e2m1 for mxfp4) + // unless the layer is in unquantized_expert_layers (e.g. layer 0 excluded + // from quantization via modules_to_not_convert). if (layer_id < moe_param.expert_num.size() && moe_param.expert_num[layer_id]) { + const DataType moe_wtype = model.unquantized_expert_layers.count(layer_id) ? data_type_ : expert_weight_type_; moe_weights.reset(new MoeFfnWeight{layer_id, moe_param, hidden_units_, model.mlp_bias, data_type_, - expert_weight_type_, + moe_wtype, model.group_size, mlp_tp_size_, mlp_tp_rank_, @@ -135,7 +159,13 @@ LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() = default; void LlamaDecoderLayerWeight::prepare(const cudaDeviceProp& prop, cudaStream_t st) { - self_attn_weights->prepare(); + if (self_attn_weights) { + self_attn_weights->prepare(); + } + + if (linear_attn_weights) { + linear_attn_weights->prepare(); + } if (ffn_weights) { ffn_weights->prepare(false); diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index a158575f4a..6ac387ab12 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -22,6 +22,7 @@ #include "src/turbomind/core/core.h" +#include "src/turbomind/models/llama/GatedDeltaNetWeight.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/llama_params.h" @@ -47,6 +48,7 @@ struct LlamaDecoderLayerWeight: core::Module { Tensor ffn_norm; std::unique_ptr self_attn_weights; + std::unique_ptr linear_attn_weights; std::unique_ptr ffn_weights; std::unique_ptr moe_weights; diff --git a/src/turbomind/models/llama/LlamaDenseWeight.cc 
b/src/turbomind/models/llama/LlamaDenseWeight.cc index 2aa7ff9da6..1764d3622a 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.cc +++ b/src/turbomind/models/llama/LlamaDenseWeight.cc @@ -273,13 +273,21 @@ LlamaAttentionWeight::LlamaAttentionWeight(int hidden_dim, DataType weight_type, int group_size, int window_size, - bool sink) + bool sink, + bool attn_output_gate) { this->window_size = window_size; + // attn_output_gate doubles Q dimension (extra gate projection fused into Q) + const int q_factor = attn_output_gate ? 2 : 1; + if (mla.kv_lora_rank == 0) { - qkv.emplace( - hidden_dim, (head_num + 2 * kv_head_num) * head_dim / tp_size, data_type, bias, weight_type, group_size); + qkv.emplace(hidden_dim, + (head_num * q_factor + 2 * kv_head_num) * head_dim / tp_size, + data_type, + bias, + weight_type, + group_size); register_module("w_qkv", qkv, tp_rank); if (qk_norm) { q_a_layernorm = Tensor{{head_dim}, data_type, kDEVICE}; diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index a8288559ef..7aa8673586 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -99,7 +99,8 @@ struct LlamaAttentionWeight: public core::Module { DataType weight_type, int group_size, int window_size, - bool sink); + bool sink, + bool attn_output_gate = false); void prepare(); diff --git a/src/turbomind/models/llama/gated_delta_net_kernels.cu b/src/turbomind/models/llama/gated_delta_net_kernels.cu new file mode 100644 index 0000000000..2d92b75020 --- /dev/null +++ b/src/turbomind/models/llama/gated_delta_net_kernels.cu @@ -0,0 +1,741 @@ +#include "src/turbomind/models/llama/gated_delta_net_kernels.h" +#include "src/turbomind/utils/cuda_utils.h" +#include +#include + +namespace turbomind { + +// ============================================================================= +// Shared utility: block-level L2 norm inverse using warp + shared reduction +// Returns 
rsqrt(sum_sq + eps) visible to all threads in the block. +// smem must be at least ceil(blockDim.x/32) floats. +// SM70 specific optimization: bypass smem entirely if blockDim.x <= 32 +// ============================================================================= +__device__ __forceinline__ float block_l2_inv_norm(float partial_sq, float* smem, float eps = 1e-6f) +{ + // Warp reduce + for (int mask = 16; mask > 0; mask >>= 1) + partial_sq += __shfl_xor_sync(0xffffffff, partial_sq, mask); + + // Fast path: if the block is a single warp, we don't need shared memory + if (blockDim.x <= 32) { + return rsqrtf(partial_sq + eps); + } + + const int lane_id = threadIdx.x & 31; + const int warp_id = threadIdx.x >> 5; + if (lane_id == 0) + smem[warp_id] = partial_sq; + __syncthreads(); + + // First warp reduces across warps + const int num_warps = (blockDim.x + 31) >> 5; + if (warp_id == 0) { + float val = (lane_id < num_warps) ? smem[lane_id] : 0.f; + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask); + if (lane_id == 0) + smem[0] = rsqrtf(val + eps); + } + __syncthreads(); + return smem[0]; +} + +// Helper to accumulate squares of a 16-bit type scalar or vector2 using float32 arithmetic +template +__device__ __forceinline__ float sq_acc(T val) +{ + return (float)val * (float)val; +} + +#if defined(__CUDA_ARCH__) +__device__ __forceinline__ float sq_acc(half2 val) +{ + float2 fval = __half22float2(val); + return fval.x * fval.x + fval.y * fval.y; +} +__device__ __forceinline__ float sq_acc(nv_bfloat162 val) +{ + float2 fval = __bfloat1622float2(val); + return fval.x * fval.x + fval.y * fval.y; +} + +__device__ __forceinline__ float2 to_float2(half2 v) +{ + return __half22float2(v); +} +__device__ __forceinline__ float2 to_float2(nv_bfloat162 v) +{ + return __bfloat1622float2(v); +} +__device__ __forceinline__ half2 to_vec2(float2 v, half) +{ + return __float22half2_rn(v); +} +__device__ __forceinline__ nv_bfloat162 to_vec2(float2 v, 
nv_bfloat16) +{ + return __float22bfloat162_rn(v); +} +#endif + +// ============================================================================= +// Causal Conv1d — Decode (seq_len == 1) +// ============================================================================= +template +__global__ void causal_conv1d_decode_kernel( + T* out, const T* in, const T* weight, const T* bias, T* conv_states, int conv_dim, int d_conv) +{ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total = gridDim.y * conv_dim; + if (idx >= total) + return; + + const int b = idx / conv_dim; + const int c = idx % conv_dim; + + T* state = conv_states + (b * conv_dim + c) * d_conv; + +#pragma unroll + for (int d = 0; d < d_conv - 1; ++d) + state[d] = state[d + 1]; + state[d_conv - 1] = in[b * conv_dim + c]; + + const T* w = weight + c * d_conv; + float acc = 0.0f; +#pragma unroll + for (int d = 0; d < d_conv; ++d) + acc += static_cast(state[d]) * static_cast(w[d]); + if (bias) + acc += static_cast(bias[c]); + out[b * conv_dim + c] = static_cast(acc / (1.0f + expf(-acc))); +} + +// ============================================================================= +// Causal Conv1d — Prefill (seq_len > 1) +// ============================================================================= +template +__global__ void causal_conv1d_prefill_kernel( + T* out, const T* in, const T* weight, const T* bias, T* conv_states, int conv_dim, int seq_len, int d_conv) +{ + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int total = gridDim.y * conv_dim * seq_len; + if (tid >= total) + return; + + const int b = tid / (conv_dim * seq_len); + const int rem = tid % (conv_dim * seq_len); + const int c = rem / seq_len; + const int t = rem % seq_len; + + const T* w = weight + c * d_conv; + float acc = 0.0f; +#pragma unroll + for (int d = 0; d < d_conv; ++d) { + int src_t = t - (d_conv - 1 - d); + float val = 0.0f; + if (src_t >= 0) + val = static_cast(in[(b * conv_dim + c) * seq_len + src_t]); 
+ acc += val * static_cast(w[d]); + } + if (bias) + acc += static_cast(bias[c]); + out[(b * conv_dim + c) * seq_len + t] = static_cast(acc / (1.0f + expf(-acc))); + + if (conv_states && t >= seq_len - d_conv) { + int state_idx = d_conv - (seq_len - t); + conv_states[(b * conv_dim + c) * d_conv + state_idx] = in[(b * conv_dim + c) * seq_len + t]; + } +} + +template +void invokeCausalConv1d(T* out, + const T* in, + const T* weight, + const T* bias, + T* conv_states, + int batch_size, + int conv_dim, + int seq_len, + int d_conv, + cudaStream_t stream) +{ + if (seq_len == 1) { + const int n = batch_size * conv_dim; + const int threads = 256; + const int blocks = (n + threads - 1) / threads; + dim3 grid(blocks, batch_size); + causal_conv1d_decode_kernel<<>>(out, in, weight, bias, conv_states, conv_dim, d_conv); + } + else { + const int n = batch_size * conv_dim * seq_len; + const int threads = 256; + const int blocks = (n + threads - 1) / threads; + dim3 grid(blocks, batch_size); + causal_conv1d_prefill_kernel<<>>( + out, in, weight, bias, conv_states, conv_dim, seq_len, d_conv); + } +} + +// ============================================================================= +// Recurrent Gated Delta Rule (decode, seq_len == 1) +// +// Optimizations vs. original: +// - Reads Q/K/V directly from the packed qkv_in buffer (stride = conv_dim), +// eliminating three cudaMemcpy2DAsync strided-copy passes. +// - L2-normalizes Q and K in-kernel using a shared-memory block reduction, +// removing the separate invokeL2Norm kernel launches. +// - GQA: maps v_head -> k_head = v_head / ratio inside the kernel, +// removing the invokeRepeatInterleave allocation. 
+// ============================================================================= +template +__global__ void recurrent_delta_rule_kernel(T* v_out, + const T* qkv_in, + const T* beta_in, + const T* g_in, + T* state, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total) +{ + const int bh = blockIdx.x; + const int b = bh / num_v_heads; + const int h = bh % num_v_heads; + const int ratio = num_v_heads / num_k_heads; + const int kh = h / ratio; + + const int state_size = key_head_dim * value_head_dim; + const int conv_dim = 2 * k_dim_total + num_v_heads * value_head_dim; + + // Pointers into packed qkv for this batch element & head + const T* q_ptr = qkv_in + b * conv_dim + kh * key_head_dim; + const T* k_ptr = qkv_in + b * conv_dim + k_dim_total + kh * key_head_dim; + const T* v_ptr = qkv_in + b * conv_dim + 2 * k_dim_total + h * value_head_dim; + T* s_ptr = state + (b * num_v_heads + h) * state_size; + T* o_ptr = v_out + (b * num_v_heads + h) * value_head_dim; + + const float beta_val = static_cast(beta_in[b * num_v_heads + h]); + const float decay = expf(static_cast(g_in[b * num_v_heads + h])); + + // Shared memory for block reductions (one slot per warp) + __shared__ float smem[32]; + + // --- In-kernel L2-normalize Q (Vectorized) --- + float q_sq = 0.f; + if (key_head_dim % 2 == 0) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + const T2* q_ptr_v2 = reinterpret_cast(q_ptr); + for (int kd = threadIdx.x; kd < key_head_dim / 2; kd += blockDim.x) + q_sq += sq_acc(q_ptr_v2[kd]); + } + else { + for (int kd = threadIdx.x; kd < key_head_dim; kd += blockDim.x) + q_sq += sq_acc(q_ptr[kd]); + } + const float q_inv_norm = block_l2_inv_norm(q_sq, smem); + + // --- In-kernel L2-normalize K (Vectorized) --- + float k_sq = 0.f; + if (key_head_dim % 2 == 0) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + const T2* k_ptr_v2 = reinterpret_cast(k_ptr); + for (int kd = 
threadIdx.x; kd < key_head_dim / 2; kd += blockDim.x) + k_sq += sq_acc(k_ptr_v2[kd]); + } + else { + for (int kd = threadIdx.x; kd < key_head_dim; kd += blockDim.x) + k_sq += sq_acc(k_ptr[kd]); + } + const float k_inv_norm = block_l2_inv_norm(k_sq, smem); + + // --- Step 1: S *= decay --- + for (int idx = threadIdx.x; idx < state_size; idx += blockDim.x) + s_ptr[idx] = static_cast((float)s_ptr[idx] * decay); + __syncthreads(); + + // --- Step 2: delta rule update (each thread owns a slice of vd) --- + for (int vd = threadIdx.x; vd < value_head_dim; vd += blockDim.x) { + float kv_mem = 0.f; + for (int kd = 0; kd < key_head_dim; ++kd) + kv_mem += (float)s_ptr[kd * value_head_dim + vd] * ((float)k_ptr[kd] * k_inv_norm); + + const float delta = ((float)v_ptr[vd] - kv_mem) * beta_val; + + for (int kd = 0; kd < key_head_dim; ++kd) + s_ptr[kd * value_head_dim + vd] = + static_cast((float)s_ptr[kd * value_head_dim + vd] + (float)k_ptr[kd] * k_inv_norm * delta); + } + __syncthreads(); + + // --- Step 3: output = (S^T @ q) * scale --- + const float scale = rsqrtf((float)key_head_dim); + for (int vd = threadIdx.x; vd < value_head_dim; vd += blockDim.x) { + float o = 0.f; + for (int kd = 0; kd < key_head_dim; ++kd) + o += (float)s_ptr[kd * value_head_dim + vd] * ((float)q_ptr[kd] * q_inv_norm); + o_ptr[vd] = static_cast(o * scale); + } +} + +template +void invokeRecurrentGatedDeltaRule(T* v_out, + const T* qkv_in, + const T* beta, + const T* g, + T* recurrent_state, + int batch_size, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total, + cudaStream_t stream) +{ + const int num_blocks = batch_size * num_v_heads; + if (num_blocks == 0) + return; + const int threads = std::min(256, value_head_dim); + const size_t smem_sz = ((threads + 31) / 32) * sizeof(float); + recurrent_delta_rule_kernel<<>>( + v_out, qkv_in, beta, g, recurrent_state, num_v_heads, num_k_heads, key_head_dim, value_head_dim, k_dim_total); +} + +// 
============================================================================= +// Single-Launch Prefill Gated Delta Rule (seq_len > 1) +// +// Optimizations vs. original invokeSerialGatedDeltaRule: +// - The entire sequence is processed inside ONE kernel launch. +// The original code called invokeRecurrentGatedDeltaRule in a host-side +// for-loop, causing O(seq_len) kernel dispatches and CPU-GPU synchronization +// round-trips that starve the GPU for every timestep. +// - L2Norm, GQA handling, and packed qkv access are also fused in-kernel +// (same improvements as the recurrent decode kernel above). +// ============================================================================= +template +__global__ void gated_delta_rule_prefill_kernel(T* v_out, + const T* qkv_in, + const T* beta_in, + const T* g_in, + T* state, + int seq_len, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total) +{ + // One block per v_head. Threads are distributed over value_head_dim. 
+ const int h = blockIdx.x; + const int ratio = num_v_heads / num_k_heads; + const int kh = h / ratio; // GQA: k_head for this v_head + + const int state_size = key_head_dim * value_head_dim; + const int conv_dim = 2 * k_dim_total + num_v_heads * value_head_dim; + const int v_dim = num_v_heads * value_head_dim; + + T* s_ptr = state + h * state_size; + const float scale = rsqrtf((float)key_head_dim); + + __shared__ float smem[32]; + + for (int t = 0; t < seq_len; ++t) { + // Pointers into packed qkv (row-major per token) + const T* q_ptr = qkv_in + t * conv_dim + kh * key_head_dim; + const T* k_ptr = qkv_in + t * conv_dim + k_dim_total + kh * key_head_dim; + const T* v_ptr = qkv_in + t * conv_dim + 2 * k_dim_total + h * value_head_dim; + T* o_ptr = v_out + t * v_dim + h * value_head_dim; + + const float beta_val = (float)beta_in[t * num_v_heads + h]; + const float decay = expf((float)g_in[t * num_v_heads + h]); + + // --- In-kernel L2-normalize Q (Vectorized) --- + float q_sq = 0.f; + if (key_head_dim % 2 == 0) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + const T2* q_ptr_v2 = reinterpret_cast(q_ptr); + for (int kd = threadIdx.x; kd < key_head_dim / 2; kd += blockDim.x) + q_sq += sq_acc(q_ptr_v2[kd]); + } + else { + for (int kd = threadIdx.x; kd < key_head_dim; kd += blockDim.x) + q_sq += sq_acc(q_ptr[kd]); + } + const float q_inv_norm = block_l2_inv_norm(q_sq, smem); + + // --- In-kernel L2-normalize K (Vectorized) --- + float k_sq = 0.f; + if (key_head_dim % 2 == 0) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + const T2* k_ptr_v2 = reinterpret_cast(k_ptr); + for (int kd = threadIdx.x; kd < key_head_dim / 2; kd += blockDim.x) + k_sq += sq_acc(k_ptr_v2[kd]); + } + else { + for (int kd = threadIdx.x; kd < key_head_dim; kd += blockDim.x) + k_sq += sq_acc(k_ptr[kd]); + } + const float k_inv_norm = block_l2_inv_norm(k_sq, smem); + + // --- Step 1: S *= decay --- + for (int idx = threadIdx.x; idx < 
state_size; idx += blockDim.x) + s_ptr[idx] = static_cast((float)s_ptr[idx] * decay); + __syncthreads(); + + // --- Step 2: delta rule update --- + for (int vd = threadIdx.x; vd < value_head_dim; vd += blockDim.x) { + float kv_mem = 0.f; + for (int kd = 0; kd < key_head_dim; ++kd) + kv_mem += (float)s_ptr[kd * value_head_dim + vd] * ((float)k_ptr[kd] * k_inv_norm); + + const float delta = ((float)v_ptr[vd] - kv_mem) * beta_val; + + for (int kd = 0; kd < key_head_dim; ++kd) + s_ptr[kd * value_head_dim + vd] = + static_cast((float)s_ptr[kd * value_head_dim + vd] + (float)k_ptr[kd] * k_inv_norm * delta); + } + __syncthreads(); + + // --- Step 3: output = (S^T @ q) * scale --- + for (int vd = threadIdx.x; vd < value_head_dim; vd += blockDim.x) { + float o = 0.f; + for (int kd = 0; kd < key_head_dim; ++kd) + o += (float)s_ptr[kd * value_head_dim + vd] * ((float)q_ptr[kd] * q_inv_norm); + o_ptr[vd] = static_cast(o * scale); + } + __syncthreads(); // Ensure state write-back is visible before next step's decay + } + // State is updated in-place in global memory throughout the loop +} + +template +void invokeGatedDeltaRulePrefill(T* v_out, + const T* qkv_in, + const T* beta, + const T* g, + T* recurrent_state, + int seq_len, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total, + cudaStream_t stream) +{ + if (num_v_heads == 0 || seq_len == 0) + return; + // One block per v_head; threads cover value_head_dim + const int threads = std::min(256, value_head_dim); + const size_t smem_sz = ((threads + 31) / 32) * sizeof(float); + gated_delta_rule_prefill_kernel<<>>(v_out, + qkv_in, + beta, + g, + recurrent_state, + seq_len, + num_v_heads, + num_k_heads, + key_head_dim, + value_head_dim, + k_dim_total); +} + +// ============================================================================= +// Compute beta = sigmoid(b) and g = -exp(A_log) * softplus(a + dt_bias) +// 
============================================================================= +template +__global__ void compute_beta_g_kernel( + T* beta_out, T* g_out, const T* b_in, const T* a_in, const T* A_log, const T* dt_bias, int total, int num_v_heads) +{ + const int tid = blockIdx.x * blockDim.x + threadIdx.x; +#if defined(__CUDA_ARCH__) + if (total % 2 == 0 && num_v_heads % 2 == 0) { + if (tid < total / 2) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + T2* beta_v2 = reinterpret_cast(beta_out); + T2* g_v2 = reinterpret_cast(g_out); + const T2* b_v2 = reinterpret_cast(b_in); + const T2* a_v2 = reinterpret_cast(a_in); + + float2 b_val = to_float2(b_v2[tid]); + float2 a_val = to_float2(a_v2[tid]); + + int h0 = (tid * 2) % num_v_heads; + int h1 = (tid * 2 + 1) % num_v_heads; + + float Al0 = static_cast(A_log[h0]); + float dt0 = static_cast(dt_bias[h0]); + float Al1 = static_cast(A_log[h1]); + float dt1 = static_cast(dt_bias[h1]); + + float beta0 = 1.0f / (1.0f + expf(-b_val.x)); + float sum0 = a_val.x + dt0; + float sp0 = sum0 > 20.0f ? sum0 : logf(1.0f + expf(sum0)); + float g_val0 = -expf(Al0) * sp0; + + float beta1 = 1.0f / (1.0f + expf(-b_val.y)); + float sum1 = a_val.y + dt1; + float sp1 = sum1 > 20.0f ? sum1 : logf(1.0f + expf(sum1)); + float g_val1 = -expf(Al1) * sp1; + + beta_v2[tid] = to_vec2(make_float2(beta0, beta1), T{}); + g_v2[tid] = to_vec2(make_float2(g_val0, g_val1), T{}); + } + } + else +#endif + { + if (tid >= total) + return; + + const int h = tid % num_v_heads; + float b_val = static_cast(b_in[tid]); + float a_val = static_cast(a_in[tid]); + float A_log_val = static_cast(A_log[h]); + float dt_bias_val = static_cast(dt_bias[h]); + + float beta = 1.0f / (1.0f + expf(-b_val)); + float sum = a_val + dt_bias_val; + float sp = sum > 20.0f ? 
sum : logf(1.0f + expf(sum)); + float g_val = -expf(A_log_val) * sp; + + beta_out[tid] = static_cast(beta); + g_out[tid] = static_cast(g_val); + } +} + +template +void invokeComputeBetaG(T* beta_out, + T* g_out, + const T* b_in, + const T* a_in, + const T* A_log, + const T* dt_bias, + int total, + int num_v_heads, + cudaStream_t stream) +{ + const int threads = 256; + if (total % 2 == 0 && num_v_heads % 2 == 0) { + const int blocks = (total / 2 + threads - 1) / threads; + compute_beta_g_kernel<<>>( + beta_out, g_out, b_in, a_in, A_log, dt_bias, total, num_v_heads); + } + else { + const int blocks = (total + threads - 1) / threads; + compute_beta_g_kernel<<>>( + beta_out, g_out, b_in, a_in, A_log, dt_bias, total, num_v_heads); + } +} + +// ============================================================================= +// RMSNorm * SiLU-Gate (fused output normalization) +// ============================================================================= +template +__global__ void rms_norm_gated_kernel( + T* hidden, const T* gate, const T* weight, float eps, int N, int head_dim, int gate_stride, int num_heads) +{ + const int row = blockIdx.x; + if (row >= N) + return; + + T* h = hidden + row * head_dim; + const int token_idx = row / num_heads; + const int head_idx = row % num_heads; + const T* g = gate + token_idx * gate_stride + head_idx * head_dim; + + __shared__ float smem[32]; + float sum_sq = 0.0f; + for (int d = threadIdx.x; d < head_dim; d += blockDim.x) { + float val = static_cast(h[d]); + sum_sq += val * val; + } + for (int mask = 16; mask > 0; mask >>= 1) + sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask); + if ((threadIdx.x & 31) == 0) + smem[threadIdx.x >> 5] = sum_sq; + __syncthreads(); + if (threadIdx.x >> 5 == 0) { + sum_sq = (threadIdx.x < (blockDim.x + 31) / 32) ? 
smem[threadIdx.x] : 0.0f; + for (int mask = 16; mask > 0; mask >>= 1) + sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask); + if (threadIdx.x == 0) + smem[0] = sum_sq; + } + __syncthreads(); + sum_sq = smem[0]; + + float inv_rms = rsqrtf(sum_sq / (float)head_dim + eps); + for (int d = threadIdx.x; d < head_dim; d += blockDim.x) { + float h_val = static_cast(h[d]) * inv_rms * static_cast(weight[d]); + float g_val = static_cast(g[d]); + float silu_g = g_val / (1.0f + expf(-g_val)); + h[d] = static_cast(h_val * silu_g); + } +} + +template +void invokeRMSNormGated(T* hidden, + const T* gate, + const T* weight, + float eps, + int N, + int head_dim, + int gate_stride, + int num_heads, + cudaStream_t stream) +{ + const int threads = std::min(256, head_dim); + rms_norm_gated_kernel<<>>(hidden, gate, weight, eps, N, head_dim, gate_stride, num_heads); +} + +// ============================================================================= +// Fused Conv1d + SiLU for row-major layout +// ============================================================================= +template +__global__ void fused_conv1d_decode_kernel( + T* out, const T* in, const T* weight, const T* bias, T* state, int conv_dim, int d_conv, int in_stride) +{ + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c >= conv_dim) + return; + + T* s = state + c * d_conv; +#pragma unroll + for (int d = 0; d < d_conv - 1; ++d) + s[d] = s[d + 1]; + s[d_conv - 1] = in[c]; + + const T* w = weight + c * d_conv; + float acc = 0.0f; +#pragma unroll + for (int d = 0; d < d_conv; ++d) + acc += static_cast(s[d]) * static_cast(w[d]); + if (bias) + acc += static_cast(bias[c]); + out[c] = static_cast(acc / (1.0f + expf(-acc))); +} + +template +__global__ void fused_conv1d_prefill_kernel( + T* out, const T* in, const T* weight, const T* bias, T* state, int conv_dim, int seq_len, int d_conv, int in_stride) +{ + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int total = seq_len * conv_dim; + if (tid >= total) 
+ return; + + const int t = tid / conv_dim; + const int c = tid % conv_dim; + + const T* w = weight + c * d_conv; + float acc = 0.0f; +#pragma unroll + for (int d = 0; d < d_conv; ++d) { + int src_t = t - (d_conv - 1 - d); + float val = 0.0f; + if (src_t >= 0) + val = static_cast(in[src_t * in_stride + c]); + acc += val * static_cast(w[d]); + } + if (bias) + acc += static_cast(bias[c]); + out[t * conv_dim + c] = static_cast(acc / (1.0f + expf(-acc))); + + if (state && t >= seq_len - d_conv) { + int state_idx = d_conv - (seq_len - t); + state[c * d_conv + state_idx] = in[t * in_stride + c]; + } +} + +template +void invokeFusedConv1dSiLU(T* out, + const T* in, + const T* weight, + const T* bias, + T* conv_states, + int batch_size, + int conv_dim, + int seq_len, + int d_conv, + int in_stride, + cudaStream_t stream) +{ + if (seq_len == 1) { + const int threads = 256; + const int blocks = (conv_dim + threads - 1) / threads; + fused_conv1d_decode_kernel<<>>( + out, in, weight, bias, conv_states, conv_dim, d_conv, in_stride); + } + else { + const int total = seq_len * conv_dim; + const int threads = 256; + const int blocks = (total + threads - 1) / threads; + fused_conv1d_prefill_kernel<<>>( + out, in, weight, bias, conv_states, conv_dim, seq_len, d_conv, in_stride); + } +} + +// ============================================================================= +// Element-wise SiLU +// ============================================================================= +template +__global__ void silu_kernel(T* data, int n) +{ + const int tid = blockIdx.x * blockDim.x + threadIdx.x; +#if defined(__CUDA_ARCH__) + if (n % 2 == 0) { + if (tid < n / 2) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + T2* data_v2 = reinterpret_cast(data); + float2 fval = to_float2(data_v2[tid]); + fval.x = fval.x / (1.0f + expf(-fval.x)); + fval.y = fval.y / (1.0f + expf(-fval.y)); + data_v2[tid] = to_vec2(fval, T{}); + } + } + else +#endif + { + if (tid >= n) + return; + 
float x = static_cast(data[tid]); + data[tid] = static_cast(x / (1.0f + expf(-x))); + } +} + +template +void invokeSiLU(T* data, int n, cudaStream_t stream) +{ + const int threads = 256; + if (n % 2 == 0) { + const int blocks = (n / 2 + threads - 1) / threads; + silu_kernel<<>>(data, n); + } + else { + const int blocks = (n + threads - 1) / threads; + silu_kernel<<>>(data, n); + } +} + +// ============================================================================= +// Explicit instantiations +// ============================================================================= + +#define INSTANTIATE(T) \ + template void invokeCausalConv1d(T*, const T*, const T*, const T*, T*, int, int, int, int, cudaStream_t); \ + template void invokeFusedConv1dSiLU(T*, const T*, const T*, const T*, T*, int, int, int, int, int, cudaStream_t); \ + template void invokeRecurrentGatedDeltaRule( \ + T*, const T*, const T*, const T*, T*, int, int, int, int, int, int, cudaStream_t); \ + template void invokeGatedDeltaRulePrefill( \ + T*, const T*, const T*, const T*, T*, int, int, int, int, int, int, cudaStream_t); \ + template void invokeComputeBetaG(T*, T*, const T*, const T*, const T*, const T*, int, int, cudaStream_t); \ + template void invokeRMSNormGated(T*, const T*, const T*, float, int, int, int, int, cudaStream_t); \ + template void invokeSiLU(T*, int, cudaStream_t); + +INSTANTIATE(half) +INSTANTIATE(nv_bfloat16) + +#undef INSTANTIATE + +} // namespace turbomind diff --git a/src/turbomind/models/llama/gated_delta_net_kernels.h b/src/turbomind/models/llama/gated_delta_net_kernels.h new file mode 100644 index 0000000000..fd58aeaa7b --- /dev/null +++ b/src/turbomind/models/llama/gated_delta_net_kernels.h @@ -0,0 +1,128 @@ +#pragma once + +#include "src/turbomind/core/tensor.h" +#include +#include +#include + +namespace turbomind { + +// ============================================================================= +// Causal Conv1d (channel-first layout, for backward compat) +// 
============================================================================= +template +void invokeCausalConv1d(T* out, + const T* in, + const T* weight, + const T* bias, + T* conv_states, + int batch_size, + int conv_dim, + int seq_len, + int d_conv, + cudaStream_t stream); + +// Fused Conv1d + SiLU for row-major (token_num, conv_dim) layout. +// in: (token_num, conv_dim) — row-major +// out: (token_num, conv_dim) — row-major +// weight: (conv_dim, d_conv) +// conv_states: (conv_dim, d_conv) per-request rolling state, may be NULL +template +void invokeFusedConv1dSiLU(T* out, + const T* in, + const T* weight, + const T* bias, + T* conv_states, + int batch_size, + int conv_dim, + int seq_len, + int d_conv, + int in_stride, + cudaStream_t stream); + +// ============================================================================= +// Gated Delta Rule — Recurrent decode (seq_len == 1) +// +// Reads Q/K/V directly from the packed qkv_in buffer (stride = conv_dim), +// L2-normalizes Q and K in-kernel, and handles GQA natively (kh = h / ratio). +// Eliminates separate invokeL2Norm and invokeRepeatInterleave passes. +// +// qkv_in layout per token (row-major): +// [Q: (num_k_heads, key_head_dim) | K: (num_k_heads, key_head_dim) +// | V: (num_v_heads, value_head_dim)] +// where k_dim_total = num_k_heads * key_head_dim. +// ============================================================================= +template +void invokeRecurrentGatedDeltaRule(T* v_out, + const T* qkv_in, + const T* beta, + const T* g, + T* recurrent_state, + int batch_size, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total, + cudaStream_t stream); + +// ============================================================================= +// Gated Delta Rule — Single-launch Prefill (seq_len > 1) +// +// Processes the ENTIRE sequence inside a single kernel launch. 
+// The original invokeSerialGatedDeltaRule called the decode kernel in a +// host-side for-loop (one CUDA launch per timestep = O(seq_len) dispatches). +// This kernel eliminates that overhead by looping over seq_len on the GPU. +// Also fuses L2Norm and GQA (same as the recurrent decode variant). +// +// v_out layout: (seq_len, num_v_heads, value_head_dim). +// qkv_in layout: (seq_len, conv_dim) packed as described above. +// state layout: (num_v_heads, key_head_dim, value_head_dim) — updated in-place. +// ============================================================================= +template +void invokeGatedDeltaRulePrefill(T* v_out, + const T* qkv_in, + const T* beta, + const T* g, + T* recurrent_state, + int seq_len, + int num_v_heads, + int num_k_heads, + int key_head_dim, + int value_head_dim, + int k_dim_total, + cudaStream_t stream); + +// ============================================================================= +// Helper kernels +// ============================================================================= + +// Compute beta = sigmoid(b) and g = -exp(A_log) * softplus(a + dt_bias) +template +void invokeComputeBetaG(T* beta_out, + T* g_out, + const T* b_in, + const T* a_in, + const T* A_log, + const T* dt_bias, + int total, + int num_v_heads, + cudaStream_t stream); + +// RMSNorm * SiLU-gate (fused output normalization) +template +void invokeRMSNormGated(T* hidden, + const T* gate, + const T* weight, + float eps, + int N, + int head_dim, + int gate_stride, + int num_heads, + cudaStream_t stream); + +// Element-wise SiLU activation in-place +template +void invokeSiLU(T* data, int n, cudaStream_t stream); + +} // namespace turbomind diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu index ff8f1ad6cc..0b70f92fd7 100644 --- a/src/turbomind/models/llama/llama_kernels.cu +++ b/src/turbomind/models/llama/llama_kernels.cu @@ -558,4 +558,33 @@ void AppendTokenIds( AppendTokenIdsKernel<<>>(token_ids_ptrs, 
output_ids, positions, batch_size); } +template +__global__ void SigmoidGateMultiplyKernel(T* attn, const T* gate_base, int dim, int gate_stride, int num_tokens) +{ + const int ti = blockIdx.x; + const int di = threadIdx.x + blockIdx.y * blockDim.x; + if (ti >= num_tokens || di >= dim) { + return; + } + float g = (float)gate_base[ti * gate_stride + di]; + float s = 1.0f / (1.0f + __expf(-g)); + float a = (float)attn[ti * dim + di]; + attn[ti * dim + di] = (T)(a * s); +} + +void invokeSigmoidGateMultiply( + void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream) +{ + constexpr int block = 256; + const dim3 grid(num_tokens, cdiv(dim, block)); + + auto invoke = [&](auto t) { + using T = decltype(t); + SigmoidGateMultiplyKernel<<>>( + (T*)attn, (const T*)gate_base, dim, gate_stride, num_tokens); + }; + + TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke); +} + } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h index 5916e1e51a..feb7f01a15 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -79,4 +79,11 @@ void AppendTokenIds(int** token_ids_ptrs, // int batch_size, cudaStream_t stream); +// Apply sigmoid gating: attn[i] *= sigmoid(gate[i]) +// attn: [num_tokens, dim], contiguous +// gate_base: pointer to first gate element in QKV buffer +// gate_stride: stride between tokens in QKV buffer (elements) +void invokeSigmoidGateMultiply( + void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream); + } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 82504af47e..9bb2e9b46a 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include 
"src/turbomind/core/data_type.h" @@ -58,6 +59,21 @@ struct ModelParam { std::vector window_size; std::vector inter_size; + std::vector layer_types; + + // Qwen3.5 Gated DeltaNet linear attention params + int linear_key_head_dim = 0; + int linear_value_head_dim = 0; + int linear_conv_kernel_dim = 0; + int linear_num_key_heads = 0; + int linear_num_value_heads = 0; + + bool attn_output_gate = false; // Qwen3.5: doubles Q projection in full-attention layers + + // Layer indices whose MoE experts use data_type (fp16) instead of + // expert_weight_type (e.g. int4). Populated from modules_to_not_convert + // patterns like 'model.layers.0.'. + std::set unquantized_expert_layers; }; /// TODO: rename all `gate` in the context of MoE router to `router` diff --git a/src/turbomind/models/llama/moe_ffn_layer.cc b/src/turbomind/models/llama/moe_ffn_layer.cc index f1a16f5a68..45a7d21b4f 100644 --- a/src/turbomind/models/llama/moe_ffn_layer.cc +++ b/src/turbomind/models/llama/moe_ffn_layer.cc @@ -92,8 +92,9 @@ void MoeFfnLayer::Forward(ForwardParam& p) // dump_logits(tokens, layer_id); - if (param_.topk_method == "noaux_tc") { + if (param_.topk_method == "noaux_tc" || param_.scoring_func == "sigmoid") { // invokeMoeGate_NoAuxTC clears accum and masks internally + // Also used for sigmoid scoring (e.g. 
Qwen3.5) which is not supported by V2 kernel TM_CHECK_EQ(param_.n_group, 1); TM_CHECK_EQ(param_.topk_group, 1); const float* correction_bias = diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 0131145410..b4c342102b 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -39,6 +39,7 @@ #include "src/turbomind/macro.h" +#include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_rope.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/mla_utils.h" @@ -331,6 +332,24 @@ void UnifiedAttentionLayer::Forward(ForwardParam p) Tensor attn = [&]() -> Tensor { TM_DISPATCH_PRIMARY_DTYPES_RET(qkv.dtype(), invoke); }(); + // Apply sigmoid gating: attn *= sigmoid(gate) + // Gate is stored at the end of each token's QKV: [Q|K|V|Gate] + if (model_param_.attn_output_gate) { + const int q_count = qkv.shape(0); + const int attn_dim = local_head_num_ * size_per_head_; + const int gate_offset = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + const int qkv_stride = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + const auto stream = core::Context::stream().handle(); + invokeSigmoidGateMultiply(attn.raw_data(), + (const char*)qkv.raw_data() + gate_offset * byte_size(qkv.dtype(), 1), + attn_dim, + qkv_stride, + q_count, + qkv.dtype(), + stream); + sync_check_cuda_error(); + } + TM_DEBUG_TENSOR(attn, Concat("attn", layer_id), 3); // if (d.dbg_size) { @@ -384,8 +403,15 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, params.stride = (local_head_num_ + 1 * local_kv_head_num_) * size_per_head_; } else { - params.v = params.k + local_kv_head_num_ * size_per_head_; - params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + params.v = params.k + local_kv_head_num_ * size_per_head_; 
+ // When attn_output_gate, QKV layout is [Q|K|V|Gate] per token + // stride must account for the extra gate portion at the end + if (model_param_.attn_output_gate) { + params.stride = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + } + else { + params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + } } if (weights.qkv.bias) { diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index bde6026f2a..0a8d7508cd 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -25,6 +25,9 @@ namespace turbomind { void UnifiedDecoder::Run(BatchOp op, int phase, TensorMap& env) { attn_layer_->Run(op, phase, env); + if (linear_attn_layer_) { + linear_attn_layer_->Run(op, phase, env); + } } UnifiedDecoder::UnifiedDecoder(const ModelParam& model, @@ -52,6 +55,10 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, attn_layer_ = std::make_unique(model, attn, engine, attn_tp_size_, ctx, phases, (bool)moe_ffn_layer_); + if (std::find(model.layer_types.begin(), model.layer_types.end(), 1) != model.layer_types.end()) { + linear_attn_layer_ = std::make_unique(model, attn, engine, attn_tp_size_, ctx, phases); + } + if (std::accumulate(model.inter_size.begin(), model.inter_size.end(), 0LL)) { ffn_layer_ = std::make_unique(model, ctx); } @@ -193,15 +200,31 @@ void UnifiedDecoder::Forward(int phase, TensorMap& args, const std::vectorForward( - {phase, local_hidden_states, local_hidden_states, weights.at(layer)->self_attn_weights.get(), layer}); + /// self-attention or linear-attention + if (weights.at(layer)->linear_attn_weights) { + linear_attn_layer_->Forward( + {phase, local_hidden_states, local_hidden_states, weights.at(layer)->linear_attn_weights.get(), layer}); + } + else { + attn_layer_->Forward( + {phase, local_hidden_states, local_hidden_states, weights.at(layer)->self_attn_weights.get(), layer}); + } 
TM_DEBUG_TENSOR(local_hidden_states, Concat("attn_block", layer), 2); + // For gated delta networks, we may need a different output.bias name or it doesn't have it. + // We will just use `output.bias` from either layer. + Tensor out_bias; + if (weights.at(layer)->linear_attn_weights) { + out_bias = weights.at(layer)->linear_attn_weights->out_proj.bias; + } + else { + out_bias = weights.at(layer)->self_attn_weights->output.bias; + } + AllreduceResidualRMSnorm(global_hidden_states, local_residual, - weights.at(layer)->self_attn_weights->output.bias, + out_bias, weights.at(layer)->ffn_norm, local_token_num, attn_tp_group_, diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index 199949f478..05e9ea73a4 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -1,6 +1,7 @@ #pragma once #include "src/turbomind/comm/device_comm.h" +#include "src/turbomind/models/llama/GatedDeltaNetLayer.h" #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" #include "src/turbomind/models/llama/LlamaFfnLayer.h" #include "src/turbomind/models/llama/context.h" @@ -45,6 +46,7 @@ class UnifiedDecoder { int& is_warm_up_; std::unique_ptr attn_layer_; + std::unique_ptr linear_attn_layer_; std::unique_ptr ffn_layer_; std::unique_ptr moe_ffn_layer_; diff --git a/src/turbomind/turbomind.cc b/src/turbomind/turbomind.cc index 1525a3afae..1d8565196f 100644 --- a/src/turbomind/turbomind.cc +++ b/src/turbomind/turbomind.cc @@ -379,6 +379,35 @@ TurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_fac for (auto it = inter_size.begin(); it != inter_size.end(); ++it) { model_param_.inter_size.push_back(it->as()); } + + if (auto layer_types = model["layer_types"]) { + for (auto it = layer_types.begin(); it != layer_types.end(); ++it) { + auto type_str = it->as(""); + if (type_str == "linear_attention") { + model_param_.layer_types.push_back(1); + } + else if (type_str 
== "full_attention" || type_str.empty()) { + model_param_.layer_types.push_back(0); + } + else { + TM_LOG_WARNING("[TM] Unknown layer_type '%s', treating as full_attention.", type_str.c_str()); + model_param_.layer_types.push_back(0); + } + } + } + + // Qwen3.5 Gated DeltaNet linear attention parameters + model_param_.linear_key_head_dim = model["linear_key_head_dim"].as(0); + model_param_.linear_value_head_dim = model["linear_value_head_dim"].as(0); + model_param_.linear_conv_kernel_dim = model["linear_conv_kernel_dim"].as(0); + model_param_.linear_num_key_heads = model["linear_num_key_heads"].as(0); + model_param_.linear_num_value_heads = model["linear_num_value_heads"].as(0); + model_param_.attn_output_gate = model["attn_output_gate"].as(false); + if (auto uqel = model["unquantized_expert_layers"]) { + for (auto it = uqel.begin(); it != uqel.end(); ++it) { + model_param_.unquantized_expert_layers.insert(it->as()); + } + } model_param_.attn_sink = model["attn_sink"].as(); model_param_.mlp_bias = model["mlp_bias"].as(); if (model["activation_type"].as("") == "gpt-oss") { diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index d3cfda8341..d635ddcf10 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -124,3 +124,66 @@ def test_filter_minp_sorted(): out = _filter_minp_sorted_(scores, min_p) torch.testing.assert_close(out, gt) + + +def test_filter_ngram(): + from lmdeploy.pytorch.engine.logits_process import _filter_repetition_ngram_ + vocab_size = 100 + + def _get_emtas(n, window_size): + batch_size = generated_ids.size(0) + max_n = int(n.max().item()) + same_n = n.eq(max_n).all().item() + max_window_size = window_size + if same_n: + n = None + return batch_size, max_n, max_window_size, n + + # base test + generated_ids = torch.tensor([ + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], + [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], + ], + 
dtype=torch.int64) + n = torch.tensor([3, 3, 2], dtype=torch.int64) + threshold = torch.tensor([3, 3, 3], dtype=torch.int64) + + batch_size, max_n, max_window_size, n = _get_emtas(n, 10) + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size) + + assert not scores[1].isinf().any().item() + assert scores[0].isinf().sum().item() == vocab_size - 1 + assert scores[2].isinf().sum().item() == vocab_size - 1 + assert scores[0, stop_words[0, 0]] == 0 + assert scores[2, stop_words[2, 0]] == 0 + + # test no ngram + generated_ids = torch.tensor([ + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]) + n = torch.tensor([3, 0], dtype=torch.int64) + threshold = torch.tensor([3, 0], dtype=torch.int64) + batch_size, max_n, max_window_size, n = _get_emtas(n, 10) + + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size) + assert not scores[1].isinf().any().item() + assert scores[0].isinf().sum().item() == vocab_size - 1 + + # test ids all 0 + generated_ids = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]) + n = torch.tensor([3], dtype=torch.int64) + threshold = torch.tensor([3], dtype=torch.int64) + batch_size, max_n, max_window_size, n = _get_emtas(n, 10) + + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size) + assert scores[0].isinf().sum().item() == vocab_size - 1 diff --git a/tests/test_lmdeploy/test_qwen3coder_parser.py b/tests/test_lmdeploy/test_qwen3coder_parser.py new file mode 100644 index 0000000000..b84735a40c --- /dev/null +++ 
"""Tests for the Qwen3Coder tool-call parser.

Drives ``Qwen3CoderToolParser`` through a minimal re-implementation of the
server's chat-completion flow (``_chat_completion_v1``) in both streaming and
non-streaming modes, and checks that tool calls and plain content are
extracted correctly.
"""
import collections
import json
import time
from typing import Generator, List, Tuple, Union

import pytest
import shortuuid

from lmdeploy.serve.openai.api_server import VariableInterface
from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                            ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)
from lmdeploy.serve.openai.tool_parser.qwen3coder_parser import Qwen3CoderToolParser

# (expected function name, expected parsed keyword arguments) for one tool call
TestExpects = collections.namedtuple('TestExpects', 'func_name kwargs')


class DummyTokenizer:
    """Minimal stand-in for a real tokenizer; the parser only needs
    ``decode``/``encode`` to exist, not to be a real vocabulary."""

    def decode(self, token_ids: List[int]) -> str:
        return ' '.join(map(str, token_ids))

    def encode(self, text: str) -> List[int]:
        return [ord(c) for c in text]


# NOTE(review): these delta chunks look like they lost XML-style tool-call
# markup (e.g. <tool_call>/<function=...>/<parameter=...> tags) during text
# extraction — the empty '' entries and bare values are suspicious.  Verify
# against the real test file before relying on the exact literals.
DELTA_TEXT_SEQUENCE = [
    '好的,我现在帮你调用工具。\n',
    '',
    '\n',
    '\n',
    '',
    '北京\n',
    'celsius\n',
    '\n',
    '',
]

# same conversation plus a second tool call appended after the first
DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [
    '\n\n',
    '',
    '\n\n',
    '上海\n',
    '\n',
    '',
]

# the non-tool-call prefix the parser should surface as plain content
EXPECTED_CONTENT = '好的,我现在帮你调用工具。'


def _chat_completion_v1(
        request: ChatCompletionRequest,
        text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:
    """Replay ``text_sequence`` through the tool/reasoning parsers.

    Mirrors the server's chat-completion endpoint: with ``request.stream``
    set it yields ``ChatCompletionStreamResponse`` objects per delta,
    otherwise it parses the joined text once and returns a single
    ``ChatCompletionResponse``.
    """
    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model
    if request.stream:

        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
            previous_text = ''
            current_text = ''
            finish_reason = 'stop'
            has_parser = (VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None)
            for text in text_sequence:
                logprobs, usage = None, None
                delta_message = DeltaMessage(role='assistant', content=text)
                if has_parser:
                    # accumulate the full text seen so far for the parsers
                    current_text = current_text + text
                has_tool = VariableInterface.tool_parser is not None
                if request.tool_choice != 'none' and has_tool:
                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_message.content,
                        previous_token_ids=[],
                        current_token_ids=[],
                        delta_token_ids=[],
                        request=request)
                    if tool_delta is not None:
                        delta_message.tool_calls = tool_delta.tool_calls
                        delta_message.content = tool_delta.content or ''
                if VariableInterface.reasoning_parser is not None:
                    parser = VariableInterface.reasoning_parser
                    reasoning_delta = parser.extract_reasoning_content_streaming(previous_text=previous_text,
                                                                                 current_text=current_text,
                                                                                 delta_text=delta_message.content,
                                                                                 previous_token_ids=[],
                                                                                 current_token_ids=[],
                                                                                 delta_token_ids=[])
                    if reasoning_delta is not None:
                        delta_message.reasoning_content = (reasoning_delta.reasoning_content)
                        delta_message.content = reasoning_delta.content or ''
                if has_parser:
                    # advance the window only after both parsers consumed it
                    previous_text = current_text

                choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                                 delta=delta_message,
                                                                 finish_reason=finish_reason,
                                                                 logprobs=logprobs)
                response = ChatCompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[choice_data],
                    usage=usage,
                )
                yield response

        return completion_stream_generator()

    # non-streaming: parse the complete joined text in one shot
    text = ''.join(text_sequence)
    tool_calls = None
    reasoning_content = None
    finish_reason = 'stop'
    has_tool = VariableInterface.tool_parser is not None
    if request.tool_choice != 'none' and has_tool:
        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
        if isinstance(tool_calls, List) and len(tool_calls):
            if finish_reason == 'stop':
                finish_reason = 'tool_calls'

    if VariableInterface.reasoning_parser is not None:
        parser = VariableInterface.reasoning_parser
        reasoning_content, text = parser.extract_reasoning_content(text, request)

    choices = []
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
        finish_reason=finish_reason,
    )
    choices.append(choice_data)

    return ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=UsageInfo(),
    )


def _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:
    """Consume the streaming generator and merge deltas by tool-call id.

    Returns the accumulated content, accumulated reasoning content, and the
    merged tool calls sorted by their ``index``.
    """
    content = ''
    reasoning_content = ''
    tool_calls = {}

    for stream_resp in _chat_completion_v1(request, text_sequence):
        delta_message: DeltaMessage = stream_resp.choices[0].delta
        if delta_message.content:
            content += delta_message.content
        if delta_message.reasoning_content:
            reasoning_content += delta_message.reasoning_content
        if delta_message.tool_calls:
            for c in delta_message.tool_calls:
                existing_call = tool_calls.get(c.id, None)
                if not existing_call:
                    tool_calls[c.id] = c
                    continue
                # merge with existing: a later delta may carry the name or
                # another fragment of the arguments for the same call id
                if c.function.name:
                    existing_call.function.name = c.function.name
                if c.function.arguments:
                    existing_call.function.arguments = (existing_call.function.arguments or '')
                    existing_call.function.arguments += c.function.arguments
    return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index))


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
        'location': '北京',
        'unit': 'celsius'
    })]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
        TestExpects('get_weather', {
            'location': '北京',
            'unit': 'celsius'
        }),
        TestExpects('get_weather', {'location': '上海'})
    ]),
])
def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
    """Streaming mode: merged deltas must yield the expected calls/content."""
    tokenizer = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
    VariableInterface.reasoning_parser = None
    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)
    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args == expected_call.kwargs
    assert content.strip() == EXPECTED_CONTENT


@pytest.mark.parametrize(('text_sequence', 'expects'), [
    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
        'location': '北京',
        'unit': 'celsius'
    })]),
    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
        TestExpects('get_weather', {
            'location': '北京',
            'unit': 'celsius'
        }),
        TestExpects('get_weather', {'location': '上海'})
    ]),
])
def test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):
    """Non-streaming mode: one-shot parse of the joined text."""
    tokenizer = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
    VariableInterface.reasoning_parser = None
    resp: ChatCompletionResponse = _chat_completion_v1(
        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)

    assert len(resp.choices) == 1
    first_message = resp.choices[0].message
    assert first_message.content.strip() == EXPECTED_CONTENT
    assert first_message.reasoning_content is None
    assert len(first_message.tool_calls) == len(expects)
    for parsed_call, expected_call in zip(first_message.tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args == expected_call.kwargs


def test_no_think_nonstream():
    """Plain chat with no tool-call markup must pass through unchanged."""
    text_sequence = [
        '你好',
        '呀',
        '!',
        '✨',
        '',
        ' 很',
        '高兴',
        '见到',
        '你',
        '!',
    ]
    tokenizer = DummyTokenizer()
    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
    VariableInterface.reasoning_parser = None
    resp: ChatCompletionResponse = _chat_completion_v1(
        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)

    assert len(resp.choices) == 1
    first_message = resp.choices[0].message
    assert first_message.content == '你好呀!✨ 很高兴见到你!'
    assert first_message.reasoning_content is None