
Commit

update
lvhan028 committed Nov 27, 2023
1 parent 533af8a commit 1f2c414
Showing 5 changed files with 46 additions and 34 deletions.
9 changes: 5 additions & 4 deletions benchmark/profile_generation.py
@@ -29,7 +29,7 @@ def infer(model, session_id: int, input_ids: List, output_seqlen: int,
prev = time.perf_counter()
n_pre_token = 0
"""
The iterator provided by `chatbot.stream_infer` denotes the number of generated tokens so far,
The iterator provided by `stream_infer` denotes the number of generated tokens so far,
which is represented by the variable `n_token`.
Please note that `n_token` is not a continuous value. In other words, during the iteration,
its value might be 5, 7, 8, 16, and so on, rather than 1, 2, 3, 4, etc.
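Because `n_token` is cumulative and may jump by more than one between iterations, per-token latency has to be derived from the deltas between successive values. The following is a minimal, self-contained sketch of that bookkeeping; it is not part of the commit, and `fake_stream_infer` is a stand-in for the real `stream_infer` iterator.

import time

def fake_stream_infer():
    # stand-in for `stream_infer`: yields a cumulative, possibly jumping token count
    for n_token in (5, 7, 8, 16):
        time.sleep(0.01)
        yield n_token

prev = time.perf_counter()
n_pre_token = 0
per_token_latency = []
for n_token in fake_stream_infer():
    now = time.perf_counter()
    if n_token > n_pre_token:
        # spread the elapsed time over however many tokens arrived in this step
        per_token_latency.append((now - prev) / (n_token - n_pre_token))
        n_pre_token = n_token
        prev = now

print(f'{n_pre_token} tokens, mean per-token latency '
      f'{sum(per_token_latency) / len(per_token_latency):.4f}s')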
@@ -275,14 +275,15 @@ class ProfileResult:

def parse_args():
parser = argparse.ArgumentParser(description='Regression Test')
parser.add_argument('--model-path',
parser.add_argument('model_path',
type=str,
help='benchmark test model path')
help='the path of the model in localhost or '
'the repo_id of the model in huggingface.co')
parser.add_argument('--concurrency',
nargs='+',
type=int,
help='how many requests launched concurrently',
default=[1, 32, 64, 128])
default=[1, 16, 32, 64])
parser.add_argument(
'--prompt-tokens',
nargs='+',
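With this change `model_path` becomes a positional argument and the default concurrency sweep drops to [1, 16, 32, 64]. A minimal sketch of the updated parsing behaviour follows; the model id used below is only illustrative.

import argparse

parser = argparse.ArgumentParser(description='Regression Test')
parser.add_argument('model_path',
                    type=str,
                    help='the path of the model in localhost or '
                    'the repo_id of the model in huggingface.co')
parser.add_argument('--concurrency',
                    nargs='+',
                    type=int,
                    help='how many requests launched concurrently',
                    default=[1, 16, 32, 64])

# equivalent to: python profile_generation.py internlm/internlm-chat-7b --concurrency 1 16
args = parser.parse_args(['internlm/internlm-chat-7b', '--concurrency', '1', '16'])
print(args.model_path)   # internlm/internlm-chat-7b
print(args.concurrency)  # [1, 16]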
53 changes: 28 additions & 25 deletions benchmark/profile_restful_api.py
@@ -1,5 +1,4 @@
import json
import os
import random
import time
from queue import Queue
@@ -8,8 +7,8 @@

import fire
import numpy as np
import requests

from lmdeploy.serve.openai.api_client import get_streaming_response
from lmdeploy.tokenizer import Tokenizer


@@ -60,15 +59,10 @@ def __init__(self,
server_addr: str,
tokenzier_path: str,
temperature: float = 0.8,
top_k: int = 1,
top_p: float = 1.0):
self.tokenizer = Tokenizer(tokenzier_path)
# We choose `v1/compeletions` API to profile the performance since
# it won't decorate prompt according to the served model's
# chat template
self.api_url = server_addr + '/v1/completions'
self.api_url = server_addr + '/v1/chat/completions'
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p

def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
@@ -80,15 +74,31 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
timestamps = []
tokens = []
timestamps.append(time.perf_counter())
for _, n_token, _ in get_streaming_response(
prompt,
self.api_url,
session_id,
request_output_len=output_seqlen,
ignore_eos=True,
stream=stream_output):
headers = {'content-type': 'application/json'}
pload = {
'model': 'llama',
'messages': prompt,
'temperature': self.temperature,
'top_p': self.top_p,
'n': 1,
'max_tokens': output_seqlen,
'stream': stream_output,
'session_id': session_id,
'ignore_eos': True,
}
response = requests.post(self.api_url,
headers=headers,
json=pload,
stream=stream_output)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
timestamps.append(time.perf_counter())
tokens.append(n_token)
if chunk:
data = json.loads(chunk.decode('utf-8'))
n_token = data.pop('tokens', 0)
tokens.append(n_token)

first_token_latency = np.round(timestamps[1] - timestamps[0], 3)
token_latency = np.round(timestamps[-1] - timestamps[0], 3)
completion_tokens = tokens[-1]
@@ -184,11 +194,9 @@ def main(server_addr: str,
dataset: str,
concurrency: int = 1,
num_prompts: int = 1000,
top_k: int = 1,
top_p: float = 1.0,
temperature: float = 0.8,
stream_output: bool = False,
log_level: str = 'INFO',
seed: int = 0):
"""Benchmark the request througput of api server.
@@ -199,28 +207,23 @@
concurrency (int, optional): Number of working threads to process the sampled prompts.
Defaults to 1.
num_prompts (int, optional): Number of prompts to process. Defaults to 1000.
top_k (int, optional): The number of highest probability vocabulary tokens
to keep for top-k-filtering. Defaults to 1.
top_p (float, optional): the set of most probable tokens with
probabilities that add up to top_p or higher
are kept for generation. Defaults to 1.0.
temperature (float, optional): The value used to modulate the next token probabilities.
Defaults to 0.8.
stream_output (bool, optional): Indicator for streaming output. Defaults to False.
log_level(str, optional): The log level. Defaults to INFO
seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0.
""" # noqa
if not server_addr.startswith('http://'):
print(f'[WARNING] server_addr of the api_server should '
f'start with http://, but got {server_addr}')
server_addr = 'http://' + server_addr
f'start with "http://", but got "{server_addr}"')
server_addr = 'http://' + server_addr.strip()

random.seed(seed)
os.environ['TM_LOG_LEVEL'] = log_level

engine = Engine(server_addr,
tokenizer_path,
top_k=top_k,
top_p=top_p,
temperature=temperature)

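The benchmark no longer goes through `get_streaming_response`; it posts to `/v1/chat/completions` and walks the streamed lines itself. Below is a self-contained sketch of that request pattern, assuming a local api_server on port 23333 and the lmdeploy-specific `session_id`, `ignore_eos`, and per-chunk `tokens` fields shown in the diff; the prompt and model name are illustrative.

import json
import time

import requests

api_url = 'http://0.0.0.0:23333/v1/chat/completions'
pload = {
    'model': 'llama',
    'messages': 'Hello, what is LMDeploy?',
    'temperature': 0.8,
    'top_p': 1.0,
    'n': 1,
    'max_tokens': 128,
    'stream': True,
    'session_id': 1,     # lmdeploy-specific extension
    'ignore_eos': True,  # lmdeploy-specific extension
}

timestamps = [time.perf_counter()]
tokens = []
response = requests.post(api_url,
                         headers={'content-type': 'application/json'},
                         json=pload,
                         stream=True)
for chunk in response.iter_lines(chunk_size=8192,
                                 decode_unicode=False,
                                 delimiter=b'\n'):
    timestamps.append(time.perf_counter())
    if chunk:
        data = json.loads(chunk.decode('utf-8'))
        tokens.append(data.pop('tokens', 0))

first_token_latency = round(timestamps[1] - timestamps[0], 3)
total_latency = round(timestamps[-1] - timestamps[0], 3)
print(f'first token: {first_token_latency}s, total: {total_latency}s, '
      f'completion tokens: {tokens[-1] if tokens else 0}')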
5 changes: 5 additions & 0 deletions lmdeploy/serve/async_engine.py
@@ -241,6 +241,11 @@ async def generate(
len(input_ids), tokens, finish_reason)
response_size = tokens

# `response_size` might not have been updated by the
# `if response.endswith('�')` branch above, so emit a final catch-up yield
if response_size != tokens:
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
# update step
self.steps[str(session_id)] += len(input_ids) + tokens
if sequence_end or stop:
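The extra final yield guards against a chunk that was held back earlier: when a multi-byte character is split across tokens, the incremental decode ends with the replacement character '�' and that step's response is withheld, so `response_size` can lag behind `tokens`. A small sketch of the underlying effect, using plain UTF-8 bytes rather than the tokenizer:

text = '你好'                # two 3-byte UTF-8 characters
raw = text.encode('utf-8')
partial = raw[:4]            # cut inside the second character
decoded = partial.decode('utf-8', errors='replace')
print(decoded)               # 你�
print(decoded.endswith('�')) # True: such a chunk is withheld, which is
                             # why the final catch-up yield above is needed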
10 changes: 6 additions & 4 deletions lmdeploy/serve/openai/api_server.py
@@ -21,8 +21,6 @@
EmbeddingsRequest, ErrorResponse, GenerateRequest, GenerateResponse,
ModelCard, ModelList, ModelPermission, UsageInfo)

os.environ['TM_LOG_LEVEL'] = 'ERROR'


class VariableInterface:
"""A IO interface maintaining variables."""
@@ -476,12 +474,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
def main(model_path: str,
server_name: str = '0.0.0.0',
server_port: int = 23333,
instance_num: int = 32,
instance_num: int = 64,
tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*'],
log_level: str = 'INFO',
**kwargs):
"""An example to perform model inference through the command line
interface.
@@ -496,7 +495,10 @@ def main(model_path: str,
allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS
"""
log_level (str): set the log level, whose value is among [CRITICAL, ERROR, WARNING, INFO, DEBUG]
""" # noqa E501
os.environ['TM_LOG_LEVEL'] = log_level

if allow_origins:
app.add_middleware(
CORSMiddleware,
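`TM_LOG_LEVEL` is now set from the new `log_level` argument inside `main` instead of being hard-coded to ERROR at import time. A hedged sketch of how the option is expected to flow; the model path is illustrative, and calling `main` directly is only for demonstration (the script is normally launched from the command line):

from lmdeploy.serve.openai.api_server import main

main('internlm/internlm-chat-7b',
     server_name='0.0.0.0',
     server_port=23333,
     instance_num=64,     # new default shown in the diff
     log_level='ERROR')   # exported as os.environ['TM_LOG_LEVEL'] before serving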
3 changes: 2 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -586,7 +586,8 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
outputs = []
for output, len_ in zip(output_ids, sequence_length):
output, len_ = output, len_.item()
if len(output) > 0 and output[-1].item() == self.eos_id:
if len(output) > 0 and output[-1].item(
) == self.eos_id and not ignore_eos:
outputs.append((output[:-1], len_ - 1))
elif len(output) > 0 and output[-1].item() in self.stop_tokens:
outputs.append((output[:-1], len_))
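The new condition keeps the trailing eos token in the output when `ignore_eos` was requested, and strips it otherwise. A sketch of that trimming rule with plain Python lists instead of torch tensors; the token ids are hypothetical:

EOS_ID = 2           # hypothetical eos token id
STOP_TOKENS = {103}  # hypothetical stop token ids

def trim(output, length, ignore_eos):
    if output and output[-1] == EOS_ID and not ignore_eos:
        return output[:-1], length - 1  # drop eos and shorten the reported length
    if output and output[-1] in STOP_TOKENS:
        return output[:-1], length      # drop stop token, keep the reported length
    return output, length

print(trim([11, 12, 2], 3, ignore_eos=False))  # ([11, 12], 2)
print(trim([11, 12, 2], 3, ignore_eos=True))   # ([11, 12, 2], 3)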

