Skip to content

Conversation

CUHKSZzxy
Copy link
Collaborator

@CUHKSZzxy CUHKSZzxy commented Sep 29, 2025

Usage

  1. quantize
lmdeploy lite blocked_fp8 ${model_path} --work-dir ${quantized_model_path} --quant-dtype fp8
  1. test case

NOTE: We can use either pytorch or turbomind backend for FP8 inference. Here we take pytorch backend as an example.

from lmdeploy import pipeline, PytorchEngineConfig

model_path = "OpenGVLab/InternVL3_5-8B-FP8"

if __name__ == '__main__':
    engine_config = PytorchEngineConfig(tp=1)
    pipe = pipeline(model_path, backend_config=engine_config)
    response = pipe(["Hi, pls intro yourself", "Shanghai is"])
    print(response)

Accuracy

Dataset: OCRBench
Model: InternVL3.5-8B (FP8), InternVL3_5-30B-A3B (FP8)

Backend InternVL3.5-8B InternVL3.5-8B-FP8 InternVL3_5-30B-A3B InternVL3_5-30B-A3B-FP8
TurboMind 84.3 84.1 88.8 88.4
PyTorch 84.3 84.2 88.7 88.1

Tested with VLMEvalKit.

Checklist

  • Align the quantization config with QWen3 / InternS1 FP8
  • Add documents for blocked FP8
  • Verify the FP8 model accuracy
  • Fix quantizations for MOE models
  • Check whether weight_scale_inv modification affects other quant methods / modules

@CUHKSZzxy CUHKSZzxy marked this pull request as ready for review September 30, 2025 04:35
@lvhan028 lvhan028 added the enhancement New feature or request label Oct 7, 2025
Comment on lines +44 to +54
skip_patterns = [
'lm_head',
'embed_tokens',
'mlp.gate', # sparse MOE router gate
'vision_model', # non-HF InternVL, vision part
'mlp1', # non-HF InternVL, projector
'mlp2', # non-HF InternVL-Flash, projector
'vision_tower', # HF InternVL, vision part
'multi_modal_projector', # HF InternVL, projector
]
modules_to_not_convert = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These configurations are model-specific. We should adopt a more maintainable approach.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the vLLM FP8 compressor example, and noticed that the ignored patterns are indeed model-specific. Currently, these patterns are passed as an input argument named ignore in the quantization recipe.

https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_fp8

https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w8a8_fp8/qwen2vl_example.py#L20

How about we also expose this as a configurable input argument, allowing users to define their own ignore patterns as needed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RunningLeon As discussed with @CUHKSZzxy, we propose a new --skip-pattern config.py option for custom skip patterns, alongside lmdeploy's internal defaults.
what's your opinion?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, if only passing skip patterns, a config file is not necessary.

"""
tensor: torch.Tensor
scale: torch.Tensor
weight_scale_inv: torch.Tensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing scale to weight_scale_inv might affect w8a8 quantized model inference.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RunningLeon @grimoire any good ideas?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants