@@ -112,10 +112,6 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate."""

-    # @staticmethod
-    # def symbolic(graph, input_):
-    #     """Symbolic function for tracing."""
-    #     return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, group, input_):
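For reference, the gather-along-last-dim pattern this autograd.Function relies on usually pairs an all-gather in the forward pass with a slice back to the local shard in the backward pass. The sketch below is illustrative only: it assumes an initialized torch.distributed process group and an evenly divisible last dimension, and the helper names are placeholders rather than the ones defined in this file.

import torch
import torch.distributed as dist


def _gather_along_last_dim(input_: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather every rank's shard and concatenate along the last dimension."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    shards = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(shards, input_.contiguous(), group=group)
    return torch.cat(shards, dim=-1)


def _split_along_last_dim(grad_output: torch.Tensor, group=None) -> torch.Tensor:
    """Backward of the gather: keep only this rank's slice of the gradient."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return grad_output
    rank = dist.get_rank(group=group)
    shard = grad_output.shape[-1] // world_size  # assumes an even split
    return grad_output[..., rank * shard:(rank + 1) * shard].contiguous()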
@@ -431,8 +427,7 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
-        if gather_output:
-            b = 0
+
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
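As a rough illustration of what partitioning these parameters across tensor-parallel ranks can look like (the real _tp_partition handles uneven shard sizes via get_shard_size, and its exact slicing is not shown in this hunk), here is a column-style split of a Linear's weight and bias; the function name and even-split assumption are mine.

import torch
import torch.nn as nn


def tp_partition_linear(weight: torch.Tensor, bias: torch.Tensor,
                        tp_index: int, tp_world_size: int):
    """Keep only this rank's slice of the output dimension: rows of weight, matching entries of bias."""
    out_features = weight.shape[0]
    shard = out_features // tp_world_size  # assumes out_features divides evenly
    rows = slice(tp_index * shard, (tp_index + 1) * shard)
    weight_shard = nn.Parameter(weight[rows, :].detach().clone())
    bias_shard = nn.Parameter(bias[rows].detach().clone()) if bias is not None else None
    return weight_shard, bias_shard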
@@ -639,7 +634,6 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-        input = input[:, :, input_shard_offset:input_shard_offset + input_shard_size]

         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
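The removed line reassigned `input` before the matmul even though the kept code already slices it inline, so it was redundant. A simplified sketch of this row-parallel lm_head forward, assuming even sharding in place of get_shard_size/get_shard_size_list, with the usual all-reduce of the partial logits included for completeness even though it falls outside this hunk:

import torch
import torch.distributed as dist


def lm_head_forward(hidden: torch.Tensor, weight_shard: torch.Tensor,
                    tp_index: int, tp_world_size: int, group=None) -> torch.Tensor:
    shard_size = hidden.shape[-1] // tp_world_size  # assumes an even split
    offset = tp_index * shard_size
    # Slice the hidden dimension inline, as the updated code does, instead of reassigning `input`.
    partial = torch.matmul(hidden[:, :, offset:offset + shard_size],
                           weight_shard.transpose(-1, -2))
    if tp_world_size > 1:
        dist.all_reduce(partial, group=group)  # sum partial logits across the TP group
    return partial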