diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5e3df1c6e8a1d..fa983fc39ce76 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,6 +1,5 @@ import multiprocessing import queue -from collections.abc import Buffer from multiprocessing.process import BaseProcess from threading import Thread from typing import List, Tuple, Type @@ -158,8 +157,8 @@ def __init__( self.ctx = zmq.Context() # type: ignore[attr-defined] - self.input_queue: queue.Queue[Buffer] = queue.Queue() - self.output_queue: queue.Queue[Buffer] = queue.Queue() + self.input_queue = queue.Queue() + self.output_queue = queue.Queue() # Get EngineCoreRequests from the LLMEngine. self.input_socket = self.ctx.socket(zmq.constants.PULL) @@ -185,12 +184,15 @@ def __init__( def process_input_socket(self): while True: frames = self.input_socket.recv_multipart(copy=False) - self.input_queue.put_nowait(frames[0].buffer) + request = self.msgpack_decoder.decode(frames[0].buffer) + self.input_queue.put_nowait(request) def process_output_socket(self): while True: - serialized = self.output_queue.get() - self.output_socket.send_multipart((serialized, ), + engine_core_outputs = self.output_queue.get() + outputs = EngineCoreOutputs(outputs=engine_core_outputs) + outputs_serialized = self.msgpack_encoder.encode(outputs) + self.output_socket.send_multipart((outputs_serialized, ), copy=False, flags=zmq.NOBLOCK) @@ -265,8 +267,8 @@ def run_busy_loop(self): while True: # Poll the input socket until there is work to do. if not self.scheduler.has_unfinished_requests(): - buffer = self.input_queue.get() - self._handle_input_buffer(buffer) + request = self.input_queue.get() + self._handle_request(request) # Handle new input from the socket. self._handle_new_input() @@ -277,10 +279,15 @@ def run_busy_loop(self): # Send outputs to the EngineCoreClient. self._send_outputs(outputs) - def _handle_input_buffer(self, buffer): + def _handle_new_input(self): + """Handle new input from the AsyncLLMEngine for async mode.""" + while not self.input_queue.empty(): + request = self.input_queue.get_nowait() + self._handle_request(request) + + def _handle_request(self, request: EngineCoreRequest): try: - engine_core_request = self.msgpack_decoder.decode(buffer) - self.add_request(engine_core_request) + self.add_request(request) # TODO: handle abort via another socket # TODO: handle logits processors via cloudpickle @@ -290,19 +297,9 @@ def _handle_input_buffer(self, buffer): # TODO: handle gracefully raise e - def _handle_new_input(self): - """Handle new input from the AsyncLLMEngine for async mode.""" - while not self.input_queue.empty(): - buffer = self.input_queue.get_nowait() - self._handle_input_buffer(buffer) - def _send_outputs(self, engine_core_outputs: List[EngineCoreOutput]) -> None: """Serialize and send output to the AsyncLLMEngine for async mode.""" - if not engine_core_outputs: - return - - outputs = EngineCoreOutputs(outputs=engine_core_outputs) - outputs_serialized = self.msgpack_encoder.encode(outputs) - self.output_queue.put_nowait(outputs_serialized) + if engine_core_outputs: + self.output_queue.put_nowait(engine_core_outputs)