@@ -306,7 +306,7 @@ def large_gemm_with_hook(hook):
     # noinspection PyShadowingNames
     def test_dispatch_hook(x, config, handle, return_recv_hook):
         _, _, _, _, _, _, hook = \
-            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         torch.cuda.synchronize()
 
@@ -318,7 +318,7 @@ def test_combine_hook(x, config, handle, return_recv_hook):
 
     def test_dispatch_combine_hook(x, config, handle, return_recv_hook):
         recv_x, _, _, _, _, _, hook = \
-            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
 
         recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
@@ -471,7 +471,7 @@ def test_func_native(x, config, handle):
     # Tune combine performance
     best_time, best_results = 1e10, None
     for nvl_chunk_size in range(1, 13, 1):
-        for rdma_chunk_size in range(8, 33, 4):
+        for rdma_chunk_size in range(12, 33, 4):
             config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
             tune_args = {'x': recv_x, 'handle': handle_native, 'config': config}
             avg_t = bench(lambda: buffer.combine(**tune_args))[0]
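For context: the two dispatch edits only drop the `num_max_dispatch_tokens_per_rank` argument; the receive-hook overlap pattern the touched helpers exercise is unchanged. Below is a minimal sketch of that pattern, reusing only names that appear in this test file (`buffer`, `x`, `config`, `handle`, `large_gemm_with_hook`, `per_token_cast_back`) and assuming the dispatch signature shown on the `+` lines; it is an illustration, not part of the commit.

import torch

# Dispatch with a receive hook; the test unpacks seven return values, the last being the hook.
recv_x, _, _, _, _, _, hook = \
    buffer.dispatch(x=x, config=config, handle=handle,
                    async_finish=False, return_recv_hook=True)

# Overlap independent compute with the pending receive; judging by its name, the test helper
# runs a large GEMM and then invokes hook() to finish the receive (assumption).
large_gemm_with_hook(hook)
torch.cuda.synchronize()

# As in the test above: if recv_x comes back as a (values, scales) tuple, cast it back.
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x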