@@ -1920,7 +1920,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                  list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                  cuda stream for context parallel execution.
+        """
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
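
The docstring above documents the context-parallel setup API. As a minimal usage sketch (not part of this diff), the following shows how a training script might build the three arguments and hand them to a Transformer Engine module before its forward pass; the cp_size value, the group layout, the te.TransformerLayer hyperparameters, and all variable names are assumptions for illustration only.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

cp_size = 2  # assumed context-parallel size; must divide world_size
cp_group = None
cp_global_ranks = None
# Every rank must take part in each new_group call, so create all CP groups
# and keep only the one this rank belongs to.
for start in range(0, world_size, cp_size):
    ranks = list(range(start, start + cp_size))
    group = dist.new_group(ranks=ranks)
    if rank in ranks:
        cp_group = group
        cp_global_ranks = ranks

cp_stream = torch.cuda.Stream()  # dedicated stream for context-parallel communication

# Hypothetical module choice; any module exposing set_context_parallel_group is used the same way.
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
layer.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
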
@@ -2560,7 +2572,15 @@ def _allocate_memory(
         )

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group

     def set_context_parallel_group(
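
A matching sketch for the tensor-parallel setter documented above (again not part of this diff; the single all-ranks TP group and the module hyperparameters are assumptions for the example):

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Assume every rank belongs to one tensor-parallel group.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
layer.set_tensor_parallel_group(tp_group)  # call before the first forward pass
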
@@ -2569,7 +2589,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                  list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                  cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0: