Commit f6fa384

format
Signed-off-by: inkcherry <[email protected]>
1 parent 1cf3038 commit f6fa384

2 files changed: 17 additions & 17 deletions

deepspeed/module_inject/auto_tp.py

Lines changed: 1 addition & 1 deletion
@@ -384,7 +384,7 @@ def _replace(self, child, name, conv_linear_layer):
         elif name == "lm_head" or name == 'embed_out':
             if is_autotp_training_mode():
                 return child
-
+
             ## gather output column parallel
             ## return LinearLayer(child, self.mp_group, name=name, gather_output=True)
         else:
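
The commented-out path above hints at wrapping lm_head/embed_out in a LinearLayer with gather_output=True rather than returning the unmodified child. As a rough single-process sketch (the shapes and names below are illustrative only, not DeepSpeed API), gathering a column-parallel output simply concatenates each rank's partial logits along the last dimension:

    import torch

    world_size = 2                                   # pretend tensor-parallel degree
    hidden, vocab = 8, 16
    x = torch.randn(4, hidden)                       # input replicated on every rank
    w_shards = torch.randn(vocab, hidden).chunk(world_size, dim=0)   # column-parallel weight shards

    partials = [x @ w.t() for w in w_shards]         # each rank computes [4, vocab // world_size]
    full = torch.cat(partials, dim=-1)               # "gather output column parallel"
    assert full.shape == (4, vocab)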

deepspeed/module_inject/layers.py

Lines changed: 16 additions & 16 deletions
@@ -109,23 +109,23 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output
 
+
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
 
-
     @staticmethod
     def forward(ctx, group, input_):
         """Forward function."""
         # gather along last dim
-        world_size=dist.get_world_size(group)
-        if world_size==1:
-            return
-        ctx.group=group
-        ctx.world_size=world_size
-
-        gather_shape = (world_size,) + input_.shape
-        output =torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name() )
-        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
         tensor_list = output.chunk(world_size, dim=0)
         output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
         return output
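
For reference, the reshaping trick in GatherTensor.forward can be reproduced without a process group. A minimal single-process sketch, assuming dist.all_gather_into_tensor fills a leading world_size dimension with one shard per rank (torch.stack imitates that here):

    import torch

    world_size = 4
    shard = torch.randn(2, 3, 5)                     # this rank's shard of the activation
    gathered = torch.stack([shard] * world_size)     # stand-in for dist.all_gather_into_tensor
    tensor_list = gathered.chunk(world_size, dim=0)  # one (1, 2, 3, 5) view per rank
    output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
    assert output.shape == (2, 3, 5 * world_size)    # gathered along the last dim
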
@@ -139,6 +139,7 @@ def backward(ctx, grad_output):
         grad_output = input_list[rank].contiguous()
         return None, grad_output
 
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
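
The matching backward, partially visible in the hunk above, splits the incoming gradient along the last dimension so each rank keeps only its own slice. A hedged single-process sketch, with rank standing in for what dist.get_rank(group) would return:

    import torch

    world_size, rank = 4, 1                          # rank would come from dist.get_rank(group)
    grad_output = torch.randn(2, 3, 20)              # gradient w.r.t. the gathered output
    input_list = grad_output.chunk(world_size, dim=-1)
    grad_shard = input_list[rank].contiguous()       # this rank's gradient slice
    assert grad_shard.shape == (2, 3, 5)
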
@@ -434,19 +435,18 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
-        self.gather_output=gather_output
-
+        self.gather_output = gather_output
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
             input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
-
+
         if self.gather_output:
-            output = GatherTensor.apply(self.mp_group,output)
-
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
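
Putting the reformatted forward together: the layer performs a local matmul against its weight shard, adds the bias, and gathers only when gather_output is set. A single-process sketch of that control flow (ColumnParallel and GatherTensor are replaced by an identity stand-in, since no mp_group exists here; names are illustrative):

    import torch

    def forward_sketch(input, weight, bias=None, gather_output=False, gather_fn=lambda t: t):
        output = torch.matmul(input, weight.transpose(-1, -2))   # local shard matmul
        if bias is not None:
            output += bias
        if gather_output:
            output = gather_fn(output)                           # GatherTensor.apply in the real layer
        return output

    x = torch.randn(4, 8)
    w_shard = torch.randn(6, 8)                                  # this rank's slice of the weight
    y = forward_sketch(x, w_shard, gather_output=True)           # identity gather for a single "rank"
    assert y.shape == (4, 6)
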
@@ -634,7 +634,7 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
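
This forward slices the input along the hidden dimension using get_shard_size/get_shard_size_list before the matmul; the hunk ends just as the mp_group branch begins, where the partial results would presumably be combined across ranks. A sketch with uniform shard sizes standing in for those helpers:

    import torch

    tp_world_size, tp_index = 4, 2
    hidden, vocab = 16, 32
    shard_sizes = [hidden // tp_world_size] * tp_world_size     # stand-in for get_shard_size_list(...)
    input_shard_size = shard_sizes[tp_index]
    input_shard_offset = sum(shard_sizes[0:tp_index])           # start of this rank's hidden slice

    x = torch.randn(2, 3, hidden)
    w_shard = torch.randn(vocab, input_shard_size)              # this rank's lm_head weight shard
    partial = torch.matmul(x[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                           w_shard.transpose(-1, -2))
    # the real layer then combines `partial` across mp_group to form the full logits
    assert partial.shape == (2, 3, vocab)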
