[Hardware][CPU] support cpu in v1 engine #11063

Status: Draft. Wants to merge 4 commits into base: main.
47 changes: 27 additions & 20 deletions vllm/attention/backends/torch_sdpa.py
@@ -297,26 +297,33 @@
dtype=torch.int32,
device="cpu",
)
-            query_lens_tensor = torch.tensor(prefill_query_lens,
-                                             dtype=torch.int32,
-                                             device="cpu")
-            kv_lens_tensor = torch.tensor(prefill_seq_lens,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            query_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                       dtype=torch.int32,
-                                       device="cpu")
-            torch.cumsum(query_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=query_start_loc[1:])
-            torch.cumsum(kv_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=kv_start_loc[1:])
+            query_start_loc: torch.Tensor
+            kv_start_loc: torch.Tensor
+            if input_data.query_start_loc is not None and input_data.seq_start_loc is not None:

[GitHub Actions / ruff (3.12)] Check failure: vllm/attention/backends/torch_sdpa.py:302:81: E501 Line too long (93 > 80)

+                query_start_loc = input_data.query_start_loc[input_data.num_prefills + 1:]

[GitHub Actions / ruff (3.12)] Check failure: vllm/attention/backends/torch_sdpa.py:303:81: E501 Line too long (90 > 80)

+                kv_start_loc = input_data.seq_start_loc
+            else:
+                query_lens_tensor = torch.tensor(prefill_query_lens,
+                                                 dtype=torch.int32,
+                                                 device="cpu")
+                kv_lens_tensor = torch.tensor(prefill_seq_lens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+                query_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                              dtype=torch.int32,
+                                              device="cpu")
+                kv_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                           dtype=torch.int32,
+                                           device="cpu")
+                torch.cumsum(query_lens_tensor,
+                             dim=0,
+                             dtype=torch.int32,
+                             out=query_start_loc[1:])
+                torch.cumsum(kv_lens_tensor,
+                             dim=0,
+                             dtype=torch.int32,
+                             out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
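Note on the change above: the builder now prefers the query_start_loc / seq_start_loc tensors already carried on input_data and only falls back to computing them with torch.cumsum. The snippet below is a minimal standalone sketch of that fallback path, not part of the PR; the helper name build_start_locs and its plain-list inputs are invented for illustration.

import torch

def build_start_locs(query_lens, kv_lens):
    # start_loc[i] is the offset of sequence i in the packed token buffer;
    # start_loc[-1] is the total number of tokens.
    num_prefills = len(query_lens)
    query_lens_t = torch.tensor(query_lens, dtype=torch.int32, device="cpu")
    kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32, device="cpu")
    query_start_loc = torch.zeros(num_prefills + 1, dtype=torch.int32, device="cpu")
    kv_start_loc = torch.zeros(num_prefills + 1, dtype=torch.int32, device="cpu")
    # Writing the cumulative sums into the [1:] slice keeps start_loc[0] == 0.
    torch.cumsum(query_lens_t, dim=0, dtype=torch.int32, out=query_start_loc[1:])
    torch.cumsum(kv_lens_t, dim=0, dtype=torch.int32, out=kv_start_loc[1:])
    return query_start_loc, kv_start_loc

# Example: query lens [3, 5, 2] -> query_start_loc tensor([0, 3, 8, 10])
print(build_start_locs([3, 5, 2], [7, 5, 4]))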
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -6,7 +6,7 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+# from vllm.vllm_flash_attn import flash_attn_varlen_func


class FlashAttentionBackend(AttentionBackend):
@@ -166,6 +166,7 @@ def forward(
)

# Compute attention and update output up to `num_actual_tokens`.
+        from vllm.vllm_flash_attn import flash_attn_varlen_func
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
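Note on the change above: the module-level import of flash_attn_varlen_func is commented out and the import is redone inside forward(), so merely importing this backend no longer requires the CUDA flash-attn extension (which a CPU-only build lacks); the extension is resolved lazily when forward() actually runs. A rough sketch of that deferred-import pattern follows; the try/except fallback and the error message are illustrative additions, not part of the PR.

def run_flash_attention(*args, **kwargs):
    try:
        # Deferred import, as in the PR's forward(): evaluated only when this
        # function is called, not when the module is imported.
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    except ImportError as exc:  # e.g. a CPU-only install
        raise RuntimeError(
            "flash_attn_varlen_func needs the CUDA flash-attn extension") from exc
    return flash_attn_varlen_func(*args, **kwargs)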
3 changes: 3 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -203,6 +203,9 @@
num_computed_tokens -= 1
num_new_tokens = 1
computed_blocks.pop()
+                # if current request can't be fully scheduled, skip and don't schedule it

[GitHub Actions / ruff (3.12)] Check failure: vllm/v1/core/scheduler.py:206:81: E501 Line too long (89 > 80)

+                if num_new_tokens > token_budget:
+                    break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

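Note on the change above: when the next waiting request needs more new tokens than the remaining token budget, it is now left unscheduled for this step (the loop breaks) instead of being clamped by the min(num_new_tokens, token_budget) line below it. A minimal sketch of that budget check, not the real scheduler; waiting requests are reduced to the token counts they still need.

from collections import deque

def schedule_waiting(waiting: deque, token_budget: int) -> list:
    # FCFS: stop as soon as a request cannot be scheduled in full.
    scheduled = []
    while waiting:
        num_new_tokens = waiting[0]
        if num_new_tokens > token_budget:
            break
        token_budget -= num_new_tokens
        scheduled.append(waiting.popleft())
    return scheduled

# Budget of 10 with requests needing [4, 5, 3] tokens -> only [4, 5] scheduled.
print(schedule_waiting(deque([4, 5, 3]), 10))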
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py
@@ -10,6 +10,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@@ -21,6 +22,7 @@
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.cpu_executor import CPUExecutor

logger = init_logger(__name__)

@@ -127,6 +129,8 @@ def shutdown(self):

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
+        if current_platform.is_cpu():
+            return CPUExecutor
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
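Note on the change above: on CPU platforms the executor class is now resolved to CPUExecutor before any distributed-executor logic is consulted. A quick sanity-check sketch, assuming a CPU-only environment with this PR applied, that the class in this file is AsyncLLM as in upstream vLLM at the time, and that the example model name is arbitrary.

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.executor.cpu_executor import CPUExecutor

vllm_config = AsyncEngineArgs(model="facebook/opt-125m").create_engine_config()
assert AsyncLLM._get_executor_cls(vllm_config) is CPUExecutor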
4 changes: 4 additions & 0 deletions vllm/v1/engine/llm_engine.py
@@ -11,6 +11,7 @@
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
+from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@@ -21,6 +22,7 @@
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.cpu_executor import CPUExecutor

logger = init_logger(__name__)

@@ -104,6 +106,8 @@ def from_engine_args(

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
+        if current_platform.is_cpu():
+            return CPUExecutor
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
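Note: with both engine entry points returning CPUExecutor on CPU platforms, an end-to-end smoke test might look like the sketch below. It assumes the V1 engine is still opted into via the VLLM_USE_V1 environment variable and that a CPU build of vLLM is installed; neither of those is shown in this diff, and the model name is only an example.

import os
os.environ["VLLM_USE_V1"] = "1"  # assumed V1 opt-in flag, set before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)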