diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index c8ce9302b4..4cd3e78333 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -25,6 +25,7 @@
     CLIPVisionModel,
     LlavaConfig,
     MistralConfig,
+    PixtralVisionConfig,
     Qwen2Config,
     SiglipVisionModel,
 )
@@ -554,4 +555,44 @@ def __init__(
             )
 
 
-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaPixtralForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = PixtralVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32001
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaPixtralForCausalLM,
+]
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index a0ca5fabbf..1a766d352f 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -24,6 +24,7 @@
         "test_no_chunked_prefill.py",
         "test_no_overlap_scheduler.py",
         "test_openai_server.py",
+        "test_pixtral.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_retract_decode.py",
diff --git a/test/srt/test_pixtral.py b/test/srt/test_pixtral.py
new file mode 100644
index 0000000000..4e4aa95373
--- /dev/null
+++ b/test/srt/test_pixtral.py
@@ -0,0 +1,183 @@
+"""
+python3 -m unittest test_pixtral
+
+"""
+
+import asyncio
+import base64
+import json
+import math
+import os
+import unittest
+from io import BytesIO
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import numpy as np
+import requests
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.server_args import PortArgs, ServerArgs
+
+
+class RawPixtralTest(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        # Define the models to test
+        self.models = {
+            "large": "mistralai/Pixtral-Large-Instruct-2411",
+            "base": "mistralai/Pixtral-12B-2409",
+        }
+
+        # Initialize objects for each model
+        self.tokenizers = {}
+        self.models_obj = {}
+        self.processors = {}
+        self.devices = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        for key, model_name in self.models.items():
+            # Load tokenizer
+            self.tokenizers[key] = AutoTokenizer.from_pretrained(
+                model_name, trust_remote_code=True
+            )
+
+            # Load model
+            self.models_obj[key] = (
+                LlavaForConditionalGeneration.from_pretrained(
+                    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
+                )
+                .eval()
+                .to(self.devices)
+            )
+
+            # Load processor
+            self.processors[key] = AutoProcessor.from_pretrained(model_name)
+
+    def test_vision_encoder(self):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            }
+        ]
+
+        for key, model_name in self.models.items():
+            print(f"\n=== Testing {model_name} ===")
+
+            # Get the tokenizer, model, and processor for the current model
+            tokenizer = self.tokenizers[key]
+            model = self.models_obj[key]
+            processor = self.processors[key]
+
+            # Apply chat template to generate text input
+            text = processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+
+            # Download the image
+            response = requests.get(messages[0]["content"][0]["image"])
+            main_image = Image.open(BytesIO(response.content))
+
+            # Process inputs using the processor
+            inputs = processor(
+                text=[text],
+                images=[main_image],
+                padding=True,
+                return_tensors="pt",
+            )
+
+            # Hugging Face model output
+            with torch.no_grad():
+                hf_output = model.visual(
+                    inputs["pixel_values"].to(self.devices),
+                    grid_thw=inputs["image_grid_thw"].to(self.devices),
+                )
+
+            # SGLang model setup
+            model_config = ModelConfig(model_name, model_override_args="{}")
+            server_args = ServerArgs(model_path=model_name)
+            model_runner = ModelRunner(
+                model_config=model_config,
+                mem_fraction_static=0.8,
+                gpu_id=0,
+                tp_rank=0,
+                tp_size=1,
+                nccl_port=12435,
+                server_args=server_args,
+            )
+
+            # SGLang model output
+            with torch.no_grad():
+                sglang_output = model_runner.model.visual(
+                    inputs["pixel_values"].to(self.devices),
+                    grid_thw=inputs["image_grid_thw"].to(self.devices),
+                )
+
+            # Comparison metrics
+            hf = hf_output.float()
+            sg = sglang_output.float()
+
+            # Basic shape and dtype comparison
+            print("\n=== Basic Properties ===")
+            print(f"Shapes match: {hf.shape == sg.shape}")
+            print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
+            print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
+
+            # Statistical metrics
+            print("\n=== Statistical Metrics ===")
+            print(
+                f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}"
+            )
+            print(
+                f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}"
+            )
+            print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
+            print(
+                f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
+            )
+
+            # Cosine similarity
+            cos_sim = F.cosine_similarity(hf, sg)
+            print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
+            print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
+
+            # Largest absolute differences
+            print("\n=== Largest Absolute Differences ===")
+            diffs = torch.abs(hf - sg)
+            flat_diffs = diffs.flatten()
+            top_k = 10
+            top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
+            top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
+
+            print(f"\nTop {top_k} largest absolute differences:")
+            print(
+                "Index".ljust(30)
+                + "Difference".ljust(15)
+                + "HF Value".ljust(15)
+                + "SGLang Value"
+            )
+            print("-" * 75)
+
+            for i in range(top_k):
+                idx = tuple(dim[i] for dim in top_indices)
+                diff_val = top_values[i].item()
+                hf_val = hf[idx].item()
+                sg_val = sg[idx].item()
+                print(f"{str(idx):<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
+
+            # Assert outputs are close
+            np.testing.assert_allclose(hf.cpu().numpy(), sg.cpu().numpy(), rtol=1e-3)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index e19e6b01d5..d1383a9016 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -440,8 +440,24 @@ def setUpClass(cls):
         )
         cls.base_url += "/v1"
 
-    def test_video_chat_completion(self):
-        pass
+
+class TestPixtralServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "mistralai/Pixtral-12B-2409"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--chat-template",
+                "pixtral",
+            ],
+        )
+        cls.base_url += "/v1"
 
 
 if __name__ == "__main__":
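
Usage sketch (not part of the diff): once a server is launched the way TestPixtralServer launches it, Pixtral can be queried through the OpenAI-compatible endpoint. The port, API key, prompt, and launch command below are illustrative assumptions, not values mandated by the patch; only the model path, the "--chat-template pixtral" flag, and the image URL come from the changes above.

# Launch the server roughly as the new test does (values are illustrative):
#   python3 -m sglang.launch_server --model-path mistralai/Pixtral-12B-2409 \
#       --chat-template pixtral --port 30000 --api-key sk-123456
import openai

# Point the standard OpenAI client at the local sglang server.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)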