@@ -338,11 +338,11 @@ def __init__(
338338 if sp_stream is not None :
339339 self .overlap_handles = {}
340340 self .sp_overlap_comm = True
341- self .dafult_stream = get_accelerator ().default_stream ()
341+ self .default_stream = get_accelerator ().default_stream ()
342342
def layer_sync(self, layer):
    """Make the default stream wait for ``layer``'s completion event.

    Nothing happens unless ``self.sp_overlap_comm`` is truthy and ``layer``
    exposes a ``done_event`` attribute; in that case
    ``self.default_stream.wait_event(layer.done_event)`` is called.

    Args:
        layer: Any object; only its optional ``done_event`` attribute is
            read.

    Returns:
        None.
    """
    # The flag is tested first so that ``self.default_stream`` is only
    # touched when overlap is enabled.
    # NOTE(review): the visible part of ``__init__`` assigns
    # ``default_stream`` in the same branch that sets ``sp_overlap_comm``
    # to True -- confirm it is not expected to exist otherwise.
    if not self.sp_overlap_comm:
        return
    if hasattr(layer, 'done_event'):
        self.default_stream.wait_event(layer.done_event)
346346
347347 def forward (self ,
348348 query : Tensor ,
@@ -374,7 +374,7 @@ def bwd_hook(layer_type):
374374 def pre_hook_fun (grad ):
375375 type = 'd' + layer_type
376376 self .overlap_handles [type + '_work' ].wait ()
377- self .sp_stream .wait_stream (self .dafult_stream )
377+ self .sp_stream .wait_stream (self .default_stream )
378378 all2all_output = self .overlap_handles [type + '_grad' ]
379379 grad = list (grad )
380380 grad [0 ] = self .overlap_handles [type + '_post_all2all_func' ](all2all_output )
@@ -389,7 +389,7 @@ def pre_hook_fun(grad):
389389 key_layer = _SeqAllToAll .apply (self .spg , key , self .scatter_idx , self .gather_idx , batch_dim_idx , None ,
390390 self .overlap_handles , 'k' )
391391 if self .sp_overlap_comm :
392- self .dafult_stream .wait_stream (self .sp_stream )
392+ self .default_stream .wait_stream (self .sp_stream )
393393
394394 value_layer = _SeqAllToAll .apply (self .spg , value , self .scatter_idx , self .gather_idx , batch_dim_idx , None ,
395395 self .overlap_handles , 'v' )
0 commit comments