Commit

rollback chatbot.py
lvhan028 committed Aug 11, 2023
1 parent d173210 commit 629a97f
Showing 1 changed file with 13 additions and 16 deletions.
lmdeploy/serve/turbomind/chatbot.py (29 changes: 13 additions & 16 deletions)
--- a/lmdeploy/serve/turbomind/chatbot.py
+++ b/lmdeploy/serve/turbomind/chatbot.py
@@ -26,7 +26,6 @@ class Session:
     request_id: str = ''
     histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
-    sequence_offset: int = 0  # the new generated token offset in the session
     prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -599,15 +598,14 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
     Yields:
         tuple: status, text, generated token number
     """
-    session.sequence_offset = n_input_token + preseq_length
-    sentinel = n_input_token + preseq_length
+    offset = n_input_token + preseq_length
     status, res, n_token = None, '', 0
     while True:
         result = res_queue.get()
         if result is None:
             status = StatusCode.TRITON_STREAM_END
             res = session.response
-            n_token = session.sequence_length - sentinel
+            n_token = session.sequence_length - offset
             session.status = StatusCode.TRITON_STREAM_END
             break
         if 'errcode' in result:
@@ -630,31 +628,30 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
             output_ids = result.as_numpy('output_ids')

             session.sequence_length = sequence_length.squeeze()
-            new_token_length = sequence_length - session.sequence_offset
+            sequence_length = sequence_length - offset
             last_token_id = output_ids[-1][-1][session.sequence_length - 1]
             if last_token_id == eos_id:
                 session.sequence_length = session.sequence_length - 1
-                new_token_length = new_token_length - 1
+                sequence_length = sequence_length - 1

             output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-            new_token_length = new_token_length.reshape(
-                (1, new_token_length.shape[-1]))
+            sequence_length = sequence_length.reshape(
+                (1, sequence_length.shape[-1]))

             if profile_generation:
                 yield (StatusCode.TRITON_STREAM_ING,
                        'postprocessing is ignored during profiling '
-                       'token generation', new_token_length.squeeze())
+                       'token generation', sequence_length.squeeze())
                 continue
-            output_str = postprocess(
-                output_ids[:, :, session.sequence_offset:],
-                new_token_length)
-            session.sequence_offset = session.sequence_length
+            output_str = postprocess(output_ids[:, :, offset:],
+                                     sequence_length)
             text = output_str[0].decode()
             if display:
-                print(text, end='', flush=True)
-            session.response += text
+                new_text = text[len(session.response):]
+                print(new_text, end='', flush=True)
+            session.response = text
             yield (StatusCode.TRITON_STREAM_ING, session.response,
-                   session.sequence_offset - sentinel)
+                   sequence_length.squeeze())
         except Exception as e:
             logger.error(f'catch exception: {e}')

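The change to stream_consumer is the substance of the rollback: instead of detokenizing only the newly generated tokens (tracked by session.sequence_offset) and appending the result, the restored code re-decodes everything after `offset` on every step, stores the full text in session.response, and prints only the suffix that was not there before. Below is a minimal, self-contained sketch of that pattern; `decode`, its toy vocabulary, and the example token stream are hypothetical stand-ins for turbomind's `postprocess` call, which this diff does not show.

def decode(token_ids):
    # Hypothetical detokenizer standing in for the `postprocess` model:
    # joins token strings from a toy vocabulary.
    vocab = {0: 'Hel', 1: 'lo', 2: ', ', 3: 'wor', 4: 'ld', 5: '!'}
    return ''.join(vocab[t] for t in token_ids)


def stream(token_stream):
    """Yield only the text appended since the previous step."""
    response = ''
    generated = []
    for token_id in token_stream:
        generated.append(token_id)
        text = decode(generated)         # re-decode the whole window
        new_text = text[len(response):]  # suffix added by this step
        response = text                  # replace, not append
        yield new_text


if __name__ == '__main__':
    # Prints "Hello, world!" one decoded suffix at a time.
    for piece in stream([0, 1, 2, 3, 4, 5]):
        print(piece, end='', flush=True)
    print()

The trade-off is more decode work per step, but the streamed text is always a consistent prefix of the final output, which sidesteps the artifacts incremental decoding can produce when a single character spans multiple tokens.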
