From f4e885b7c3fbed59ce48c7c3046e628e7a58d396 Mon Sep 17 00:00:00 2001
From: Mingyi
Date: Sun, 7 Jul 2024 19:35:22 -0700
Subject: [PATCH] Reduce number of workspaces (#601)

---
 python/sglang/srt/layers/radix_attention.py           | 8 ++++----
 python/sglang/srt/managers/controller/model_runner.py | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index b9ea6ad85be..6ee4a31a144 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from flashinfer.cascade import merge_state
+
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
@@ -95,8 +97,6 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
@@ -117,10 +117,10 @@ def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
             logits_soft_cap=self.logit_cap,
         )
 
-        from flashinfer.cascade import merge_state
-
         o, _ = merge_state(o1, s1, o2, s2)
 
+        self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 0bb869cf793..4eeaeac76ec 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -408,7 +408,7 @@ def init_flash_infer(self):
                 use_tensor_cores = False
 
             workspace_buffers = torch.empty(
-                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
             self.flashinfer_prefill_wrapper_ragged = (
                 BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
@@ -417,7 +417,7 @@ def init_flash_infer(self):
                 workspace_buffers[1], "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
             self.flashinfer_prefill_wrapper_ragged = (
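
The sketch below is not part of the patch; it is a minimal illustration of the workspace-sharing scheme the model_runner.py hunks set up, assuming (as the diff suggests) that the ragged prefill wrapper and the decode wrapper are never active for the same forward batch and can therefore reuse the same 96 MB workspace slice. The wrapper class names and constructor arguments are the ones visible in the diff; the variable names are illustrative.

    # Minimal sketch, assuming the ragged prefill wrapper and the decode
    # wrapper never run in the same forward pass, so slice 0 of the
    # workspace can be shared between them.
    import torch
    from flashinfer import (
        BatchDecodeWithPagedKVCacheWrapper,
        BatchPrefillWithPagedKVCacheWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )

    # Two 96 MB workspaces instead of three.
    workspace_buffers = torch.empty(
        2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
    )

    prefill_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
    prefill_paged = BatchPrefillWithPagedKVCacheWrapper(workspace_buffers[1], "NHD")
    # The decode wrapper reuses slice 0, which is otherwise only touched by
    # the ragged prefill wrapper.
    decode = BatchDecodeWithPagedKVCacheWrapper(workspace_buffers[0], "NHD")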