
Commit

fix rebase
Signed-off-by: yan ma <[email protected]>
yma11 committed Dec 10, 2024
1 parent a3d0fa0 commit 512bd05
Showing 2 changed files with 20 additions and 37 deletions.
24 changes: 16 additions & 8 deletions vllm/v1/executor/cpu_executor.py
@@ -11,8 +11,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async,
-                        enable_trace_function_call_for_thread,
+                        make_async, enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, update_environment_variables)
 
 from vllm.v1.executor.abstract import Executor
@@ -44,10 +43,6 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         #
         # Environment variables for CPU executor
         #
-
-        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
-        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
@@ -191,7 +186,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         return self.driver_method_invoker(self.driver_worker,
                                           "determine_num_available_blocks")
 
-    def initialize_cache(self, num_gpu_blocks: int,
+    def initialize(self, num_gpu_blocks: int,
                          num_cpu_blocks: int = 0) -> None:
         """Initialize the KV cache by invoking the underlying worker.
         """
@@ -246,6 +241,19 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
         for result in parallel_worker_tasks:
             result.get()
 
+    # def initialize(self, num_gpu_blocks: int) -> None:
+    #     pass
+
+    def profile(self, is_start=True):
+        pass
+
+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> []:
+        pass
+
 
 class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
 
@@ -301,7 +309,7 @@ def init_worker(self, *args, **kwargs):
         Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
-        enable_trace_function_call_for_thread()
+        enable_trace_function_call_for_thread(self.vllm_config)
 
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'
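
Note on the stubs added above: profile() and collective_rpc() land in CPUExecutor as empty placeholders. The toy example below is a rough, self-contained sketch of what a filled-in collective_rpc() could do, namely call a named method on every worker and gather the results. ToyWorker and ToyExecutor are invented for illustration and are not vLLM APIs; a real implementation would likely run the workers in parallel and enforce the timeout.

from typing import Any, Dict, List, Optional, Tuple


class ToyWorker:
    def echo(self, value: int) -> int:
        return value


class ToyExecutor:
    def __init__(self, workers: List[ToyWorker]) -> None:
        self.workers = workers

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Look up `method` on each worker, call it with the given arguments,
        # and return the per-worker results in order. The timeout is ignored
        # here; a real executor would enforce it.
        kwargs = kwargs or {}
        return [getattr(worker, method)(*args, **kwargs)
                for worker in self.workers]


executor = ToyExecutor([ToyWorker(), ToyWorker()])
assert executor.collective_rpc("echo", args=(7,)) == [7, 7]
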
33 changes: 4 additions & 29 deletions vllm/v1/worker/cpu_model_runner.py
@@ -15,7 +15,7 @@
 class CPUModelRunner(GPUModelRunner):
     #
     def __init__(self, vllm_config):
-        super().__init__(vllm_config)
+        super().__init__(vllm_config, vllm_config.device_config.device)
         self.use_cuda_graph = False
         num_attn_heads = self.model_config.get_num_attention_heads(
             self.parallel_config)
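
Note: the CPU runner now passes the device from vllm_config.device_config.device as a second argument to the GPUModelRunner constructor, presumably because the base class signature gained a device parameter during the rebase. A minimal sketch of that forwarding pattern, using invented classes rather than the actual vLLM runners:

class ToyBaseRunner:
    def __init__(self, config: dict, device: str) -> None:
        # Record the device; a real base runner would allocate its buffers here.
        self.config = config
        self.device = device


class ToyCPURunner(ToyBaseRunner):
    def __init__(self, config: dict) -> None:
        # Forward the device recorded in the config so everything stays on CPU.
        super().__init__(config, config["device"])


runner = ToyCPURunner({"device": "cpu"})
assert runner.device == "cpu"
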
@@ -85,8 +85,7 @@ def execute_model(
         )
 
         # NOTE: CPU-GPU synchronization happens here.
-        sampled_token_ids = sampler_output.sampled_token_ids.cpu()
-        sampled_token_ids_list = sampled_token_ids.tolist()
+        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
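
Note: on the CPU backend the sampler output tensor already lives in host memory, so the explicit .cpu() transfer and .tolist() conversion are dropped and the tensor is indexed directly in the loop that follows. A small sketch of the simplification, with made-up token values standing in for sampler_output.sampled_token_ids:

import torch

# Made-up sampled token ids; on the CPU backend this tensor is already on the
# host, so .cpu() is a no-op and .tolist() is an extra conversion.
sampled_token_ids = torch.tensor([101, 7, 42])

old_style = sampled_token_ids.cpu().tolist()  # previous pattern
token_id = sampled_token_ids[1]               # new pattern: index in place

assert int(token_id) == old_style[1] == 7
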
@@ -97,7 +96,7 @@ def execute_model(
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids_list[i]
+                token_id = sampled_token_ids[i]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 req_state.output_token_ids.append(token_id)
             else:
@@ -119,7 +118,7 @@ def execute_model(
         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids[:num_reqs],
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids_cpu=sampled_token_ids,
+            sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=logprob_token_ids,
             logprobs_cpu=logprobs,
         )
@@ -129,30 +128,6 @@ def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
-        # build input_data
-        '''
-        self.use_mrope = use_mrope
-        self.input_tokens: List[int] = []
-        self.input_positions: Optional[
-            List[int]] = [] if not self.use_mrope else None
-        self.token_type_ids: Optional[List[int]] = []
-        self.seq_lens: List[int] = []
-        self.query_lens: List[int] = []
-        self.prefill_block_tables: List[List[int]] = []
-        self.decode_block_tables: List[List[int]] = []
-        self.max_decode_seq_len: int = 0
-        self.num_prefills: int = 0
-        self.num_prefill_tokens: int = 0
-        self.num_decode_tokens: int = 0
-        self.slot_mapping: List[int] = []
-        self.multi_modal_inputs_list: List[MultiModalKwargs] = []
-        self.multi_modal_placeholder_maps: Dict[
-            str, MultiModalPlaceholderMap] = defaultdict(
-                MultiModalPlaceholderMap)
-        self.input_mrope_positions: Optional[List[List[int]]] = [
-            [] for _ in range(3)
-        ] if self.use_mrope else None
-        '''
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
