
Commit f9a0f75: Various updates
njhill committed Nov 6, 2024
1 parent 7d3c114
Showing 9 changed files with 128 additions and 97 deletions.
11 changes: 8 additions & 3 deletions benchmarks/backend_request_func.py
@@ -230,6 +230,8 @@ async def async_request_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

stream = True

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
@@ -238,7 +240,7 @@
"best_of": request_func_input.best_of,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
"stream": stream,
"ignore_eos": request_func_input.ignore_eos,
}
headers = {
@@ -263,9 +265,10 @@

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk == "[DONE]":
stream_is_done = stream and chunk == "[DONE]"
if not stream or stream_is_done:
latency = time.perf_counter() - st
else:
if not stream_is_done:
data = json.loads(chunk)

# NOTE: Some completion API might have a last
@@ -379,10 +382,12 @@ async def async_request_openai_chat_completions(
else:
output.error = response.reason or ""
output.success = False
print("Error reason", response.reason)
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
traceback.print_exc()

if pbar:
pbar.update(1)
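
For context, a minimal sketch (separate from the benchmark code above) of how an OpenAI-style streaming completions response is typically consumed with aiohttp, including the "data: " prefix and "[DONE]" sentinel the hunk handles; the URL, payload, and response fields are illustrative assumptions:

import json
import time

import aiohttp


async def stream_completion(api_url: str, payload: dict):
    """Consume an SSE-style completions stream; return (generated_text, latency)."""
    generated, latency = "", 0.0
    st = time.perf_counter()
    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, json={**payload, "stream": True}) as resp:
            # Each body line is empty, a "data: {...}" event, or the "data: [DONE]" sentinel.
            async for raw_line in resp.content:
                chunk = raw_line.decode("utf-8").strip()
                if not chunk:
                    continue
                chunk = chunk.removeprefix("data: ")
                if chunk == "[DONE]":
                    latency = time.perf_counter() - st
                    break
                data = json.loads(chunk)
                generated += data["choices"][0].get("text", "")
    return generated, latency
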
44 changes: 44 additions & 0 deletions vllm/config.py
@@ -2038,3 +2038,47 @@ def __post_init__(self):
self.model_config is not None and self.load_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)

def __str__(self):
return ("model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s") % \
(self.model_config.model, self.speculative_config,
self.model_config.tokenizer,
self.model_config.skip_tokenizer_init,
self.model_config.tokenizer_mode,
self.model_config.revision,
self.model_config.override_neuron_config,
self.model_config.rope_scaling,
self.model_config.rope_theta,
self.model_config.tokenizer_revision,
self.model_config.trust_remote_code,
self.model_config.dtype,
self.model_config.max_model_len,
self.load_config.download_dir,
self.load_config.load_format,
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.disable_custom_all_reduce,
self.model_config.quantization,
self.model_config.enforce_eager,
self.cache_config.cache_dtype,
self.model_config.quantization_param_path,
self.device_config.device, self.decoding_config,
self.observability_config, self.model_config.seed,
self.model_config.served_model_name,
self.scheduler_config.num_scheduler_steps,
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)
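
The point of moving this config dump into __str__ is that call sites can now hand the whole config to the logger with lazy %-formatting; a sketch of the intended call, mirroring the vllm/v1/engine/core.py hunk later in this commit:

logger.info("Initializing an LLM engine (v%s) with config: %s",
            VLLM_VERSION, vllm_config)
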
6 changes: 5 additions & 1 deletion vllm/engine/output_processor/stop_checker.py
@@ -98,7 +98,11 @@ def check_stop_strings(
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
Returns tuple (stop_string, offset) if matched or else None.
Where stop_string is the matched stop string and offset is the
length to which output_text should be truncated, or -1 for no
truncation.
"""
if not new_char_count or not stop:
return None
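
For illustration, a minimal sketch of how a caller might apply the (stop_string, offset) contract described in the docstring above; apply_stop_match and its arguments are hypothetical names, not the vLLM API:

from typing import Optional, Tuple


def apply_stop_match(output_text: str,
                     match: Optional[Tuple[str, int]]) -> Tuple[str, Optional[str]]:
    """Apply the documented return contract of check_stop_strings."""
    if match is None:
        return output_text, None          # no stop string matched
    stop_string, truncate_to = match
    if truncate_to != -1:
        # Truncate to drop the matched stop string from the visible output.
        output_text = output_text[:truncate_to]
    return output_text, stop_string


# e.g. ("Hello world<|end|>", ("<|end|>", 11)) -> ("Hello world", "<|end|>")
print(apply_stop_match("Hello world<|end|>", ("<|end|>", 11)))
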
2 changes: 1 addition & 1 deletion vllm/outputs.py
@@ -123,7 +123,7 @@ def new(
token_ids: List[int],
finished: bool = False,
) -> "RequestOutput":
"""Initialize a new "empty" RequestOutput object."""
"""Initialize a new RequestOutput object."""

# TODO: Support `n` > 1.
completion_output = CompletionOutput(
12 changes: 9 additions & 3 deletions vllm/v1/engine/__init__.py
@@ -25,7 +25,7 @@ class DetokenizerRequest:
include_stop_str_in_output: bool


class EngineCoreRequest(msgspec.Struct):
class EngineCoreRequest(msgspec.Struct, omit_defaults=True):

# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
@@ -42,7 +42,10 @@ class EngineCoreRequest(msgspec.Struct):
lora_request: Optional[LoRARequest]


class EngineCoreOutput(msgspec.Struct, array_like=True):
class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):

request_id: str
new_token_ids: List[int]
@@ -51,7 +54,10 @@ class EngineCoreOutput(msgspec.Struct, array_like=True):
stop_reason: Union[int, str, None] = None


class EngineCoreOutputs(msgspec.Struct, array_like=True):
class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):

#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
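
For readers unfamiliar with these msgspec flags, a small self-contained sketch of their effect; ExampleOutput is an illustrative stand-in rather than a vLLM type:

from typing import List

import msgspec


class ExampleOutput(msgspec.Struct,
                    array_like=True,      # encode as a positional array, not a name->value map
                    omit_defaults=True,   # skip trailing fields still at their default value
                    gc=False):            # skip GC tracking for these small, acyclic objects
    request_id: str
    new_token_ids: List[int]
    finished: bool = False


buf = msgspec.msgpack.encode(ExampleOutput("req-0", [1, 2, 3]))
out = msgspec.msgpack.decode(buf, type=ExampleOutput)
assert out.request_id == "req-0" and not out.finished
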
66 changes: 36 additions & 30 deletions vllm/v1/engine/async_llm.py
@@ -55,6 +55,8 @@ def __init__(

# Map (request_id -> Stream)
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []

# Processor (converts Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config,
@@ -76,28 +78,26 @@ def __init__(
# TODO: add background loop shielding
# TODO: add AsyncEngineDeadError

self.is_output_handler_running = False
self.output_handler = None

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
vllm_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an AsyncLLMEngine from the EngineArgs."""
"""Creates an AsyncLLM from the EngineArgs."""

# Create the engine configs.
if engine_config is None:
if vllm_config is None:
vllm_config = engine_args.create_engine_config()
else:
vllm_config = engine_config

executor_class = cls._get_executor_cls(vllm_config)

# Create the AsyncLLMEngine.
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
Expand All @@ -112,8 +112,8 @@ def shutdown(self):
"""Shutdown the EngineCore."""
self.engine_core.shutdown()

if hasattr(self, "output_handler"):
self.output_handler.cancel()
if handler := getattr(self, "output_handler", None):
handler.cancel()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
@@ -123,15 +123,9 @@ def _add_request_to_streams(self, request_id: str) -> AsyncStream:
if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")

# TODO: handle abort.
# IDEA(Nick): we could batch up aborts rather than sending
# them individually, so that we send at most one batch of
# aborts per step (added to any that we're doing due to
# stop string matches for that step)
def _abort():
pass

stream = AsyncStream(request_id, _abort)
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream
return stream

@@ -140,20 +134,33 @@ def _send_to_streams(self, request_outputs: List[RequestOutput]):

for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams

self.request_streams[request_id].put(request_output)
stream = self.request_streams.get(request_id)
if stream is not None:
finished = request_output.finished
stream.put(request_output)
if finished:
self._finish_stream(request_id)

if request_output.finished:
self.request_streams[request_id].finish()
self.request_streams.pop(request_id)
def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id)
if stream is not None:
stream.finish()

async def abort_requests(self, request_ids: List[str]) -> None:
async def _abort_requests(self, request_ids: List[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""

if len(request_ids) > 0:
# Include any client cancellations.
client_aborted_reqs = self.client_aborted_requests
if client_aborted_reqs:
self.detokenizer.abort_requests(client_aborted_reqs)
for request_id in client_aborted_reqs:
self._finish_stream(request_id)
request_ids.extend(client_aborted_reqs)
client_aborted_reqs.clear()

if request_ids:
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)

async def add_request(
self,
@@ -205,10 +212,9 @@ async def generate(
# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if not self.is_output_handler_running:
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
self.is_output_handler_running = True

async for output in await self.add_request(
request_id,
@@ -241,7 +247,7 @@ async def _run_output_handler(self):
self._send_to_streams(request_outputs)

# Abort any requests that finished due to stop strings.
await self.abort_requests(reqs_to_abort)
await self._abort_requests(reqs_to_abort)

except BaseException as e:
logger.error(e)
@@ -288,7 +294,7 @@ async def do_log_stats(
logger.debug("Called do_log_stats.")

async def check_health(self) -> None:
logger.debug("Called do_log_stats.")
logger.debug("Called check_health.")

async def start_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
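
A minimal sketch of the batched-abort pattern the async_llm.py changes above introduce: each stream's cancel callback is just list.append on a shared list, so the stream holds no reference back to the engine, and the output handler drains the whole batch once per step. The class and variable names below are illustrative, not the actual vLLM types:

from typing import Callable, List


class MiniStream:
    """Stand-in for AsyncStream: cancellation only appends the request id."""

    def __init__(self, request_id: str, on_cancel: Callable[[str], None]):
        self.request_id = request_id
        self._on_cancel = on_cancel

    def cancel(self) -> None:
        self._on_cancel(self.request_id)


client_aborted_requests: List[str] = []

# Registering a stream: pass the shared list's append as the abort callback,
# avoiding a circular reference to the parent engine object.
stream = MiniStream("req-0", client_aborted_requests.append)
stream.cancel()

# Once per output-handler step: drain the batch and abort it in one call.
to_abort = list(client_aborted_requests)
client_aborted_requests.clear()
print(to_abort)  # ['req-0']
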
60 changes: 12 additions & 48 deletions vllm/v1/engine/core.py
@@ -44,50 +44,8 @@ def __init__(

assert vllm_config.model_config.task != "embedding"

logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)", VLLM_VERSION,
vllm_config.model_config.model, vllm_config.speculative_config,
vllm_config.model_config.tokenizer,
vllm_config.model_config.skip_tokenizer_init,
vllm_config.model_config.tokenizer_mode,
vllm_config.model_config.revision,
vllm_config.model_config.override_neuron_config,
vllm_config.model_config.rope_scaling,
vllm_config.model_config.rope_theta,
vllm_config.model_config.tokenizer_revision,
vllm_config.model_config.trust_remote_code,
vllm_config.model_config.dtype,
vllm_config.model_config.max_model_len,
vllm_config.load_config.download_dir,
vllm_config.load_config.load_format,
vllm_config.parallel_config.tensor_parallel_size,
vllm_config.parallel_config.pipeline_parallel_size,
vllm_config.parallel_config.disable_custom_all_reduce,
vllm_config.model_config.quantization,
vllm_config.model_config.enforce_eager,
vllm_config.cache_config.cache_dtype,
vllm_config.model_config.quantization_param_path,
vllm_config.device_config.device, vllm_config.decoding_config,
vllm_config.observability_config, vllm_config.model_config.seed,
vllm_config.model_config.served_model_name,
vllm_config.scheduler_config.num_scheduler_steps,
vllm_config.cache_config.enable_prefix_caching,
vllm_config.model_config.use_async_output_proc,
vllm_config.model_config.mm_processor_kwargs)
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)

# Setup Model.
self.model_executor = executor_class(vllm_config)
@@ -129,6 +87,9 @@ def add_request(self, request: EngineCoreRequest):
def abort_requests(self, request_ids: List[str]):
"""Abort requests from the scheduler."""

# TODO: The scheduler doesn't really need to know the
# specific finish reason, TBD whether we propagate that
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

@@ -166,7 +127,9 @@ def __init__(
self.should_shutdown = should_shutdown

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL.
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
@@ -271,7 +234,6 @@ def make_engine_core_process(
def run_engine_core(*args, **kwargs):
"""Launch EngineCore busy loop in background process."""

engine_core = None
try:
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
@@ -352,12 +314,14 @@ def process_output_socket(self, output_path: str):

# Msgpack serialization encoding..
encoder = msgpack.Encoder()
# Reuse send buffer
buffer = bytearray()

with self.make_socket(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs_serialized = encoder.encode(outputs)
socket.send_multipart((outputs_serialized, ),
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ),
copy=False,
flags=zmq.NOBLOCK)
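
A hedged sketch of the output-socket pattern above (background thread, plain queue, reusable encode buffer, zero-copy ZMQ send); the socket path, queue payloads, and function names are illustrative assumptions, not the vLLM code:

import queue
import threading

import msgspec
import zmq


def output_socket_loop(out_queue: "queue.Queue", output_path: str) -> None:
    encoder = msgspec.msgpack.Encoder()
    buffer = bytearray()  # reused across sends to avoid a fresh allocation per message
    ctx = zmq.Context.instance()
    socket = ctx.socket(zmq.PUSH)
    socket.bind(output_path)
    while True:
        outputs = out_queue.get()             # blocks until the busy loop enqueues output
        encoder.encode_into(outputs, buffer)  # serialize into the reusable buffer
        socket.send_multipart((buffer, ), copy=False, flags=zmq.NOBLOCK)


out_q: "queue.Queue" = queue.Queue()
threading.Thread(target=output_socket_loop,
                 args=(out_q, "ipc:///tmp/example_outputs"),
                 daemon=True).start()
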