
[torch.compile] add a flag to track batchsize statistics #11059

Merged: 8 commits, Dec 10, 2024
vllm/envs.py (3 additions, 0 deletions)
@@ -69,6 +69,7 @@
    VLLM_DISABLED_KERNELS: List[str] = []
    VLLM_USE_V1: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1


def get_default_cache_root():
@@ -452,6 +453,8 @@ def get_default_config_root():
    # If set, enable multiprocessing in LLM for the V1 code path.
    "VLLM_ENABLE_V1_MULTIPROCESSING":
    lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
    "VLLM_LOG_BATCHSIZE_INTERVAL":
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
}

# end-env-vars-definition
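For context (not part of the diff), a minimal usage sketch: the tracking switch in vllm/forward_context.py below is computed at import time from this variable, so the environment variable has to be set before vLLM is imported. Any value >= 0 enables tracking; the default "-1" leaves it off. This sketch assumes vLLM is installed and that vllm.envs resolves its entries lazily on attribute access.

import os

# Log the batch-size distribution roughly every 10 seconds.
os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "10"

import vllm.envs as envs

assert envs.VLLM_LOG_BATCHSIZE_INTERVAL == 10.0  # parsed by the lambda above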
vllm/forward_context.py (31 additions, 1 deletion)
@@ -1,8 +1,19 @@
import time
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
batchsize_counter: Counter = Counter()
last_logging_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL


@dataclass
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
"""A context manager that stores the current forward context,
can be attention metadata, etc."""
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global track_batchsize, batchsize_counter
global last_logging_time, batchsize_logging_interval
if track_batchsize and context is not None:
if hasattr(context, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + context.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize_counter[batchsize] += 1
if time.monotonic() - last_logging_time > batchsize_logging_interval:
last_logging_time = time.monotonic()
sorted_data = sorted(batchsize_counter.items(),
key=lambda x: x[1],
reverse=True)
logger.info("Batchsize distribution (batchsize, count): %s",
sorted_data)
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
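As an aside, the counting logic is small enough to illustrate on its own. Below is a minimal standalone sketch of the same pattern (a Counter plus time-based throttling), using the plain logging module and made-up names rather than vLLM's logger and module globals:

import logging
import time
from collections import Counter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical stand-ins for the module-level state used above.
_batchsize_counter: Counter = Counter()
_last_logging_time: float = 0.0
_LOG_INTERVAL_S: float = 10.0  # plays the role of VLLM_LOG_BATCHSIZE_INTERVAL


def record_batchsize(batchsize: int) -> None:
    """Count one forward pass and periodically log the distribution."""
    global _last_logging_time
    _batchsize_counter[batchsize] += 1
    if time.monotonic() - _last_logging_time > _LOG_INTERVAL_S:
        _last_logging_time = time.monotonic()
        # Most frequent batch sizes first, as in the diff above.
        sorted_data = sorted(_batchsize_counter.items(),
                             key=lambda x: x[1],
                             reverse=True)
        logger.info("Batchsize distribution (batchsize, count): %s",
                    sorted_data)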
vllm/v1/attention/backends/flash_attn.py (1 addition, 0 deletions)
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
    seq_start_loc: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    num_input_tokens: int = 0  # Number of tokens including padding.


class FlashAttentionImpl(AttentionImpl):
vllm/v1/worker/gpu_model_runner.py (2 additions, 0 deletions)
@@ -445,6 +445,8 @@ def execute_model(
            # Eager mode.
            num_input_tokens = num_scheduled_tokens

        attn_metadata.num_input_tokens = num_input_tokens

        # Get the inputs embeds.
        if encoder_outputs:
            inputs_embeds = self.model.get_input_embeddings(
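To tie the three files together: the runner stores the (possibly padded) token count on the attention metadata, and set_forward_context reads it back through the v0/v1 branch shown earlier. A hypothetical, simplified sketch of that hand-off follows; the stand-in classes below are made up and carry only the fields relevant here.

from dataclasses import dataclass


@dataclass
class _V0Metadata:
    """Hypothetical stand-in for a v0 attention metadata object."""
    num_prefill_tokens: int
    num_decode_tokens: int


@dataclass
class _V1Metadata:
    """Hypothetical stand-in for the v1 FlashAttentionMetadata."""
    num_input_tokens: int = 0  # number of tokens including padding


def extract_batchsize(context) -> int:
    """Mirrors the v0/v1 branch in set_forward_context above."""
    if hasattr(context, "num_prefill_tokens"):
        # v0 backends expose prefill and decode token counts separately.
        return context.num_prefill_tokens + context.num_decode_tokens
    # v1 backends expose the (padded) total token count directly.
    return context.num_input_tokens


assert extract_batchsize(_V0Metadata(num_prefill_tokens=3, num_decode_tokens=5)) == 8
assert extract_batchsize(_V1Metadata(num_input_tokens=16)) == 16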