diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py
index 23876ac82..eabe4d3c1 100644
--- a/lmdeploy/serve/client.py
+++ b/lmdeploy/serve/client.py
@@ -20,7 +20,6 @@ def main(tritonserver_addr: str, session_id: int = 1):
     Args:
         tritonserver_addr (str): the address in format "ip:port" of triton
             inference server
-        model_name (str): the name of the deployed model
         session_id (int): the identical id of a session
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py
index ff253e49d..05bb24869 100644
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -26,6 +26,7 @@ class Session:
     request_id: str = ''
     histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
+    sequence_offset: int = 0  # the new generated token offset in the session
     prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -539,14 +540,15 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
     Yields:
         tuple: status, text, generated token number
     """
-    offset = n_input_token + preseq_length
+    session.sequence_offset = n_input_token + preseq_length
+    sentinel = n_input_token + preseq_length
     status, res, n_token = None, '', 0
     while True:
         result = res_queue.get()
         if result is None:
             status = StatusCode.TRITON_STREAM_END
             res = session.response
-            n_token = session.sequence_length - offset
+            n_token = session.sequence_length - sentinel
             session.status = StatusCode.TRITON_STREAM_END
             break
         if 'errcode' in result:
@@ -569,30 +571,31 @@
             output_ids = result.as_numpy('output_ids')

             session.sequence_length = sequence_length.squeeze()
-            sequence_length = sequence_length - offset
+            new_token_length = sequence_length - session.sequence_offset
             last_token_id = output_ids[-1][-1][session.sequence_length - 1]
             if last_token_id == eos_id:
                 session.sequence_length = session.sequence_length - 1
-                sequence_length = sequence_length - 1
+                new_token_length = new_token_length - 1

             output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-            sequence_length = sequence_length.reshape(
-                (1, sequence_length.shape[-1]))
+            new_token_length = new_token_length.reshape(
+                (1, new_token_length.shape[-1]))

             if profile_generation:
                 yield (StatusCode.TRITON_STREAM_ING,
                        'postprocessing is ignored during profiling '
-                       'token generation', sequence_length.squeeze())
+                       'token generation', new_token_length.squeeze())
                 continue
-            output_str = postprocess(output_ids[:, :, offset:],
-                                     sequence_length)
+            output_str = postprocess(
+                output_ids[:, :, session.sequence_offset:],
+                new_token_length)
+            session.sequence_offset = session.sequence_length
             text = output_str[0].decode()
             if display:
-                new_text = text[len(session.response):]
-                print(new_text, end='', flush=True)
-                session.response = text
+                print(text, end='', flush=True)
+                session.response += text
             yield (StatusCode.TRITON_STREAM_ING, session.response,
-                   sequence_length.squeeze())
+                   session.sequence_offset - sentinel)
         except Exception as e:
             logger.error(f'catch exception: {e}')
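
For reference, below is a minimal, self-contained sketch (not part of the patch) of the incremental postprocessing pattern this change adopts in stream_consumer: keep a sequence_offset into the token stream, decode only the tokens produced since the last step, and append the decoded text to session.response, instead of re-decoding everything and slicing off the previously printed prefix. The decode helper, the flat output_ids list, and consume_step are hypothetical stand-ins for the Triton postprocess model and its tensors.

# Sketch of the incremental decoding flow; names here are illustrative only.
from dataclasses import dataclass


def decode(token_ids):
    """Hypothetical detokenizer standing in for the Triton `postprocess` call."""
    return ''.join(f'<{t}>' for t in token_ids)


@dataclass
class Session:
    sequence_offset: int = 0  # first token position not yet decoded
    sequence_length: int = 0  # total tokens (prompt + generated) seen so far
    response: str = ''


def consume_step(session, output_ids, sequence_length):
    """Decode only the newly generated tokens and append them to the response."""
    session.sequence_length = sequence_length
    new_tokens = output_ids[session.sequence_offset:session.sequence_length]
    session.response += decode(new_tokens)
    session.sequence_offset = session.sequence_length  # advance past decoded tokens
    return session.response


# Usage: simulate two streaming steps; the 2 prompt tokens are never re-decoded.
session = Session(sequence_offset=2, sequence_length=2)
print(consume_step(session, [11, 12, 101, 102], 4))       # decodes tokens 101, 102
print(consume_step(session, [11, 12, 101, 102, 103], 5))  # decodes only token 103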