Support for Pixtral model (Mistral) #2381

Closed · wants to merge 24 commits
Changes from all commits
Commits (24)
cd5312f
so far
yixin-huang1 Dec 6, 2024
4d2307f
adding in support for pixtral
yixin-huang1 Dec 7, 2024
78d7cc2
Merge branch 'main' of https://github.com/sgl-project/sglang
yixin-huang1 Dec 7, 2024
4f139fa
updated llava
yixin-huang1 Dec 7, 2024
8b59db0
formatted
yixin-huang1 Dec 21, 2024
9c262d8
Merge branch 'sgl-project:main' into main
yixin-huang1 Dec 21, 2024
c46125b
pixtral test
yixin-huang1 Dec 21, 2024
4605b00
Merge branch 'main' into yixin's_branch
yixin-huang1 Dec 21, 2024
2c42655
Merge pull request #1 from yixin-huang1/yixin's_branch
yixin-huang1 Dec 21, 2024
dae8cc5
Update http_llava_onevision_test.py
yixin-huang1 Dec 21, 2024
9153821
Merge branch 'yixin's_branch' of https://github.com/yixin-huang1/sgla…
yixin-huang1 Dec 21, 2024
26ecb00
Merge pull request #2 from yixin-huang1/yixin's_branch
yixin-huang1 Dec 21, 2024
50124b5
fixed more issues
yixin-huang1 Dec 26, 2024
0cf22ca
Update run_suite.py
yixin-huang1 Dec 26, 2024
692605c
Update test_vision_openai_server.py
yixin-huang1 Dec 26, 2024
520f7b0
Update test_vision_openai_server.py
yixin-huang1 Dec 27, 2024
3c010b7
Update test_vision_openai_server.py
yixin-huang1 Dec 27, 2024
8f67337
Merge pull request #3 from yixin-huang1/yixin's_branch
yixin-huang1 Dec 27, 2024
1de7e76
Merge branch 'sgl-project:main' into main
yixin-huang1 Dec 27, 2024
2c7bf65
Update http_llava_onevision_test.py
yixin-huang1 Dec 27, 2024
d27a59c
Update test_vision_openai_server.py
yixin-huang1 Dec 27, 2024
063a619
Merge pull request #4 from yixin-huang1/yixin's_branch
yixin-huang1 Dec 27, 2024
9439b9a
fixed import issues
yixin-huang1 Dec 27, 2024
7e6fc29
Merge pull request #5 from yixin-huang1/yixin's_branch
yixin-huang1 Dec 27, 2024
43 changes: 42 additions & 1 deletion python/sglang/srt/models/llava.py
@@ -25,6 +25,7 @@
    CLIPVisionModel,
    LlavaConfig,
    MistralConfig,
    PixtralVisionConfig,
    Qwen2Config,
    SiglipVisionModel,
)
@@ -554,4 +555,44 @@ def __init__(
)


EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
class LlavaPixtralForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.vision_tower = None

        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = PixtralVisionConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"

        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32001

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(config, quant_config=quant_config)

        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )


EntryClass = [
    LlavaLlamaForCausalLM,
    LlavaQwenForCausalLM,
    LlavaMistralForCausalLM,
    LlavaPixtralForCausalLM,
]
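Note (not part of this diff): a minimal sketch of how the new LlavaPixtralForCausalLM entry would typically be exercised end to end, using the same popen_launch_server helper that TestPixtralServer below relies on. The import path, base URL, and timeout value here are assumptions for illustration, not taken from this PR.

# Sketch only: launch an SGLang server for the Pixtral checkpoint used by the tests below.
# The import path, URL, and timeout are assumed placeholder values.
from sglang.test.test_utils import popen_launch_server  # assumed helper location

process = popen_launch_server(
    "mistralai/Pixtral-12B-2409",  # model referenced by the tests in this PR
    "http://127.0.0.1:30000",  # placeholder base URL
    timeout=600,  # placeholder launch timeout in seconds
    api_key="sk-123456",  # same placeholder key as TestPixtralServer
    other_args=["--chat-template", "pixtral"],
)
# Requests then go to <base URL>/v1; terminate the returned process when finished.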
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -24,6 +24,7 @@
"test_no_chunked_prefill.py",
"test_no_overlap_scheduler.py",
"test_openai_server.py",
"test_pixtral.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py",
183 changes: 183 additions & 0 deletions test/srt/test_pixtral.py
@@ -0,0 +1,183 @@
"""
python3 -m unittest test_pixtral

"""

import asyncio
import base64
import json
import math
import os
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import AsyncMock, patch

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import PortArgs, ServerArgs


class RawPixtralTest(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        # Define the models to test
        self.models = {
            "large": "mistralai/Pixtral-Large-Instruct-2411",
            "base": "mistralai/Pixtral-12B-2409",
        }

        # Initialize objects for each model
        self.tokenizers = {}
        self.models_obj = {}
        self.processors = {}
        self.devices = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        for key, model_name in self.models.items():
            # Load tokenizer
            self.tokenizers[key] = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True
            )

            # Load model
            self.models_obj[key] = (
                LlavaForConditionalGeneration.from_pretrained(
                    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
                )
                .eval()
                .to(self.devices)
            )

            # Load processor
            self.processors[key] = AutoProcessor.from_pretrained(model_name)

    def test_vision_encoder(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]

        for key, model_name in self.models.items():
            print(f"\n=== Testing {model_name} ===")

            # Get the tokenizer, model, and processor for the current model
            tokenizer = self.tokenizers[key]
            model = self.models_obj[key]
            processor = self.processors[key]

            # Apply chat template to generate text input
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Download the image
            response = requests.get(messages[0]["content"][0]["image"])
            main_image = Image.open(BytesIO(response.content))

            # Process inputs using the processor
            inputs = processor(
                text=[text],
                images=[main_image],
                padding=True,
                return_tensors="pt",
            )

            # Hugging Face model output
            with torch.no_grad():
                hf_output = model.visual(
                    inputs["pixel_values"].to(self.devices),
                    grid_thw=inputs["image_grid_thw"].to(self.devices),
                )

            # SGLang model setup
            model_config = ModelConfig(model_name, model_override_args="{}")
            server_args = ServerArgs(model_path=model_name)
            model_runner = ModelRunner(
                model_config=model_config,
                mem_fraction_static=0.8,
                gpu_id=0,
                tp_rank=0,
                tp_size=1,
                nccl_port=12435,
                server_args=server_args,
            )

            # SGLang model output
            with torch.no_grad():
                sglang_output = model_runner.model.visual(
                    inputs["pixel_values"].to(self.devices),
                    grid_thw=inputs["image_grid_thw"].to(self.devices),
                )

            # Comparison metrics
            hf = hf_output.float()
            sg = sglang_output.float()

            # Basic shape and dtype comparison
            print("\n=== Basic Properties ===")
            print(f"Shapes match: {hf.shape == sg.shape}")
            print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
            print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")

            # Statistical metrics
            print("\n=== Statistical Metrics ===")
            print(
                f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}"
            )
            print(
                f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}"
            )
            print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
            print(
                f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
            )

            # Cosine similarity
            cos_sim = F.cosine_similarity(hf, sg)
            print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
            print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")

            # Largest absolute differences
            print("\n=== Largest Absolute Differences ===")
            diffs = torch.abs(hf - sg)
            flat_diffs = diffs.flatten()
            top_k = 10
            top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
            top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)

            print(f"\nTop {top_k} largest absolute differences:")
            print(
                "Index".ljust(30)
                + "Difference".ljust(15)
                + "HF Value".ljust(15)
                + "SGLang Value"
            )
            print("-" * 75)

            for i in range(top_k):
                idx = tuple(dim[i] for dim in top_indices)
                diff_val = top_values[i].item()
                hf_val = hf[idx].item()
                sg_val = sg[idx].item()
                print(f"{str(idx):<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")

            # Assert outputs are close
            np.testing.assert_allclose(hf.cpu().numpy(), sg.cpu().numpy(), rtol=1e-3)


if __name__ == "__main__":
    unittest.main()
20 changes: 18 additions & 2 deletions test/srt/test_vision_openai_server.py
@@ -440,8 +440,24 @@ def setUpClass(cls):
        )
        cls.base_url += "/v1"

    def test_video_chat_completion(self):
        pass

class TestPixtralServer(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
        cls.model = "mistralai/Pixtral-12B-2409"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "pixtral",
            ],
        )
        cls.base_url += "/v1"


if __name__ == "__main__":
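For reference (not part of this diff): a minimal client-side sketch of the kind of OpenAI-compatible vision request that TestPixtralServer exercises once the server is up. The base URL is a placeholder, the API key mirrors the placeholder in setUpClass, and the image URL is the one used in test_pixtral.py.

# Sketch only: query the Pixtral server through SGLang's OpenAI-compatible /v1 endpoint.
import openai

client = openai.OpenAI(
    api_key="sk-123456",  # placeholder key from TestPixtralServer.setUpClass
    base_url="http://127.0.0.1:30000/v1",  # placeholder; use the server's actual base URL
)

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)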