diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ddadb947762..ac4c3e5f8eb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -209,7 +209,7 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode - - pytest -v -s v1/kv_transfer + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh similarity index 96% rename from tests/v1/kv_connector/run_accuracy_test.sh rename to tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 4959e29c3cb..17eac262968 100755 --- a/tests/v1/kv_connector/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -138,7 +138,7 @@ run_tests_for_model() { done # Build the command for the proxy server with all the hosts and ports - PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/toy_proxy_server.py --port 8192" + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" # Add all prefill hosts and ports PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" @@ -157,7 +157,7 @@ run_tests_for_model() { # Run lm eval for this model echo "Running tests for $model_name" - TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/test_accuracy.py + TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py # Clean up before running next model cleanup_instances diff --git a/tests/v1/kv_connector/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py similarity index 100% rename from tests/v1/kv_connector/test_accuracy.py rename to tests/v1/kv_connector/nixl_integration/test_accuracy.py diff --git a/tests/v1/kv_connector/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py similarity index 100% rename from tests/v1/kv_connector/toy_proxy_server.py rename to tests/v1/kv_connector/nixl_integration/toy_proxy_server.py diff --git a/tests/v1/kv_connector/__init__.py b/tests/v1/kv_connector/unit/__init__.py similarity index 100% rename from tests/v1/kv_connector/__init__.py rename to tests/v1/kv_connector/unit/__init__.py diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py similarity index 100% rename from tests/v1/kv_connector/test_nixl_connector.py rename to tests/v1/kv_connector/unit/test_nixl_connector.py diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py similarity index 100% rename from tests/v1/kv_connector/test_remote_decode_lifecycle.py rename to tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py similarity index 89% rename from tests/v1/kv_connector/test_remote_prefill_lifecycle.py rename to tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index b9deeda18e9..e6d254443e9 100644 --- a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -272,3 +272,37 @@ def test_no_spurious_prefix_caching(): for block in remote_blocks: assert block.ref_cnt == 1 assert block._block_hash is None + + +def 
test_short_prompt_lifecycle(): + """Test lifecycle of a Remote Decode request with short prompt.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Not enough tokens for full block. + NUM_TOKENS = vllm_config.cache_config.block_size // 2 + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + # Since tokens < block_size, there will be no kv xfer. + # So this should be cleaned up immediately. + _ = scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + # We need one more call to schedule() to clear data for persistent batch. + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/unit/utils.py similarity index 96% rename from tests/v1/kv_connector/utils.py rename to tests/v1/kv_connector/unit/utils.py index c5527bc0ee5..a681b1ad5f2 100644 --- a/tests/v1/kv_connector/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -6,7 +6,7 @@ from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) from vllm.sampling_params import KVTransferParams, SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.sched.scheduler_disagg import DisaggregatedScheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput @@ -16,7 +16,7 @@ EOS_TOKEN_ID = 50256 -def assert_scheduler_empty(scheduler: Scheduler): +def assert_scheduler_empty(scheduler: DisaggregatedScheduler): """Confirm the scheduler is "empty" - i.e. no leaks.""" # Scheduler Metadata. 
assert len(scheduler.requests) == 0 @@ -88,7 +88,7 @@ def create_vllm_config( def create_scheduler( vllm_config: VllmConfig, num_blocks: int = 10000, -) -> Scheduler: +) -> DisaggregatedScheduler: """Initialize Scheduler For Testing.""" block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( @@ -101,7 +101,7 @@ def create_scheduler( ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks - return Scheduler( + return DisaggregatedScheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_transfer/test_multi_connector.py deleted file mode 100644 index ed26ba0f0d3..00000000000 --- a/tests/v1/kv_transfer/test_multi_connector.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import filecmp -import shutil -import tempfile -from collections import defaultdict -from pathlib import Path - -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) - -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" - -PROMPT_CONTEXT = "Hi " * 100 -PROMPTS = [ - PROMPT_CONTEXT + "Hello, my name is", - PROMPT_CONTEXT + "The capital of France is", -] - -SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) - - -class TestSharedStorageConnector(SharedStorageConnector): - - def __init__(self, config: VllmConfig, role): - self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) - self.call_record: dict[str, int] = defaultdict(int) - # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}_events.log" - # Start with an empty file - with open(self._event_file, "w") as _: - pass - - def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion - return object.__getattribute__(self, name) - if not hasattr(self._connector, name): - return object.__getattribute__(self, name) - attr = getattr(self._connector, name) - - if callable(attr): - - def wrapper(*args, **kwargs): - self.call_record[name] += 1 - # Log the event as a line to the file - try: - with open(self._event_file, "a") as f: - f.write(name + "\n") - except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") - return attr(*args, **kwargs) - - return wrapper - return attr - - -KVConnectorFactory.register_connector("TestSharedStorageConnector", - TestSharedStorageConnector.__module__, - TestSharedStorageConnector.__name__) - - -# Helper function to compare directories recursively -def _compare_directories(dir1: Path, dir2: Path) -> bool: - """Compares two directories recursively for identical content.""" - dcmp = filecmp.dircmp(dir1, dir2) - if dcmp.left_only or dcmp.right_only or dcmp.diff_files: - print(f"Differences found between {dir1} and {dir2}:") - print(f" Left only: {dcmp.left_only}") - print(f" Right only: {dcmp.right_only}") - print(f" Different files: {dcmp.diff_files}") - return False - for sub_dir in dcmp.common_dirs: - if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): - return False - return True - - -def test_multi_shared_storage_connector_consistency(): - """ - Tests that 
MultiConnector with two SharedStorageConnectors saves - identical KV cache data to separate storage locations. - """ - storage_1_path = Path("storage_1/") - storage_2_path = Path("storage_2/") - shutil.rmtree(storage_1_path, ignore_errors=True) - shutil.rmtree(storage_2_path, ignore_errors=True) - storage_1_path.mkdir() - storage_2_path.mkdir() - - # Configure MultiConnector with two SharedStorageConnectors - kv_transfer_config = KVTransferConfig( - kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [{ - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path), - "name": "storage1", - } - }, { - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path), - "name": "storage2", - } - }] - }, - ) - - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - gpu_memory_utilization=0.5, - kv_transfer_config=kv_transfer_config, - ) - # Run generation - this should trigger saving KV cache - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - # --- Verification --- - - # Check that both storage directories were populated - local_subdirs = list(storage_1_path.iterdir()) - external_subdirs = list(storage_2_path.iterdir()) - - assert len( - local_subdirs - ) > 0, f"Local storage path {storage_1_path} is empty after generation." - assert len(external_subdirs) > 0, ( - f"External storage path {storage_2_path} is empty after generation.") - assert len(local_subdirs) == len(external_subdirs), ( - f"Mismatch in number of cache entries: " - f"Local={len(local_subdirs)}, External={len(external_subdirs)}") - - # The subdirectories should correspond to the prompt hashes - # Since prompts are the same, the hash directories should be the same name - local_subdir_names = sorted([d.name for d in local_subdirs]) - external_subdir_names = sorted([d.name for d in external_subdirs]) - assert local_subdir_names == external_subdir_names, ( - "Cache directory names do not match between local and external storage" - ) - - # Compare the contents of each corresponding cache directory - for subdir_name in local_subdir_names: - print(f"Comparing contents of cache directory: {subdir_name}") - assert _compare_directories(storage_1_path / subdir_name, - storage_2_path / subdir_name), \ - (f"Contents differ for cache directory '{subdir_name}' between " - f"{storage_1_path} and {storage_2_path}") - - events = get_connector_events() - # get_num_new_matched_tokens will be called on each connector in turn. - # neither of them have hits so update_state_after_alloc won't be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will return new tokens from the first - # connector so update_state_after_alloc will be called once blocks - # are allocated for the first connector. - # get_num_new_matched_tokens *won't* be called on the second connector - # in this case. 
- assert events["storage1"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - assert events["storage2"][:2] == [ - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Delete storage1 connector state - shutil.rmtree(storage_1_path) - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will be called for the first connector but it - # won't have a hit so update_state_after_alloc won't be called. - # get_num_new_matched_tokens will also be called on the second connector, - # but it should have a hit so update_state_after_alloc will be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Clean up - shutil.rmtree(storage_1_path) - shutil.rmtree(storage_2_path) - - -def get_connector_events() -> dict[str, list[str]]: - # Read in connector events and reset the files. - import glob - event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") - connector_events = {} - for fname in event_files: - name = fname.split("connector_")[1].split("_events.log")[0] - try: - with open(fname, "r+") as f: - connector_events[name] = [ - line.strip() for line in f if line.strip() - ] - f.truncate(0) - except Exception as e: - print(f"[ERROR] Could not read connector events for {name}: {e}") - - return connector_events diff --git a/vllm/config.py b/vllm/config.py index 5ca70f2f67b..f60547bdd4e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3402,8 +3402,6 @@ class KVTransferConfig(BaseModel): kv_connector: Optional[str] = None # Engine ID for the KV transfers. - # Note(tms): sticking this here so the engine_id is consistent between - # scheduler-side and worker-side of the KVConnector engine_id: str = str(uuid.uuid4()) # The device used by kv connector to buffer the KV cache. 
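For reference, the engine_id field above still defaults to a per-process str(uuid.uuid4()), so the scheduler-side and worker-side connectors built from one config share the same ID. A minimal sketch of constructing such a config, using only fields visible in this diff (kv_connector, kv_role) and the NixlConnector name registered in factory.py below; illustrative only, not part of this patch:

# Illustrative only -- not part of this patch.
from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="NixlConnector",  # registered name retained in factory.py
    kv_role="kv_both",             # same role string used in the tests above
)
# engine_id defaults to str(uuid.uuid4()), so the scheduler-side and
# worker-side connectors created from this config observe the same ID.
print(kv_cfg.engine_id)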
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 9d342115ccf..54cb1871db3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -110,8 +110,3 @@ def create_connector_v1( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", "NixlConnector") - -KVConnectorFactory.register_connector( - "MultiConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", - "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index a4a735890da..ca9e1915671 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,6 +22,7 @@ import enum from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING import torch @@ -46,6 +47,7 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +@dataclass class KVConnectorMetadata: pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py deleted file mode 100644 index e8857d6e367..00000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import copy -from typing import TYPE_CHECKING - -import torch - -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionMetadata - from vllm.forward_context import ForwardContext - from vllm.v1.request import Request - -logger = init_logger(__name__) - - -class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], - KVConnectorMetadata): - pass - - -class MultiConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self._connectors = [] - ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") - assert ktcs is not None - for ktc in ktcs: - temp_config = copy.copy(vllm_config) - temp_config.kv_transfer_config = KVTransferConfig(**ktc) - self._connectors.append( - KVConnectorFactory.create_connector_v1(temp_config, role)) - - # A mapping from request id to the connector that is assigned to it. - self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} - - # We must override the base class method here because we need to bind - # the metadata to each connector in the order of the connectors in the - # MultiKVConnectorMetadata. 
- def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: - assert isinstance(connector_metadata, MultiKVConnectorMetadata) - for c, cm in zip(self._connectors, connector_metadata): - c.bind_connector_metadata(cm) - - def clear_connector_metadata(self) -> None: - for c in self._connectors: - c.clear_connector_metadata() - - # ============================== - # Worker-side methods - # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - for c in self._connectors: - c.start_load_kv(forward_context, **kwargs) - - def wait_for_layer_load(self, layer_name: str) -> None: - for c in self._connectors: - c.wait_for_layer_load(layer_name) - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - for c in self._connectors: - c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) - - def wait_for_save(self): - for c in self._connectors: - c.wait_for_save() - - # ============================== - # Scheduler-side methods - # ============================== - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> int: - for c in self._connectors: - toks = c.get_num_new_matched_tokens(request, num_computed_tokens) - # The first connector that has new matched tokens will be assigned - # to this request. - if toks > 0: - self._requests_to_connector[request.request_id] = c - return toks - return 0 - - def update_state_after_alloc(self, request: "Request", - block_ids: list[int], - num_external_tokens: int): - # If the request is not assigned to any connector, we do nothing. - if request.request_id not in self._requests_to_connector: - return - # We assume that the request is assigned to only one connector. - c = self._requests_to_connector.pop(request.request_id) - c.update_state_after_alloc(request, block_ids, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - return MultiKVConnectorMetadata( - c.build_connector_meta(scheduler_output) for c in self._connectors) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aefba620e18..bc10b67f0ce 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1430,7 +1430,11 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: - self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" + if self.kv_transfer_config: + self.scheduler_cls = ( + "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler") + else: + self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" # When no user override, set the default values based on the usage # context. 
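The _set_default_args_v1 change above only applies when the user has not overridden --scheduler-cls, and the choice itself reduces to a branch on whether a kv_transfer_config is present. A standalone sketch of that selection logic (the helper name is invented for illustration):

# Hypothetical helper mirroring the default-scheduler selection added above.
def default_v1_scheduler_cls(kv_transfer_config) -> str:
    if kv_transfer_config:
        # P/D disaggregation: use the connector-aware scheduler subclass.
        return "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler"
    return "vllm.v1.core.sched.scheduler.Scheduler"

assert default_v1_scheduler_cls(None).endswith("scheduler.Scheduler")
assert default_v1_scheduler_cls(object()).endswith("DisaggregatedScheduler")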
diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c24ba0f45f9..f6e33d35e3b 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,6 +11,10 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 12c55be0037..9ccb7d15b56 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -178,7 +178,7 @@ def allocate_slots( prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such - as eagle. + as eagle. skip_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer which will complete in a future step. @@ -271,13 +271,9 @@ def allocate_slots( if not self.enable_caching: return new_blocks + # For disaggregated, avoid caching until KVs are recved. if skip_cache_blocks: - # NOTE(rob): this assert is valid because we only call - # skip_cache_blocks=True on the first time of WAITING - # during a P/D setup. assert request.request_id not in self.num_cached_block - # NOTE(rob): this is necessary so we don't double - # cache a block after is has finished recving. self.num_cached_block[request.request_id] = len( new_computed_blocks) return new_blocks @@ -306,8 +302,7 @@ def cache_blocks( # for a running request. num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we do # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index bea766ff464..1cd3551dbf7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,21 +2,15 @@ from __future__ import annotations -import itertools import time from collections import defaultdict, deque from collections.abc import Iterable from typing import Optional, Union -from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.sampling_params import KVTransferParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -71,14 +65,6 @@ def __init__( self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events) - # Create KVConnector for the Scheduler. Note that each Worker - # will have a corresponding KVConnector with Role=WORKER. - # KV Connector pushes/pull of remote KVs for P/D and offloading.
- self.connector = None - if self.vllm_config.kv_transfer_config is not None: - self.connector = KVConnectorFactory.create_connector_v1( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) - self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config) @@ -99,9 +85,6 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() - # Requests in states for tracking KV transfers for P/D disagg - self.finished_recving_kv_req_ids: set[str] = set() - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -314,27 +297,6 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] - # Skip request if the remote KV recv is still waiting - # for the requests to arrive. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - if request.request_id in self.finished_recving_kv_req_ids: - assert self.kv_cache_manager.enable_caching - # Now that the KVs have been recved, we can cache - # them and set num_computed_tokens. - self.kv_cache_manager.cache_blocks( - request, - num_tokens=0, - num_computed_tokens=(len(request.all_token_ids) - - 1)) - self.finished_recving_kv_req_ids.remove( - request.request_id) - request.status = RequestStatus.WAITING - self.kv_cache_manager.free(request) - else: - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) - continue - # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -362,50 +324,9 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks( request) - # Get externally-cached tokens if using a KVConnector. - num_external_tokens = ( - 0 if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens += num_external_tokens - - if request.do_remote_prefill and num_external_tokens > 0: - # Allocate slots for the external tokens, but skip - # caching until after the KV transfer is done. - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_external_tokens, - computed_blocks, - skip_cache_blocks=True) - if new_blocks is None: - # Requests cannot be scheduled - break - - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) - request.status = RequestStatus.WAITING_FOR_REMOTE_KVS - - # KVConnector: update internal state after allocation. - # This information is used to determine if a load is - # needed for this request. - if self.connector is not None: - self.connector.update_state_after_alloc( - request, - [ - b.block_id for b in itertools.chain( - computed_blocks, new_blocks) - ], - num_external_tokens, - ) - # We should only trigger a KV transfer once per request. - request.do_remote_prefill = False - continue - # Number of tokens to be scheduled. # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed reqs, + # `request.num_prompt_tokens` to consider the resumed requests, # which have output tokens. 
num_new_tokens = request.num_tokens - num_computed_tokens if (0 < self.scheduler_config.long_prefill_token_threshold < @@ -430,7 +351,7 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_tokens, + num_new_tokens, computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, ) @@ -438,19 +359,6 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - # KVConnector: update internal state after allocation. - # This information is used to determine if a load is - # needed for this request. - if self.connector is not None: - self.connector.update_state_after_alloc( - request, - [ - b.block_id for b in itertools.chain( - computed_blocks, new_blocks) - ], - num_external_tokens, - ) - self.waiting.popleft() if request.use_structured_output: structured_output_request_ids[ @@ -479,7 +387,7 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if not request.do_remote_prefill and encoder_inputs_to_schedule: + if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. @@ -558,14 +466,6 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=grammar_bitmask, ) - # NOTE(Kuntai): this function is designed for multiple purposes: - # 1. Plan the KV cache store - # 2. Wrap up all the KV cache load / save ops into an opaque object - # 3. Clear the internal states of the connector - if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) - scheduler_output.kv_connector_metadata = meta - events = self.kv_cache_manager.take_events() if events: batch = KVEventBatch(ts=time.time(), events=events) @@ -581,8 +481,7 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - if req := self.requests.get(req_id): - req.num_computed_tokens += num_scheduled_token + self.requests[req_id].num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output @@ -846,10 +745,7 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events(), - kv_transfer_params=kv_transfer_params, - )) - + events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. 
assert not prompt_logprobs_tensors @@ -924,23 +820,14 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, - request: Request, - skip_free_blocks: bool = False) -> None: - assert request.is_finished() - self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) - - if not skip_free_blocks: - self._free_blocks(request) - - def _free_blocks(self, request: Request): + def _free_request(self, request: Request) -> None: assert request.is_finished() - assert request.request_id not in self._cached_reqs_data self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) + self.encoder_cache_manager.free(request) + self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] + self.finished_req_ids.add(request.request_id) def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py new file mode 100644 index 00000000000..2eecb98436e --- /dev/null +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -0,0 +1,636 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import itertools +import time +from collections import deque + +from vllm import envs +from vllm.distributed.kv_events import KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.logger import init_logger +from vllm.sampling_params import KVTransferParams +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.sched.utils import check_stop +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class DisaggregatedScheduler(Scheduler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # NOTE(rob): there is no reason to believe these are not + # supported. However, I would like to test them first + # before enabling them with P/D. + if self.use_eagle or self.vllm_config.speculative_config: + raise NotImplementedError( + "Speculative Decoding is not yet supported with " + "KV Disaggregation.") + if self.lora_config: + raise NotImplementedError( + "LoRA is not yet supported with KV Disaggregation.") + + # Create KVConnector for the Scheduler. + if self.vllm_config.kv_transfer_config is None: + raise ValueError("Using Disaggregated Scheduler but found unset " + "kv_transfer_config.") + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + # Requests in states for tracking KV transfers. + self.finished_recving_kv_req_ids: set[str] = set() + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. 
This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, list[int]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp) + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. 
+ # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[request.request_id] = req_index + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Encoder-related. + if not request.do_remote_prefill and encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later (e.g. for FSM + # or KVCacheSending). + skipped_waiting_requests: deque[Request] = deque() + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + # Skip request if the remote KV recv is still waiting + # for the requests to arrive. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + if request.request_id in self.finished_recving_kv_req_ids: + assert self.kv_cache_manager.enable_caching + # Now that the KVs have been recved, we can cache + # them and set num_computed_tokens. + self.kv_cache_manager.cache_blocks( + request, + num_tokens=0, + num_computed_tokens=(len(request.all_token_ids) - + 1)) + self.finished_recving_kv_req_ids.remove( + request.request_id) + request.status = RequestStatus.WAITING + self.kv_cache_manager.free(request) + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Get already-cached tokens. + computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + + if request.do_remote_prefill and num_external_tokens > 0: + # Allocate slots for the external tokens, but skip + # caching until after the KV transfer is done. + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_external_tokens, + computed_blocks, + skip_cache_blocks=True) + if new_blocks is None: + # Requests cannot be scheduled + break + + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], + num_external_tokens, + ) + # We should only trigger a KV transfer once per request. 
+ request.do_remote_prefill = False + continue + + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed request, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + else: + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_tokens, + computed_blocks, + ) + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], + num_external_tokens, + ) + + self.waiting.popleft() + if request.use_structured_output: + structured_output_request_ids[ + request.request_id] = req_index + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. 
+ num_common_prefix_blocks = 0 + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens={}) + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + request=req, + num_scheduled_tokens=num_scheduled_tokens[req.request_id], + num_scheduled_spec_tokens=0, + new_block_ids=req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + request=req, + num_scheduled_tokens=num_scheduled_tokens[req.request_id], + num_scheduled_spec_tokens=0, + new_block_ids=req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. 
+ for req_id, num_scheduled_token in num_scheduled_tokens.items(): + if req := self.requests.get(req_id): + req.num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() + return scheduler_output + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + new_running: list[Request] = [] + outputs: list[EngineCoreOutput] = [] + send_kv_no_op: list[str] = [] + + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. + for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions.offset + num_tokens = mm_positions.length + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. + break + + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and request.use_structured_output: + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if request.use_structured_output: + metadata = request.structured_output_request + assert metadata is not None and metadata.grammar is not None + # Needs to happen after new_token_ids are accepted. 
+ request.spec_token_ids = metadata.grammar.validate_tokens( + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids: + # Stop request after the first token if doing a remote_decode. + # NOTE(rob): req is not freed (or preempted) in the EngineCore + # until the xfer is done to ensure we do not free the KV blocks. + kv_transfer_params = None + if request.do_remote_decode and not stopped: + request.status = RequestStatus.FINISHED_REMOTE_DECODE + self._free_request(request, skip_free_blocks=True) + stopped = True + + # TODO(rob): do this on a per-Connector basis. + remote_blocks = [ + block.block_id for block in + self.kv_cache_manager.req_to_blocks[request.request_id] + if block._block_hash is not None + ] + # If prompt < block_size, then there will be no KV xfer. + # Free these requests so we don't have a mem leak. + if len(remote_blocks) == 0: + send_kv_no_op.append(request.request_id) + + engine_id = self.vllm_config.kv_transfer_config.engine_id + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_block_ids=remote_blocks, + remote_engine_id=engine_id, + remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, + remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + ) + + # Add EngineCoreOutput for this Request. + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + if not stopped: + new_running.append(request) + + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or ()): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + for req_id in send_kv_no_op: + logger.debug("No op sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + + # Return the cached request data to the queue so they can + # be reused. + for req_data in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): since we free stopped reqs above, adding stopped reqs + # to _cached_reqs_data will cause a memory leak. 
+ if req_data.req_id not in self.finished_req_ids: + self._cached_reqs_data[req_data.req_id].append(req_data) + + self.running = new_running + engine_core_outputs = EngineCoreOutputs( + outputs=outputs, + scheduler_stats=self.make_stats(), + ) + if self.include_finished_set: + #TODO currently sending duplicates here, improve this + engine_core_outputs.finished_requests = ( + scheduler_output.finished_req_ids | self.finished_req_ids) + + return engine_core_outputs + + def _free_request(self, + request: Request, + skip_free_blocks: bool = False) -> None: + assert request.is_finished() + self.encoder_cache_manager.free(request) + self._cached_reqs_data.pop(request.request_id, None) + self.finished_req_ids.add(request.request_id) + + if not skip_free_blocks: + self._free_blocks(request) + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e772615b786..d13cef4a3ab 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -28,6 +28,8 @@ from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler +from vllm.v1.core.sched.scheduler_disagg import ( # noqa: E501 + DisaggregatedScheduler as V1DisaggregatedScheduler) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache @@ -86,7 +88,7 @@ def __init__(self, # This warning can be removed once the V1 Scheduler interface is # finalized and we can maintain support for scheduler classes that # implement it - if Scheduler is not V1Scheduler: + if Scheduler not in [V1Scheduler, V1DisaggregatedScheduler]: logger.warning( "Using configured V1 scheduler class %s. " "This scheduler interface is not public and " diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1d98f15ebde..3f6f1f685e4 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -146,7 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], - kv_transfer_params: KVTransferParams, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -176,7 +176,7 @@ def _new_request_output( request_id: str, outputs: list[CompletionOutput], finished: bool, - kv_transfer_params: KVTransferParams, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -305,22 +305,22 @@ def process_outputs( 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. ****************** NOTE FOR DEVELOPERS ****************** vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. 
This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from within the loop below. - + ********************************************************** """ diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e8ce0df5ed8..f4a240bc7b0 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -105,11 +105,13 @@ class ModelRunnerOutput: finished_recving: Optional[set[str]] = None -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - finished_sending=None, - finished_recving=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None, +) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3a8dae04ee0..d0e8d62eba2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1064,10 +1064,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: return output # Prepare the decoder inputs. - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. @@ -1141,7 +1140,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: num_tokens=num_input_tokens): maybe_setup_kv_connector() - model_output = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -1152,9 +1151,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: finished_sending, finished_recving = maybe_get_finished() if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output + hidden_states, aux_hidden_states = output else: - hidden_states = model_output + hidden_states = output if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states.
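With KVConnectorMetadata now declared as a @dataclass in the base.py hunk earlier, a connector implementation can describe its scheduler-to-worker payload as plain fields. A hypothetical subclass as a sketch (the class and field names are invented for illustration and are not part of this patch):

from dataclasses import dataclass, field

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata)


# Illustrative subclass; field names are made up for this example.
@dataclass
class ToyConnectorMetadata(KVConnectorMetadata):
    # Block IDs the worker-side connector should load or save.
    block_ids: list[int] = field(default_factory=list)
    # Engine ID of the remote instance, matching KVTransferConfig.engine_id.
    remote_engine_id: str = ""


meta = ToyConnectorMetadata(block_ids=[0, 1, 2], remote_engine_id="prefill-0")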