-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor registry to plugin pattern in order to support specifying dummy data
- Loading branch information
1 parent
c48a7d4
commit 3232231
Showing
10 changed files
with
326 additions
and
225 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .data import * | ||
from .base import * | ||
from .registry import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, | ||
TypeVar) | ||
|
||
from vllm.config import ModelConfig, VisionLanguageConfig | ||
from vllm.logger import init_logger | ||
|
||
if TYPE_CHECKING: | ||
import torch | ||
from torch import nn | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class MultiModalData:
    """Base marker class for one piece of multi-modal input data.

    To add a new data type, add a new file under the ``multimodal``
    directory. In that file, subclass both
    :class:`~vllm.multimodal.base.MultiModalData` and
    :class:`~vllm.multimodal.base.MultiModalPlugin`, then extend
    `~vllm.multimodal.registry.MultiModalRegistry` with methods that
    interact with the newly defined registry.
    """
|
||
|
||
# D: the concrete MultiModalData subclass a given plugin handles.
D = TypeVar("D", bound=MultiModalData)
# N: a model class; lets registration decorators return their argument
# with its type preserved.
N = TypeVar("N", bound=Type["nn.Module"])

# A per-model input processor: maps (data, model config, VLM config) to the
# keyword arguments for the model's forward pass.
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
"""Returns a dictionary whose entries are passed as keyword arguments to
:meth:`torch.nn.Module.forward`.
"""
|
||
|
||
class MultiModalPlugin(ABC, Generic[D]):
    """Base class for handling one modality of multi-modal data.

    A plugin ties a :class:`MultiModalData` subclass to the per-model input
    processors that convert raw data into keyword arguments for
    :meth:`torch.nn.Module.forward`.
    """

    @classmethod
    def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
        """Resolve the model class selected by ``model_config``."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        return get_model_architecture(model_config)[0]

    def __init__(self) -> None:
        # Maps each registered model class to its input processor.
        self._input_processors: Dict[Type["nn.Module"],
                                     MultiModalInputProcessor[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        """Return the :class:`MultiModalData` subclass this plugin serves."""
        raise NotImplementedError

    @abstractmethod
    def _default_input_processor(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Returns a dictionary which are passed as keyword arguments to
        :meth:`torch.nn.Module.forward`.
        """
        raise NotImplementedError

    def register_input_processor(self,
                                 processor: Optional[
                                     MultiModalInputProcessor[D]] = None):
        """Create a class decorator that registers ``processor`` (or, if
        ``None``, the plugin's default processor) for the decorated model
        class."""

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors:
                # Use lazy %-style arguments (consistent with the rest of
                # the package) so the message is only built when emitted.
                logger.warning(
                    "Model class %s already has an input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors[model_cls] = processor \
                or self._default_input_processor

            return model_cls

        return wrapper

    def process_input(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Run the processor registered for the model resolved from
        ``model_config``.

        Raises:
            KeyError: If no input processor is registered for the resolved
                model class.
        """
        model_cls = self.get_model_cls(model_config)

        processor = self._input_processors.get(model_cls)
        if processor is None:
            raise KeyError(f"No input processor in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return processor(data, model_config, vlm_config)
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import Dict, Type | ||
|
||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
|
||
from vllm.config import ModelConfig, VisionLanguageConfig | ||
from vllm.logger import init_logger | ||
from vllm.transformers_utils.image_processor import cached_get_image_processor | ||
|
||
from .base import MultiModalData, MultiModalPlugin | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class ImagePixelData(MultiModalData):
    """Raw image pixels wrapped as multi-modal input data."""

    def __init__(self, image: Image.Image) -> None:
        # Force PIL to read the underlying file immediately, so instances
        # remain usable when created inside an ``Image`` context manager.
        image.load()

        self.image = image
|
||
|
||
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
    """Plugin that turns raw image pixels into model forward() inputs."""

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
        # No HuggingFace processor configured for this model.
        if vlm_config is None or vlm_config.image_processor is None:
            return None

        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        # Temporary patch to make LLaVA-NeXT usable
        _, _, height, width = vlm_config.image_input_shape
        image = data.image.resize((width, height))

        hf_processor = self._get_hf_image_processor(model_config, vlm_config)

        if hf_processor is None:
            # Fall back to a manual conversion: copy the pixel array,
            # reshape to (1, H, W, C), then move channels first and cast
            # to the model dtype.
            pixel_array = np.array(image, copy=True)
            pixel_values = torch.as_tensor(pixel_array)
            pixel_values = pixel_values.view(1, image.height, image.width, -1)
            pixel_values = pixel_values.permute((0, 3, 1, 2))
            pixel_values = pixel_values.to(model_config.dtype)

            return {"pixel_values": pixel_values}

        try:
            processed = hf_processor.preprocess(image) \
                .convert_to_tensors("pt")
        except Exception:
            logger.error("Failed to process image (%s)", image)
            raise

        return {
            key: tensor.to(model_config.dtype)
            for key, tensor in processed.data.items()
        }
|
||
|
||
class ImageFeatureData(MultiModalData):
    """Pre-computed image feature embeddings as multi-modal input data."""

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features
|
||
|
||
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
    """Plugin that forwards pre-computed image features to the model."""

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        # Only a dtype cast is required; the features are already encoded.
        return {"image_features": data.image_features.to(model_config.dtype)}
Oops, something went wrong.