Skip to content

Commit

Permalink
Reduce number of workspaces (#601)
Browse files Browse the repository at this point in the history
  • Loading branch information
wisclmy0611 authored Jul 8, 2024
1 parent 0877f1e commit f4e885b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from torch import nn

from flashinfer.cascade import merge_state

from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
Expand Down Expand Up @@ -95,8 +97,6 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
return o

def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)

o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
Expand All @@ -117,10 +117,10 @@ def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
logits_soft_cap=self.logit_cap,
)

from flashinfer.cascade import merge_state

o, _ = merge_state(o1, s1, o2, s2)

self.store_kv_cache(k, v, input_metadata)

if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
torch.cuda.synchronize()

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/controller/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def init_flash_infer(self):
use_tensor_cores = False

workspace_buffers = torch.empty(
3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
Expand All @@ -417,7 +417,7 @@ def init_flash_infer(self):
workspace_buffers[1], "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
)
else:
self.flashinfer_prefill_wrapper_ragged = (
Expand Down

0 comments on commit f4e885b

Please sign in to comment.