[WIP] Add Qwen-VL #244

Closed
wants to merge 1 commit into from
261 changes: 261 additions & 0 deletions python/sglang/srt/models/qwenvl.py
@@ -0,0 +1,261 @@
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""

from typing import List, Optional

import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.models.llava import (
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward,
)
from sglang.srt.models.qwen import QWenLMHeadModel
from torch import nn
from vllm.transformers_utils.configs.qwen import QWenConfig
from transformers import CLIPVisionModel
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.transformers_utils.configs.qwen import QWenConfig

def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype

if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
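
# Worked example of the shapes above (derived from the code, shapes only): a
# (256, C) positional table is treated as a 16x16 grid; for tgt_size=1024 it is
# bicubically resized to a 32x32 grid and returned flattened as (1024, C),
# cast back to the original dtype.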

# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product

emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)

emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
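
# Shape sketch for the helpers above: get_2d_sincos_pos_embed(embed_dim=8, grid_size=2)
# returns a (4, 8) array. The first embed_dim/2 columns encode one grid axis and the
# last embed_dim/2 columns the other; within each half, the sine terms come before
# the cosine terms.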

class Resampler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.grid_size = 16
        self.num_queries = self.grid_size * self.grid_size  # 256 by default (grid_size ** 2)
        self.embed_dim = config.visual.output_dim
        self.num_heads = config.visual.num_heads
        # ASSUMPTION: the incoming visual features have the vision transformer width;
        # fall back to embed_dim if the config does not expose it.
        kv_dim = getattr(config.visual, "width", self.embed_dim)

        self.pos_embed = nn.Parameter(
            torch.from_numpy(
                get_2d_sincos_pos_embed(self.embed_dim, self.grid_size)
            ).float()
        ).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, self.embed_dim))
        self.kv_proj = nn.Linear(kv_dim, self.embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads)
        self.ln_q = nn.LayerNorm(self.embed_dim)
        self.ln_kv = nn.LayerNorm(self.embed_dim)

def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))

x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)

N = x.shape[1]
q = self.ln_q(self.query)
repeated_query = q.unsqueeze(1).repeat(1, N, 1)
out = self.attn(
repeated_query + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
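
    # Shape sketch, inferred from the forward pass above: visual features x of shape
    # (B, L, kv_dim) are projected to embed_dim, and the 256 learned queries cross-attend
    # to them, so every image is compressed to a fixed (B, 256, embed_dim) sequence.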

class QwenVLForCausalLM(nn.Module):
def __init__(
self,
        config: QWenConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
self.config.vision_config.hidden_size = config.visual.hidden_size
self.config.text_config.hidden_size = config.hidden_size
        self.resampler = Resampler(self.config)
self.language_model = QWenLMHeadModel(config, linear_method)

def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
if self.mm_patch_merge_type.startswith("spatial"):
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)

pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
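
    # Note on the padding arithmetic above: the single image token at `offset` is replaced
    # by new_image_feature_len placeholder ids so that image embeddings can later be written
    # into those positions; in the "anyres" + "unpad" branch the feature length grows by
    # new_h * (new_w + 1), mirroring the LLaVA-style per-row newline feature.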

def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient; output_hidden_states=True keeps all the hidden states.

selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )
image_features = self.multi_modal_projector(selected_image_feature)

return image_features

def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.ndarray]]] = None,
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
pass

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()

self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size

# NOTE(chris): Not relevant?
# self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
# self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
# self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")

# load mm_projector
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
            # FIXME: why are the projector weights read twice?
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)

monkey_path_clip_vision_embed_forward()

EntryClass = QwenVLForCausalLM