diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index c6a274ee5894b..6c9b02972c1d5 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -5,6 +5,8 @@
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import random
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -23,7 +25,11 @@ def run_llava(question: str, modality: str):
 
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_model_len=4096,
+        # TODO: `args` is read from module scope; pass the flag explicitly.
+        mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -524,14 +530,35 @@ def main(args):
 
     else:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
-
+        if args.image_repeat_ratio is not None:
+            assert (args.image_repeat_ratio <= 1.0
+                    and args.image_repeat_ratio >= 0)
+            no_yes = [0, 1]
+            probs = [1.0 - args.image_repeat_ratio, args.image_repeat_ratio]
+
+        inputs = []
+        cur_image = data
+        for i in range(args.num_prompts):
+            if args.image_repeat_ratio is not None:
+                res = random.choices(no_yes, probs)[0]
+                if res == 0:
+                    # Not a repeat: modify one pixel so the image is unique
+                    cur_image = cur_image.copy()
+                    new_val = (i // 256 // 256, i // 256, i % 256)
+                    cur_image.putpixel((0, 0), new_val)
+
+            inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: cur_image
+                }
+            })
+
+    import time
+    start_time = time.time()
     outputs = llm.generate(inputs, sampling_params=sampling_params)
+    elapsed_time = time.time() - start_time
+    print("-- generate time = {}".format(elapsed_time))
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -561,5 +588,18 @@
                         type=int,
                         default=16,
                         help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-ratio',
+        type=float,
+        default=None,
+        help='Simulate the given hit ratio for the multi-modal '
+        'preprocessor cache (if enabled).')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='Enable caching of the multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()
     main(args)
diff --git a/vllm/config.py b/vllm/config.py
index 7fbe04eaaf4f8..34e80e1119142 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -133,6 +133,8 @@ class ModelConfig:
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        mm_cache_preprocessor: If True, enable caching of multi-modal
+            preprocessor/mapper.
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -171,6 +173,7 @@ def __init__(
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        mm_cache_preprocessor: bool = False,
         override_neuron_config: Optional[Dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None) -> None:
         self.model = model
@@ -237,6 +240,7 @@ def __init__(
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
+        self.mm_cache_preprocessor = mm_cache_preprocessor
 
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -2610,9 +2614,10 @@ def __str__(self):
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"use_async_output_proc={self.model_config.use_async_output_proc}, "
+            f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, "  # noqa
             f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
-            f"pooler_config={self.model_config.pooler_config!r},"
-            f" compilation_config={self.compilation_config!r}")
+            f"pooler_config={self.model_config.pooler_config!r}, "
+            f"compilation_config={self.compilation_config!r}")
 
 
 _current_vllm_config: Optional[VllmConfig] = None
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3db069ec64ee4..0ca2f51e5b0a2 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -143,6 +143,7 @@ class EngineArgs:
     tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    mm_cache_preprocessor: bool = False
     enable_lora: bool = False
     enable_lora_bias: bool = False
     max_loras: int = 1
@@ -590,6 +591,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=json.loads,
             help=('Overrides for the multimodal input mapping/processing, '
                   'e.g., image processor. For example: {"num_crops": 4}.'))
+        parser.add_argument(
+            '--mm-cache-preprocessor',
+            action='store_true',
+            help='Enable caching of the multi-modal preprocessor/mapper.')
 
         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -962,6 +967,7 @@ def create_model_config(self) -> ModelConfig:
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
+            mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
         )
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 3cf0e610ae7af..abeea052c1fa5 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -35,7 +35,8 @@ class EngineCoreRequest:
     # always be tokenized?
     prompt: Optional[str]
     prompt_token_ids: List[int]
-    mm_inputs: Optional[List[MultiModalKwargs]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
+    mm_hashes: Optional[List[Optional[str]]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 751eb3b40a68d..924ee203c7806 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -19,7 +19,7 @@
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -55,9 +55,6 @@ def __init__(
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
 
-        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
-        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
-
         # Setup scheduler.
         self.scheduler = Scheduler(vllm_config.scheduler_config,
                                    vllm_config.cache_config,
@@ -65,6 +62,8 @@ def __init__(
 
         self._last_logging_time = time.time()
 
+        self.mm_input_mapper_server = MMInputMapperServer()
+
     def _initialize_kv_caches(self,
                               cache_config: CacheConfig) -> Tuple[int, int]:
         start = time.time()
@@ -88,7 +87,14 @@ def _initialize_kv_caches(self,
 
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
+
+        # Fill cached (None) multi-modal inputs from the server-side cache.
+        if request.mm_hashes is not None:
+            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+                request.mm_inputs, request.mm_hashes)
+
         req = Request.from_engine_core_request(request)
+
         self.scheduler.add_request(req)
 
     def abort_requests(self, request_ids: List[str]):
diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py
index 7ad6882b04520..3efd235fa7a15 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_mapper.py
@@ -1,11 +1,18 @@
 from typing import Any, Dict, List, Optional
 
+from PIL import Image
+from blake3 import blake3
+
 from vllm.config import ModelConfig
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
+from vllm.v1.utils import LRUDictCache
+
+# Both the client and the server must use the same cache size.
+MM_CACHE_SIZE = 128
 
 
-class MMInputMapper:
+class MMInputMapperClient:
 
     def __init__(
         self,
@@ -18,23 +25,115 @@ def __init__(
             model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)
 
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+        # Debug hit-ratio logging interval (TODO: set to None by default).
+        self.mm_debug_cache_hit_ratio_steps = 32
+        self.mm_cache_hits = 0
+        self.mm_cache_misses = 0
+
+    def cache_hit_ratio(self, steps: int) -> None:
+        total_steps = self.mm_cache_hits + self.mm_cache_misses
+
+        if total_steps > 0 and total_steps % steps == 0:
+            print("[debug] MMInputMapperClient: cache_hit_ratio = {}".format(
+                self.mm_cache_hits / total_steps))
+
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
+        mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
     ) -> List[MultiModalKwargs]:
         image_inputs = mm_data["image"]
         if not isinstance(image_inputs, list):
             image_inputs = [image_inputs]
 
+        use_hash = mm_hashes is not None
+        if use_hash:
+            assert len(image_inputs) == len(mm_hashes)  # Sanity check
+
         # Process each image input separately so that later we can schedule
         # them in a fine-grained manner.
-        mm_inputs: List[MultiModalKwargs] = []
-        num_images = len(image_inputs)
-        for i in range(num_images):
-            mm_input = self.multi_modal_input_mapper(
-                {"image": image_inputs[i]},
-                mm_processor_kwargs=mm_processor_kwargs,
-            )
-            mm_inputs.append(mm_input)
-        return mm_inputs
+        # Utilize the cache (if enabled).
+        ret_hashes = [] if use_hash else None
+        ret_inputs: List[Optional[MultiModalKwargs]] = []
+        for i in range(len(image_inputs)):
+            if self.mm_debug_cache_hit_ratio_steps is not None:
+                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
+
+            if use_hash:
+                mm_hash = mm_hashes[i]
+                mm_input = self.mm_cache.get(mm_hash)
+            else:
+                mm_hash = None
+                mm_input = None
+
+            if mm_input is None:
+                self.mm_cache_misses += 1
+                mm_input = self.multi_modal_input_mapper(
+                    {"image": [image_inputs[i]]},
+                    mm_processor_kwargs=mm_processor_kwargs,
+                )
+
+                if use_hash:
+                    self.mm_cache.put(mm_hash, mm_input)
+            else:
+                self.mm_cache_hits += 1
+                mm_input = None  # Avoid re-sending the input to the server
+
+            if use_hash:
+                ret_hashes.append(mm_hash)
+            ret_inputs.append(mm_input)
+
+        return ret_inputs, ret_hashes
+
+
+class MMInputMapperServer:
+
+    def __init__(self):
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+    def process_inputs(
+        self,
+        mm_inputs: List[Optional[MultiModalKwargs]],
+        mm_hashes: List[Optional[str]],
+    ) -> List[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        full_mm_inputs = []
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_input is None:
+                mm_input = self.mm_cache.get(mm_hash)
+                assert mm_input is not None
+            else:
+                self.mm_cache.put(mm_hash, mm_input)
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+
+class MMHasher:
+
+    def __init__(self):
+        pass
+
+    def hash(self, mm_data: MultiModalDataDict) -> List[str]:
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+
+        ret = []
+        for image in image_inputs:
+            assert isinstance(image, Image.Image)
+
+            # Serialize the raw pixel data to bytes.
+            image_bytes = image.tobytes()
+
+            # Hash the pixel bytes with blake3.
+            hasher = blake3()
+            hasher.update(image_bytes)
+            ret.append(hasher.hexdigest())
+
+        return ret
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 120fc64969552..abaa007921aea 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -15,7 +15,7 @@
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
 
 
 class Processor:
@@ -42,7 +42,11 @@ def __init__(
             model_config)
 
         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper = MMInputMapper(model_config)
+        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+
+        # Multi-modal hasher (for images)
+        self.mm_hasher = (MMHasher()
+                          if model_config.mm_cache_preprocessor else None)
 
         # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
         # This ideally should releases the GIL, so we should not block the
@@ -102,15 +106,20 @@ def process_inputs(
             self.generation_config_fields, eos_token_id)
 
         # Preprocess multi-modal data
+        mm_hashes = None
+        mm_inputs = None
         if len(decoder_inputs.multi_modal_data) == 0:
-            mm_inputs = None
+            pass
         elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
+            # Already-processed inputs; nothing to hash here
             mm_inputs = [decoder_inputs.multi_modal_data]
         else:
-            mm_inputs = self.mm_input_mapper.process_inputs(
-                decoder_inputs.multi_modal_data,
-                decoder_inputs.mm_processor_kwargs,
-            )
+            mm_hashes = (self.mm_hasher.hash(decoder_inputs.multi_modal_data)
+                         if self.mm_hasher is not None else None)
+
+            mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
+                decoder_inputs.multi_modal_data, mm_hashes,
+                decoder_inputs.mm_processor_kwargs)
 
         # Make Request for Detokenizer.
         detokenizer_request = DetokenizerRequest(
@@ -130,6 +139,7 @@ def process_inputs(
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
             mm_inputs,
+            mm_hashes,
             decoder_inputs.multi_modal_placeholders,
             sampling_params,
             eos_token_id,
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 4b26749712e32..d9aed20ca1886 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from typing import Generic, List, TypeVar, overload
 
 T = TypeVar("T")
@@ -62,3 +63,23 @@ def __contains__(self, item):
 
     def __len__(self):
         return len(self._x)
+
+
+class LRUDictCache:
+
+    def __init__(self, size: int):
+        self.cache = OrderedDict()
+        self.size = size
+
+    def get(self, key):
+        if key not in self.cache:
+            return None
+
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        self.cache[key] = value
+        self.cache.move_to_end(key)
+        if len(self.cache) > self.size:
+            self.cache.popitem(last=False)
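
Reviewer note: the moving pieces above are spread across three files, so here is a minimal, self-contained sketch of the protocol they implement, assuming the same "hash on the client, send `None` on a cache hit, rebuild on the server" flow as the diff. `client_process`, `server_process` and `expensive_map` are made-up stand-ins for `MMInputMapperClient.process_inputs`, `MMInputMapperServer.process_inputs` and the HuggingFace input mapper; only the LRU cache shape and the blake3 hashing mirror the code above.

```python
from collections import OrderedDict

from blake3 import blake3


class LRUDictCache:
    """Same shape as the cache added in vllm/v1/utils.py."""

    def __init__(self, size: int):
        self.cache = OrderedDict()
        self.size = size

    def get(self, key):
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.size:
            self.cache.popitem(last=False)


# Client side (processor): run the mapper on a miss, send only the hash on a hit.
client_cache = LRUDictCache(128)


def client_process(image_bytes: bytes, expensive_map):
    mm_hash = blake3(image_bytes).hexdigest()  # what MMHasher does to the pixels
    cached = client_cache.get(mm_hash)
    if cached is not None:
        return None, mm_hash                   # hit: the server already has the tensors
    mm_input = expensive_map(image_bytes)      # miss: run the (expensive) preprocessor
    client_cache.put(mm_hash, mm_input)
    return mm_input, mm_hash


# Server side (engine core): rebuild the full input from the hash.
server_cache = LRUDictCache(128)


def server_process(mm_input, mm_hash):
    if mm_input is None:
        mm_input = server_cache.get(mm_hash)
        # Relies on both caches using the same size and seeing the same
        # access order, which is why MM_CACHE_SIZE is shared in the diff.
        assert mm_input is not None
    else:
        server_cache.put(mm_hash, mm_input)
    return mm_input


image = b"\x00" * 16  # stand-in for PIL.Image.tobytes()
for _ in range(3):
    sent, h = client_process(image, expensive_map=lambda b: {"pixel_values": b})
    full = server_process(sent, h)
    assert full == {"pixel_values": image}
```

The final assert only holds because both sides observe the same sequence of hashes and use the same cache size, so their LRU states evict in lockstep; that is the invariant behind the shared `MM_CACHE_SIZE` constant and the `assert mm_input is not None` in `MMInputMapperServer`.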
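
For trying the change outside the bundled example script, a rough usage sketch through the public `LLM` entrypoint (assuming the V1 engine is enabled, e.g. via `VLLM_USE_V1=1`; the model name, image path and sampling settings below are placeholders, not part of this PR):

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    max_model_len=4096,
    mm_cache_preprocessor=True,  # new EngineArgs/ModelConfig field from this PR
)

image = Image.open("example.jpg")  # placeholder image path
prompt = "USER: <image>\nWhat is in this image?\nASSISTANT:"

# Repeating the same image should hit the preprocessor cache after the first
# request; the --image-repeat-ratio knob in the example script simulates
# partial repetition instead.
outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"image": image}}] * 8,
    sampling_params=SamplingParams(max_tokens=32),
)
for o in outputs:
    print(o.outputs[0].text)
```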