From 9ea0ba65bb3cfdef182f98008891031da4663032 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Sun, 4 Jan 2026 11:12:29 +0000 Subject: [PATCH 1/4] [ascend] fix awq and smoothq --- lmdeploy/pytorch/config.py | 11 ++++++++--- lmdeploy/pytorch/nn/linear/__init__.py | 5 ++++- lmdeploy/pytorch/nn/linear/awq.py | 15 ++++++++++----- lmdeploy/pytorch/nn/linear/w8a8.py | 2 +- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 5a32eded25..459c254ab1 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -24,9 +24,14 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): quantization_config = getattr(config.hf_config, 'quantization_config', dict()) quant_method = quantization_config.get('quant_method', None) if quant_method == 'awq': - logger.debug('set torch_dtype to float16 for awq.') - config.hf_config.torch_dtype = 'float16' - config.dtype = torch.float16 + if dtype == 'bfloat16': + logger.debug('set torch_dtype to bfloat16 for awq.') + config.hf_config.torch_dtype = 'bfloat16' + config.dtype = torch.bfloat16 + else: + logger.debug('set torch_dtype to float16 for awq.') + config.hf_config.torch_dtype = 'float16' + config.dtype = torch.float16 return config torch_dtype = getattr(config.hf_config, 'dtype', None) diff --git a/lmdeploy/pytorch/nn/linear/__init__.py b/lmdeploy/pytorch/nn/linear/__init__.py index 7213fc8e56..06ea53591d 100644 --- a/lmdeploy/pytorch/nn/linear/__init__.py +++ b/lmdeploy/pytorch/nn/linear/__init__.py @@ -70,6 +70,7 @@ def build_linear( is_tp=is_tp, all_reduce=all_reduce, layer_type=layer_type, + dtype=dtype, ) if quant_method == 'smooth_quant': return W8A8Linear(in_features, @@ -229,6 +230,7 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, layer_type=layer_type, + dtype=dtype, ) if quant_method == 'smooth_quant': return MergedW8A8Linear(in_features=in_features, @@ -314,7 +316,8 @@ def build_qkv_proj(in_features: int, bias=bias, device=device, is_tp=is_tp, - num_replicate_kv_heads=num_replicate_kv_heads) + num_replicate_kv_heads=num_replicate_kv_heads, + dtype=dtype) if quant_method == 'smooth_quant': return QKVW8A8Linear(in_features=in_features, num_q_heads=num_q_heads, diff --git a/lmdeploy/pytorch/nn/linear/awq.py b/lmdeploy/pytorch/nn/linear/awq.py index 5e24d93db7..32ab953b24 100644 --- a/lmdeploy/pytorch/nn/linear/awq.py +++ b/lmdeploy/pytorch/nn/linear/awq.py @@ -26,8 +26,9 @@ def __init__( is_tp: bool = False, all_reduce: bool = True, layer_type: str = 'attn', + dtype: Optional[torch.dtype] = torch.float16, ): - super().__init__(dtype=torch.float16, + super().__init__(dtype=dtype, device=device, colwise=colwise, is_tp=is_tp, @@ -180,7 +181,8 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - layer_type: str = 'attn'): + layer_type: str = 'attn', + dtype: Optional[torch.dtype] = torch.float16): self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section_s = all_out_features @@ -202,7 +204,8 @@ def __init__(self, device, colwise=True, is_tp=is_tp, - layer_type=layer_type) + layer_type=layer_type, + dtype=dtype) self.setup_loaders() def setup_loaders(self): @@ -282,7 +285,8 @@ def __init__(self, bias: bool = False, device: Optional[torch.device] = None, is_tp: bool = True, - num_replicate_kv_heads: int = 1): + num_replicate_kv_heads: int = 1, + dtype: Optional[torch.dtype] = torch.float16): self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, @@ -307,7 +311,8 @@ def __init__(self, device=device, is_tp=is_tp, out_names=out_names, - layer_type='attn') + layer_type='attn', + dtype=dtype) def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int): """Update all out features.""" diff --git a/lmdeploy/pytorch/nn/linear/w8a8.py b/lmdeploy/pytorch/nn/linear/w8a8.py index c9105e5599..f59ed97a1e 100644 --- a/lmdeploy/pytorch/nn/linear/w8a8.py +++ b/lmdeploy/pytorch/nn/linear/w8a8.py @@ -25,7 +25,7 @@ def __init__(self, all_reduce: bool = True, quant_dtype: Optional[torch.dtype] = torch.int8, layer_type: str = 'attn'): - super().__init__(dtype=torch.float16, + super().__init__(dtype=dtype, device=device, colwise=colwise, is_tp=is_tp, From 7d0d1bd8bfb42d0c37538b726df1a61be0fd22d5 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 16 Jan 2026 08:11:40 +0000 Subject: [PATCH 2/4] fix code --- lmdeploy/pytorch/check_env/model.py | 2 +- lmdeploy/pytorch/config.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py index 0e5884b69b..e3686a90c9 100644 --- a/lmdeploy/pytorch/check_env/model.py +++ b/lmdeploy/pytorch/check_env/model.py @@ -52,7 +52,7 @@ def check_dtype(self, config): from lmdeploy.pytorch.config import ModelConfig from lmdeploy.utils import is_bf16_supported - model_config = ModelConfig.from_hf_config(config, model_path=model_path, dtype=dtype) + model_config = ModelConfig.from_hf_config(config, model_path=model_path, dtype=dtype, device_type=device_type) if model_config.dtype == torch.bfloat16: if not is_bf16_supported(device_type): logger.warning('Device does not support bfloat16.') diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 459c254ab1..96f12e437c 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -10,7 +10,7 @@ from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value -def _update_torch_dtype(config: 'ModelConfig', dtype: str): +def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'cuda'): """Update the torch dtype from the model config. Args: @@ -24,8 +24,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): quantization_config = getattr(config.hf_config, 'quantization_config', dict()) quant_method = quantization_config.get('quant_method', None) if quant_method == 'awq': - if dtype == 'bfloat16': - logger.debug('set torch_dtype to bfloat16 for awq.') + if dtype == 'bfloat16' and device_type == 'ascend': + logger.debug('awq on ascend only support bfloat16, set torch_dtype to bfloat16 for awq.') config.hf_config.torch_dtype = 'bfloat16' config.dtype = torch.bfloat16 else: @@ -372,6 +372,7 @@ def from_hf_config( dist_config: DistConfig = None, is_draft_model: bool = False, spec_method: str = None, + device_type: str = 'cuda', ): """From huggingface config.""" from lmdeploy.pytorch.configurations import AutoModelConfigBuilder @@ -400,7 +401,7 @@ def from_hf_config( assert tp % model_config.num_key_value_heads == 0 # should after setting `hf_config` and `model_arch` attributes - model_config = _update_torch_dtype(model_config, dtype) + model_config = _update_torch_dtype(model_config, dtype, device_type=device_type) # update eos_token_id to list if isinstance(model_config.eos_token_id, int): From 83a0aaef93d83d38f188bec60e0ee536fe319ed9 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 16 Jan 2026 08:24:20 +0000 Subject: [PATCH 3/4] fix code --- lmdeploy/pytorch/config.py | 2 ++ lmdeploy/pytorch/engine/executor/__init__.py | 1 + 2 files changed, 3 insertions(+) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 96f12e437c..a96020e900 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -323,6 +323,7 @@ def from_pretrained( hf_overrides: Dict[str, Any] = None, is_draft_model: bool = False, spec_method: str = None, + device_type: str = 'cuda', ): """Instantiate one of the configuration classes of the library from a pretrained model configuration. @@ -351,6 +352,7 @@ def from_pretrained( dist_config=dist_config, is_draft_model=is_draft_model, spec_method=spec_method, + device_type=device_type, ) if hf_overrides is not None: diff --git a/lmdeploy/pytorch/engine/executor/__init__.py b/lmdeploy/pytorch/engine/executor/__init__.py index ec7b736015..0073d3cea8 100644 --- a/lmdeploy/pytorch/engine/executor/__init__.py +++ b/lmdeploy/pytorch/engine/executor/__init__.py @@ -78,6 +78,7 @@ def build_executor( dist_config=dist_config, is_draft_model=False, spec_method=None if specdecode_config is None else specdecode_config.method, + device_type=device_type, ) if distributed_executor_backend is None: From edf66612d4261efcda799888d73c20b6580ea8d1 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 16 Jan 2026 08:49:14 +0000 Subject: [PATCH 4/4] fix code --- lmdeploy/pytorch/check_env/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py index e3686a90c9..b1a2daaa42 100644 --- a/lmdeploy/pytorch/check_env/model.py +++ b/lmdeploy/pytorch/check_env/model.py @@ -52,7 +52,10 @@ def check_dtype(self, config): from lmdeploy.pytorch.config import ModelConfig from lmdeploy.utils import is_bf16_supported - model_config = ModelConfig.from_hf_config(config, model_path=model_path, dtype=dtype, device_type=device_type) + model_config = ModelConfig.from_hf_config(config, + model_path=model_path, + dtype=dtype, + device_type=device_type) if model_config.dtype == torch.bfloat16: if not is_bf16_supported(device_type): logger.warning('Device does not support bfloat16.')