@@ -109,23 +109,23 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output
 
+
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate."""
 
-
     @staticmethod
     def forward(ctx, group, input_):
         """Forward function."""
         # gather along last dim
-        world_size=dist.get_world_size(group)
-        if world_size==1:
-            return
-        ctx.group=group
-        ctx.world_size=world_size
-
-        gather_shape = (world_size,) + input_.shape
-        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name()   )
-        dist.all_gather_into_tensor(output, input_.contiguous(),   group)
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
         tensor_list = output.chunk(world_size, dim=0)
         output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
         return output
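
The chunk/cat/squeeze post-processing in GatherTensor.forward can be checked in a single process. The sketch below (not part of the diff; all variable names are illustrative) fills the gather buffer by hand instead of calling dist.all_gather_into_tensor, then applies the same reshaping.

import torch

# Emulate a full activation and the per-rank shards it would be split into.
world_size = 4
full = torch.arange(2 * 3 * 8, dtype=torch.float32).reshape(2, 3, 8)
shards = full.chunk(world_size, dim=-1)
local = shards[0]  # what "rank 0" would pass in as input_

# dist.all_gather_into_tensor writes into a (world_size,) + input_.shape buffer;
# here the buffer is filled by hand so no process group is needed.
gather_shape = (world_size, ) + local.shape
output = torch.empty(gather_shape, dtype=local.dtype)
for rank, shard in enumerate(shards):
    output[rank].copy_(shard)

# Same post-processing as GatherTensor.forward: split the leading world_size
# dimension back into per-rank chunks and concatenate along the last dim.
tensor_list = output.chunk(world_size, dim=0)
gathered = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
assert torch.equal(gathered, full)
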
@@ -139,6 +139,7 @@ def backward(ctx, grad_output):
         grad_output = input_list[rank].contiguous()
         return None, grad_output
 
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
@@ -434,19 +435,18 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
-        self.gather_output=gather_output
-     
+        self.gather_output = gather_output
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
             input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
-             
+
         if self.gather_output:
-            output = GatherTensor.apply(self.mp_group,output)
-             
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
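
To see what gather_output changes in the forward above, here is a single-process shape sketch (illustrative only; it emulates the ranks of mp_group with a Python list instead of collectives).

import torch

in_features, out_features, world_size = 8, 6, 2
x = torch.randn(4, in_features)
full_weight = torch.randn(out_features, in_features)

# Column parallelism: each rank keeps a slice of the output dimension.
weight_shards = full_weight.chunk(world_size, dim=0)

# Per-rank compute, matching torch.matmul(input, self.weight.transpose(-1, -2)).
partials = [torch.matmul(x, w.transpose(-1, -2)) for w in weight_shards]
print(partials[0].shape)  # torch.Size([4, 3]) -- the local shard when gather_output=False

# gather_output=True concatenates the per-rank shards along the last dim,
# which is what GatherTensor.apply does with an all-gather.
full_output = torch.cat(partials, dim=-1)
assert torch.allclose(full_output, torch.matmul(x, full_weight.transpose(-1, -2)), atol=1e-5)
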
@@ -634,7 +634,7 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-         
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
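
The sharded-input matmul above is the usual row-parallel decomposition. The following single-process sketch (illustrative; even shard sizes stand in for DeepSpeed's get_shard_size helpers, and a Python sum stands in for dist.all_reduce over mp_group) shows why summing the per-rank partial products reproduces the unsharded lm_head output.

import torch

hidden, vocab, world_size = 8, 10, 2
x = torch.randn(1, 4, hidden)
weight = torch.randn(vocab, hidden)
shard = hidden // world_size  # even split; the real code uses get_shard_size(...)

partials = []
for rank in range(world_size):
    offset = rank * shard  # input_shard_offset for this rank
    # Each rank multiplies its input slice by its slice of the weight's input dim.
    partials.append(torch.matmul(x[:, :, offset:offset + shard],
                                 weight[:, offset:offset + shard].transpose(-1, -2)))

# dist.all_reduce across ranks sums these partial products;
# summing them here reproduces the full matmul.
reduced = torch.stack(partials).sum(dim=0)
assert torch.allclose(reduced, torch.matmul(x, weight.transpose(-1, -2)), atol=1e-5)
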