1+ import argparse
12import os
23import time
34import torch
1112import test_low_latency
1213
1314
14- def test_main (num_sms : int , local_rank : int , num_local_ranks : int , num_ranks : int , num_nodes : int , rank : int , buffer : deep_ep .Buffer , group : dist .ProcessGroup , args ):
15+ # noinspection PyShadowingNames
16+ def test_main (args : argparse .Namespace , num_sms : int ,
17+ local_rank : int , num_local_ranks : int , num_ranks : int , num_nodes : int , rank : int ,
18+ buffer : deep_ep .Buffer , group : dist .ProcessGroup ):
1519 # Settings
16- num_tokens = args .num_tokens
17- hidden = args .hidden
18- num_topk_groups = args .num_topk_groups
19- num_topk = args .num_topk
20- num_experts = args .num_experts
20+ num_tokens , hidden = args .num_tokens , args .hidden
21+ num_topk_groups , num_topk , num_experts = args .num_topk_groups , args .num_topk , args .num_experts
2122
2223 assert num_experts % num_ranks == 0 and num_local_ranks == 8
2324 if local_rank == 0 :
@@ -223,29 +224,28 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
223224 print ('' , flush = True )
224225
225226
226- # noinspection PyUnboundLocalVariable
227- def test_loop (local_rank : int , num_local_ranks : int , args ):
227+ # noinspection PyUnboundLocalVariable,PyShadowingNames
228+ def test_loop (local_rank : int , num_local_ranks : int , args : argparse . Namespace ):
228229 num_nodes = int (os .getenv ('WORLD_SIZE' , 1 ))
229230 rank , num_ranks , group = init_dist (local_rank , num_local_ranks )
230- test_ll_compatibility = os .getenv ('EP_TEST_LL_COMPATIBILITY' , False )
231- if test_ll_compatibility :
231+ if args .test_ll_compatibility :
232232 ll_num_tokens , ll_hidden , ll_num_experts , ll_num_topk = 16 , 5120 , 256 , 9
233233
234234 num_sms = 24
235- num_qps_per_rank = max (num_sms , ll_num_experts // num_ranks if test_ll_compatibility else 0 )
235+ num_qps_per_rank = max (num_sms , ll_num_experts // num_ranks if args . test_ll_compatibility else 0 )
236236
237- buffer = deep_ep .Buffer (group , int (1e9 ), int (1e9 ), low_latency_mode = test_ll_compatibility ,
237+ buffer = deep_ep .Buffer (group , int (1e9 ), int (1e9 ), low_latency_mode = args . test_ll_compatibility ,
238238 num_qps_per_rank = num_qps_per_rank )
239239 assert num_local_ranks == 8 and num_ranks > 8
240240 torch .manual_seed (rank )
241241
242242 for i in (num_sms , ):
243- test_main (i , local_rank , num_local_ranks , num_ranks , num_nodes , rank , buffer , group , args )
243+ test_main (args , i , local_rank , num_local_ranks , num_ranks , num_nodes , rank , buffer , group )
244244 if local_rank == 0 :
245245 print ('' , flush = True )
246246
247247 # Test compatibility with low latency functions
248- if test_ll_compatibility :
248+ if args . test_ll_compatibility :
249249 buffer .clean_low_latency_buffer (ll_num_tokens , ll_hidden , ll_num_experts )
250250 test_low_latency .test_main (ll_num_tokens , ll_hidden , ll_num_experts , ll_num_topk , rank , num_ranks , group , buffer , seed = 1 )
251251
@@ -255,30 +255,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
255255
256256
257257if __name__ == '__main__' :
258- import argparse
259- parser = argparse .ArgumentParser (description = 'Test internode expert parallel' )
258+ parser = argparse .ArgumentParser (description = 'Test internode EP kernels' )
260259 parser .add_argument ('--num-processes' , type = int , default = 8 ,
261260 help = 'Number of processes to spawn (default: 8)' )
262261 parser .add_argument ('--num-tokens' , type = int , default = 4096 ,
263262 help = 'Number of tokens (default: 4096)' )
264263 parser .add_argument ('--hidden' , type = int , default = 7168 ,
265264 help = 'Hidden dimension size (default: 7168)' )
266265 parser .add_argument ('--num-topk-groups' , type = int , default = None ,
267- help = 'Number of top-k groups (default: min(num_nodes, 4))' )
266+ help = 'Number of top-k groups (default: ` min(num_nodes, 4)` )' )
268267 parser .add_argument ('--num-topk' , type = int , default = 8 ,
269268 help = 'Number of top-k experts (default: 8)' )
270- parser .add_argument ('--num-experts' , type = int , default = None ,
271- help = 'Number of experts (default: calculated as (256 // num_ranks) * num_ranks)' )
269+ parser .add_argument ('--num-experts' , type = int , default = 256 ,
270+ help = 'Number of experts (default: 256' )
271+ parser .add_argument ('--test-ll-compatibility' , action = 'store_true' ,
272+ help = 'whether to test compatibility with low-latency kernels' )
272273 args = parser .parse_args ()
273274
274- # Set default num_topk_groups if not provided
275+ # Set default ` num_topk_groups` if not provided
275276 if args .num_topk_groups is None :
276277 num_nodes = int (os .getenv ('WORLD_SIZE' , 1 ))
277278 args .num_topk_groups = min (num_nodes , 4 )
278279
279- # Set default num_experts if not provided
280- if args .num_experts is None :
281- args .num_experts = (256 // args .num_processes ) * args .num_processes
282-
283280 num_processes = args .num_processes
284281 torch .multiprocessing .spawn (test_loop , args = (num_processes , args ), nprocs = num_processes )
0 commit comments