Skip to content

Commit 3676f96

Browse files
zhiyi Huzhiyi Hu
authored andcommitted
minor modifications
1 parent 4136519 commit 3676f96

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_internode_hook.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def large_gemm_with_hook(hook):
306306
# noinspection PyShadowingNames
307307
def test_dispatch_hook(x, config, handle, return_recv_hook):
308308
_, _, _, _, _, _, hook = \
309-
buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
309+
buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
310310
large_gemm_with_hook(hook) if return_recv_hook else None
311311
torch.cuda.synchronize()
312312

@@ -318,7 +318,7 @@ def test_combine_hook(x, config, handle, return_recv_hook):
318318

319319
def test_dispatch_combine_hook(x, config, handle, return_recv_hook):
320320
recv_x, _, _, _, _, _, hook = \
321-
buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
321+
buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
322322
large_gemm_with_hook(hook) if return_recv_hook else None
323323

324324
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
@@ -471,7 +471,7 @@ def test_func_native(x, config, handle):
471471
# Tune combine performance
472472
best_time, best_results = 1e10, None
473473
for nvl_chunk_size in range(1, 13, 1):
474-
for rdma_chunk_size in range(8, 33, 4):
474+
for rdma_chunk_size in range(12, 33, 4):
475475
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
476476
tune_args = {'x': recv_x, 'handle': handle_native, 'config': config}
477477
avg_t = bench(lambda: buffer.combine(**tune_args))[0]

0 commit comments

Comments
 (0)