From d198e8fc7372134366a2c262c5ea30d7cefb39e2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 1 Dec 2024 14:36:51 +0800 Subject: [PATCH] [Misc] Rename embedding classes to pooling (#10801) Signed-off-by: DarkLight1337 Signed-off-by: Andrew Feldman --- examples/offline_inference_embedding.py | 2 +- tests/entrypoints/llm/test_encode.py | 6 +- tests/models/test_registry.py | 4 +- tests/worker/test_model_input.py | 4 +- vllm/__init__.py | 31 +++++++++-- vllm/config.py | 2 +- vllm/engine/async_llm_engine.py | 24 ++++---- vllm/engine/llm_engine.py | 8 +-- vllm/engine/multiprocessing/client.py | 14 ++--- vllm/engine/protocol.py | 5 +- vllm/entrypoints/llm.py | 30 +++++----- vllm/entrypoints/openai/serving_embedding.py | 12 ++-- vllm/entrypoints/openai/serving_score.py | 10 ++-- vllm/model_executor/models/__init__.py | 11 ++-- vllm/model_executor/models/adapters.py | 6 +- vllm/model_executor/models/interfaces.py | 4 +- vllm/model_executor/models/interfaces_base.py | 15 +++-- vllm/model_executor/models/registry.py | 16 +++--- vllm/outputs.py | 55 +++++++++++++------ vllm/v1/engine/async_llm.py | 4 +- vllm/v1/engine/async_stream.py | 8 +-- ..._runner.py => cpu_pooling_model_runner.py} | 4 +- vllm/worker/cpu_worker.py | 4 +- ...odel_runner.py => pooling_model_runner.py} | 6 +- vllm/worker/worker.py | 4 +- 25 files changed, 166 insertions(+), 123 deletions(-) rename vllm/worker/{cpu_embedding_model_runner.py => cpu_pooling_model_runner.py} (98%) rename vllm/worker/{embedding_model_runner.py => pooling_model_runner.py} (98%) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 7d5ef128bc8e0..ae158eef2ca4c 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -10,7 +10,7 @@ # Create an LLM. model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) -# Generate embedding. The output is a list of EmbeddingRequestOutputs. +# Generate embedding. The output is a list of PoolingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. for output in outputs: diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 4c9f796e5ed71..41163809237e9 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -3,7 +3,7 @@ import pytest -from vllm import LLM, EmbeddingRequestOutput, PoolingParams +from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -43,8 +43,8 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: List[EmbeddingRequestOutput], - o2: List[EmbeddingRequestOutput]): +def assert_outputs_equal(o1: List[PoolingRequestOutput], + o2: List[PoolingRequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 1886b1f9898ad..b5368aab3ecf1 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,7 +3,7 @@ import pytest import torch.cuda -from vllm.model_executor.models import (is_embedding_model, +from vllm.model_executor.models import (is_pooling_model, is_text_generation_model, supports_multimodal) from vllm.model_executor.models.adapters import as_embedding_model @@ -31,7 +31,7 @@ def test_registry_imports(model_arch): # All vLLM models should be convertible to an embedding model embed_model = as_embedding_model(model_cls) - assert is_embedding_model(embed_model) + assert is_pooling_model(embed_model) if model_arch in _MULTIMODAL_MODELS: assert supports_multimodal(model_cls) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index b36e8bfe73ff3..309854e6babf3 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -8,10 +8,10 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.worker.embedding_model_runner import ( - ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.multi_step_model_runner import StatefulModelInput +from vllm.worker.pooling_model_runner import ( + ModelInputForGPUWithPoolingMetadata) class MockAttentionBackend(AttentionBackend): diff --git a/vllm/__init__.py b/vllm/__init__.py index 8f477ea84756d..a10f6d3128cb6 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -7,8 +7,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import (CompletionOutput, EmbeddingOutput, - EmbeddingRequestOutput, RequestOutput) +from vllm.outputs import (CompletionOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -25,8 +25,8 @@ "SamplingParams", "RequestOutput", "CompletionOutput", - "EmbeddingOutput", - "EmbeddingRequestOutput", + "PoolingOutput", + "PoolingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", @@ -34,3 +34,26 @@ "initialize_ray_cluster", "PoolingParams", ] + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/config.py b/vllm/config.py index 51b8cf24803ab..da043afbe1ae7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -359,7 +359,7 @@ def _resolve_task( # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "generate": ModelRegistry.is_text_generation_model(architectures), - "embedding": ModelRegistry.is_embedding_model(architectures), + "embedding": ModelRegistry.is_pooling_model(architectures), } supported_tasks_lst: List[_Task] = [ task for task, is_supported in task_support.items() if is_supported diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 31a15b04314d5..7b1bb7b05708d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -25,7 +25,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -74,7 +74,7 @@ def _log_task_completion(task: asyncio.Task, class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: @@ -83,7 +83,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -103,7 +103,7 @@ def finished(self) -> bool: async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: try: while True: result = await self._queue.get() @@ -154,7 +154,7 @@ def propagate_exception(self, def process_request_output(self, request_output: Union[RequestOutput, - EmbeddingRequestOutput], + PoolingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -265,7 +265,7 @@ def __init__(self, *args, **kwargs): async def step_async( self, virtual_engine: int - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -907,7 +907,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @overload @@ -922,7 +922,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @deprecate_kwargs( @@ -941,7 +941,7 @@ async def add_request( priority: int = 0, *, inputs: Optional[PromptType] = None, # DEPRECATED - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -1070,7 +1070,7 @@ async def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -1088,7 +1088,7 @@ async def encode( Only applicable with priority scheduling. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. Details: @@ -1141,7 +1141,7 @@ async def encode( trace_headers=trace_headers, priority=priority, ): - yield LLMEngine.validate_output(output, EmbeddingRequestOutput) + yield LLMEngine.validate_output(output, PoolingRequestOutput) async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ecc222f692c41..7911dc8d04500 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,7 +40,7 @@ get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, +from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) @dataclass @@ -112,7 +112,7 @@ class SchedulerContext: def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] + PoolingRequestOutput]] = [] self.seq_group_metadata_list: Optional[ List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None @@ -1314,7 +1314,7 @@ def _advance_to_next_step( else: seq.append_token_id(sample.output_token, sample.logprobs) - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index fe21c58c775fe..d26728e8c6e67 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -35,7 +35,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -495,7 +495,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @overload @@ -507,7 +507,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @deprecate_kwargs( @@ -524,7 +524,7 @@ def encode( priority: int = 0, *, inputs: Optional[PromptType] = None # DEPRECATED - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -540,7 +540,7 @@ def encode( trace_headers: OpenTelemetry trace headers. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. """ if inputs is not None: @@ -549,7 +549,7 @@ def encode( and request_id is not None) return cast( - AsyncGenerator[EmbeddingRequestOutput, None], + AsyncGenerator[PoolingRequestOutput, None], self._process_request(prompt, pooling_params, request_id, @@ -567,7 +567,7 @@ async def _process_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - EmbeddingRequestOutput, None]]: + PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e15395d75c91f..4079de7d36793 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,8 +11,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -209,7 +208,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model.""" ... diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1551a9a998160..a25c401b4ea10 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -26,7 +26,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -679,7 +679,7 @@ def encode( prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (prompt + optional token ids) @@ -691,7 +691,7 @@ def encode( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single (token ids + optional prompt) @@ -704,7 +704,7 @@ def encode( prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (token ids + optional prompt) @@ -717,7 +717,7 @@ def encode( prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single or multi token ids [pos-only] @@ -728,7 +728,7 @@ def encode( prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload @@ -741,7 +741,7 @@ def encode( Sequence[PoolingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @deprecate_kwargs( @@ -759,7 +759,7 @@ def encode( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: """Generates the completions for the input prompts. This class automatically batches the given prompts, considering @@ -778,7 +778,7 @@ def encode( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of ``PoolingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. Note: @@ -821,7 +821,7 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, - EmbeddingRequestOutput) + PoolingRequestOutput) def score( self, @@ -832,7 +832,7 @@ def score( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: """Generates similarity scores for all pairs . The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case @@ -854,7 +854,7 @@ def score( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of ``PoolingRequestOutput`` objects containing the generated scores in the same order as the input prompts. """ task = self.llm_engine.model_config.task @@ -943,7 +943,7 @@ def ensure_str(prompt: SingletonPrompt): outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, - EmbeddingRequestOutput) + PoolingRequestOutput) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -1085,7 +1085,7 @@ def _add_guided_params( def _run_engine( self, *, use_tqdm: bool - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -1098,7 +1098,7 @@ def _run_engine( ) # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + outputs: List[Union[RequestOutput, PoolingRequestOutput]] = [] total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 78e2416d9d4da..2cbb252610e39 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -18,14 +18,14 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger -from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput +from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) def _get_embedding( - output: EmbeddingOutput, + output: PoolingOutput, encoding_format: Literal["float", "base64"], ) -> Union[List[float], str]: if encoding_format == "float": @@ -40,7 +40,7 @@ def _get_embedding( def request_output_to_embedding_response( - final_res_batch: List[EmbeddingRequestOutput], request_id: str, + final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] @@ -169,7 +169,7 @@ async def create_embedding( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] try: pooling_params = request.to_pooling_params() @@ -207,7 +207,7 @@ async def create_embedding( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -215,7 +215,7 @@ async def create_embedding( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch) response = request_output_to_embedding_response( diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 7cd8ff08b5608..a1f14449ba9c3 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger -from vllm.outputs import EmbeddingRequestOutput +from vllm.outputs import PoolingRequestOutput from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import make_async, merge_async_iterators, random_uuid @@ -21,7 +21,7 @@ def request_output_to_score_response( - final_res_batch: List[EmbeddingRequestOutput], request_id: str, + final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str) -> ScoreResponse: data: List[ScoreResponseData] = [] score = None @@ -133,7 +133,7 @@ async def create_score( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] input_pairs = make_pairs(request.text_1, request.text_2) @@ -194,7 +194,7 @@ async def create_score( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: @@ -203,7 +203,7 @@ async def create_score( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch) response = request_output_to_score_response( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d66373512b95e..a3ef9adad16d9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,15 +1,14 @@ from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) -from .interfaces_base import (VllmModelForEmbedding, - VllmModelForTextGeneration, is_embedding_model, - is_text_generation_model) +from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, + is_pooling_model, is_text_generation_model) from .registry import ModelRegistry __all__ = [ "ModelRegistry", - "VllmModelForEmbedding", - "is_embedding_model", + "VllmModelForPooling", + "is_pooling_model", "VllmModelForTextGeneration", "is_text_generation_model", "HasInnerState", @@ -20,4 +19,4 @@ "supports_multimodal", "SupportsPP", "supports_pp", -] \ No newline at end of file +] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 360433a07c5b8..9cc43ae9181b9 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from .interfaces_base import VllmModelForEmbedding, is_embedding_model +from .interfaces_base import VllmModelForPooling, is_pooling_model _T = TypeVar("_T", bound=type[nn.Module]) @@ -12,7 +12,7 @@ def as_embedding_model(cls: _T) -> _T: """Subclass an existing vLLM model to support embeddings.""" # Avoid modifying existing embedding models - if is_embedding_model(cls): + if is_pooling_model(cls): return cls # Lazy import @@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T: from .utils import AutoWeightsLoader, WeightsMapper - class ModelForEmbedding(cls, VllmModelForEmbedding): + class ModelForEmbedding(cls, VllmModelForPooling): def __init__( self, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 1545ce332309f..01a381381ccec 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.utils import supports_kw -from .interfaces_base import is_embedding_model +from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -389,4 +389,4 @@ def _supports_cross_encoding( def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: - return is_embedding_model(model) and _supports_cross_encoding(model) + return is_pooling_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 957a5a6e26b5c..de733b6d49a53 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -141,7 +141,7 @@ def is_text_generation_model( @runtime_checkable -class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): +class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]): def pooler( self, @@ -153,23 +153,22 @@ def pooler( @overload -def is_embedding_model( - model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: +def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: ... @overload -def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... -def is_embedding_model( +def is_pooling_model( model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: +) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: if not is_vllm_model(model): return False if isinstance(model, type): - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7d2bfce9ba264..2b7b69e8c3a95 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -24,7 +24,7 @@ from .interfaces import (has_inner_state, is_attention_free, supports_cross_encoding, supports_multimodal, supports_pp) -from .interfaces_base import is_embedding_model, is_text_generation_model +from .interfaces_base import is_pooling_model, is_text_generation_model logger = init_logger(__name__) @@ -211,7 +211,7 @@ class _ModelInfo: architecture: str is_text_generation_model: bool - is_embedding_model: bool + is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool @@ -220,19 +220,19 @@ class _ModelInfo: @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": - is_embedding_model_ = is_embedding_model(model) - if not is_embedding_model_: + is_pooling_model_ = is_pooling_model(model) + if not is_pooling_model_: try: as_embedding_model(model) except Exception: pass else: - is_embedding_model_ = True + is_pooling_model_ = True return _ModelInfo( architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), - is_embedding_model=is_embedding_model_, + is_pooling_model=is_pooling_model_, supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), @@ -441,12 +441,12 @@ def is_text_generation_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_text_generation_model - def is_embedding_model( + def is_pooling_model( self, architectures: Union[str, List[str]], ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) - return model_cls.is_embedding_model + return model_cls.is_pooling_model def is_cross_encoder_model( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index 912e485e40b59..ead37164f1113 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -53,8 +53,8 @@ def __repr__(self) -> str: @dataclass -class EmbeddingOutput: - """The output data of one completion output of a request. +class PoolingOutput: + """The output data of one pooling output of a request. Args: embedding: The embedding vector, which is a list of floats. The @@ -63,7 +63,7 @@ class EmbeddingOutput: embedding: List[float] def __repr__(self) -> str: - return (f"EmbeddingOutput(" + return (f"PoolingOutput(" f"embedding={len(self.embedding)})") @@ -327,18 +327,18 @@ def __repr__(self) -> str: f"multi_modal_placeholders={self.multi_modal_placeholders})") -class EmbeddingRequestOutput: +class PoolingRequestOutput: """ - The output data of an embedding request to the LLM. + The output data of a pooling request to the LLM. Args: - request_id (str): A unique identifier for the embedding request. - outputs (EmbeddingOutput): The embedding results for the given input. + request_id (str): A unique identifier for the pooling request. + outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (List[int]): A list of token IDs used in the prompt. - finished (bool): A flag indicating whether the embedding is completed. + finished (bool): A flag indicating whether the pooling is completed. """ - def __init__(self, request_id: str, outputs: "EmbeddingOutput", + def __init__(self, request_id: str, outputs: "PoolingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids @@ -347,11 +347,11 @@ def __init__(self, request_id: str, outputs: "EmbeddingOutput", @classmethod def from_seq_group(cls, - seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + seq_group: 'SequenceGroup') -> "PoolingRequestOutput": if seq_group.embeddings is None: raise ValueError( "Embeddings are missing in seq_group for EmbeddingRequest.") - output = EmbeddingOutput(seq_group.embeddings) + output = PoolingOutput(seq_group.embeddings) prompt_token_ids = seq_group.prompt_token_ids finished = seq_group.is_finished() @@ -359,15 +359,15 @@ def from_seq_group(cls, def __repr__(self): """ - Returns a string representation of an EmbeddingRequestOutput instance. + Returns a string representation of an PoolingRequestOutput instance. The representation includes the request_id and the number of outputs, - providing a quick overview of the embedding request's results. + providing a quick overview of the pooling request's results. Returns: - str: A string representation of the EmbeddingRequestOutput instance. + str: A string representation of the PoolingRequestOutput instance. """ - return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + return (f"PoolingRequestOutput(request_id='{self.request_id}', " f"outputs={repr(self.outputs)}, " f"prompt_token_ids={self.prompt_token_ids}, " f"finished={self.finished})") @@ -426,7 +426,30 @@ def create(seq_group: SequenceGroup, # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: - return EmbeddingRequestOutput.from_seq_group(seq_group) + return PoolingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, use_cache, seq_id_to_seq_group) + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 421ecc8c0d921..1df9bc57a1cb2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,7 +9,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -133,7 +133,7 @@ async def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: """Add new request to the AsyncLLM.""" if self.detokenizer.is_request_active(request_id): diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py index 3e6c759ad5ebd..35449238c3259 100644 --- a/vllm/v1/engine/async_stream.py +++ b/vllm/v1/engine/async_stream.py @@ -1,11 +1,11 @@ import asyncio from typing import Any, AsyncGenerator, Callable, Optional, Type, Union -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" STOP_ITERATION = Exception() # Sentinel @@ -16,7 +16,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -32,7 +32,7 @@ def finish( async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: finished = False try: while True: diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py similarity index 98% rename from vllm/worker/cpu_embedding_model_runner.py rename to vllm/worker/cpu_pooling_model_runner.py index 3954e4c4c8a5b..17b2fd2564a04 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -16,12 +16,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): """ - Used by the CPUEmbeddingModelRunner. + Used by the CPUPoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class CPUEmbeddingModelRunner( +class CPUPoolingModelRunner( CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( ModelInputForCPUWithPoolingMetadata) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index cf04808b73372..4fad1a3f4caeb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -14,9 +14,9 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase +from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -164,7 +164,7 @@ def __init__( else {"return_hidden_states": True} ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner if self.model_config.task == "embedding": - ModelRunnerClass = CPUEmbeddingModelRunner + ModelRunnerClass = CPUPoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunnerBase = ModelRunnerClass( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/pooling_model_runner.py similarity index 98% rename from vllm/worker/embedding_model_runner.py rename to vllm/worker/pooling_model_runner.py index f56805918fd15..1beae1e3884c5 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -21,12 +21,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ - Used by the EmbeddingModelRunner. + Used by the PoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class EmbeddingModelRunner( +class PoolingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( ModelInputForGPUWithPoolingMetadata) @@ -52,7 +52,7 @@ def execute_model( ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "EmbeddingModelRunner does not support multi-step execution.") + "PoolingModelRunner does not support multi-step execution.") if self.lora_config: assert model_input.lora_requests is not None diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 24e7bc760b0c0..d58cb029618e9 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -22,9 +22,9 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -75,7 +75,7 @@ def __init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_config.task == "embedding": - ModelRunnerClass = EmbeddingModelRunner + ModelRunnerClass = PoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass(