
Commit

Refactor registry to plugin pattern in order to support specifying dummy data
DarkLight1337 committed Apr 23, 2024
1 parent c48a7d4 commit 3232231
Showing 10 changed files with 326 additions and 225 deletions.
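At a glance, the commit turns the multimodal registry into a set of per-data-type plugins and lets each model also register a factory for dummy (profiling) data. Below is a rough sketch of the registration pattern, using the MM_REGISTRY decorators that appear in the llava.py diff later in this commit; the model class and helper names here are illustrative, not part of the diff:

from torch import nn

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MM_REGISTRY


def _build_dummy_data(seq_len: int, model_config: ModelConfig,
                      vlm_config: VisionLanguageConfig):
    # Should return (SequenceData, MultiModalData) used for memory profiling;
    # see _get_dummy_data in vllm/model_executor/models/llava.py below.
    raise NotImplementedError


@MM_REGISTRY.register_image_feature_input()
@MM_REGISTRY.register_image_pixel_input()
@MM_REGISTRY.register_dummy_data(_build_dummy_data)  # new in this commit
class MyVisionLanguageModel(nn.Module):
    ...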
2 changes: 1 addition & 1 deletion examples/llava_example.py
@@ -6,7 +6,7 @@
 from PIL import Image
 
 from vllm import LLM
-from vllm.multimodal import ImageFeatureData, ImagePixelData
+from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 
 # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
 
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -12,7 +12,8 @@
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.multimodal import ImageFeatureData, ImagePixelData, MultiModalData
+from vllm.multimodal import MultiModalData
+from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 _TEST_DIR = os.path.dirname(__file__)
43 changes: 41 additions & 2 deletions vllm/model_executor/models/llava.py
@@ -1,13 +1,14 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
+from PIL import Image
 from torch import nn
 # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
 # transformers' impl.
 from transformers import CLIPVisionModel, LlavaConfig
 
 from vllm.attention import AttentionMetadata
-from vllm.config import VisionLanguageConfig
+from vllm.config import ModelConfig, VisionLanguageConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -17,14 +18,51 @@
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MM_REGISTRY
-from vllm.sequence import SamplerOutput
+from vllm.multimodal.image import ImageFeatureData, ImagePixelData
+from vllm.sequence import SamplerOutput, SequenceData
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
     "language_model.model": "language_model",
 }
 
 
+def _get_seq_data(seq_len: int, vlm_config: VisionLanguageConfig):
+    token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size \
+        + [0] * (seq_len - vlm_config.image_feature_size)
+
+    return SequenceData(token_ids)
+
+
+def _get_values(vlm_config: VisionLanguageConfig):
+    if vlm_config.image_processor is None:
+        values_dtype = torch.float16
+    else:
+        values_dtype = torch.uint8
+
+    return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
+
+
+def _get_dummy_data(seq_len: int, model_config: ModelConfig,
+                    vlm_config: VisionLanguageConfig):
+    seq_data = _get_seq_data(seq_len, vlm_config)
+    values = _get_values(vlm_config)
+
+    config_input_type = vlm_config.image_input_type
+    ImageInputType = VisionLanguageConfig.ImageInputType
+
+    if config_input_type == ImageInputType.PIXEL_VALUES:
+        values_arr = values.squeeze(dim=0).permute((1, 2, 0)).numpy()
+        image = Image.fromarray(values_arr, mode="RGB")
+        fake_mm_data = ImagePixelData(image)
+    elif config_input_type == ImageInputType.IMAGE_FEATURES:
+        fake_mm_data = ImageFeatureData(values)
+    else:
+        raise NotImplementedError
+
+    return seq_data, fake_mm_data
+
+
 # TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
@@ -75,6 +113,7 @@ class LlavaImageFeatureInputs(TypedDict):
 
 @MM_REGISTRY.register_image_feature_input()
 @MM_REGISTRY.register_image_pixel_input()
+@MM_REGISTRY.register_dummy_data(_get_dummy_data)
 class LlavaForConditionalGeneration(nn.Module):
 
     def __init__(self,
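For reference, a small sketch (not part of the diff) of what the dummy-data hook above returns. Since VisionLanguageConfig's constructor is not shown in this commit, a SimpleNamespace stands in for it, exposing only the fields the helpers read; the numeric values are illustrative LLaVA-1.5-style settings:

from types import SimpleNamespace

from vllm.config import VisionLanguageConfig
from vllm.model_executor.models.llava import _get_dummy_data

vlm_config = SimpleNamespace(
    image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
    image_token_id=32000,                # assumed LLaVA-1.5 image token id
    image_feature_size=576,              # assumed 24 x 24 image patches
    image_input_shape=(1, 3, 336, 336),
    image_processor="llava-hf/llava-1.5-7b-hf",
)

# model_config is unused by the helpers shown above, so None is enough here.
seq_data, dummy_mm_data = _get_dummy_data(seq_len=2048, model_config=None,
                                          vlm_config=vlm_config)
# seq_data: SequenceData whose first 576 tokens are the image token, padded with 0s.
# dummy_mm_data: ImagePixelData wrapping an all-zero 336x336 RGB image.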
2 changes: 1 addition & 1 deletion vllm/multimodal/__init__.py
@@ -1,2 +1,2 @@
-from .data import *
+from .base import *
 from .registry import *
91 changes: 91 additions & 0 deletions vllm/multimodal/base.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                    TypeVar)

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
    import torch
    from torch import nn

logger = init_logger(__name__)


class MultiModalData:
    """To add a new data type, add a new file under the `multimodal` directory.
    In this new file, create a subclass of
    :class:`~vllm.multimodal.base.MultiModalData`
    and :class:`~vllm.multimodal.base.MultiModalPlugin`.
    Finally, update :class:`~vllm.multimodal.registry.MultiModalRegistry`
    with new methods to interact with the newly defined plugin.
    """
    pass


D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
"""Returns a dictionary which is passed as keyword arguments to
:meth:`torch.nn.Module.forward`.
"""


class MultiModalPlugin(ABC, Generic[D]):

    @classmethod
    def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        return get_model_architecture(model_config)[0]

    def __init__(self) -> None:
        self._input_processors: Dict[Type["nn.Module"],
                                     MultiModalInputProcessor[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        raise NotImplementedError

    @abstractmethod
    def _default_input_processor(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Returns a dictionary which is passed as keyword arguments to
        :meth:`torch.nn.Module.forward`.
        """
        raise NotImplementedError

    def register_input_processor(self,
                                 processor: Optional[
                                     MultiModalInputProcessor[D]] = None):

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors:
                logger.warning(
                    f"Model class {model_cls} already has an input processor "
                    f"registered to {self}. It is overwritten by the new one.")

            self._input_processors[model_cls] = processor \
                or self._default_input_processor

            return model_cls

        return wrapper

    def process_input(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        model_cls = self.get_model_cls(model_config)

        processor = self._input_processors.get(model_cls)
        if processor is None:
            raise KeyError(f"No input processor in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return processor(data, model_config, vlm_config)
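Following the extension recipe in the MultiModalData docstring, a new modality pairs a MultiModalData subclass with a MultiModalPlugin subclass. A minimal sketch under that pattern; the audio names are hypothetical and not part of this commit, and as the docstring notes, MultiModalRegistry would also need matching methods wired to the new plugin:

from typing import Dict, Type

import torch

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal.base import MultiModalData, MultiModalPlugin


class AudioData(MultiModalData):

    def __init__(self, waveform: torch.Tensor) -> None:
        self.waveform = waveform


class AudioPlugin(MultiModalPlugin[AudioData]):

    def get_data_type(self) -> Type[AudioData]:
        return AudioData

    def _default_input_processor(
            self, data: AudioData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        # Keyword arguments that would be forwarded to the model's forward().
        return {"audio_input": data.waveform.to(model_config.dtype)}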
27 changes: 0 additions & 27 deletions vllm/multimodal/data.py

This file was deleted.

85 changes: 85 additions & 0 deletions vllm/multimodal/image.py
@@ -0,0 +1,85 @@
from typing import Dict, Type

import numpy as np
import torch
from PIL import Image

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import cached_get_image_processor

from .base import MultiModalData, MultiModalPlugin

logger = init_logger(__name__)


class ImagePixelData(MultiModalData):

    def __init__(self, image: Image.Image) -> None:
        # So that this class can be created inside the Image context manager
        image.load()

        self.image = image


class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
        if vlm_config is None or vlm_config.image_processor is None:
            return None

        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        # Temporary patch to make LLaVA-NeXT usable
        _, _, h, w = vlm_config.image_input_shape
        image = data.image.resize((w, h))

        image_processor = self._get_hf_image_processor(model_config,
                                                       vlm_config)
        if image_processor is None:
            image_arr = np.array(image, copy=True)
            pixel_values = torch.as_tensor(image_arr) \
                .view(1, image.height, image.width, -1) \
                .permute((0, 3, 1, 2)) \
                .to(model_config.dtype)

            return {"pixel_values": pixel_values}

        try:
            out_dict = image_processor.preprocess(image) \
                .convert_to_tensors("pt")
        except Exception:
            logger.error("Failed to process image (%s)", image)
            raise

        return {k: v.to(model_config.dtype) for k, v in out_dict.data.items()}


class ImageFeatureData(MultiModalData):

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features


class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image_features = data.image_features.to(model_config.dtype)

        return {"image_features": image_features}
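A short usage sketch for the two wrappers defined in this file (the file path and feature shape are illustrative): ImagePixelData calls image.load() eagerly, so it can be constructed inside a with Image.open(...) block, while ImageFeatureData carries features from an external vision encoder.

import torch
from PIL import Image

from vllm.multimodal.image import ImageFeatureData, ImagePixelData

with Image.open("example.jpg") as img:    # hypothetical local image file
    pixel_data = ImagePixelData(img)      # safe: pixels are loaded before the file closes

feature_data = ImageFeatureData(torch.rand(1, 576, 1024))  # illustrative feature shape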
