move cp_group setting to DotProductAttention (#468)
* rename set_context_parallel_running to set_context_parallel_group

Signed-off-by: xren <[email protected]>

* bug fix

Signed-off-by: xren <[email protected]>

---------

Signed-off-by: xren <[email protected]>
xrennvidia authored Oct 11, 2023
1 parent d7511ec commit 2574a1c
Showing 2 changed files with 23 additions and 9 deletions.
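
For existing callers, the visible effect is an API rename plus delegation: code that previously called set_context_parallel_running on a TransformerLayer or MultiheadAttention now calls set_context_parallel_group, and the CP group, global ranks, and stream end up stored on DotProductAttention. A minimal caller-side sketch, assuming a module instance named layer and already-constructed cp_group, cp_global_ranks, and cp_stream (illustrative names, not part of this commit):

# Old name, removed by this commit:
#   layer.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
# New name, forwarded down the module tree to DotProductAttention:
layer.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
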
24 changes: 19 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -1914,6 +1914,17 @@ def custom_forward(*input_args, **input_kwargs):
 
         return hidden_states
 
+    def set_context_parallel_group(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group"""
+        self.cp_group = cp_group
+        self.cp_global_ranks = cp_global_ranks
+        self.cp_stream = cp_stream
+
     def forward(
         self,
         query_layer: torch.Tensor,

@@ -2549,16 +2560,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """Set TP group"""
         self.tp_group = tp_group
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
-        self.core_attention.cp_group = cp_group
-        self.core_attention.cp_global_ranks = cp_global_ranks
-        self.core_attention.cp_stream = cp_stream
+        """Set CP group"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
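
A hedged end-to-end sketch of how the relocated setter might be driven from user code; the process-group and stream setup below (the rank list, torch.distributed.new_group, torch.cuda.Stream) is illustrative rather than part of this commit, and assumes torch.distributed has already been initialized:

import torch
import torch.distributed as dist

# Illustrative CP setup: two global ranks form the context-parallel group.
cp_ranks = [0, 1]
cp_group = dist.new_group(ranks=cp_ranks)
cp_stream = torch.cuda.Stream()  # extra CUDA stream used by context-parallel attention

# model is assumed to be a module tree that contains DotProductAttention;
# the call is forwarded down, and DotProductAttention simply stores the three values.
model.set_context_parallel_group(cp_group, cp_ranks, cp_stream)
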
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/transformer.py
@@ -433,19 +433,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
             if hasattr(child, "set_tensor_parallel_group"):
                 child.set_tensor_parallel_group(tp_group)
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
+        """Set CP group"""
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
                 continue
-            if hasattr(child, "set_context_parallel_running"):
-                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
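
The delegation used in both MultiheadAttention and TransformerLayer relies on nn.Module.modules() yielding the module itself first, which is why index 0 is skipped; otherwise the method would call itself forever. A self-contained sketch of the same pattern with toy class names (illustrative, not from this repository):

import torch

class Leaf(torch.nn.Module):
    """Stands in for DotProductAttention: just records the CP state."""
    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

class Parent(torch.nn.Module):
    """Stands in for TransformerLayer: forwards the call to its children."""
    def __init__(self):
        super().__init__()
        self.attn = Leaf()

    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        # modules() yields self at index 0; skip it to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

layer = Parent()
layer.set_context_parallel_group(None, [0], None)  # toy values; real code passes a group and stream
assert layer.attn.cp_global_ranks == [0]
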
