
Commit 2574a1c

move cp_group setting to DotProductAttention (#468)
* rename set_context_parallel_running to set_context_parallel_group

  Signed-off-by: xren <[email protected]>

* bug fix

  Signed-off-by: xren <[email protected]>

---------

Signed-off-by: xren <[email protected]>
1 parent d7511ec commit 2574a1c

2 files changed: +23 -9 lines

transformer_engine/pytorch/attention.py

Lines changed: 19 additions & 5 deletions
@@ -1914,6 +1914,17 @@ def custom_forward(*input_args, **input_kwargs):

         return hidden_states

+    def set_context_parallel_group(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group"""
+        self.cp_group = cp_group
+        self.cp_global_ranks = cp_global_ranks
+        self.cp_stream = cp_stream
+
     def forward(
         self,
         query_layer: torch.Tensor,

@@ -2549,16 +2560,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
         """Set TP group"""
         self.tp_group = tp_group

-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
-        self.core_attention.cp_group = cp_group
-        self.core_attention.cp_global_ranks = cp_global_ranks
-        self.core_attention.cp_stream = cp_stream
+        """Set CP group"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

     def forward(
         self,
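After this change the context-parallel (CP) state (cp_group, cp_global_ranks, cp_stream) is stored on DotProductAttention itself, and callers configure it through the renamed set_context_parallel_group setter. A minimal usage sketch, assuming torch.distributed is already initialized; the rank list and the DotProductAttention constructor arguments below are illustrative assumptions, not taken from this commit:

import torch
import torch.distributed as dist
from transformer_engine.pytorch import DotProductAttention

# Illustrative values; assumes dist.init_process_group() has already been called.
cp_global_ranks = [0, 1, 2, 3]                    # ranks participating in context parallelism
cp_group = dist.new_group(ranks=cp_global_ranks)  # communication group for CP
cp_stream = torch.cuda.Stream()                   # extra CUDA stream used by CP

# Constructor arguments here are assumptions for the sketch.
attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
attn.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)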

transformer_engine/pytorch/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -433,19 +433,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
             if hasattr(child, "set_tensor_parallel_group"):
                 child.set_tensor_parallel_group(tp_group)

-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
+        """Set CP group"""
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
                 continue
-            if hasattr(child, "set_context_parallel_running"):
-                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

     def forward(
         self,
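TransformerLayer (above) and the attention-side setter in attention.py now share the same propagation pattern: iterate over self.modules(), skip index 0 (nn.Module.modules() yields the module itself first), and forward the call to any child that exposes set_context_parallel_group. A self-contained sketch of the pattern with hypothetical module names, not the library's own classes:

import torch.nn as nn

class Child(nn.Module):
    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        # Leaf module stores the CP state directly.
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        # Deep iterate but skip self (index 0) to avoid infinite recursion,
        # then delegate to every descendant that implements the setter.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

Dispatching through hasattr keeps the parent agnostic about which descendant actually stores the CP state, which is what lets this commit move the storage into DotProductAttention without changing the propagation code in transformer.py.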
