[Bug]: InternVL2 Mismatch in number of image tokens and image embedding size #7160

Closed
GohioAC opened this issue Aug 5, 2024 · 18 comments · Fixed by #7164
Labels
bug Something isn't working

Comments

@GohioAC

GohioAC commented Aug 5, 2024

Your current environment

Collecting environment information...
/opt/aritra.c/worktree/vllm-main/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
No module named 'vllm.commit_id'
  from vllm.version import __version__ as VLLM_VERSION
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.30.0
Libc version: glibc-2.31

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-30-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        52 bits physical, 57 bits virtual
CPU(s):                               208
On-line CPU(s) list:                  0-207
Thread(s) per core:                   2
Core(s) per socket:                   52
Socket(s):                            2
NUMA node(s):                         2
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                143
Model name:                           Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
Stepping:                             8
CPU MHz:                              2699.998
BogoMIPS:                             5399.99
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            4.9 MiB
L1i cache:                            3.3 MiB
L2 cache:                             208 MiB
L3 cache:                             210 MiB
NUMA node0 CPU(s):                    0-51,104-155
NUMA node1 CPU(s):                    52-103,156-207
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize arch_capabilities

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] pytorch-lightning==2.3.3
[pip3] pyzmq==26.0.3
[pip3] torch==2.4.0
[pip3] torchmetrics==1.4.0.post0
[pip3] torchvision==0.19.0
[pip3] transformers==4.43.3
[pip3] triton==3.0.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] pytorch-lightning         2.3.3                    pypi_0    pypi
[conda] pyzmq                     26.0.3                   pypi_0    pypi
[conda] torch                     2.4.0                    pypi_0    pypi
[conda] torchmetrics              1.4.0.post0              pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi
[conda] transformers              4.43.3                   pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    52-103,156-207  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    52-103,156-207  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    52-103,156-207  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      52-103,156-207  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

Offline inference for InternVL2 frequently fails because the number of image tokens in the prompt does not match the size of the ViT embeddings.

from vllm import LLM, SamplingParams
from PIL import Image


llm = LLM(
    model="OpenGVLab/InternVL2-26B",
    enforce_eager=True,
    tensor_parallel_size=1,
    seed=42,
    max_model_len=8192,
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=256, stop=["<|im_end|>"])
image = Image.open(
    "images/89874e-pale-yellow-fs-mini-klub-12-18-months-original-imaeph9vtzfhrnav.jpeg"
)
prompt = llm.get_tokenizer().apply_chat_template(
    [
        {"role": "system", "content": "Answer the question."},
        {"role": "user", "content": "<image>\nWhat is shown in the image?"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
outputs = llm.generate(inputs, sampling_params=sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

$ CUDA_VISIBLE_DEVICES=2 python debug.py
/opt/aritra.c/worktree/vllm-main/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
No module named 'vllm.commit_id'
  from vllm.version import __version__ as VLLM_VERSION
INFO 08-05 19:48:10 llm_engine.py:174] Initializing an LLM engine (v0.5.3.post1) with config: model='/opt/aritra.c/worktree/llava-finetune-v1/LLaVA/data/checkpoint/future/InternVL2-26B', speculative_config=None, tokenizer='/opt/aritra.c/worktree/llava-finetune-v1/LLaVA/data/checkpoint/future/InternVL2-26B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=42, served_model_name=/opt/aritra.c/worktree/llava-finetune-v1/LLaVA/data/checkpoint/future/InternVL2-26B, use_v2_block_manager=False, enable_prefix_caching=False)
WARNING 08-05 19:48:10 tokenizer.py:129] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
WARNING 08-05 19:48:10 logger.py:146] VLLM_TRACE_FUNCTION is enabled. It will record every function executed by Python. This will slow down the code. It is suggested to be used for debugging hang or crashes only.
INFO 08-05 19:48:10 logger.py:150] Trace frame log is saved to /tmp/vllm/vllm-instance-10c45dd2af9949f9b2e55d4a3b04579c/VLLM_TRACE_FUNCTION_for_process_2460720_thread_140321903204160_at_2024-08-05_19:48:10.741247.log
DEBUG 08-05 19:48:15 parallel_state.py:845] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://10.146.32.17:35725 backend=nccl
INFO 08-05 19:48:15 model_runner.py:720] Starting to load model /opt/aritra.c/worktree/llava-finetune-v1/LLaVA/data/checkpoint/future/InternVL2-26B...
Loading safetensors checkpoint shards:   0% Completed | 0/11 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  18% Completed | 2/11 [00:00<00:01,  8.51it/s]
Loading safetensors checkpoint shards:  64% Completed | 7/11 [00:01<00:00,  5.34it/s]
Loading safetensors checkpoint shards: 100% Completed | 11/11 [00:02<00:00,  4.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 11/11 [00:02<00:00,  4.79it/s]

INFO 08-05 19:48:31 model_runner.py:732] Loading model weights took 47.5707 GB
WARNING 08-05 19:48:31 tokenizer.py:129] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 08-05 19:48:34 gpu_executor.py:102] # GPU blocks: 6654, # CPU blocks: 1365
Processed prompts:   0%|                                                                                                                                          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/opt/aritra.c/worktree/llava-finetune-v1/LLaVA/scripts/debug.py", line 362, in <module>
[rank0]:     outputs = llm.generate(inputs, sampling_params=sampling_params)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/utils.py", line 895, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/entrypoints/llm.py", line 330, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/entrypoints/llm.py", line 611, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/engine/llm_engine.py", line 919, in step
[rank0]:     output = self.model_executor.execute_model(
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/executor/gpu_executor.py", line 110, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/worker/worker_base.py", line 273, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/opt/aritra.c/.venvs/vllm_infer_latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/worker/model_runner.py", line 1363, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/opt/aritra.c/.venvs/vllm_infer_latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/aritra.c/.venvs/vllm_infer_latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/model_executor/models/internvl.py", line 397, in forward
[rank0]:     inputs_embeds = merge_vision_embeddings(
[rank0]:   File "/opt/aritra.c/worktree/vllm-main/vllm/model_executor/models/utils.py", line 31, in merge_vision_embeddings
[rank0]:     raise ValueError(
[rank0]: ValueError: Attempted to assign 1 x 256 = 256 image tokens to 512 placeholders
Processed prompts:   0%|                                                                                                                                          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
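
The ValueError above is a consistency check: the number of image placeholder tokens in the prompt has to equal the number of embedding rows the vision encoder produces. A minimal sketch of that kind of check follows, in plain Python rather than vLLM's actual tensor code. The function name check_placeholder_count and the placeholder id are hypothetical and only serve to reproduce the numbers from the traceback.

```python
IMAGE_TOKEN_ID = -1  # hypothetical placeholder id, for illustration only


def check_placeholder_count(input_ids, num_tiles, tokens_per_tile):
    """Raise if the prompt's image placeholders don't match the embedding rows."""
    num_placeholders = sum(1 for t in input_ids if t == IMAGE_TOKEN_ID)
    num_embeddings = num_tiles * tokens_per_tile
    if num_placeholders != num_embeddings:
        raise ValueError(
            f"Attempted to assign {num_tiles} x {tokens_per_tile} = "
            f"{num_embeddings} image tokens to {num_placeholders} placeholders"
        )
    return num_embeddings


# Same situation as the traceback: 512 placeholders, one 256-token tile.
prompt_ids = [1, 2] + [IMAGE_TOKEN_ID] * 512 + [3]
try:
    check_placeholder_count(prompt_ids, num_tiles=1, tokens_per_tile=256)
except ValueError as exc:
    message = str(exc)
    print(message)
```

In other words, the prompt side reserved room for two tiles' worth of tokens while the encoder side produced one, so the two code paths disagree on the tile count for this image.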
@GohioAC GohioAC added the bug Something isn't working label Aug 5, 2024
@GohioAC
Author

GohioAC commented Aug 5, 2024

@DarkLight1337
Member

DarkLight1337 commented Aug 5, 2024

cc @Isotr0py

@Isotr0py
Collaborator

Isotr0py commented Aug 5, 2024

It seems there is a problem when calculating num_patch for small images. I will fix it soon.
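
For context, InternVL2 splits an image into a grid of fixed-size tiles and, when more than one tile is used, appends a thumbnail. Each tile contributes 256 image tokens (the "x 256" in the error message). The sketch below is modeled on the reference preprocessing published with the InternVL2 checkpoints, not on vLLM's code. The 448-pixel tile size, the 12-tile cap, and the function names closest_grid and num_tiles are assumptions made for illustration.

```python
TILE = 448             # assumed tile edge in pixels
TOKENS_PER_TILE = 256  # matches the "x 256" in the error message


def closest_grid(width, height, max_tiles=12):
    """Pick the (cols, rows) grid whose aspect ratio is closest to the image's."""
    aspect = width / height
    area = width * height
    # Candidate grids, smallest tile count first; ties are ordered by the
    # tuple itself so the result is deterministic.
    grids = sorted(
        {(c, r)
         for c in range(1, max_tiles + 1)
         for r in range(1, max_tiles + 1)
         if c * r <= max_tiles},
        key=lambda g: (g[0] * g[1], g),
    )
    best, best_diff = (1, 1), float("inf")
    for cols, rows in grids:
        diff = abs(aspect - cols / rows)
        if diff < best_diff:
            best, best_diff = (cols, rows), diff
        elif diff == best_diff and area > 0.5 * TILE * TILE * cols * rows:
            # Same aspect ratio: prefer the larger grid only for large images.
            best = (cols, rows)
    return best


def num_tiles(width, height, use_thumbnail=True, max_tiles=12):
    cols, rows = closest_grid(width, height, max_tiles)
    blocks = cols * rows
    # The thumbnail is only added when the image is split into several tiles.
    return blocks + 1 if use_thumbnail and blocks > 1 else blocks


for size in [(448, 448), (896, 448), (100, 150)]:
    tiles = num_tiles(*size)
    print(size, tiles, "tiles ->", tiles * TOKENS_PER_TILE, "image tokens")
```

Under these assumptions the grid is chosen by aspect ratio alone, so even a small 100x150 image maps to a 2x3 grid plus a thumbnail. Any code that estimates the placeholder count from pixel dimensions has to reproduce this rule exactly, or the two counts drift apart.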

@Isotr0py
Collaborator

Isotr0py commented Aug 5, 2024

@GohioAC #7164 should fix this bug. And it works on InternVL2-2B with this image after the fix:

$ python examples/bug_example.py
INFO 08-05 23:27:38 llm_engine.py:174] Initializing an LLM engine (v0.5.3.post1) with config: model='/data/LLM-model/InternVL2-2B', speculative_config=None, tokenizer='/data/LLM-model/InternVL2-2B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=42, served_model_name=/data/LLM-model/InternVL2-2B, use_v2_block_manager=False, enable_prefix_caching=False)
WARNING 08-05 23:27:38 tokenizer.py:129] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
WARNING 08-05 23:27:38 cpu_executor.py:345] Environment variable VLLM_CPU_KVCACHE_SPACE (GB) for CPU backend is not set, using 4 by default.
INFO 08-05 23:27:38 selector.py:117] Cannot use _Backend.FLASH_ATTN backend on CPU.
INFO 08-05 23:27:38 selector.py:66] Using Torch SDPA backend.
INFO 08-05 23:27:41 selector.py:117] Cannot use _Backend.FLASH_ATTN backend on CPU.
INFO 08-05 23:27:41 selector.py:66] Using Torch SDPA backend.
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 14.97it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 14.94it/s]
 
INFO 08-05 23:27:42 cpu_executor.py:208] # CPU blocks: 2730
WARNING 08-05 23:27:42 tokenizer.py:129] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:54<00:00, 54.82s/it, est. speed input: 5.22 toks/s, output: 4.67 toks/s]
Prompt: '<|im_start|>system\nAnswer the question.<|im_end|>\n<|im_start|>user\n\nWhat is shown in the image?<|im_end|>\n<|im_start|>assistant\n', Generated text: "The image shows a bright yellow jacket with a hood. The jacket has a colorful design on the front, including a green bow on the left chest area and a patch on the right side. The patch features a cartoonish design with a smiling face and some text. The jacket also has a pocket on the left side with a cartoon character and some text. The hood is up, and the jacket appears to be made of a soft, possibly fleece material.\nIs there anything else I can help you with?{No, that's all!}"

@github-0-searcher

@GohioAC #7164 should fix this bug. And it works on InternVL2-2B with this image after the fix:

Still getting this error with vllm==0.5.4.
I gather that this PR has not yet been merged into a release version, so I applied the fix manually.

This PR seems to change only vllm/model_executor/models/internvl.py.
I copied and pasted https://github.com/vllm-project/vllm/blob/2676f58c9e3f9d84399822c04657a83a4fae30dd/vllm/model_executor/models/internvl.py to my local path.

This does not work for me. I tried 2B, 8B, and 28B.
My error looks like:
ValueError: Attempted to assign 7 x 256 = 1792 image tokens to 507 placeholders

Thanks for your work!

@Isotr0py
Collaborator

Isotr0py commented Aug 6, 2024

@github-0-searcher Did you set max_model_len? You can try to set a larger max_model_len like 4096.
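
One way to read this suggestion: the image placeholders are part of the prompt, so they count against the context length together with the text tokens. A rough pre-flight check is sketched below, assuming 256 tokens per tile as in the error messages in this thread. The helper name fits_context and the text-token count of 30 are hypothetical.

```python
def fits_context(num_tiles, text_tokens, max_model_len, tokens_per_tile=256):
    """Return (fits, prompt_len) for an image prompt under a context limit."""
    prompt_len = num_tiles * tokens_per_tile + text_tokens
    return prompt_len <= max_model_len, prompt_len


# 7 tiles as in the reported error, plus an assumed 30 text tokens.
small = fits_context(7, 30, max_model_len=1024)
large = fits_context(7, 30, max_model_len=4096)
print("max_model_len=1024:", small)
print("max_model_len=4096:", large)
```

With 7 tiles the image alone needs 1792 tokens, so any context limit below that cannot hold the full set of placeholders no matter how short the text is.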

@github-0-searcher

Thanks for your fast reply.
Seems to work well now :-)

@github-0-searcher

By the way, I noticed something odd. The 26B model's output is fine, but with the 8B or 2B model the output is repetitive, repeating either some punctuation or a short sentence.

What could be wrong here?

@Isotr0py
Collaborator

Isotr0py commented Aug 6, 2024

@github-0-searcher Can you provide the prompt and image so that I can figure it out?

@JUNJIE99

JUNJIE99 commented Aug 7, 2024

Thanks for your fast reply. Seems to work well now :-)

Hello, have you tried multi-image inference? I want to know how to pass two images into the inputs of llm.generate.

Thanks!

@Isotr0py
Collaborator

Isotr0py commented Aug 7, 2024

@JUNJIE99 The InternVL implementation in vLLM doesn't support multi-image inference yet, but it is on our roadmap (#4194) and we will work on it soon. A PR for this feature is also welcome!

@JUNJIE99

JUNJIE99 commented Aug 7, 2024 via email

@JUNJIE99

JUNJIE99 commented Aug 7, 2024

@JUNJIE99 The InternVL implementation in vLLM doesn't support multi-image inference yet, but it is on our roadmap (#4194) and we will work on it soon. A PR for this feature is also welcome!

Apologies for the interruption once again, but I was wondering if there is a timeline for multi-image inference support?

@Isotr0py
Collaborator

Isotr0py commented Aug 7, 2024

@JUNJIE99 As shown in #4194, I think we had better wait for #7230 to be merged before adding multi-image inference for this model.
Once that PR is merged, multi-image inference should follow soon.

@JUNJIE99

JUNJIE99 commented Aug 7, 2024

Thank you for your response and for the great work you're doing. I look forward to your updates.

@Howe-Young

@GohioAC #7164 should fix this bug. And it works on InternVL2-2B with this image after the fix:

Still getting this error with vllm==0.5.4. I gather that this PR has not yet been merged into a release version, so I applied the fix manually.

This PR seems to change only vllm/model_executor/models/internvl.py. I copied and pasted https://github.com/vllm-project/vllm/blob/2676f58c9e3f9d84399822c04657a83a4fae30dd/vllm/model_executor/models/internvl.py to my local path.

This does not work for me. I tried 2B, 8B, and 28B. My error looks like: ValueError: Attempted to assign 7 x 256 = 1792 image tokens to 507 placeholders

Thanks for your work!

Same error, even though I have already set max_model_len=4096:

 ValueError: Attempted to assign 7 x 256 = 1792 image tokens to 3328 placeholders

@DarkLight1337
Member

The fix is currently only available if you build vLLM from source (main branch) since there hasn't been a release since then.
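
Going by this thread (0.5.4 still fails and the fix is only on main), a pip-installed release can be checked against that cutoff. The sketch below uses only the standard library. The helper names and the cutoff constant are assumptions drawn from the comments here, not from release notes. A source build may report the same version number as the last release (the post-fix log above still prints v0.5.3.post1), so this check is only meaningful for released wheels.

```python
import re
from importlib import metadata

LAST_RELEASE_WITHOUT_FIX = (0, 5, 4)  # assumption taken from this thread


def release_tuple(version):
    """Extract the leading numeric components, e.g. '0.5.3.post1' -> (0, 5, 3)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def predates_fix(version):
    """True if a released version is at or below the assumed cutoff."""
    return release_tuple(version) <= LAST_RELEASE_WITHOUT_FIX


try:
    installed = metadata.version("vllm")
    print(installed, "predates fix:", predates_fix(installed))
except metadata.PackageNotFoundError:
    print("vllm is not installed")
```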

@Howe-Young

The fix is currently only available if you build vLLM from source (main branch) since there hasn't been a release since then.

Building vLLM from source works! Thank you!


6 participants