Skip to content

Commit

Permalink
Refactor SchedulePolicy to improve code organization (#2571)
Browse files Browse the repository at this point in the history
  • Loading branch information
libratiger authored Jan 3, 2025
1 parent f5d0865 commit bdb3929
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 90 deletions.
249 changes: 159 additions & 90 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set, Union

import torch

Expand Down Expand Up @@ -50,124 +50,193 @@
)


class CacheAwarePolicy(Enum):
    """Scheduling policies that rely on prefix matches against the tree cache.

    Each member's value is the policy name string it is constructed from.
    """

    # Longest prefix match.
    LPM = "lpm"
    # Depth-first search weighting.
    DFS_WEIGHT = "dfs-weight"


class CacheAgnosticPolicy(Enum):
    """Scheduling policies that order requests without consulting the tree cache.

    Each member's value is the policy name string it is constructed from.
    """

    # First come, first served: the queue order is left as-is.
    FCFS = "fcfs"
    # Longest output first.
    LOF = "lof"
    # Random order.
    RANDOM = "random"


class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.
policy = "fcfs"
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

self.policy = policy
def __init__(self, policy: str, tree_cache: BasePrefixCache):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache

# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool=None, disable=False
)

def calc_priority(self, waiting_queue: List[Req]):
if len(waiting_queue) > 128 and self.policy == "lpm":
# Turn off the expensive prefix matching and sorting when the #queue is large.
policy = "fcfs"
else:
policy = self.policy
def calc_priority(self, waiting_queue: List[Req]) -> bool:
policy = self._determine_active_policy(waiting_queue)

# Compute matched prefix length
prefix_computed = False
if policy == "lpm" or policy == "dfs-weight":
# rid to deprioritize in the current run for in-batch prefix caching.
temporary_deprioritized = set()
self.waiting_queue_radix_tree.reset()

for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()

# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
if isinstance(policy, CacheAwarePolicy):
prefix_computed = True
temporary_deprioritized = self._compute_prefix_matches(
waiting_queue, policy
)
if policy == CacheAwarePolicy.LPM:
SchedulePolicy._sort_by_longest_prefix(
waiting_queue, temporary_deprioritized
)
elif policy == CacheAwarePolicy.DFS_WEIGHT:
SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
else:
raise ValueError(f"Unknown CacheAware Policy: {policy=}")
else:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized.add(r.rid)
else:
# Insert with a dummy key
self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
)
return prefix_computed

prefix_computed = True
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
    """Pick the policy to apply to this particular scheduling pass.

    Returns the configured policy, except that LPM is replaced by FCFS
    when the waiting queue holds more than 128 requests.
    """
    configured = self.policy
    queue_is_large = len(waiting_queue) > 128
    if queue_is_large and configured == CacheAwarePolicy.LPM:
        # Prefix matching and sorting are expensive, so skip them while
        # the queue is large.
        return CacheAgnosticPolicy.FCFS
    return configured

def _validate_and_adjust_policy(
    self, policy: str, tree_cache: BasePrefixCache
) -> Policy:
    """Resolve a policy name into its enum member, honoring the cache setting.

    Args:
        policy: Policy name; must equal the value of a ``CacheAwarePolicy``
            or ``CacheAgnosticPolicy`` member.
        tree_cache: The prefix cache; only its ``disable`` flag is read.

    Returns:
        The matching enum member. A cache-aware policy is replaced by
        ``CacheAgnosticPolicy.FCFS`` when ``tree_cache.disable`` is truthy,
        because there is no cache for such a policy to exploit.

    Raises:
        ValueError: If ``policy`` names neither kind of policy.
    """
    # Keep the try body to the single lookup that can legitimately fail, so
    # an error raised while reading the cache flag is never mistaken for an
    # unknown policy name.
    try:
        cache_aware = CacheAwarePolicy(policy)
    except ValueError:
        cache_aware = None

    if cache_aware is not None:
        if tree_cache.disable:
            # Cache-aware ordering is meaningless without a tree cache.
            return CacheAgnosticPolicy.FCFS
        return cache_aware

    try:
        return CacheAgnosticPolicy(policy)
    except ValueError:
        # Suppress the enum lookup failure; the message below already
        # tells the caller everything relevant.
        raise ValueError(f"Unknown schedule_policy: {policy=}") from None

def _compute_prefix_matches(
self, waiting_queue: List[Req], policy: CacheAwarePolicy
) -> Set[int]:
"""
Computes and caches the matching prefixes for requests in the waiting queue,
and handles in-batch prefix caching logic.
"""
temporary_deprioritized: Set[int] = set()
self.waiting_queue_radix_tree.reset()

for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()

# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)

