cutlass_scaled_mm RuntimeError: Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 in quantized linear layer (Mistral 3.1 + vLLM + CUTLASS kernel) #1542

Open
@Alex31840

While serving a quantized Mistral Small 3.1 model with vLLM, I hit a fatal runtime error during engine startup (in the memory-profiling run). The stack trace shows that the error originates in the custom CUTLASS kernel used for quantized matrix multiplication:
RuntimeError: Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 to be true, but got false.

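This occurs in cutlass_scaled_mm during self.qkv_proj(hidden_states). According to the traceback, the call is reached through the Pixtral vision tower (pixtral.py), not through the language model's decoder layers. The root cause seems to be that one or more tensors are not 2D, as the quantization kernel expects.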
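As a rough illustration of what the failed check implies, the sketch below mimics a 2D-only guard using plain shape tuples. It is not vLLM's or CUTLASS's actual code: the helper names `flatten_to_2d` and `check_2d_matmul` and the example shapes are hypothetical, chosen only to show how a 3D activation (for example batch, patches, hidden) would be rejected while its flattened 2D form would pass.

```python
from math import prod


def flatten_to_2d(shape):
    """Merge all leading dimensions into one, keeping the last dimension.

    Illustrative only: works on a shape tuple, not on real tensors.
    """
    if len(shape) < 2:
        raise ValueError(f"need at least 2 dimensions, got {shape}")
    return (prod(shape[:-1]), shape[-1])


def check_2d_matmul(a_shape, b_shape):
    """Mimic a kernel-side guard that accepts only 2D operands."""
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise RuntimeError("Expected a.dim() == 2 && b.dim() == 2")
    if a_shape[1] != b_shape[0]:
        raise ValueError("inner dimensions do not match")
    return (a_shape[0], b_shape[1])


# Hypothetical shapes: 1 image, 64 patches, hidden size 1024
activations = (1, 64, 1024)
weight = (1024, 3072)

# Passing the 3D shape directly would raise RuntimeError;
# the flattened shape satisfies the guard.
out_shape = check_2d_matmul(flatten_to_2d(activations), weight)
```

Whether the vision tower should flatten its input or should simply not be routed to the quantized kernel at all is a separate question that this sketch does not answer.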

System context:
Model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
Quantization: compressed_tensors_w8a8_int8
Library: PyTorch 2.7
Environment: CUDA/NCCL backend, AWS p4d.24xlarge, llm-compressor (main branch), vLLM 0.9.1

This issue consistently occurs across multiple implementation attempts and configurations using llm-compressor.

Recipe:
if scheme == "W8A8":
    return [
        SmoothQuantModifier(
            smoothing_strength=0.8,
            mappings=[
                [["re:^(?!.*vision_tower).*q_proj", "re:^(?!.*vision_tower).*k_proj", "re:^(?!.*vision_tower).*v_proj"], "re:^(?!.*vision_tower).*input_layernorm"],
                [["re:^(?!.*vision_tower).*gate_proj", "re:^(?!.*vision_tower).*up_proj"], "re:^(?!.*vision_tower).*post_attention_layernorm"],
                [["re:^(?!.*vision_tower).*down_proj"], "re:^(?!.*vision_tower).*up_proj"],
            ],
        ),
        GPTQModifier(
            ignore=[
                "lm_head",
                "re:.vision_tower.",
                "re:.multi_modal_projector."
            ],
            sequential_targets=["MistralDecoderLayer"],
            dampening_frac=0.01,
            targets="Linear",
            scheme=scheme,
            # offload_hessians=True
        ),
    ]
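The negative-lookahead patterns in the mappings above can be sanity-checked offline against module names. The following is a minimal sketch, assuming that targets prefixed with `re:` are treated as regular expressions applied with `re.match` (anchored at the start of the name) and that other targets must equal the name exactly. The `matches` helper and the two module names are hypothetical and are not taken from llm-compressor.

```python
import re


def matches(name, target):
    """Return True if `name` is selected by `target`.

    Assumption: "re:"-prefixed targets are regexes applied with re.match;
    any other target must equal the module name exactly.
    """
    if target.startswith("re:"):
        return re.match(target[3:], name) is not None
    return name == target


# Hypothetical module names, for illustration only
names = [
    "language_model.model.layers.0.self_attn.q_proj",
    "vision_tower.transformer.layers.0.attention.q_proj",
]

pattern = "re:^(?!.*vision_tower).*q_proj"
selected = [n for n in names if matches(n, pattern)]
```

Under these assumptions only the language-model projection is selected. Running the same check over the real names from `model.named_modules()` would show which layers each pattern actually catches.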

MAX_SEQ_LEN = 8192

I've also tested with Mistral3ForConditionalGeneration
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    device_map="balanced",
    torch_dtype="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    max_seq_length=MAX_SEQ_LEN,
    data_collator=create_data_collator,
    num_calibration_samples=len(dataset),
    trust_remote_code_model=True,
)

save_path = f"{model_name}-quantized.{scheme.lower()}"
logger.info("Saving quantized model to: %s", save_path)

os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path, save_compressed=True)
processor.save_pretrained(save_path)

The original ignore filter does not appear to have been respected by GPTQModifier: it seems the relevant modules (e.g., the vision tower, ...) were either not found or not filtered out as expected.
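One way to narrow this down is to look at what the saved checkpoint records as ignored. The sketch below is a hedged example: it assumes the saved `config.json` contains a `quantization_config` object with an `ignore` list, and the helper names and the inline example config are hypothetical stand-ins for the real file.

```python
import json


def quantization_ignore_list(config):
    """Return the `ignore` entries from a model config dictionary.

    Assumption: config.json holds a "quantization_config" object with an
    "ignore" list; adjust the keys if the checkpoint is laid out differently.
    """
    return list(config.get("quantization_config", {}).get("ignore", []))


def vision_entries(ignore):
    """Keep only the ignored names that mention the vision tower."""
    return [name for name in ignore if "vision_tower" in name]


# Hypothetical contents standing in for the real config.json in save_path
example_config = json.loads('{"quantization_config": {"ignore": ["lm_head"]}}')

ignore = quantization_ignore_list(example_config)
ignored_vision = vision_entries(ignore)
```

If `ignored_vision` comes back empty for the real checkpoint, the vision tower layers were not recorded as ignored at save time, which would be consistent with the traceback below reaching the quantized kernel from pixtral.py.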

INFO 06-13 08:42:50 [loader.py:447] Loading weights took 6.95 seconds
INFO 06-13 08:42:51 [gpu_model_runner.py:1273] Model loading took 23.6805 GiB and 7.212285 seconds
INFO 06-13 08:42:51 [gpu_model_runner.py:1542] Encoder cache will be initialized with a budget of 3080 tokens, and profiled with 1 image items of the maximum feature size.
ERROR 06-13 08:42:53 [core.py:390] EngineCore hit an exception: Traceback (most recent call last):
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 378, in run_engine_core
ERROR 06-13 08:42:53 [core.py:390] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 319, in __init__
ERROR 06-13 08:42:53 [core.py:390] super().__init__(vllm_config, executor_class, log_stats)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 71, in __init__
ERROR 06-13 08:42:53 [core.py:390] self._initialize_kv_caches(vllm_config)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 132, in _initialize_kv_caches
ERROR 06-13 08:42:53 [core.py:390] available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 66, in determine_available_memory
ERROR 06-13 08:42:53 [core.py:390] output = self.collective_rpc("determine_available_memory")
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 06-13 08:42:53 [core.py:390] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/utils.py", line 2347, in run_method
ERROR 06-13 08:42:53 [core.py:390] return func(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-13 08:42:53 [core.py:390] return func(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 157, in determine_available_memory
ERROR 06-13 08:42:53 [core.py:390] self.model_runner.profile_run()
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1562, in profile_run
ERROR 06-13 08:42:53 [core.py:390] dummy_encoder_outputs = self.model.get_multimodal_embeddings(
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/mistral3.py", line 558, in get_multimodal_embeddings
ERROR 06-13 08:42:53 [core.py:390] vision_embeddings = self._process_image_input(image_input)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/mistral3.py", line 534, in _process_image_input
ERROR 06-13 08:42:53 [core.py:390] image_features = self.vision_tower(image_input["pixel_values"])
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1300, in forward
ERROR 06-13 08:42:53 [core.py:390] out = self.transformer(
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1191, in forward
ERROR 06-13 08:42:53 [core.py:390] x = layer(x, attention_mask, position_embeddings)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1148, in forward
ERROR 06-13 08:42:53 [core.py:390] r, _ = self.attention.forward(self.attention_norm(hidden_states),
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/models/pixtral.py", line 1091, in forward
ERROR 06-13 08:42:53 [core.py:390] qkv_states, _ = self.qkv_proj(hidden_states)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 06-13 08:42:53 [core.py:390] return self._call_impl(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 06-13 08:42:53 [core.py:390] return forward_call(*args, **kwargs)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/linear.py", line 474, in forward
ERROR 06-13 08:42:53 [core.py:390] output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 581, in apply
ERROR 06-13 08:42:53 [core.py:390] return scheme.apply_weights(layer, x, bias=bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py", line 110, in apply_weights
ERROR 06-13 08:42:53 [core.py:390] return self.kernel.apply_weights(layer, x, bias)
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py", line 131, in apply_weights
ERROR 06-13 08:42:53 [core.py:390] return ops.cutlass_scaled_mm(x_q,
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/vllm/_custom_ops.py", line 557, in cutlass_scaled_mm
ERROR 06-13 08:42:53 [core.py:390] torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
ERROR 06-13 08:42:53 [core.py:390] File "/opt/pytorch/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 06-13 08:42:53 [core.py:390] return self._op(*args, **(kwargs or {}))
ERROR 06-13 08:42:53 [core.py:390] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-13 08:42:53 [core.py:390] RuntimeError: Expected a.dim() == 2 && b.dim() == 2 && c.dim() == 2 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
ERROR 06-13 08:42:53 [core.py:390]
CRITICAL 06-13 08:42:53 [core_client.py:361] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Killed
