Skip to content

Commit

Permalink
[torch.compile] add a flag to track batchsize statistics (vllm-projec…
Browse files Browse the repository at this point in the history
…t#11059)

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Akshat Tripathi <[email protected]>
  • Loading branch information
youkaichao authored and Akshat-Tripathi committed Dec 12, 2024
1 parent 8fbffb2 commit 3e734b3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
3 changes: 3 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1


def get_default_cache_root():
Expand Down Expand Up @@ -452,6 +453,8 @@ def get_default_config_root():
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
"VLLM_LOG_BATCHSIZE_INTERVAL":
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
}

# end-env-vars-definition
Expand Down
32 changes: 31 additions & 1 deletion vllm/forward_context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import time
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
batchsize_counter: Counter = Counter()
last_logging_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL


@dataclass
Expand All @@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global track_batchsize, batchsize_counter
global last_logging_time, batchsize_logging_interval
if track_batchsize and context is not None:
if hasattr(context, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + context.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize_counter[batchsize] += 1
if time.monotonic() - last_logging_time > batchsize_logging_interval:
last_logging_time = time.monotonic()
sorted_data = sorted(batchsize_counter.items(),
key=lambda x: x[1],
reverse=True)
logger.info("Batchsize distribution (batchsize, count): %s",
sorted_data)
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class FlashAttentionMetadata:
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_input_tokens: int = 0 # Number of tokens including padding.


class FlashAttentionImpl(AttentionImpl):
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ def execute_model(
# Eager mode.
num_input_tokens = num_scheduled_tokens

attn_metadata.num_input_tokens = num_input_tokens

# Get the inputs embeds.
if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
Expand Down

0 comments on commit 3e734b3

Please sign in to comment.