Fix flush cache (#627)
merrymercy committed Jul 16, 2024
1 parent 56f5fc4 commit 41d1f67
Showing 5 changed files with 32 additions and 22 deletions.
10 changes: 7 additions & 3 deletions benchmark/latency_throughput/README.md
@@ -1,9 +1,8 @@

# Benchmark Latency and Throughput

## SGLang

### Launch server
### Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
@@ -33,6 +32,11 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```

### Profile with Nsight
1. To profile a single batch, use `nsys profile --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
2. To profile a server, use `nsys profile --cuda-graph-trace=node python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B`.


## Other baselines

### vLLM
@@ -64,4 +68,4 @@ python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat

```
python3 bench_serving.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```
```
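Between benchmark runs it can help to reset the server's cache so measurements start cold, which is what the flush-cache path fixed by this commit is for. A minimal sketch of driving this over HTTP, assuming the server launched above listens on port 30000 and exposes `/flush_cache` and `/generate` routes (the exact routes, HTTP methods, and payload shape are assumptions here, not taken from this diff):

```
import requests

BASE_URL = "http://127.0.0.1:30000"  # port from the launch command above

# Ask the server to drop its cached state so the next run starts cold.
# Route name and method are assumed, not confirmed by this diff.
requests.get(f"{BASE_URL}/flush_cache", timeout=10)

# One small generation request to confirm the server still responds.
resp = requests.post(
    f"{BASE_URL}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    },
    timeout=60,
)
print(resp.json())
```
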
4 changes: 2 additions & 2 deletions benchmark/latency_throughput/bench_one.py
@@ -102,8 +102,8 @@ def run_one_batch_size(bs):
    output_throughput = bs * max_new_tokens / latency
    overall_throughput = bs * (input_len + output_len) / latency
    print(f"latency: {latency:.2f} s")
    print(f"decode throughput: {output_throughput:.2f} token/s")
    print(f"overall throughput: {overall_throughput:.2f} token/s")
    print(f"output throughput: {output_throughput:.2f} token/s")
    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")

    with open("results.jsonl", "a") as fout:
        res = {
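The relabeled prints in `bench_one.py` distinguish generated-token throughput from total processed-token throughput. A quick illustration of the arithmetic with made-up numbers (not measured results):

```
# Hypothetical run: batch size 64, 512 input tokens and 256 generated tokens
# per request, taking 10 seconds end to end.
bs, input_len, output_len, latency = 64, 512, 256, 10.0
max_new_tokens = output_len

output_throughput = bs * max_new_tokens / latency             # 1638.40 token/s
overall_throughput = bs * (input_len + output_len) / latency  # 4915.20 token/s
print(f"output throughput: {output_throughput:.2f} token/s")
print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
```
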
37 changes: 20 additions & 17 deletions python/sglang/bench_latency.py
@@ -284,23 +284,26 @@ def main(server_args, bench_args):
    else:
        work_func = latency_test

    workers = []
    for tp_rank in range(server_args.tp_size):
        proc = multiprocessing.Process(
            target=work_func,
            args=(
                server_args,
                bench_args,
                tp_rank,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    proc.terminate()
    if server_args.tp_size == 1:
        work_func(server_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        proc.terminate()


if __name__ == "__main__":
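The `bench_latency.py` change runs the benchmark directly in the parent process when `tp_size` is 1 and only spawns worker processes for multi-GPU tensor parallelism, which plausibly keeps single-GPU runs simpler to profile and debug. A standalone sketch of the same dispatch pattern, with a dummy work function standing in for `latency_test` (names here are illustrative, not from the repository):

```
import multiprocessing


def work_func(tp_size: int, tp_rank: int) -> None:
    # Placeholder for the real benchmark body (e.g. latency_test).
    print(f"running rank {tp_rank} of {tp_size}")


def run(tp_size: int) -> None:
    if tp_size == 1:
        # Single rank: stay in this process instead of forking a worker.
        work_func(tp_size, 0)
        return

    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(target=work_func, args=(tp_size, tp_rank))
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    run(2)
```
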
1 change: 1 addition & 0 deletions python/sglang/srt/managers/controller/manager_single.py
@@ -96,6 +96,7 @@ class ControllerSingle:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
        # Parse args
        self.server_args = server_args
        self.tp_procs = []

        # Init communication
        context = zmq.Context(2)
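The single added line in `manager_single.py` gives the controller a list for the tensor-parallel worker processes it launches; the hunk does not show how the list is used later, but the usual point of such bookkeeping is to be able to join or terminate the workers on shutdown. A generic sketch of that pattern (a hypothetical class, not the real `ControllerSingle`):

```
import multiprocessing
from typing import List


def tp_worker(tp_rank: int) -> None:
    print(f"worker {tp_rank} started")


class Controller:
    def __init__(self, tp_size: int) -> None:
        # Keep a handle to every spawned worker so shutdown code can
        # terminate them instead of leaving orphaned processes behind.
        self.tp_procs: List[multiprocessing.Process] = []
        for tp_rank in range(tp_size):
            proc = multiprocessing.Process(target=tp_worker, args=(tp_rank,))
            proc.start()
            self.tp_procs.append(proc)

    def shutdown(self) -> None:
        for proc in self.tp_procs:
            proc.terminate()
            proc.join()
```
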
2 changes: 2 additions & 0 deletions python/sglang/srt/memory_pool.py
@@ -98,6 +98,8 @@ def free(self, free_index: torch.Tensor):
        self.can_use_mem_size += len(free_index)

    def clear(self):
        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)

        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size

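The `memory_pool.py` change makes `clear()` also drop the prefetch buffer, so flushing the cache cannot leave stale pre-reserved slot indices behind while the free-list is reset. A simplified sketch of a pool with the same fields (a toy illustration mirroring the names visible in the diff, not the real class, and kept on CPU so it runs without a GPU):

```
import torch


class SimpleTokenPool:
    """Toy pool of KV-cache slots tracked by a boolean free-list."""

    def __init__(self, size: int, device: str = "cpu") -> None:
        self.size = size
        self.device = device
        self.mem_state = torch.ones(size, dtype=torch.bool, device=device)
        self.can_use_mem_size = size
        # Slot indices reserved ahead of time but not yet handed out.
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32, device=device)

    def free(self, free_index: torch.Tensor) -> None:
        self.mem_state[free_index] = True
        self.can_use_mem_size += len(free_index)

    def clear(self) -> None:
        # Drop prefetched reservations too; otherwise a flush would keep
        # handing out indices for slots that were just marked free.
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32, device=self.device)
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size


pool = SimpleTokenPool(size=16)
pool.clear()  # reset the free-list and forget any prefetched reservations
```
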
