
Commit 0963020

Improve documentation (#478)
Improve docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent: 2c41083

5 files changed (+72, -14 lines)


docs/api/c/index.rst

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 C/C++ API
 =========
 
-.. Caution:: This feature is not officially supported yet and may change without notice.
-
 The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library
 directly from C/C++, without Python.

docs/api/pytorch.rst

Lines changed: 6 additions & 6 deletions
@@ -7,26 +7,26 @@ pyTorch
 =======
 
 .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
   :members: swap_key_value_dict
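
For orientation, the methods newly listed under `:members:` are meant to be called on a module before its forward pass. Below is a minimal sketch of attaching a tensor parallel group to `transformer_engine.pytorch.Linear`; the process-group setup is an illustrative assumption (a job launched with torchrun and NCCL available), not part of this commit:

import torch.distributed as dist
import transformer_engine.pytorch as te

# Illustrative setup: assumes the script was launched with torchrun so
# torch.distributed can initialize from environment variables.
dist.init_process_group(backend="nccl")
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

linear = te.Linear(1024, 1024, bias=True)
# Attach the tensor parallel process group before the forward pass.
linear.set_tensor_parallel_group(tp_group)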

transformer_engine/pytorch/attention.py

Lines changed: 35 additions & 3 deletions
@@ -1920,7 +1920,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream

@@ -2560,7 +2572,15 @@ def _allocate_memory(
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group
 
     def set_context_parallel_group(

@@ -2569,7 +2589,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
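
To make the new docstrings concrete, here is a minimal sketch of setting the context parallel attributes on `DotProductAttention`. The group construction and module sizes are illustrative assumptions; an already-initialized distributed job and a CUDA device are required:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Illustrative: assumes torch.distributed is already initialized.
cp_global_ranks = list(range(dist.get_world_size()))
cp_group = dist.new_group(ranks=cp_global_ranks)
cp_stream = torch.cuda.Stream()  # dedicated stream for context parallel execution

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
# Matches the documented signature: (cp_group, cp_global_ranks, cp_stream).
attn.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)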

transformer_engine/pytorch/module/base.py

Lines changed: 9 additions & 1 deletion
@@ -467,7 +467,15 @@ def set_fp8_weights(self) -> None:
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group."""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group
         self.tp_group_initialized = True
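Because this method lives on the common module base, subclasses such as `Linear`, `LayerNormLinear`, and `LayerNormMLP` all expose the same call. Per the `Union[dist_group_type, None]` annotation, `None` is an accepted argument when tensor parallelism is not used; a small sketch (module sizes are arbitrary):

import transformer_engine.pytorch as te

ln_linear = te.LayerNormLinear(1024, 3072)
ln_mlp = te.LayerNormMLP(1024, 4096)

# The base-class method accepts a ProcessGroup or None; None simply
# records that no tensor parallel group is in use.
for module in (ln_linear, ln_mlp):
    module.set_tensor_parallel_group(None)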

transformer_engine/pytorch/transformer.py

Lines changed: 22 additions & 2 deletions
@@ -425,7 +425,15 @@ def __init__(
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:

@@ -439,7 +447,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
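
Since `TransformerLayer` deep-iterates over its children, a single call on the layer configures every submodule that participates in tensor or context parallelism. A hedged end-to-end sketch; the group layout, sizes, and torchrun launch are assumptions, not part of this commit:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Illustrative: assumes a torchrun-launched job with NCCL available.
dist.init_process_group(backend="nccl")
ranks = list(range(dist.get_world_size()))
tp_group = dist.new_group(ranks=ranks)
cp_group = dist.new_group(ranks=ranks)

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16)
# One call each; both methods recurse into child modules.
layer.set_tensor_parallel_group(tp_group)
layer.set_context_parallel_group(cp_group, ranks, torch.cuda.Stream())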
