-
Notifications
You must be signed in to change notification settings - Fork 663
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor SchedulePolicy to improve code organization (#2571)
- Loading branch information
1 parent
f5d0865
commit bdb3929
Showing
2 changed files
with
211 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import unittest | ||
|
||
from sglang.srt.managers.schedule_batch import Req | ||
from sglang.srt.managers.schedule_policy import ( | ||
CacheAgnosticPolicy, | ||
CacheAwarePolicy, | ||
SchedulePolicy, | ||
) | ||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode | ||
from sglang.srt.sampling.sampling_params import SamplingParams | ||
|
||
|
||
class TestSchedulePolicy(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.tree_cache = RadixCache(None, None, False) | ||
|
||
def test_init_with_cache_aware_policy(self): | ||
policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache) | ||
self.assertEqual(policy.policy, CacheAwarePolicy.LPM) | ||
|
||
def test_init_with_cache_agnostic_policy(self): | ||
policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache) | ||
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) | ||
|
||
def test_init_with_unknown_policy(self): | ||
with self.assertRaises(ValueError): | ||
SchedulePolicy(policy="invalid", tree_cache=self.tree_cache) | ||
|
||
def test_init_with_disabled_cache(self): | ||
disabled_tree_cache = RadixCache(None, None, disable=True) | ||
policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache) | ||
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) | ||
|
||
def test_calc_priority_fcfs(self): | ||
tree_cache = RadixCache(None, None, False) | ||
waiting_queue = [ | ||
Req(1, "a b", [1, 2], SamplingParams()), | ||
Req(3, "a b c", [1, 2, 3], SamplingParams()), | ||
Req(2, "a", [1], SamplingParams()), | ||
] | ||
|
||
policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache) | ||
policy.calc_priority(waiting_queue) | ||
# Check if FCFS keeps the original order | ||
self.assertEqual(waiting_queue[0].rid, 1) | ||
self.assertEqual(waiting_queue[1].rid, 3) | ||
self.assertEqual(waiting_queue[2].rid, 2) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |