Merge branch 'main' into float8tensor_experiments
ksivaman authored Oct 18, 2023
2 parents 76af588 + f456ba1 commit 5d51e7b
Showing 6 changed files with 73 additions and 15 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
1.0.0dev
1.1.0dev
2 changes: 0 additions & 2 deletions docs/api/c/index.rst
@@ -6,8 +6,6 @@
C/C++ API
=========

.. Caution:: This feature is not officially supported yet and may change without notice.

The C/C++ API allows you to access the custom kernels defined in the `libtransformer_engine.so` library
directly from C/C++, without Python.

12 changes: 6 additions & 6 deletions docs/api/pytorch.rst
@@ -7,26 +7,26 @@ pyTorch
=======

.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
:members: forward
:members: forward, set_context_parallel_group

.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
:members: forward
:members: forward, set_context_parallel_group, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward
:members: forward, set_context_parallel_group, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
:members: swap_key_value_dict
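The entries above now document `set_tensor_parallel_group` and `set_context_parallel_group` alongside `forward`. As a rough usage sketch (not part of this commit), the group is created with `torch.distributed` and registered on the module before the first forward pass; the backend, group layout, and layer sizes below are illustrative assumptions for a job launched with `torchrun`:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes a multi-GPU job launched with torchrun, so rank/world-size env vars are set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Illustrative layout: all ranks form a single tensor-parallel group.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
).cuda()

# Register the group before executing the forward pass, as the new docstrings describe.
layer.set_tensor_parallel_group(tp_group)

hidden_states = torch.randn(128, 2, 1024, device="cuda")  # [seq, batch, hidden]
out = layer(hidden_states)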
38 changes: 35 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -1920,7 +1920,19 @@ def set_context_parallel_group(
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
@@ -2560,7 +2572,15 @@ def _allocate_memory(
)

def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
self.tp_group = tp_group

def set_context_parallel_group(
@@ -2569,7 +2589,19 @@ def set_context_parallel_group(
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
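For the context-parallel setter documented above, the caller provides the process group, the list of global ranks in that group, and a CUDA stream for context-parallel execution. A hedged sketch of wiring those into a `DotProductAttention` module follows; the rank layout and stream usage are assumptions, not something mandated by this commit:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes the default process group was already initialized as in the previous sketch.
cp_global_ranks = list(range(dist.get_world_size()))  # all ranks in one CP group
cp_group = dist.new_group(ranks=cp_global_ranks)      # context parallel process group
cp_stream = torch.cuda.Stream()                       # side stream for CP execution

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Same argument order as the new docstring: cp_group, cp_global_ranks, cp_stream.
attn.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)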
10 changes: 9 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -476,7 +476,15 @@ def set_fp8_weights(self) -> None:
)

def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group."""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
self.tp_group = tp_group
self.tp_group_initialized = True

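At the base-module level the setter is pure bookkeeping: it caches the group and flags that it has been configured. A simplified standalone sketch of that contract (not Transformer Engine's actual base class, just the pattern the hunk shows):

from typing import Optional

import torch


class _TpAwareModule(torch.nn.Module):
    # Mirrors the two assignments in the documented method: cache the group,
    # then mark it as initialized so later checks can rely on it.
    def __init__(self) -> None:
        super().__init__()
        self.tp_group: Optional[object] = None
        self.tp_group_initialized = False

    def set_tensor_parallel_group(self, tp_group=None) -> None:
        self.tp_group = tp_group
        self.tp_group_initialized = True


m = _TpAwareModule()
m.set_tensor_parallel_group(None)  # `None` is the documented default
assert m.tp_group_initialized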
24 changes: 22 additions & 2 deletions transformer_engine/pytorch/transformer.py
@@ -425,7 +425,15 @@ def __init__(
)

def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
@@ -439,7 +447,19 @@ def set_context_parallel_group(
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
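Both `TransformerLayer` setters end with the same "deep iterate but skip self" loop, whose body is truncated in the hunks above. A hedged reconstruction of that propagation pattern as a standalone helper (the `hasattr` dispatch is an assumption; the upstream method may select children differently):

import torch


def propagate_tp_group(root: torch.nn.Module, tp_group) -> None:
    # Walk every submodule; index 0 is `root` itself and is skipped, which in the
    # upstream method prevents the setter from calling itself without bound.
    for index, child in enumerate(root.modules()):
        if index == 0:
            continue
        if hasattr(child, "set_tensor_parallel_group"):
            child.set_tensor_parallel_group(tp_group)

Skipping index 0 matters upstream because the loop lives inside `set_tensor_parallel_group` itself, so including the root module would recurse forever.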
