Description
While serving a quantized Mistral Small 3.1 model with vLLM, I hit a fatal runtime error during engine startup (while profiling the multimodal encoder). The stack trace shows the error originates in the custom CUTLASS kernel used for quantized matrix multiplication:
RuntimeError: Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 to be true, but got false.
The error is raised from cutlass_scaled_mm during self.qkv_proj(hidden_states) inside the vision tower. The root cause appears to be that one or more of the tensors handed to the quantization kernel are not 2D, as the kernel expects.
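For illustration, a minimal toy sketch (not the actual vLLM kernel; all names and shapes below are made up) of why a 3D activation coming out of the vision tower trips this check:

    import torch

    # Toy stand-in for the check inside vLLM's cutlass_scaled_mm: the CUTLASS
    # path requires the activation (a), weight (b), and output (c) to be 2D.
    def scaled_mm_like(a, b, c):
        if not (a.dim() == 2 and b.dim() == 2 and c.dim() == 2):
            raise RuntimeError(
                "Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 "
                "to be true, but got false."
            )

    # The vision tower hands qkv_proj a 3D tensor (batch, patches, hidden),
    # whereas the decoder layers see flattened 2D (tokens, hidden) inputs.
    a = torch.randn(1, 16, 64)                                 # 3D vision-tower activation
    b = torch.randint(-128, 127, (64, 192), dtype=torch.int8)  # quantized weight
    c = torch.empty(1, 16, 192)                                # 3D output buffer
    scaled_mm_like(a, b, c)  # raises the same RuntimeError as in the log below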
System context:
Model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
Quantization: compressed_tensors_w8a8_int8
PyTorch: 2.7
Environment: CUDA/NCCL backend, AWS p4d.24xlarge, llm-compressor (main branch), vLLM 0.9.1
The failure reproduces consistently across multiple implementation attempts and llm-compressor configurations.
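For reference, the quantized checkpoint is loaded in vLLM roughly as follows (a minimal sketch; the model path and arguments here are illustrative, not the exact invocation from my run):

    from vllm import LLM

    # Hypothetical load of the quantized checkpoint produced by the script
    # below; path and max_model_len are assumptions for the example.
    llm = LLM(
        model="Mistral-Small-3.1-24B-Instruct-2503-quantized.w8a8",
        max_model_len=8192,
    )
    outputs = llm.generate("Hello")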
Recipe:
if scheme == "W8A8":
    return [
        SmoothQuantModifier(
            smoothing_strength=0.8,
            mappings=[
                [["re:^(?!.*vision_tower).*q_proj", "re:^(?!.*vision_tower).*k_proj", "re:^(?!.*vision_tower).*v_proj"], "re:^(?!.*vision_tower).*input_layernorm"],
                [["re:^(?!.*vision_tower).*gate_proj", "re:^(?!.*vision_tower).*up_proj"], "re:^(?!.*vision_tower).*post_attention_layernorm"],
                [["re:^(?!.*vision_tower).*down_proj"], "re:^(?!.*vision_tower).*up_proj"],
            ],
        ),
        GPTQModifier(
            ignore=[
                "lm_head",
                "re:.vision_tower.",
                "re:.multi_modal_projector.",
            ],
            sequential_targets=["MistralDecoderLayer"],
            dampening_frac=0.01,
            targets="Linear",
            scheme=scheme,
            # offload_hessians=True
        ),
    ]
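For completeness, the snippets here assume imports along these lines (as used in recent llm-compressor examples; exact module paths may differ between versions):

    import os

    from transformers import AutoModelForImageTextToText, AutoProcessor

    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.modifiers.smoothquant import SmoothQuantModifier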
MAX_SEQ_LEN = 8192
I've also tested loading the model with Mistral3ForConditionalGeneration; the snippet below uses AutoModelForImageTextToText.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    device_map="balanced",
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    max_seq_length=MAX_SEQ_LEN,
    data_collator=create_data_collator,
    num_calibration_samples=len(dataset),
    trust_remote_code_model=True,
)

save_path = f"{model_name}-quantized.{scheme.lower()}"
logger.info("Saving quantized model to: %s", save_path)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path, save_compressed=True)
processor.save_pretrained(save_path)
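create_data_collator is my own helper and is not shown above; a minimal sketch of what it does, following the pattern in llm-compressor's multimodal calibration examples (the exact implementation is an assumption):

    import torch

    # Hypothetical sketch of the collator referenced above: one pre-processed
    # sample per batch, with each field converted to a tensor.
    def create_data_collator(batch):
        assert len(batch) == 1
        return {key: torch.tensor(value) for key, value in batch[0].items()}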
In addition, the ignore filter does not appear to be respected by GPTQModifier: the modules it should skip (e.g., the vision tower and multi_modal_projector) seem not to be found or filtered as expected.
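One way to confirm this is to inspect the quantization config written into the saved checkpoint (a hedged sketch; the exact layout of quantization_config depends on the compressed-tensors version):

    import json

    # Hypothetical check of the saved checkpoint: list which modules ended up
    # in the ignore list versus the quantized target groups.
    with open(f"{save_path}/config.json") as f:
        cfg = json.load(f)

    qcfg = cfg.get("quantization_config", {})
    print("ignore:", qcfg.get("ignore"))
    for name, group in qcfg.get("config_groups", {}).items():
        print(name, "targets:", group.get("targets"))

The full vLLM startup log follows.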
INFO 06-13 08:42:50 [loader.py:447] Loading weights took 6.95 seconds
INFO 06-13 08:42:51 [gpu_model_runner.py:1273] Model loading took 23.6805 GiB and 7.212285 seconds
INFO 06-13 08:42:51 [gpu_model_runner.py:1542] Encoder cache will be initialized with a budget of 3080 tokens, and profiled with 1 image items of the maximum feature size.
ERROR 06-13 08:42:53 [core.py:390] EngineCore hit an exception: Traceback (most recent call last):
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 378, in run_engine_core
ERROR 06-13 08:42:53 [core.py:390] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 319, in __init__
ERROR 06-13 08:42:53 [core.py:390] super().__init__(vllm_config, executor_class, log_stats)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 71, in __init__
ERROR 06-13 08:42:53 [core.py:390] self._initialize_kv_caches(vllm_config)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 132, in _initialize_kv_caches
ERROR 06-13 08:42:53 [core.py:390] available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 66, in determine_available_memory
ERROR 06-13 08:42:53 [core.py:390] output = self.collective_rpc("determine_available_memory")
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 06-13 08:42:53 [core.py:390] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/utils.py", line 2347, in run_method
ERROR 06-13 08:42:53 [core.py:390] return func(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-13 08:42:53 [core.py:390] return func(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 157, in determine_available_memory
ERROR 06-13 08:42:53 [core.py:390] self.model_runner.profile_run()
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1562, in profile_run
ERROR 06-13 08:42:53 [core.py:390] dummy_encoder_outputs = self.model.get_multimodal_embeddings(
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/mistral3.py", line 558, in get_multimodal_embeddings
ERROR 06-13 08:42:53 [core.py:390] vision_embeddings = self._process_image_input(image_input)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/mistral3.py", line 534, in _process_image_input
ERROR 06-13 08:42:53 [core.py:390] image_features = self.vision_tower(image_input["pixel_values"])
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1300, in forward
ERROR 06-13 08:42:53 [core.py:390] out = self.transformer(
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1191, in forward
ERROR 06-13 08:42:53 [core.py:390] x = layer(x, attention_mask, position_embeddings)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1148, in forward
ERROR 06-13 08:42:53 [core.py:390] r, _ = self.attention.forward(self.attention_norm(hidden_states),
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1091, in forward
ERROR 06-13 08:42:53 [core.py:390] qkv_states, _ = self.qkv_proj(hidden_states)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/linear.py", line 474, in forward
ERROR 06-13 08:42:53 [core.py:390] output_parallel = self.quant_method.apply(self, input, bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 581, in apply
ERROR 06-13 08:42:53 [core.py:390] return scheme.apply_weights(layer, x, bias=bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py", line 110, in apply_weights
ERROR 06-13 08:42:53 [core.py:390] return self.kernel.apply_weights(layer, x, bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py", line 131, in apply_weights
ERROR 06-13 08:42:53 [core.py:390] return ops.cutlass_scaled_mm(x_q,
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/_custom_ops.py", line 557, in cutlass_scaled_mm
ERROR 06-13 08:42:53 [core.py:390] torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 06-13 08:42:53 [core.py:390] return self._op(*args, **(kwargs or {}))
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] RuntimeError: Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
ERROR 06-13 08:42:53 [core.py:390]
CRITICAL 06-13 08:42:53 [core_client.py:361] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Killed