From ae26a38dcff1a966fd0def98a64b2a86d6602f2c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 31 Oct 2024 16:12:42 -0700 Subject: [PATCH] make incremental detokenization OO, handle stop strings --- vllm/engine/output_processor/stop_checker.py | 37 +-- vllm/v1/engine/__init__.py | 6 +- vllm/v1/engine/async_llm_engine.py | 12 +- vllm/v1/engine/detokenizer.py | 263 ++++++++++++------- vllm/v1/engine/llm_engine.py | 5 +- vllm/v1/engine/processor.py | 4 +- 6 files changed, 203 insertions(+), 124 deletions(-) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index a71ad493d9920..e25649b2b0b9f 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, List, Optional, Tuple from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -67,9 +67,13 @@ def maybe_stop_sequence( return # Check if any stop strings are matched. - stop_str = self._check_stop_strings(seq, new_char_count, - sampling_params) - if stop_str is not None: + stop = self.check_stop_strings( + seq.output_text, new_char_count, sampling_params.stop, + sampling_params.include_stop_str_in_output) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + seq.output_text = seq.output_text[:truncate_to] seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = stop_str return @@ -85,33 +89,36 @@ def maybe_stop_sequence( return @staticmethod - def _check_stop_strings(seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> Optional[str]: + def check_stop_strings( + output_text: str, + new_char_count: int, + stop: List[str], + include_in_output: bool, + ) -> Optional[Tuple[str, int]]: """Check if any stop strings are matched and truncate sequence output text accordingly. Returns the stop string if matched or else None. """ - if not new_char_count or not sampling_params.stop: + if not new_char_count or not stop: return None - for stop_str in sampling_params.stop: + for stop_str in stop: stop_string_len = len(stop_str) # Avoid searching already-searched text. - stop_index = seq.output_text.find( - stop_str, -new_char_count - stop_string_len) + stop_index = output_text.find(stop_str, + -new_char_count - stop_string_len) if stop_index == -1: continue - if sampling_params.include_stop_str_in_output: + if include_in_output: # Truncate to end of stop string. stop_index += stop_string_len - if stop_index >= len(seq.output_text): + if stop_index >= len(output_text): # No truncation required. - return stop_str + return stop_str, -1 # Truncate the output text to either the beginning # or end of the stop string. - seq.output_text = seq.output_text[:stop_index] - return stop_str + return stop_str, stop_index return None diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 94c29faed0811..d929fa7884dd6 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,11 +1,9 @@ -import asyncio from dataclasses import dataclass from typing import List, Optional, Union import msgspec from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams POLLING_TIMEOUT_MS = 5000 @@ -21,8 +19,8 @@ class DetokenizerRequest: spaces_between_special_tokens: bool output_kind: RequestOutputKind - # Queue for streaming outputs to clients. - output_queue: Optional[asyncio.Queue[RequestOutput]] = None + stop: List[str] + include_stop_str_in_output: bool class EngineCoreRequest(msgspec.Struct): diff --git a/vllm/v1/engine/async_llm_engine.py b/vllm/v1/engine/async_llm_engine.py index 6058ce16be9c9..063decdc1368b 100644 --- a/vllm/v1/engine/async_llm_engine.py +++ b/vllm/v1/engine/async_llm_engine.py @@ -131,6 +131,10 @@ async def add_request( raise KeyError(f"Request {request_id} already exists.") # TODO: handle abort. + # IDEA(Nick): we could batch up aborts rather than sending + # them individually, so that we send at most one batch of + # aborts per step (added to any that we're doing due to + # stop string matches for that step) def _abort(): pass @@ -152,6 +156,11 @@ def _abort(): return stream.generator() + # TODO: we should support multiple prompts in one call, as you + # can do with LLM.generate. So that for multi-prompt completion + # requests we don't need to send multiple messages to core proc, + # and so we don't need multiple streams which then get + # re-multiplexed in the API server anyhow. async def generate( self, prompt: PromptType, @@ -188,7 +197,8 @@ async def run_output_handler(self): # NOTE: we could simplify the Detokenizer code by returning full # List[RequestOutput] rather than pushing to the Queue at the # expense of doing another loop through List[RequestOutput]. - self.detokenizer.step_streaming(outputs) + _to_abort = self.detokenizer.step_streaming(outputs) + # TODO: send aborts (in one message) except BaseException as e: logger.error(e) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index c44baf568f339..72d93649493e2 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,6 +1,9 @@ from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple +from sequence import SequenceStatus + +from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind @@ -14,13 +17,17 @@ @dataclass -class DetokenizerRequestState: +class IncrementalDetokenizer: # Generation data output_text: str tokens: List[str] token_ids: List[int] + # Stop strings + stop: List[str] + include_stop_str_in_output: bool + # Metadata for incremental detokenization prefix_offset: int read_offset: int @@ -33,6 +40,13 @@ class DetokenizerRequestState: # Request output (Cached + updated incrementally) request_output: RequestOutput + # Tokenizer for this request + tokenizer: AnyTokenizer + + # Accounting for stop string buffering + buffer_length: int + _last_output_text_offset: int = 0 + # Streaming RequestOutputs to clients in async mode. stream: Optional[AsyncStream] = None @@ -42,7 +56,7 @@ def from_new_request( tokenizer: AnyTokenizer, request: DetokenizerRequest, stream: Optional[AsyncStream] = None, - ) -> "DetokenizerRequestState": + ) -> "IncrementalDetokenizer": tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( tokenizer=tokenizer, @@ -56,11 +70,20 @@ def from_new_request( request.prompt_token_ids, ) + stops = request.stop + # How many chars to hold back stop strings are to be excluded + # from the output when streaming + buffer_length = 0 if not stops or request.include_stop_str_in_output \ + else max(len(s) for s in stops) - 1 + return cls( output_text="", tokens=tokens, # Detokenizer mutates this list, so need a unique copy. + # NOTE(Nick): could we take ownership of it though? token_ids=request.prompt_token_ids.copy(), + stop=stops, + include_stop_str_in_output=request.include_stop_str_in_output, prefix_offset=prefix_offset, read_offset=read_offset, skip_special_tokens=request.skip_special_tokens, @@ -68,9 +91,104 @@ def from_new_request( spaces_between_special_tokens, output_kind=request.output_kind, request_output=request_output, + tokenizer=tokenizer, + buffer_length=buffer_length, stream=stream, ) + def add_tokens( + self, + new_token_ids: List[int], + finish_reason: Optional[str], + stop_reason: Optional[str], + ) -> Optional[RequestOutput]: + """ + Update RequestState for the request_id by: + 1) Detokenize the new token ids incrementally. + 2) Update the RequestOutput with the new text. + """ + + # 1) Detokenize the new token ids incrementally. + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. + decoded_text = "" + for new_token_id in new_token_ids: + self.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + ) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + self.output_text += new_decoded_token_text + + decoded_text += new_decoded_token_text + + # 2) Evaluate stop criteria + if self.stop: + stop = StopChecker.check_stop_strings( + output_text=self.output_text, + new_char_count=666, + stop=self.stop, + include_in_output=self.include_stop_str_in_output, + ) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + self.output_text = self.output_text[:truncate_to] + finish_reason = SequenceStatus.FINISHED_STOPPED + stop_reason = stop_str + + # TODO: handle stop_token_ids here too? + + # 2) Update the RequestOutput object with the new text. + finished = bool(finish_reason) + if self.output_kind == RequestOutputKind.FINAL_ONLY \ + and not finished: + return None + + delta = self.output_kind == RequestOutputKind.DELTA + request_output = self.request_output + completion_output = request_output.outputs[0] + + output_text = self._get_next_output_text(finished, delta) + token_ids = new_token_ids if delta else self.token_ids + + completion_output.text = output_text + completion_output.token_ids = token_ids + + if finished: + completion_output.finish_reason = finish_reason + completion_output.stop_reason = stop_reason + request_output.finished = finished + + return request_output + + def _get_next_output_text(self, finished: bool, delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + + # We return the full output text if the sequence is finished. + buffer_length = 0 if finished else self.buffer_length + if not delta: + return self.output_text[:-buffer_length] if buffer_length else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + class Detokenizer: @@ -78,8 +196,8 @@ def __init__(self, tokenizer_name: str, stream_mode: bool = False): self.tokenizer = get_tokenizer(tokenizer_name) self.stream_mode = stream_mode - # Request id -> DetokenizerRequestState - self.request_states: Dict[str, DetokenizerRequestState] = {} + # Request id -> IncrementalDetokenizer + self.request_states: Dict[str, IncrementalDetokenizer] = {} def is_request_active(self, request_id: str): return request_id in self.request_states @@ -101,135 +219,76 @@ def add_request( assert ((self.stream_mode and stream is not None) or (not self.stream_mode and stream is None)) - request_state = DetokenizerRequestState.from_new_request( + request_state = IncrementalDetokenizer.from_new_request( self.tokenizer, request, stream) self.request_states[request.request_id] = request_state def step( - self, encore_core_outputs: List[EngineCoreOutput] - ) -> List[RequestOutput]: + self, encore_core_outputs: List[EngineCoreOutput] + ) -> Tuple[List[RequestOutput], List[str]]: """Update state and request the RequestOutputs to the LLMEngine.""" assert not self.stream_mode request_outputs: List[RequestOutput] = [] + requests_to_abort: List[str] = [] for engine_core_output in encore_core_outputs: request_id = engine_core_output.request_id + detokenizer = self.request_states[request_id] # Detokenize and update state. - request_output = self._update_request_state( - tokenizer=self.tokenizer, - request_state=self.request_states[request_id], + request_output = detokenizer.add_tokens( new_token_ids=engine_core_output.new_token_ids, - finished=engine_core_output.finished, finish_reason=engine_core_output.finish_reason, stop_reason=engine_core_output.stop_reason, ) - # Add to RequestOutputs list. - request_outputs.append(request_output) + if request_output is not None: + # Add to RequestOutputs list. + request_outputs.append(request_output) - # Free completed requests. - if engine_core_output.finished: - self.request_states.pop(request_id) + # Free completed requests. + if request_output.finished: + self.request_states.pop(request_id) + if not engine_core_output.finished: + requests_to_abort.append(request_id) # Return to EngineClient. - return request_outputs + return request_outputs, requests_to_abort - def step_streaming(self, - encore_core_outputs: List[EngineCoreOutput]) -> None: + def step_streaming( + self, encore_core_outputs: List[EngineCoreOutput]) -> List[str]: """Update state and put the RequestOutput in the per request queues.""" assert self.stream_mode + requests_to_abort: List[str] = [] for engine_core_output in encore_core_outputs: request_id = engine_core_output.request_id + detokenizer = self.request_states[request_id] # Detokenize and update state. - request_output = self._update_request_state( - tokenizer=self.tokenizer, - request_state=self.request_states[request_id], + request_output = detokenizer.add_tokens( new_token_ids=engine_core_output.new_token_ids, - finished=engine_core_output.finished, finish_reason=engine_core_output.finish_reason, stop_reason=engine_core_output.stop_reason, ) - # Send the RequestOutput to the per client output queue. - assert self.request_states[request_id].stream is not None - self.request_states[request_id].stream.put(request_output) - # TODO: is caching RequestOutput sound? - # What happens if the reader from the stream falls behind? - # Won't the object in the queue get mutated? - - # Free completed requests. - if engine_core_output.finished: - self.request_states[request_id].stream.finish() - self.request_states.pop(request_id) - logger.debug("Finished request %s.", request_id) - - @staticmethod - def _update_request_state( - tokenizer: AnyTokenizer, - request_state: DetokenizerRequestState, - new_token_ids: List[int], - finished: bool, - finish_reason: Optional[str], - stop_reason: Optional[str], - ) -> RequestOutput: - """ - Update RequestState for the request_id by: - 1) Detokenize the new token ids incrementally. - 2) Update the RequestOutput with the new text. - """ - - # 1) Detokenize the new token ids incrementally. - # TODO(woosuk): This method becomes very inefficient when the number of - # new_token_ids is more than 1. We need to optimize this. - decoded_text = "" - for new_token_id in new_token_ids: - request_state.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=request_state.token_ids, - prev_tokens=request_state.tokens, - prefix_offset=request_state.prefix_offset, - read_offset=request_state.read_offset, - skip_special_tokens=request_state.skip_special_tokens, - spaces_between_special_tokens=request_state. - spaces_between_special_tokens, - ) - - request_state.tokens.extend(new_tokens) - request_state.prefix_offset = prefix_offset - request_state.read_offset = read_offset - request_state.output_text += new_decoded_token_text - - decoded_text += new_decoded_token_text - - # 2) Update the RequestOutput object with the new text. - request_output = request_state.request_output - completion_output = request_output.outputs[0] - if request_state.output_kind == RequestOutputKind.CUMULATIVE: - completion_output.text += decoded_text - completion_output.token_ids = request_state.token_ids - elif request_state.output_kind == RequestOutputKind.DELTA: - completion_output.text = decoded_text - num_prev_tokens = len(completion_output.token_ids) - completion_output.token_ids = request_state.token_ids[ - num_prev_tokens:] - elif request_state.output_kind == RequestOutputKind.FINAL_ONLY: - if finished: - completion_output.text = request_state.output_text - completion_output.token_ids = request_state.token_ids - else: - completion_output.text = "" - completion_output.token_ids = [] - - if finished: - completion_output.finish_reason = finish_reason - completion_output.stop_reason = stop_reason - request_output.finished = finished - - return request_output + if request_output is not None: + # Send the RequestOutput to the per client output queue. + stream = detokenizer.stream + assert stream is not None + stream.put(request_output) + # TODO: is caching RequestOutput sound? + # What happens if the reader from the stream falls behind? + # Won't the object in the queue get mutated? + + # Free completed requests. + if request_output.finished: + stream.finish() + self.request_states.pop(request_id) + logger.debug("Finished request %s.", request_id) + if not engine_core_output.finished: + requests_to_abort.append(request_id) + + return requests_to_abort diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 609d66bdca4d2..c0016f5a4575a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -134,6 +134,9 @@ def step(self) -> List[RequestOutput]: engine_core_outputs = self.engine_core.step() # 2) Step the Detokenizer. - request_outputs = self.detokenizer.step(engine_core_outputs) + request_outputs, to_abort = self.detokenizer.step(engine_core_outputs) + + # 3) Abort requests that finished due to stop criteria + self.abort_request(to_abort) return request_outputs diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index f6b3538fe8990..852a7d3c55979 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -84,6 +84,7 @@ def process_inputs( eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) assert isinstance(params, SamplingParams) + # TODO: can we avoid cloning here in multiproc case sampling_params = params.clone() sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) @@ -94,7 +95,8 @@ def process_inputs( processed_inputs.get("prompt_token_ids"), sampling_params.skip_special_tokens, sampling_params.spaces_between_special_tokens, - sampling_params.output_kind) + sampling_params.output_kind, sampling_params.stop, + sampling_params.include_stop_str_in_output) # Make Request for EngineCore. engine_core_request = EngineCoreRequest(