[Model] Add Idefics3 support (vllm-project#9767)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: B-201 <[email protected]>
Co-authored-by: B-201 <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
2 people authored and JC1DA committed Nov 11, 2024
1 parent dae1bf0 commit 9a8555a
Showing 8 changed files with 723 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/models/supported_models.rst
@@ -446,6 +446,12 @@ Text Generation
- :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
-
- ✅︎
* - :code:`Idefics3ForConditionalGeneration`
- Idefics3
- T + I
- :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-
-
* - :code:`InternVLChatModel`
- InternVL2
- T + I\ :sup:`E+`
17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -377,6 +377,22 @@ def run_glm4v(question: str, modality: str):
return llm, prompt, stop_token_ids


# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
assert modality == "image"
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"

llm = LLM(model=model_name,
max_model_len=8192,
max_num_seqs=2,
enforce_eager=True)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
)
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -397,6 +413,7 @@ def run_glm4v(question: str, modality: str):
"mllama": run_mllama,
"molmo": run_molmo,
"glm4v": run_glm4v,
"idefics3": run_idefics3,
}


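For reference, the driver loop in this example file consumes the returned triple roughly as follows (a minimal sketch; the local image path is hypothetical):

from PIL import Image

from vllm import SamplingParams

llm, prompt, stop_token_ids = run_idefics3("What is in this image?", "image")
image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64, stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)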
25 changes: 25 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -290,6 +290,30 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
)


def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"

# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=16,
enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)


model_example_map = {
"phi3_v": load_phi3v,
"h2ovl_chat": load_h2onvl,
@@ -298,6 +322,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
"qwen2_vl": load_qwen2_vl,
"qwen_vl_chat": load_qwenvl_chat,
"mllama": load_mllama,
"idefics3": load_idefics3,
}


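The multi-image entry point can then be exercised like this (a hedged sketch, assuming the IMAGE_URLS list defined at the top of this example file):

from vllm import SamplingParams

req = load_idefics3("What do these images have in common?", IMAGE_URLS)
outputs = req.llm.generate(
    {"prompt": req.prompt, "multi_modal_data": {"image": req.image_data}},
    SamplingParams(temperature=0.0, max_tokens=128, stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)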
16 changes: 16 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -327,6 +327,22 @@
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
"idefics3": VLMTestInfo(
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
marks=[
pytest.mark.skipif(
transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0"
),
large_gpu_mark(min_gb=48),
],
),
### Tensor parallel / multi-gpu broadcast tests
"broadcast-chameleon": VLMTestInfo(
models=["facebook/chameleon-7b"],
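To see what the new test entry produces, composing img_idx_to_prompt with prompt_formatter yields prompts of the following shape (build_prompt is a hypothetical stand-in for the harness plumbing, not test code):

def build_prompt(question: str, num_images: int = 1) -> str:
    # img_idx_to_prompt contributes one "<image>" tag per image ...
    img_prompt = "".join("<image>" for _ in range(num_images)) + question
    # ... and prompt_formatter wraps it in the Idefics3 chat template.
    return f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:"

print(build_prompt("What is shown here?", num_images=2))
# <|begin_of_text|>User:<image><image>What is shown here?<end_of_utterance>
# Assistant: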
2 changes: 2 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -187,6 +187,8 @@ def _placeholder_str(self, modality: ModalityStr,
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "idefics3":
return "<image>"

raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
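With this branch in place, image parts in OpenAI-style chat requests are expanded to "<image>" automatically. A sketch of a client request that would hit this path (assuming a local vLLM OpenAI-compatible server on port 8000 and a reachable image URL):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="HuggingFaceM4/Idefics3-8B-Llama3",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
            {"type": "text", "text": "Describe the image."},
        ],
    }],
)
print(resp.choices[0].message.content)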
25 changes: 24 additions & 1 deletion vllm/model_executor/models/idefics2_vision_model.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""PyTorch Idefics2 model."""

from typing import Optional
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
@@ -29,6 +29,7 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


class Idefics2VisionEmbeddings(nn.Module):
@@ -329,3 +330,25 @@ def forward(
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
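The stacked-parameter mapping can be illustrated in isolation: checkpoint weights named q_proj/k_proj/v_proj are redirected into the fused qkv_proj parameter, with shard_id selecting which slice the shard-aware loader writes (a toy sketch, not vLLM code):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
name = "encoder.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), "<- shard", shard_id)
        # prints: encoder.layers.0.self_attn.qkv_proj.weight <- shard k
        break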