From 41d1f67704a3761423131f48c357b957452a00a9 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 15 Jul 2024 19:56:55 -0700
Subject: [PATCH] Fix flush cache (#627)

---
 benchmark/latency_throughput/README.md        | 10 +++--
 benchmark/latency_throughput/bench_one.py     |  4 +-
 python/sglang/bench_latency.py                | 37 ++++++++++---------
 .../srt/managers/controller/manager_single.py |  1 +
 python/sglang/srt/memory_pool.py              |  2 +
 5 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/benchmark/latency_throughput/README.md b/benchmark/latency_throughput/README.md
index cabbb6eced..31433a0301 100644
--- a/benchmark/latency_throughput/README.md
+++ b/benchmark/latency_throughput/README.md
@@ -1,9 +1,8 @@
-
 # Benchmark Latency and Throughput
 
 ## SGLang
 
-### Launch server
+### Launch a server
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 ```
@@ -33,6 +32,11 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
 python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
 ```
 
+### Profile with Nsight
+1. To profile a single batch, use `nsys profile --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
+2. To profile a server, use `nsys profile --cuda-graph-trace=node python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B`.
+
+
 ## Other baselines
 
 ### vLLM
@@ -64,4 +68,4 @@ python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat
 
 ```
 python3 bench_serving.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
-```
\ No newline at end of file
+```
diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py
index 0bb26ee155..b390c44a53 100644
--- a/benchmark/latency_throughput/bench_one.py
+++ b/benchmark/latency_throughput/bench_one.py
@@ -102,8 +102,8 @@ def run_one_batch_size(bs):
     output_throughput = bs * max_new_tokens / latency
     overall_throughput = bs * (input_len + output_len) / latency
     print(f"latency: {latency:.2f} s")
-    print(f"decode throughput: {output_throughput:.2f} token/s")
-    print(f"overall throughput: {overall_throughput:.2f} token/s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
 
     with open("results.jsonl", "a") as fout:
         res = {
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index c4c6d0ecf8..45d23b63da 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -284,23 +284,26 @@ def main(server_args, bench_args):
     else:
         work_func = latency_test
 
-    workers = []
-    for tp_rank in range(server_args.tp_size):
-        proc = multiprocessing.Process(
-            target=work_func,
-            args=(
-                server_args,
-                bench_args,
-                tp_rank,
-            ),
-        )
-        proc.start()
-        workers.append(proc)
-
-    for proc in workers:
-        proc.join()
-
-    proc.terminate()
+    if server_args.tp_size == 1:
+        work_func(server_args, bench_args, 0)
+    else:
+        workers = []
+        for tp_rank in range(server_args.tp_size):
+            proc = multiprocessing.Process(
+                target=work_func,
+                args=(
+                    server_args,
+                    bench_args,
+                    tp_rank,
+                ),
+            )
+            proc.start()
+            workers.append(proc)
+
+        for proc in workers:
+            proc.join()
+
+        proc.terminate()
 
 
 if __name__ == "__main__":
diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py
index c2cb922fca..d4186d4845 100644
--- a/python/sglang/srt/managers/controller/manager_single.py
+++ b/python/sglang/srt/managers/controller/manager_single.py
@@ -96,6 +96,7 @@ class ControllerSingle:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
         # Parse args
         self.server_args = server_args
+        self.tp_procs = []
 
         # Init communication
         context = zmq.Context(2)
diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py
index 6b5b411152..c0a384ccca 100644
--- a/python/sglang/srt/memory_pool.py
+++ b/python/sglang/srt/memory_pool.py
@@ -98,6 +98,8 @@ def free(self, free_index: torch.Tensor):
         self.can_use_mem_size += len(free_index)
 
     def clear(self):
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+
         self.mem_state.fill_(True)
         self.can_use_mem_size = self.size
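Below the patch, a minimal sketch of the pool state that `clear()` now resets, for readers who want to see the flush-cache fix in isolation. Everything outside `free()`/`clear()` is hypothetical scaffolding added so the snippet runs on its own (the `SimpleTokenPool` name, the `alloc()` helper, CPU tensors instead of CUDA); only the fields and the reset lines mirror the `memory_pool.py` hunk above.

```
# Hypothetical stand-alone model of the pool fields touched by this patch.
# Not the real sglang pool class; CPU tensors are used so it runs anywhere.
import torch


class SimpleTokenPool:
    def __init__(self, size: int):
        self.size = size
        self.mem_state = torch.ones(size, dtype=torch.bool)  # True = slot free
        self.can_use_mem_size = size
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)

    def alloc(self, need: int) -> torch.Tensor:
        # Hypothetical allocator: hand out the first `need` free slots.
        idx = torch.nonzero(self.mem_state).squeeze(1)[:need].to(torch.int32)
        self.mem_state[idx.long()] = False
        self.can_use_mem_size -= len(idx)
        return idx

    def free(self, free_index: torch.Tensor):
        self.mem_state[free_index.long()] = True
        self.can_use_mem_size += len(free_index)

    def clear(self):
        # Reset the prefetch buffer together with the free-slot bitmap and the
        # counter, so a flush leaves no reserved indices behind.
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size


if __name__ == "__main__":
    pool = SimpleTokenPool(8)
    pool.alloc(3)
    pool.clear()
    assert pool.can_use_mem_size == 8 and pool.prefetch_buffer.numel() == 0
```

Read this way, the line added by the patch keeps a cache flush from leaving indices parked in the prefetch buffer while `mem_state` and `can_use_mem_size` already report the pool as fully free.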