From 2574a1ca23f6d7fe9b4748c6cc347f158d232e22 Mon Sep 17 00:00:00 2001
From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Date: Wed, 11 Oct 2023 08:43:08 -0700
Subject: [PATCH] move cp_group setting to DotProductAttention (#468)

* rename set_context_parallel_running to set_context_parallel_group

Signed-off-by: xren

* bug fix

Signed-off-by: xren

---------

Signed-off-by: xren
---
 transformer_engine/pytorch/attention.py   | 24 +++++++++++++++++++-----
 transformer_engine/pytorch/transformer.py |  8 ++++----
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 381a4ad553..dc4c9fca9c 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1914,6 +1914,17 @@ def custom_forward(*input_args, **input_kwargs):
 
         return hidden_states
 
+    def set_context_parallel_group(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group"""
+        self.cp_group = cp_group
+        self.cp_global_ranks = cp_global_ranks
+        self.cp_stream = cp_stream
+
     def forward(
         self,
         query_layer: torch.Tensor,
@@ -2549,16 +2560,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
         """Set TP group"""
         self.tp_group = tp_group
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
-        self.core_attention.cp_group = cp_group
-        self.core_attention.cp_global_ranks = cp_global_ranks
-        self.core_attention.cp_stream = cp_stream
+        """Set CP group"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index c1006942d0..f5089e2d90 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -433,19 +433,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
             if hasattr(child, "set_tensor_parallel_group"):
                 child.set_tensor_parallel_group(tp_group)
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
        self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
+        """Set CP group"""
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
                 continue
-            if hasattr(child, "set_context_parallel_running"):
-                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
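
Note on the pattern (not part of the patch): the setter that DotProductAttention gains simply stores cp_group, cp_global_ranks, and cp_stream on the module, while MultiheadAttention and TransformerLayer now propagate the call by walking self.modules() and invoking set_context_parallel_group on every child that defines it, skipping index 0 (the module itself) so the call does not recurse into itself. Below is a minimal, self-contained sketch of that propagation pattern; ToyAttention and ToyLayer are hypothetical stand-ins, not Transformer Engine classes, and passing None for the group and stream is only to keep the sketch runnable without CUDA or an initialized process group.

from typing import List, Union

import torch

# Alias used in the patch; Transformer Engine defines it as torch.distributed.ProcessGroup.
dist_group_type = torch.distributed.ProcessGroup


class ToyAttention(torch.nn.Module):
    """Leaf module: stores the CP settings, like DotProductAttention in the patch."""

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream


class ToyLayer(torch.nn.Module):
    """Container module: forwards the call to every child that supports it."""

    def __init__(self) -> None:
        super().__init__()
        self.core_attention = ToyAttention()

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        # modules() yields self first and then all descendants; skipping index 0
        # avoids calling this very method again (infinite recursion).
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)


if __name__ == "__main__":
    layer = ToyLayer()
    # Real callers would pass a group from torch.distributed.new_group(...), the
    # global ranks of that group, and a torch.cuda.Stream(); None/[] keep this CPU-only.
    layer.set_context_parallel_group(None, [], None)
    print(hasattr(layer.core_attention, "cp_group"))  # True: the setting reached the leaf

Because modules() is already a deep traversal, a leaf sitting below a nested container may be visited more than once (once from the top-level walk and again when the intermediate container re-walks its own children); that is harmless since the leaf setter is idempotent, and the index == 0 skip is what prevents true infinite recursion.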