diff --git a/VERSION b/VERSION
index 11524f9cda..fb570a932f 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.0.0dev
+1.1.0dev
diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst
index faf6cd4575..a9c1443592 100644
--- a/docs/api/c/index.rst
+++ b/docs/api/c/index.rst
@@ -6,8 +6,6 @@ C/C++ API
 =========
 
-.. Caution:: This feature is not officially supported yet and may change without notice.
-
 The C/C++ API allows you to access the custom kernels defined in
 `libtransformer_engine.so` library directly from C/C++, without Python.
 
diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index ea6d2b8763..53fdb3db66 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -7,26 +7,26 @@ pyTorch
 =======
 
 .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
   :members: swap_key_value_dict
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b313880d66..0d2dbe0bc8 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1920,7 +1920,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                   context parallel process group.
+        cp_global_ranks : List[int]
+                          list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                    cuda stream for context parallel execution.
+        """
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
@@ -2560,7 +2572,15 @@ def _allocate_memory(
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                   tensor parallel process group.
+        """
         self.tp_group = tp_group
 
     def set_context_parallel_group(
@@ -2569,7 +2589,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                   context parallel process group.
+        cp_global_ranks : List[int]
+                          list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                    cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 2edd10bcb4..9b6ab6e684 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -476,7 +476,15 @@ def set_fp8_weights(self) -> None:
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group."""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                   tensor parallel process group.
+        """
         self.tp_group = tp_group
         self.tp_group_initialized = True
 
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index f5089e2d90..c21be000e3 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -425,7 +425,15 @@ def __init__(
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                   tensor parallel process group.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
@@ -439,7 +447,19 @@ def set_context_parallel_group(
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                   context parallel process group.
+        cp_global_ranks : List[int]
+                          list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                    cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0: