update profile scripts
lvhan028 committed Nov 22, 2023
1 parent 42e57c8 commit e706ae4
Showing 4 changed files with 106 additions and 59 deletions.
149 changes: 95 additions & 54 deletions benchmark/profile_generation.py
@@ -22,33 +22,44 @@
from lmdeploy.turbomind import TurboMind


def infer(model, session_id: int, input_ids: str, output_seqlen: int,
def infer(model, session_id: int, input_ids: List, output_seqlen: int,
test_round: int, que: Queue):
chatbot = model.create_instance()
stats = []
for i in range(test_round):
start = time.perf_counter()
timestamps = []
tokens = []
for _ in range(test_round):
token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_pre_token = 0
"""
The iterator provided by `chatbot.stream_infer` denotes the number of generated tokens so far,
which is represented by the variable `n_token`.
Please note that `n_token` is not a continuous value. In other words, during the iteration,
its value might be 5, 7, 8, 16, and so on, rather than 1, 2, 3, 4, etc.
So, it is quite difficult to get the latency of each generated token.
As a work-around, we set the latency `new-prev` of each iteration to the first token of
the new generated tokens, and leave the latency of the rest tokens being 0.
For example, in the first iteration, 5 tokens are generated.
The time elapsing in this iteration `now-prev` is set to the latency of first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
for outputs in chatbot.stream_infer(session_id,
input_ids,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
res, token = outputs[0]
timestamps.append(time.perf_counter())
tokens.append(token)

# TODO: ignore first token
first_token_latency = np.round(timestamps[0] - start, 2)
if len(timestamps) == 1:
token_latency = np.round(timestamps[0] - start, 2)
token = tokens[0]
else:
token_latency = np.round(timestamps[-1] - timestamps[0], 2)
token = tokens[-1] - tokens[0]
stats.append([first_token_latency, token, token_latency])
ignore_eos=True,
stream_output=True):
_, n_token = outputs[0]
now = time.perf_counter()
if n_pre_token != n_token:
token_latency_stats[n_pre_token] = np.round(now - prev, 3)
n_pre_token = n_token
prev = now

assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) requested {output_seqlen} ' \
f'tokens, but generated {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
que.put((session_id, stats))
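A note on the bookkeeping above: `stream_infer` reports the cumulative number of generated tokens at each iteration, so the elapsed time of an iteration is credited to the first token of the newly produced chunk and the remaining positions stay at 0. A minimal sketch of that idea, with a hypothetical `fake_stream()` generator standing in for the real engine:

```python
import time


def fake_stream():
    """Hypothetical stand-in for chatbot.stream_infer: yields cumulative token counts."""
    for n_token in (5, 7, 8, 16):
        time.sleep(0.01)  # pretend some decoding work happened
        yield n_token


output_seqlen = 16
token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_pre_token = 0
for n_token in fake_stream():
    now = time.perf_counter()
    if n_pre_token != n_token:
        # credit the whole interval to the first token of the new chunk
        token_latency_stats[n_pre_token] = round(now - prev, 3)
        n_pre_token = n_token
        prev = now

# positions 1-4, 6 and 9-15 remain 0; their cost is folded into each chunk's first token
print(token_latency_stats[:output_seqlen])
```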


@@ -128,33 +139,49 @@ def profile_throughput(model_path: str,
_end = time.perf_counter()
elapsed_time = _end - _start

stats = []
token_latency_stats = []
while not que.empty():
session_id, _stats = que.get()
print(f'\n{"-" * 50}\n'
f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
stats.append(_stats)

stats = np.array(stats).reshape(-1, 3)

first_token_latency_min = np.min(stats[:, 0], axis=0)
first_token_latency_max = np.max(stats[:, 0], axis=0)
first_token_latency_ave = np.mean(stats[:, 0], axis=0)
token_latency_min = np.min(stats[:, 2], axis=0)
token_latency_max = np.max(stats[:, 2], axis=0)
token_latency_ave = np.mean(stats[:, 2], axis=0)
throughput = np.sum(stats[:, 1], axis=0) / np.sum(stats[:, 2],
axis=0) * concurrency
print(f'\n{"-" * 50}\nconcurrency: {concurrency}, input_tokens: '
f'{input_seqlen}, output_tokens: {output_seqlen}\n'
f'elapsed_time: {elapsed_time:.2f}s\n'
_, _stats = que.get()
token_latency_stats += _stats

# The shape is [concurrency*test_round, output_seqlen]
token_latency_stats = np.stack(token_latency_stats, axis=0)

first_token_latency_min = np.round(
np.min(token_latency_stats[:, 0], axis=0), 3)
first_token_latency_max = np.round(
np.max(token_latency_stats[:, 0], axis=0), 3)
first_token_latency_ave = np.round(
np.mean(token_latency_stats[:, 0], axis=0), 3)
token_latency_max = np.round(np.max(np.sum(token_latency_stats, axis=1)),
3)
token_latency_min = np.round(np.min(np.sum(token_latency_stats, axis=1)),
3)
token_latency_ave = np.round(np.mean(np.sum(token_latency_stats, axis=1)),
3)
# sort token_latency without the first token's latency
sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten())
percentiles = [
np.round(
sorted_token_latency[int(percent * len(sorted_token_latency))], 3)
for percent in [0.5, 0.75, 0.95, 0.99]
]

throughput = np.round(token_latency_stats.size / elapsed_time, 2)
print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n'
f'concurrency: {concurrency}, test_round: {test_round}\n'
f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n'
f'first_token latency(min, max, ave): '
f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): '
f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, '
f'{token_latency_ave:.2f}s\n'
f'throughput: {throughput:.2f} token/s\n{"-" * 50}')
return tm_model.model_name, throughput, tm_model.gpu_count
f'{first_token_latency_min}s, {first_token_latency_max}s, '
f'{first_token_latency_ave}s\ntotal_token latency(min, max, ave): '
f'{token_latency_min}s, {token_latency_max}s, '
f'{token_latency_ave}s\n'
f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n'
f'throughput: {throughput} token/s\n{"-" * 50}')
return tm_model.model_name, \
[first_token_latency_min, first_token_latency_max,
first_token_latency_ave], \
percentiles, throughput, tm_model.gpu_count
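For clarity, the aggregation works on a matrix whose rows are rounds and whose columns are token positions: column 0 holds the first-token latencies, row sums give the total generation time of a round, the percentiles are taken over every latency except the first token's, and throughput is the matrix size (total generated tokens) divided by the wall-clock time. A small numpy sketch with made-up latencies (two rounds, four output tokens, purely illustrative):

```python
import numpy as np

# hypothetical per-token latencies in seconds, shape [rounds, output_seqlen]
token_latency_stats = np.array([
    [0.210, 0.015, 0.014, 0.016],
    [0.190, 0.017, 0.000, 0.031],  # a 0 means the token arrived in the same chunk as its predecessor
])

first_token = token_latency_stats[:, 0]            # first-token latency per round
total_per_round = token_latency_stats.sum(axis=1)  # end-to-end generation time per round

rest = np.sort(token_latency_stats[:, 1:].flatten())
percentiles = [round(rest[int(p * len(rest))], 3) for p in (0.5, 0.75, 0.95, 0.99)]

print(first_token.min(), first_token.max(), first_token.mean())        # first_token latency (min, max, ave)
print(total_per_round.min(), total_per_round.max(), total_per_round.mean())  # total_token latency
print(percentiles)                                                      # percentiles (50%, 75%, 95%, 99%)
```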


class MemoryMonitor:
Expand Down Expand Up @@ -235,6 +262,8 @@ class ProfileResult:
batch: int
prompt_tokens: int
completion_tokens: int
first_token_latency: List
percentiles: List
throughput_per_proc: float
throughput_per_node: float
mem_per_proc: float
@@ -258,22 +287,26 @@ def parse_args():
type=int,
help='how many requests launched concurrently. One-to-one '
'correspondence with completion-tokens',
default=[64, 512, 512, 1024])
default=[1, 512, 512, 1024])
parser.add_argument('--completion-tokens',
nargs='+',
type=int,
help='how many tokens to be generated. One-to-one '
'correspondence with prompt-tokens',
default=[512, 512, 1024, 1024])
parser.add_argument('--tp', type=int, help='Tensor parallel', default=1)
parser.add_argument('--dst-csv',
parser.add_argument('--csv',
type=str,
help='Where to save the result.',
default='profile_generation.csv')
parser.add_argument('--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
parser.add_argument('--test-round',
type=int,
help='number of test rounds',
default=10)
args = parser.parse_args()
return args

@@ -292,9 +325,11 @@ def main():
concurrency=batch,
input_seqlen=prompt_tokens,
output_seqlen=completion_tokens,
tp=args.tp)
tp=args.tp,
test_round=args.test_round)
output = Pool(1).map(profile_target, (args.model_path, ))
model_name, throughput_per_proc, tp = output[0]
model_name, first_token_latency, percentiles, \
throughput_per_proc, tp = output[0]
time.sleep(5) # wait a while for releasing GPU mem
memory = MemoryMonitor.terminate()
device_count = MemoryMonitor.device_count.value
@@ -303,25 +338,31 @@
batch=batch,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
first_token_latency=first_token_latency,
percentiles=percentiles,
throughput_per_proc=throughput_per_proc,
throughput_per_node=throughput_per_proc / tp *
device_count,
mem_per_proc=memory,
mem_per_gpu=memory / tp,
mem_per_node=memory / tp * device_count))
with open(args.dst_csv, 'w') as csvfile:
with open(args.csv, 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([
'batch', 'prompt_tokens', 'completion_tokens',
'throughput_per_proc(token/s)', 'throughput_per_node(token/s)',
'mem_per_proc(GB)', 'mem_per_gpu(GB)', 'mem_per_node(GB)'
'1st_token_latency(min)(s)', '1st_token_latency(max)(s)',
'1st_token_latency(ave)(s)', 'percentile50(s)', 'percentile75(s)',
'percentile95(s)', 'percentile99(s)', 'throughput(token/s)',
'mem_per_proc(GB)', 'mem_per_gpu(GB)'
])
for re in results:
writer.writerow([
re.batch, re.prompt_tokens, re.completion_tokens,
f'{re.throughput_per_proc:.2f}',
f'{re.throughput_per_node:.2f}', f'{re.mem_per_proc:.2f}',
f'{re.mem_per_gpu:.2f}', f'{re.mem_per_node:.2f}'
re.first_token_latency[0], re.first_token_latency[1],
re.first_token_latency[2], re.percentiles[0],
re.percentiles[1], re.percentiles[2], re.percentiles[3],
f'{re.throughput_per_proc:.2f}', f'{re.mem_per_proc:.2f}',
f'{re.mem_per_gpu:.2f}'
])
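Since the CSV schema changes here (the per-node columns are dropped and first-token latency plus percentile columns are added), downstream consumers of the report need updating. As a rough example, assuming pandas is available, the results could be inspected like this:

```python
import pandas as pd

df = pd.read_csv('profile_generation.csv')
# columns now include 1st_token_latency(min/max/ave)(s), percentile50/75/95/99(s),
# throughput(token/s), mem_per_proc(GB) and mem_per_gpu(GB)
print(df.sort_values('throughput(token/s)', ascending=False).to_string(index=False))
```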


4 changes: 3 additions & 1 deletion benchmark/profile_restful_api.py
@@ -127,7 +127,9 @@ def main(server_addr: str,
concurrency: int = 1,
session_len: int = 2048,
samples: int = 1000,
stream_output: bool = False):
stream_output: bool = False,
seed: int = 0):
random.seed(seed)
api_url = server_addr + '/v1/chat/interactive'
warmup(api_url, concurrency, session_len - 1, 4, stream_output)
req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples,
4 changes: 3 additions & 1 deletion benchmark/profile_serving.py
@@ -123,7 +123,9 @@ def main(tritonserver_addr: str,
dataset_path: str,
concurrency: int = 1,
session_len: int = 2048,
samples: int = 1000):
samples: int = 1000,
seed: int = 0):
random.seed(seed)
warmup(tritonserver_addr, concurrency, session_len - 1)
req_que = mp.Queue()
res_que = mp.Queue()
8 changes: 5 additions & 3 deletions benchmark/profile_throughput.py
@@ -179,14 +179,16 @@ def process_request(self,
def main(dataset: str,
model_path: str,
concurrency: int = 1,
num_prompts: int = 1000,
samples: int = 1000,
tp: int = 1,
stream_output: bool = True):
stream_output: bool = True,
seed: int = 0):
random.seed(seed)

engine = Engine(model_path, tp=tp)
tokenizer = engine.tokenizer

requests = sample_requests(dataset, num_prompts, tokenizer)
requests = sample_requests(dataset, samples, tokenizer)

engine.process_request(requests, concurrency, stream_output)
