
Commit 1cf3038

update
Signed-off-by: inkcherry <[email protected]>
1 parent 6ab42aa commit 1cf3038

File tree: 3 files changed, +5 −13 lines

deepspeed/module_inject/auto_tp.py

Lines changed: 4 additions & 3 deletions
@@ -383,9 +383,10 @@ def _replace(self, child, name, conv_linear_layer):
                 return Conv_LinearALlreduce(child, self.mp_group, name=name)
             elif name == "lm_head" or name == 'embed_out':
                 if is_autotp_training_mode():
-                    # pass
-                    # return child
-                    return LinearLayer(child, self.mp_group, name=name, gather_output=True)
+                    return child
+
+                    ## gather output column parallel
+                    ## return LinearLayer(child, self.mp_group, name=name, gather_output=True)
                 else:
                     return LmHeadLinearAllreduce(child, self.mp_group)

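
The net effect of this hunk: under autoTP training mode, lm_head / embed_out is now returned unmodified (return child) rather than being wrapped in a gather-output LinearLayer; the column-parallel path survives only as a comment. For orientation, a minimal sketch of what that commented-out "gather output column parallel" path would do (hypothetical names, not DeepSpeed's implementation) looks roughly like this:

# Hypothetical sketch of a gather-output column-parallel lm_head: each rank holds a
# column (vocab) shard of the weight, computes its partial logits, and the shards
# are all-gathered along the last dimension so every rank sees the full vocabulary.
import torch
import torch.distributed as dist

def column_parallel_lm_head(hidden, weight_shard, mp_group):
    # hidden: [batch, seq, hidden_dim]; weight_shard: [vocab_shard, hidden_dim]
    local_logits = torch.matmul(hidden, weight_shard.transpose(-1, -2))
    gathered = [torch.empty_like(local_logits) for _ in range(dist.get_world_size(mp_group))]
    dist.all_gather(gathered, local_logits.contiguous(), group=mp_group)
    return torch.cat(gathered, dim=-1)  # full [batch, seq, vocab] logits on every rank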

deepspeed/module_inject/layers.py

Lines changed: 1 addition & 7 deletions
@@ -112,10 +112,6 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

-    # @staticmethod
-    # def symbolic(graph, input_):
-    #     """Symbolic function for tracing."""
-    #     return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, group, input_):
@@ -431,8 +427,7 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
-        if gather_output:
-            b=0
+
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
@@ -639,7 +634,6 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-        input= input[:, :, input_shard_offset:input_shard_offset + input_shard_size]

         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
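
For context, the forward touched by the last hunk follows the usual row-parallel lm_head pattern: each rank takes its slice of the hidden dimension, multiplies it by its local weight shard, and the partial logits are summed across the tensor-parallel group with an all-reduce; the change simply drops the intermediate reassignment of input and slices inline in the matmul. A rough, hypothetical sketch of that pattern (assumed names and shapes, not the file's actual code):

# Hypothetical row-parallel lm_head: slice the hidden dimension per rank, matmul with
# the local weight shard, then all-reduce so every rank ends up with the full logits.
import torch
import torch.distributed as dist

def row_parallel_lm_head(hidden, weight_shard, shard_offset, shard_size, mp_group):
    # hidden: [batch, seq, hidden_dim]; weight_shard: [vocab, hidden_shard_size]
    local_in = hidden[:, :, shard_offset:shard_offset + shard_size]
    partial = torch.matmul(local_in, weight_shard.transpose(-1, -2))
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=mp_group)  # sum partial logits
    return partial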

deepspeed/module_inject/replace_module.py

Lines changed: 0 additions & 3 deletions
@@ -335,9 +335,6 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module

     def set_lm_head(module):
-        # if is_autotp_training_mode():
-        #     # we need to handle autoTP training mode separately.
-        #     return

         embedding_weight = None
         for n, p in module.named_parameters():
