Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Jul 22, 2023
1 parent aa528ba commit 07d5ed5
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions benchmark/profile_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,25 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):


def warmup(tritonserver_addr: str,
model_name: str,
concurrency: int,
session_len: int,
output_seqlen: int,
warmup_round: int = 4):
print('start to warmup ...')

def _infer(_chatbot, session_id):
for _ in range(warmup_round):
for _, _, _ in chatbot.stream_infer(
for _, _, _ in _chatbot.stream_infer(
session_id,
prompt='',
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
continue
chatbot.reset_session()
_chatbot.reset_session()

_start = time.perf_counter()
chatbots = [
Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
ignore_eos=True,
profile_generation=True) for _ in range(concurrency)
]
Expand All @@ -90,8 +86,8 @@ def _infer(_chatbot, session_id):
print(f'end warmup, elapsed time: {round(_end - _start, 2)} s')


def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,
samples: int, test_round: int, session_len: int):
def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
test_round: int, session_len: int):
start = time.perf_counter()
with open(dataset_path) as f:
dataset = json.load(f)
Expand Down Expand Up @@ -134,24 +130,20 @@ def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,


def main(tritonserver_addr: str,
model_name: str,
tokenizer_path: str,
dataset_path: str,
concurrency: int = 1,
session_len: int = 2048,
samples: int = 2000,
samples: int = 1000,
test_round: int = 1):
warmup(tritonserver_addr, model_name, concurrency, session_len,
session_len)
req_que = read_dataset(tritonserver_addr, tokenizer_path, dataset_path,
samples, test_round, session_len)
warmup(tritonserver_addr, concurrency, session_len - 1)
req_que = read_dataset(tokenizer_path, dataset_path, samples, test_round,
session_len)
res_que = mp.Queue()
procs = []
_start = time.perf_counter()
for i in range(concurrency):
chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
display=False,
profile_serving=True,
ignore_eos=True)
Expand Down

0 comments on commit 07d5ed5

Please sign in to comment.