Support qwen2 vl (#2689)
* feat: add support for qwen2 vl model

* feat: fix token padding, enable warmup and process basic request

* fix: improve get_position_ids, lift embed_tokens

* fix: remove get_cos_sin_hack dev function

* feat: add simple test chat with message and text

* fix: lint test

* fix: adjust positional embeddings for multi dimensional position ids

* fix: update docs and lint unused vars

* fix: include linted file

* fix: add norm after text output

* fix: format model file

* fix: adjust for ruff lints

* fix: remove unused rotate_half

* feat: refactors and calc num features

* fix: prefer position_ids passed from vlm causal lm and reset ids on batch

* fix: adjust get_position_ids if not available and add required args to signatures

* fix: adjust resize case for qwen2_vl warmup

* fix: avoid qwen2 vl specific paths with qwen2
drbh authored Oct 30, 2024
1 parent 46aeb08 commit befd9f6
Showing 13 changed files with 705 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/supported_models.md
@@ -24,6 +24,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
- [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1730164250,
  "id": "",
  "model": "Qwen/Qwen2-VL-7B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 58,
    "prompt_tokens": 349,
    "total_tokens": 407
  }
}
42 changes: 42 additions & 0 deletions integration-tests/models/test_flash_qwen2_vl.py
@@ -0,0 +1,42 @@
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
    response = await flash_qwen2.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
    )

    assert response == response_snapshot
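The integration test above doubles as a usage reference. For readers who want to exercise the model by hand, here is a minimal sketch against TGI's OpenAI-compatible chat route; it assumes a server is already running (e.g. launched with `--model-id Qwen/Qwen2-VL-7B-Instruct`) and that it listens on localhost port 8080 — both are assumptions, adjust as needed.

```python
# Minimal sketch, not part of this commit: query a locally running TGI
# server over its OpenAI-compatible /v1/chat/completions route.
import requests

payload = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",
    "max_tokens": 100,
    "seed": 42,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                    },
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
}

# Assumes localhost:8080; the image is fetched and preprocessed server-side.
response = requests.post(
    "http://localhost:8080/v1/chat/completions", json=payload, timeout=300
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```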
29 changes: 29 additions & 0 deletions router/src/config.rs
@@ -138,10 +138,39 @@ impl Paligemma {
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2VlVisionConfig {
    pub(crate) depth: usize,
    pub(crate) embed_dim: usize,
    pub(crate) mlp_ratio: usize,
    pub(crate) num_heads: usize,
    pub(crate) in_chans: usize,
    pub(crate) hidden_size: usize,
    pub(crate) patch_size: usize,
    pub(crate) spatial_merge_size: usize,
    pub(crate) spatial_patch_size: usize,
    pub(crate) temporal_patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2Vl {
    pub(crate) vision_config: Qwen2VlVisionConfig,
}

impl Qwen2Vl {
    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
        let num_pixels = height * width;
        num_pixels / self.vision_config.patch_size.pow(2)
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    Qwen2Vl(Qwen2Vl),
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
    Mistral,
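The router-side estimate above maps pixel dimensions to the number of image placeholder tokens. A quick Python mirror of `get_number_of_features` (a sketch, assuming `height` and `width` are the pixel dimensions the router sees after any resizing, and that `patch_size` comes from the model's `vision_config`):

```python
# Python mirror of Qwen2Vl::get_number_of_features from router/src/config.rs.
# Sketch only: patch_size=14 is an assumed default for illustration.
def get_number_of_features(height: int, width: int, patch_size: int = 14) -> int:
    num_pixels = height * width
    return num_pixels // patch_size**2


# e.g. a 224x224 image with 14x14 patches yields 256 placeholder tokens
assert get_number_of_features(224, 224) == 256
```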
8 changes: 7 additions & 1 deletion router/src/validation.rs
@@ -594,6 +594,10 @@ fn image_tokens(
        }
        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        Qwen2Vl(config) => format!(
            "<|vision_start|>{:?}<|vision_end|>",
            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
        ),
        _ => unimplemented!("Images tokens are not supported for this model configuration"),
    }
}
@@ -620,7 +624,9 @@ fn prepare_input<T: TokenizerTrait>(
    use Config::*;
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
        Some(
            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
        ) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
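The `Qwen2Vl` arm wraps the repeated `<|image_pad|>` placeholder in `<|vision_start|>`/`<|vision_end|>` markers. A rough Python equivalent of the expansion (a sketch; note that the Rust arm formats the repeated string with `{:?}`, i.e. `Debug` formatting, so the emitted placeholder is additionally wrapped in literal quote characters — the sketch shows the plain concatenation):

```python
# Rough Python equivalent of the Qwen2Vl arm of image_tokens() above.
# num_features is the count produced by get_number_of_features().
def qwen2_vl_image_tokens(num_features: int) -> str:
    return "<|vision_start|>" + "<|image_pad|>" * num_features + "<|vision_end|>"


# e.g. 2 features -> "<|vision_start|><|image_pad|><|image_pad|><|vision_end|>"
print(qwen2_vl_image_tokens(2))
```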
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/rotary.py
@@ -89,6 +89,8 @@ def static(cls, config, dim, base, device):

        if rope_type == "linear":
            pass
        elif rope_type == "default":
            pass
        elif rope_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            return DynamicPositionRotaryEmbedding(
20 changes: 20 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -146,6 +146,9 @@
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@@ -275,6 +278,11 @@ class ModelType(enum.Enum):
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
OPT = {
"type": "opt",
"name": "Opt",
@@ -1193,6 +1201,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == QWEN2_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
@@ -80,6 +80,7 @@ def forward(
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -61,6 +61,11 @@ def __init__(
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.mrope_section = (
            config.rope_scaling.get("mrope_section", None)
            if config.rope_scaling is not None
            else None
        )
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

@@ -122,6 +127,17 @@ def forward(
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        if self.mrope_section is not None:
            # if mrope_section is set, we need to split the cos and sin into 3
            # parts and concatenate them in a specific order
            cos = torch.cat(
                [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
                dim=-1,
            )
            sin = torch.cat(
                [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
                dim=-1,
            )

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
@@ -270,9 +286,6 @@ def __init__(self, prefix: str, config, weights):
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                Qwen2Layer(
@@ -296,7 +309,7 @@ def __init__(self, prefix: str, config, weights):

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -307,13 +320,16 @@ def forward(
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        # flatten position ids from 2D to 1D
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, true_max_s, hidden_states.dtype
            position_ids.flatten(), true_max_s, hidden_states.dtype
        )
        # reshape back to 2D if the position_ids were 2D
        if position_ids.size(0) != cos.size(0):
            cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
            sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)

        residual = None
        for i, layer in enumerate(self.layers):
@@ -352,6 +368,12 @@ def __init__(self, prefix: str, config, weights):
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)

self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)

self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
@@ -382,8 +404,10 @@ def forward(
            # kernel requires the true values
            seqlen = seqlen.clamp(max=self.max_past_tensor)

        inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = self.model(
            input_ids,
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
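The `mrope_section` handling above is the heart of Qwen2-VL's multimodal rotary embedding (M-RoPE): each token carries three position ids (temporal, height, width), the multi-dimensional `position_ids` are flattened for the cos/sin lookup and reshaped afterwards, and `mrope_section` says how many rotary dimensions each of the three components owns. A standalone sketch of the cos/sin sectioning, with shapes assumed for illustration rather than taken from the diff:

```python
# Standalone sketch of M-RoPE cos/sin sectioning, mirroring the attention
# snippet above. Assumed shapes (illustrative, not from this commit):
#   cos: [3, num_tokens, rotary_dim] - one row per position component
#        (temporal, height, width)
#   mrope_section: per-component widths summing to rotary_dim, e.g. [16, 24, 24]
import torch


def mrope_interleave(cos: torch.Tensor, mrope_section: list) -> torch.Tensor:
    # Split the rotary dimension into per-component sections, take section i
    # from component i % 3, and re-concatenate: the first 16 rotary dims rotate
    # with the temporal position, the next 24 with height, the last 24 with width.
    return torch.cat(
        [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
        dim=-1,
    )


mrope_section = [16, 24, 24]
cos = torch.randn(3, 10, sum(mrope_section))
assert mrope_interleave(cos, mrope_section).shape == (10, sum(mrope_section))
```

For plain text tokens the three position components are equal, so all three rows of `cos` coincide and the shuffle reduces to standard RoPE, which is why the text-only Qwen2 path is unaffected.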
@@ -750,6 +750,7 @@ def forward(
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
@@ -180,6 +180,7 @@ def forward(
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None and len(pixel_values) > 0: