[Misc] Rename embedding classes to pooling (vllm-project#10801)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
DarkLight1337 authored and afeldman-nm committed Dec 2, 2024
1 parent db1ca39 commit d198e8f
Showing 25 changed files with 166 additions and 123 deletions.
2 changes: 1 addition & 1 deletion examples/offline_inference_embedding.py
@@ -10,7 +10,7 @@

 # Create an LLM.
 model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+# Generate embedding. The output is a list of PoolingRequestOutputs.
 outputs = model.encode(prompts)
 # Print the outputs.
 for output in outputs:

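For context, a minimal sketch of how the renamed output type is consumed downstream of model.encode. The embedding field on the pooled output is an assumption carried over from the old EmbeddingOutput API; this commit only renames the classes:

    from vllm import LLM

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
    # encode() returns one PoolingRequestOutput per prompt.
    for output in llm.encode(["What is vLLM?"]):
        pooled = output.outputs        # PoolingOutput (formerly EmbeddingOutput)
        print(len(pooled.embedding))   # assumed field: the pooled vector as a list of floats
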
6 changes: 3 additions & 3 deletions tests/entrypoints/llm/test_encode.py
@@ -3,7 +3,7 @@

 import pytest

-from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -43,8 +43,8 @@ def llm():
     cleanup_dist_env_and_memory()


-def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
-                         o2: List[EmbeddingRequestOutput]):
+def assert_outputs_equal(o1: List[PoolingRequestOutput],
+                         o2: List[PoolingRequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]

4 changes: 2 additions & 2 deletions tests/models/test_registry.py
@@ -3,7 +3,7 @@
 import pytest
 import torch.cuda

-from vllm.model_executor.models import (is_embedding_model,
+from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
 from vllm.model_executor.models.adapters import as_embedding_model
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):

     # All vLLM models should be convertible to an embedding model
     embed_model = as_embedding_model(model_cls)
-    assert is_embedding_model(embed_model)
+    assert is_pooling_model(embed_model)

     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)

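As a rough usage sketch of the renamed registry helper: the ModelRegistry.resolve_model_cls lookup and the architecture name below are assumptions for illustration; only as_embedding_model and is_pooling_model appear in this diff:

    from vllm.model_executor.models import ModelRegistry, is_pooling_model
    from vllm.model_executor.models.adapters import as_embedding_model

    # Assumed lookup: resolve a registered generative architecture to its class.
    model_cls, _ = ModelRegistry.resolve_model_cls(["LlamaForCausalLM"])

    embed_cls = as_embedding_model(model_cls)  # wrap the generative model with a pooler
    assert is_pooling_model(embed_cls)         # the wrapped class reports as a pooling model
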
4 changes: 2 additions & 2 deletions tests/worker/test_model_input.py
@@ -8,10 +8,10 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.worker.embedding_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import StatefulModelInput
+from vllm.worker.pooling_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)


 class MockAttentionBackend(AttentionBackend):

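Downstream code that imported the old module path can stay compatible with a guarded import; a hedged sketch, assuming only the module was renamed (the class name is unchanged in this diff):

    try:
        # New module path introduced by this commit.
        from vllm.worker.pooling_model_runner import (
            ModelInputForGPUWithPoolingMetadata)
    except ImportError:
        # Older releases exported the same class from embedding_model_runner.
        from vllm.worker.embedding_model_runner import (
            ModelInputForGPUWithPoolingMetadata)
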
31 changes: 27 additions & 4 deletions vllm/__init__.py
@@ -7,8 +7,8 @@
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, RequestOutput)
+from vllm.outputs import (CompletionOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams

@@ -25,12 +25,35 @@
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "EmbeddingOutput",
-    "EmbeddingRequestOutput",
+    "PoolingOutput",
+    "PoolingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
 ]
+
+
+def __getattr__(name: str):
+    import warnings
+
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingOutput
+
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingRequestOutput
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
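
The module-level __getattr__ above (PEP 562) keeps the old names importable while emitting a DeprecationWarning; a small sketch of the expected behaviour (the test-harness code is illustrative, not part of this commit):

    import warnings

    import vllm

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_cls = vllm.EmbeddingRequestOutput      # resolved through vllm.__getattr__

    assert old_cls is vllm.PoolingRequestOutput    # same class, new canonical name
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
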
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -359,7 +359,7 @@ def _resolve_task(
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
-            "embedding": ModelRegistry.is_embedding_model(architectures),
+            "embedding": ModelRegistry.is_pooling_model(architectures),
         }
         supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]

24 changes: 12 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -25,7 +25,7 @@
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ def _log_task_completion(task: asyncio.Task,


 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""

     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ def finished(self) -> bool:

     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -154,7 +154,7 @@ def propagate_exception(self,

     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                     PoolingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -265,7 +265,7 @@ def __init__(self, *args, **kwargs):

     async def step_async(
         self, virtual_engine: int
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.
@@ -907,7 +907,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @overload
@@ -922,7 +922,7 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @deprecate_kwargs(
@@ -941,7 +941,7 @@ async def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ async def encode(
                 Only applicable with priority scheduling.

         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.

         Details:
@@ -1141,7 +1141,7 @@ async def encode(
             trace_headers=trace_headers,
             priority=priority,
         ):
-            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)

     async def abort(self, request_id: str) -> None:
         """Abort a request.

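A hedged end-to-end sketch of the renamed async encode path; the engine construction and prompt below are assumptions, and only the PoolingRequestOutput yield type comes from this diff:

    import asyncio

    from vllm import PoolingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    async def embed_once() -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct",
                            enforce_eager=True))
        # encode() is an async generator; each request streams
        # PoolingRequestOutput objects until it finishes.
        async for output in engine.encode("What is vLLM?",
                                          PoolingParams(),
                                          request_id="embed-0"):
            if output.finished:
                print(output.outputs)


    asyncio.run(embed_once())
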
8 changes: 4 additions & 4 deletions vllm/engine/llm_engine.py
@@ -40,7 +40,7 @@
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:


 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)


 @dataclass
@@ -112,7 +112,7 @@ class SchedulerContext:
     def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+                                         PoolingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -1314,7 +1314,7 @@ def _advance_to_next_step(
             else:
                 seq.append_token_id(sample.output_token, sample.logprobs)

-    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         .. figure:: https://i.imgur.com/sv2HssD.png

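For the synchronous engine, step() now advertises the renamed union; a minimal drain-loop sketch under the same assumptions (engine setup and prompt are illustrative, only the return type comes from this diff):

    from vllm import EngineArgs, LLMEngine, PoolingParams, PoolingRequestOutput

    engine = LLMEngine.from_engine_args(
        EngineArgs(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True))
    engine.add_request("embed-0", "What is vLLM?", PoolingParams())

    while engine.has_unfinished_requests():
        # step() returns List[Union[RequestOutput, PoolingRequestOutput]].
        for out in engine.step():
            if isinstance(out, PoolingRequestOutput) and out.finished:
                print(out.request_id, out.outputs)
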
14 changes: 7 additions & 7 deletions vllm/engine/multiprocessing/client.py
@@ -35,7 +35,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -495,7 +495,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @overload
@@ -507,7 +507,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @deprecate_kwargs(
@@ -524,7 +524,7 @@ def encode(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -540,7 +540,7 @@ def encode(
             trace_headers: OpenTelemetry trace headers.
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         """
         if inputs is not None:
@@ -549,7 +549,7 @@ def encode(
                 and request_id is not None)

         return cast(
-            AsyncGenerator[EmbeddingRequestOutput, None],
+            AsyncGenerator[PoolingRequestOutput, None],
             self._process_request(prompt,
                                   pooling_params,
                                   request_id,
@@ -567,7 +567,7 @@ async def _process_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            EmbeddingRequestOutput, None]]:
+            PoolingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.

5 changes: 2 additions & 3 deletions vllm/engine/protocol.py
@@ -11,8 +11,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -209,7 +208,7 @@ def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...