@@ -1914,6 +1914,17 @@ def custom_forward(*input_args, **input_kwargs):
 
         return hidden_states
 
+    def set_context_parallel_group(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group"""
+        self.cp_group = cp_group
+        self.cp_global_ranks = cp_global_ranks
+        self.cp_stream = cp_stream
+
     def forward(
         self,
         query_layer: torch.Tensor,
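The first hunk adds a `set_context_parallel_group` setter that simply records the context-parallel (CP) process group, the global ranks making up that group, and a dedicated CUDA stream on the attention module. A minimal usage sketch, assuming `torch.distributed` has already been initialized and a CUDA device is available; the variable `attn`, the rank list, and the group construction are illustrative assumptions, not part of this commit:

```python
import torch
import torch.distributed as dist

# Hypothetical CP group of 4 ranks; in practice the trainer decides the layout.
cp_ranks = [0, 1, 2, 3]
cp_group = dist.new_group(ranks=cp_ranks)

# Side stream so CP communication can overlap with attention compute.
cp_stream = torch.cuda.Stream()

# `attn` stands in for the attention module that gained the setter above.
attn.set_context_parallel_group(cp_group, cp_ranks, cp_stream)
```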
@@ -2549,16 +2560,19 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
         """Set TP group"""
         self.tp_group = tp_group
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
-        self.core_attention.cp_group = cp_group
-        self.core_attention.cp_global_ranks = cp_global_ranks
-        self.core_attention.cp_stream = cp_stream
+        """Set CP group"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
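The second hunk renames `set_context_parallel_running` to `set_context_parallel_group` and, instead of reaching into a hard-coded `self.core_attention` attribute, walks the module tree and forwards the call to every descendant that exposes the same setter. `nn.Module.modules()` yields the module itself first, hence the index-0 skip. A standalone sketch of that propagation pattern; the `Leaf` and `Parent` classes are illustrative, not the library's code:

```python
import torch.nn as nn

class Leaf(nn.Module):
    """Stands in for a core-attention module that stores CP state directly."""
    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream

class Parent(nn.Module):
    """Stands in for a layer that propagates the call to its descendants."""
    def __init__(self):
        super().__init__()
        self.inner = nn.Sequential(Leaf(), nn.Linear(4, 4))

    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        # modules() yields self first, so skip index 0 to avoid calling
        # this very method on ourselves recursively.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

parent = Parent()
parent.set_context_parallel_group(None, [0], None)  # dummy args, CPU-only demo
assert parent.inner[0].cp_global_ranks == [0]       # Leaf received the settings
```

The upside of this pattern over the old direct assignment is that the parent no longer needs to know which attribute holds the attention module; any submodule implementing the setter, at any depth, is configured in one pass.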