if policy == "lpm":
# Longest Prefix Match
waiting_queue.sort(
key=lambda r: (
-len(r.prefix_indices)
if r.rid not in temporary_deprioritized
else float("inf")
# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized.add(r.rid)
else:
# Insert with a dummy key
self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
)
return temporary_deprioritized

@staticmethod
def _sort_by_longest_prefix(
waiting_queue: List[Req], temporary_deprioritized: Set[int]
) -> None:
"""Sorts the waiting queue based on the longest prefix match."""
waiting_queue.sort(
key=lambda r: (
-len(r.prefix_indices)
if r.rid not in temporary_deprioritized
else float("inf")
)
elif policy == "fcfs":
# first come first serve
pass
elif policy == "lof":
# longest output first
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
elif policy == "random":
random.shuffle(waiting_queue)
elif policy == "dfs-weight":
# Experimental policy based on custom weights
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)

node_to_weight = defaultdict(int)
for node in last_node_to_reqs:
node_to_weight[node] = len(last_node_to_reqs[node])
self.calc_weight(self.tree_cache.root_node, node_to_weight)

waiting_queue.clear()
self.get_dfs_priority(
self.tree_cache.root_node,
node_to_weight,
last_node_to_reqs,
waiting_queue,
)
else:
raise ValueError(f"Unknown schedule_policy: {policy=}")
)

return prefix_computed
@staticmethod
def _sort_by_dfs_weight(
    waiting_queue: List[Req], tree_cache: BasePrefixCache
) -> None:
    """Reorder the waiting queue in place using depth-first tree weights.

    Requests are grouped by their ``last_node``. Every such node starts with
    a weight equal to the number of requests attached to it; ``_calc_weight``
    then updates the weights starting from the cache root, and
    ``_get_dfs_priority`` refills the queue from the root using those weights.
    """
    reqs_by_node = defaultdict(list)
    for pending in waiting_queue:
        reqs_by_node[pending.last_node].append(pending)

    # Seed each matched node with its own request count.
    weights = defaultdict(int)
    weights.update((node, len(group)) for node, group in reqs_by_node.items())
    SchedulePolicy._calc_weight(tree_cache.root_node, weights)

    # Empty and refill the same list object so the caller's reference sees
    # the new order.
    waiting_queue.clear()
    SchedulePolicy._get_dfs_priority(
        tree_cache.root_node,
        weights,
        reqs_by_node,
        waiting_queue,
    )

@staticmethod
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
    """Sort the waiting queue in place, largest ``max_new_tokens`` first."""

    def negated_output_budget(req):
        # Negating the value makes the ascending sort put the largest first.
        return -req.sampling_params.max_new_tokens

    waiting_queue.sort(key=negated_output_budget)

def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
@staticmethod
def _sort_randomly(waiting_queue: List[Req]) -> None:
    """Shuffle the waiting queue in place into a random order.

    The list object itself is reordered via ``random.shuffle``; nothing is
    returned.
    """
    random.shuffle(waiting_queue)

@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
for child in cur_node.children.values():
self.calc_weight(child, node_to_weight)
SchedulePolicy._calc_weight(child, node_to_weight)
node_to_weight[cur_node] += node_to_weight[child]

def get_dfs_priority(
self,
@staticmethod
def _get_dfs_priority(
cur_node: TreeNode,
node_to_priority: Dict[TreeNode, int],
last_node_to_reqs: Dict[TreeNode, List[Req]],
q: List,
):
) -> None:
childs = [child for child in cur_node.children.values()]
childs.sort(key=lambda x: -node_to_priority[x])
for child in childs:
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
SchedulePolicy._get_dfs_priority(
child, node_to_priority, last_node_to_reqs, q
)
q.extend(last_node_to_reqs[cur_node])


Expand Down
52 changes: 52 additions & 0 deletions test/srt/test_schedule_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import unittest

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
CacheAgnosticPolicy,
CacheAwarePolicy,
SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams


class TestSchedulePolicy(unittest.TestCase):
    """Checks policy-name resolution and FCFS ordering of SchedulePolicy."""

    def setUp(self):
        # Every test starts with a fresh radix cache that is not disabled.
        self.tree_cache = RadixCache(None, None, False)

    def test_init_with_cache_aware_policy(self):
        # A cache-aware name maps to its enum member when the cache is enabled.
        scheduler = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache)
        self.assertEqual(scheduler.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        # A cache-agnostic name maps to its enum member.
        scheduler = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache)
        self.assertEqual(scheduler.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        # A name that matches no policy is rejected at construction time.
        self.assertRaises(
            ValueError,
            SchedulePolicy,
            policy="invalid",
            tree_cache=self.tree_cache,
        )

    def test_init_with_disabled_cache(self):
        # With the cache disabled, a cache-aware request degrades to FCFS.
        cache_off = RadixCache(None, None, disable=True)
        scheduler = SchedulePolicy(policy="lpm", tree_cache=cache_off)
        self.assertEqual(scheduler.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        tree_cache = RadixCache(None, None, False)
        pending = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]

        scheduler = SchedulePolicy(policy="fcfs", tree_cache=tree_cache)
        scheduler.calc_priority(pending)

        # FCFS must leave the arrival order untouched.
        self.assertEqual([req.rid for req in pending], [1, 3, 2])


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit bdb3929

Please sign in to comment.