diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py
index 8a9ac8e96..9b9387d04 100644
--- a/benchmark/profile_serving.py
+++ b/benchmark/profile_serving.py
@@ -53,29 +53,25 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
 
 
 def warmup(tritonserver_addr: str,
-           model_name: str,
            concurrency: int,
-           session_len: int,
            output_seqlen: int,
            warmup_round: int = 4):
     print('start to warmup ...')
 
     def _infer(_chatbot, session_id):
         for _ in range(warmup_round):
-            for _, _, _ in chatbot.stream_infer(
+            for _, _, _ in _chatbot.stream_infer(
                    session_id,
                    prompt='',
                    request_output_len=output_seqlen,
                    sequence_start=True,
                    sequence_end=True):
                continue
-            chatbot.reset_session()
+            _chatbot.reset_session()
 
     _start = time.perf_counter()
     chatbots = [
         Chatbot(tritonserver_addr=tritonserver_addr,
-                model_name=model_name,
-                session_len=session_len,
                 ignore_eos=True,
                 profile_generation=True) for _ in range(concurrency)
     ]
@@ -90,8 +86,8 @@ def _infer(_chatbot, session_id):
     print(f'end warmup, elapsed time: {round(_end - _start, 2)} s')
 
 
-def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,
-                 samples: int, test_round: int, session_len: int):
+def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
+                 test_round: int, session_len: int):
     start = time.perf_counter()
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -134,24 +130,20 @@ def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,
 
 
 def main(tritonserver_addr: str,
-         model_name: str,
          tokenizer_path: str,
          dataset_path: str,
          concurrency: int = 1,
          session_len: int = 2048,
-         samples: int = 2000,
+         samples: int = 1000,
          test_round: int = 1):
-    warmup(tritonserver_addr, model_name, concurrency, session_len,
-           session_len)
-    req_que = read_dataset(tritonserver_addr, tokenizer_path, dataset_path,
-                           samples, test_round, session_len)
+    warmup(tritonserver_addr, concurrency, session_len - 1)
+    req_que = read_dataset(tokenizer_path, dataset_path, samples, test_round,
+                           session_len)
     res_que = mp.Queue()
     procs = []
     _start = time.perf_counter()
     for i in range(concurrency):
         chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
-                          model_name=model_name,
-                          session_len=session_len,
                           display=False,
                           profile_serving=True,
                           ignore_eos=True)
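
Below is a minimal usage sketch of the entry point after this change, matching the new main() signature in the diff above. The import path, server address, and file paths are placeholders, not values taken from the source.

    # Hypothetical direct call of the revised main(): model_name is no longer
    # a parameter, and session_len is no longer passed to the Chatbot
    # constructor; each Chatbot gets only the server address and profiling flags.
    from profile_serving import main  # assumes benchmark/ is on sys.path

    main(tritonserver_addr='0.0.0.0:33337',     # placeholder Triton endpoint
         tokenizer_path='/path/to/tokenizer',   # placeholder
         dataset_path='/path/to/dataset.json',  # placeholder
         concurrency=16,
         session_len=2048,
         samples=1000,
         test_round=1)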