From 50ff81166840dcfdecdb6d42cdc4610833c5dd99 Mon Sep 17 00:00:00 2001
From: zlsh80826
Date: Thu, 9 Nov 2023 02:19:37 +0800
Subject: [PATCH 1/6] [JAX/Paddle] Deprecate QKV_INTERLEAVED enum (#504)

* Deprecate QKV_INTERLEAVED use in JAX

Signed-off-by: Reese Wang

* Deprecate QKV_INTERLEAVED use in Paddle

Signed-off-by: Reese Wang

* Enhance qkv enum mappings

Signed-off-by: rewang

* Fix LD_LIBRARY_PATH issue

Signed-off-by: rewang

* Arbitrary seqlen kernels only support self attention currently

Signed-off-by: rewang

---------

Signed-off-by: Reese Wang
Signed-off-by: rewang
---
 tests/jax/test_fused_attn.py                 |  8 ++---
 tests/paddle/test_layers.py                  |  8 ++---
 tests/paddle/test_operators.py               | 10 +++---
 tests/paddle/utils.py                        | 11 ++++---
 .../common/fused_attn/fused_attn.cpp         |  4 +--
 transformer_engine/jax/cpp_extensions.py     | 33 +++++++++++--------
 transformer_engine/jax/csrc/extensions.cpp   |  5 ++-
 transformer_engine/jax/csrc/modules.cpp      | 16 ++++-----
 transformer_engine/jax/flax/transformer.py   | 12 ++++---
 transformer_engine/jax/fused_attn.py         | 13 ++++++--
 transformer_engine/paddle/constants.py       |  6 ----
 transformer_engine/paddle/cpp_extensions.py  |  8 ++---
 transformer_engine/paddle/csrc/common.cpp    | 29 ++++++++++++++++
 transformer_engine/paddle/csrc/common.h      |  4 ++-
 transformer_engine/paddle/csrc/custom_ops.cu | 15 +--------
 transformer_engine/paddle/csrc/extensions.cu | 19 +++++++++--
 transformer_engine/paddle/layer/attention.py |  9 +++--
 17 files changed, 123 insertions(+), 87 deletions(-)

diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index d67670c662..c4f70c85aa 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -18,7 +18,7 @@
 from flax.linen import make_causal_mask
 from jax import value_and_grad, jit

-from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
+from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
 from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
 from transformer_engine_jax import get_device_compute_capability
@@ -163,13 +163,13 @@ def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probabi
         if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
             pytest.skip("Arbitrary seqlen backend hasn't support padded input.")

-        if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
-                                              dropout_probability, s, s, head_dim):
+        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
+                                              attn_mask_type, dropout_probability, s, s, head_dim):
             pytest.skip("Unsupported inputs combination or device compute capability.")

         compute_capability = get_device_compute_capability(0)
         if (backend == Backend.Max512
-                and not (compute_capability == 80 or compute_capability >= 90)):
+                and not (compute_capability == 80 or compute_capability >= 90)):
             pytest.skip("Unsupported compute capability for "
                         "fused attention with <=512 sequence length")

diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py
index f398925b4f..fb544069b8 100644
--- a/tests/paddle/test_layers.py
+++ b/tests/paddle/test_layers.py
@@ -639,7 +639,7 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
             kv_seqlen=kv_seqlen,
             dtype=math_dtype,
             dropout=0.0,
-            qkv_layout="qkv_interleaved" if attn_type == "self" else "kv_interleaved",
+            qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
             bias_type="no_bias",
mask_type=mask_type, ): @@ -767,7 +767,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, kv_seqlen=kv_seqlen, dtype=math_dtype, dropout=0.0, - qkv_layout="qkv_interleaved", + qkv_layout="bs3hd", bias_type="no_bias", mask_type=mask_type, ): @@ -945,7 +945,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, kv_seqlen=kv_seqlen, dtype=math_dtype, dropout=0.0, - qkv_layout="qkv_interleaved", + qkv_layout="bs3hd", bias_type="no_bias", mask_type=mask_type, ): @@ -956,7 +956,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, kv_seqlen=kv_seqlen, dtype=math_dtype, dropout=0.0, - qkv_layout="kv_interleaved", + qkv_layout="bshd_bs2hd", bias_type="no_bias", mask_type=mask_type, ): diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index 455dbb02da..e81e3d87fa 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -683,9 +683,9 @@ def _get_fused_attention_out(self): kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) qkv_layout = ( - "qkv_interleaved" + "bs3hd" if self.attn_mode == "self_attn" - else "kv_interleaved" + else "bshd_bs2hd" ) fused_attention_backend = get_fused_attention_backend( head_size=self.head_size, @@ -779,7 +779,7 @@ def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): kv_seqlen=s, dtype=dtype, dropout=0.0, - qkv_layout="qkv_interleaved", + qkv_layout="bs3hd", bias_type="no_bias", mask_type="causal" if is_causal_masking else "padding", ): @@ -804,7 +804,7 @@ def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): kv_seqlen=s_kv, dtype=dtype, dropout=0.0, - qkv_layout="kv_interleaved", + qkv_layout="bshd_bs2hd", bias_type="no_bias", mask_type="padding", ): @@ -830,7 +830,7 @@ def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking) kv_seqlen=s, dtype=dtype, dropout=0.0, - qkv_layout="qkv_interleaved", + qkv_layout="bs3hd", bias_type="no_bias", mask_type="causal" if is_causal_masking else "padding", ): diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py index 1eb4eb7599..45c78f0bcf 100644 --- a/tests/paddle/utils.py +++ b/tests/paddle/utils.py @@ -14,13 +14,12 @@ import transformer_engine # pylint: disable=unused-import from transformer_engine.paddle.constants import ( TE_DType, - QKVLayout, AttnBiasType, AttnMaskType, FusedAttnBackend, ) from transformer_engine.paddle.fp8 import FP8TensorMeta -import transformer_engine_paddle as tex +import transformer_engine_paddle as tex # pylint: disable=wrong-import-order def create_fp8_meta(num_gemms=1, amax_history_len=10): @@ -101,13 +100,14 @@ def set_random_seed(seed): paddle.seed(global_seed) + def get_fused_attention_backend( head_size: int, q_seqlen: int, kv_seqlen: int, dtype: Union[paddle.dtype, str], dropout: float, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = "bs3hd", bias_type: str = "no_bias", mask_type: str = "causal", ) -> tex.NVTE_Fused_Attn_Backend: @@ -121,7 +121,7 @@ def get_fused_attention_backend( return tex.get_fused_attn_backend( TE_DType[dtype], TE_DType[dtype], - QKVLayout[qkv_layout], + tex.get_nvte_qkv_layout(qkv_layout), AttnBiasType[bias_type], AttnMaskType[mask_type], dropout, @@ -130,13 +130,14 @@ def get_fused_attention_backend( head_size, ) + def is_fused_attention_supported( head_size: int, q_seqlen: int, kv_seqlen: int, dtype: Union[paddle.dtype, str], dropout: float, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = 
"bs3hd", bias_type: str = "no_bias", mask_type: str = "causal", ) -> bool: diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index f724d1d051..ed120a26a7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -133,8 +133,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) - || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { + || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) + || (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index b6fe6e1fd9..a876fe5315 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -83,6 +83,7 @@ class FusedAttnHelper: q_type: jnp.dtype kv_type: jnp.dtype + qkv_layout: NVTE_QKV_Layout attn_bias_type: NVTE_Bias_Type attn_mask_type: NVTE_Mask_Type dropout_probability: float @@ -96,10 +97,13 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" - return transformer_engine_jax.get_fused_attn_backend( - jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type), - NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, self.attn_bias_type, self.attn_mask_type, - self.dropout_probability, self.max_seqlen_q, self.max_seqlen_kv, self.head_dim) + return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type), + jax_dtype_to_te_dtype(self.kv_type), + self.qkv_layout, self.attn_bias_type, + self.attn_mask_type, + self.dropout_probability, + self.max_seqlen_q, self.max_seqlen_kv, + self.head_dim) def merge_named_shape(base, new): @@ -210,14 +214,15 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): # Need to disable one pylint error as the second function # parameter name recenctly in JAX. Otherwise we won't be # compatible with multiple JAX version. 
- out = custom_call(name, # pylint: disable=too-many-function-args - args.output_types, - operands=args.operands, - operand_layouts=args.operand_layouts, - result_layouts=args.output_layouts, - backend_config=opaque, - has_side_effect=has_side_effect, - **kwargs) + out = custom_call( # pylint: disable=too-many-function-args + name, + args.output_types, + operands=args.operands, + operand_layouts=args.operand_layouts, + result_layouts=args.output_layouts, + backend_config=opaque, + has_side_effect=has_side_effect, + **kwargs) return out @@ -2103,8 +2108,8 @@ def abstract( output_shape = (batch, max_seqlen, num_head, head_dim) output_dtype = qkv_dtype - backend = FusedAttnHelper(qkv_dtype, qkv_dtype, attn_bias_type, attn_mask_type, - dropout_probability, max_seqlen, max_seqlen, + backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, + attn_mask_type, dropout_probability, max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index d9e8361f1e..5c8e534126 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -86,9 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) - .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index db5668db3e..eb38830243 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -762,7 +762,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen should be equal to kv_max_seqlen in the self attention."); @@ -845,7 +845,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen should be equal to kv_max_seqlen in the self attention."); @@ -929,7 +929,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_KV_INTERLEAVED; + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; auto dtype = descriptor.dtype; auto q_shape = std::vector{batch 
* q_max_seqlen, num_head, head_dim}; @@ -1064,9 +1064,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa s_tensor.data(), // not used for FP16/BF16 &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.scaling_factor, descriptor.dropout_probability, - NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type, - query_workspace_tensor.data(), stream); + descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, + descriptor.bias_type, descriptor.mask_type, query_workspace_tensor.data(), stream); size_t workspace_size = query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype()); @@ -1081,9 +1080,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa s_tensor.data(), // not used for FP16/BF16 &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.scaling_factor, descriptor.dropout_probability, - NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type, - workspace_tensor.data(), stream); + descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, + descriptor.bias_type, descriptor.mask_type, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 451d7731b1..316bfcdd56 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -23,7 +23,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax -from ..fused_attn import AttnBiasType, AttnMaskType +from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout from ..fused_attn import is_fused_attn_kernel_available from ..fused_attn import self_fused_attn, cross_fused_attn from ..softmax import SoftmaxType @@ -428,6 +428,8 @@ def canonicalize_attn_mask_type(attn_mask_type): raise ValueError(f"Unsupported {attn_mask_type=}, " "supported attn_mask_type = {'causal', 'padding'}") + is_self_attn = (inputs_q is inputs_kv) + qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype) @@ -441,7 +443,7 @@ def _check_seqlen(seqlen): def _check_head_dim(head_dim): return head_dim in [64, 128] - has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, + has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout, attn_bias_type, attn_mask_type, self.dropout_rate, q_seqlen, kv_seqlen, self.head_dim) @@ -484,7 +486,7 @@ def _check_head_dim(head_dim): residual = inputs_q if self.fuse_qkv: - if inputs_q is inputs_kv: + if is_self_attn: qkv_proj, ln_out = LayerNormDenseGeneral( enable_layernorm=not self.output_layernorm, layernorm_type=self.layernorm_type, @@ -571,7 +573,7 @@ def _check_head_dim(head_dim): kernel_init=query_init, name='query')(inputs_q) - if inputs_q is inputs_kv: + if is_self_attn: assert ln_out is not None inputs_kv = ln_out @@ -650,7 +652,7 @@ def _check_head_dim(head_dim): # ensure the old key never used del dropout_rng - if inputs_q is inputs_kv: + if is_self_attn: qkv_proj = 
qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index 3951d87274..100ad89f46 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -10,6 +10,7 @@ from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Mask_Type +from transformer_engine_jax import NVTE_QKV_Layout from .cpp_extensions import FusedAttnHelper from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd @@ -36,13 +37,19 @@ class AttnMaskType(Enum): CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK -def is_fused_attn_kernel_available(q_type, kv_type, attn_bias_type, attn_mask_type, +class QKVLayout(Enum): + """QKV layout""" + BS3HD = NVTE_QKV_Layout.NVTE_BS3HD + BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD + + +def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim): """ To check whether the fused attention kernel is available """ - return FusedAttnHelper(q_type, kv_type, attn_bias_type.value, attn_mask_type.value, - dropout_probability, max_seqlen_q, max_seqlen_kv, + return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value, + attn_mask_type.value, dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available() diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py index 58d603374c..76ba76b95e 100644 --- a/transformer_engine/paddle/constants.py +++ b/transformer_engine/paddle/constants.py @@ -53,12 +53,6 @@ class FP8BwdTensors(Enum): RecomputeFunctionNames = ('unpack', 'backward') -QKVLayout = { - "not_interleaved": tex.NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED, - "qkv_interleaved": tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, - "kv_interleaved": tex.NVTE_QKV_Layout.NVTE_KV_INTERLEAVED, -} - AttnBiasType = { "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS, "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index 0a13df9684..0ad4878b07 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -431,7 +431,7 @@ def fused_attn_fwd_qkvpacked( attn_scale: float = None, dropout: float = 0.0, set_zero: bool = True, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = "bs3hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -518,7 +518,7 @@ def fused_attn_bwd_qkvpacked( attn_scale: float = None, dropout: float = 0.0, set_zero: bool = True, - qkv_layout: str = "qkv_interleaved", + qkv_layout: str = "bs3hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -587,7 +587,7 @@ def fused_attn_fwd_kvpacked( attn_scale: float = None, dropout: float = 0.0, set_zero: bool = True, - qkv_layout: str = "kv_interleaved", + qkv_layout: str = "bshd_bs2hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -685,7 +685,7 @@ def fused_attn_bwd_kvpacked( attn_scale: float = None, dropout: float = 0.0, set_zero: bool = True, - qkv_layout: str = "kv_interleaved", + qkv_layout: str = "bshd_bs2hd", bias_type: str = "no_bias", attn_mask_type: str = "padding", ) -> 
Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp index e71a201480..7f3bbd3126 100644 --- a/transformer_engine/paddle/csrc/common.cpp +++ b/transformer_engine/paddle/csrc/common.cpp @@ -53,5 +53,34 @@ paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const pad NVTE_CHECK(false, "Should never reach here! func: AllocateSpace"); } +// MHA utils +// convert QKV layout to enum +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) { + static const std::unordered_map layout_map = { + {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, + {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, + {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, + {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, + {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, + {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, + {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, + {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, + {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, + {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, + {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, + {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, + {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, + {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, + {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, + }; + + auto it = layout_map.find(qkv_layout); + if (it != layout_map.end()) { + return it->second; + } else { + NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); + } +} + } // namespace paddle_ext } // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index ae12ab44be..265c167e32 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -12,7 +12,6 @@ #include "paddle/extension.h" #include "paddle/phi/backends/all_context.h" -#include "common/util/logging.h" #include #include #include @@ -22,6 +21,7 @@ #include #include #include +#include "common/util/logging.h" namespace transformer_engine { namespace paddle_ext { @@ -177,5 +177,7 @@ TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, c TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor); +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout); + } // namespace paddle_ext } // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 3703285d18..40e7951f2a 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include #include "common.h" @@ -13,20 +14,6 @@ namespace transformer_engine { namespace paddle_ext { -// MHA utils -// convert QKV layout to enum -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) { - if (qkv_layout == "not_interleaved") { - return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED; - } else if (qkv_layout == "qkv_interleaved") { - return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; - } else if (qkv_layout == "kv_interleaved") { - return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED; - } else { - NVTE_ERROR("Invalid QKV layout. 
\n"); - } -} - // convert bias type to enum NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { if (bias_type == "no_bias") { diff --git a/transformer_engine/paddle/csrc/extensions.cu b/transformer_engine/paddle/csrc/extensions.cu index 095ef7c58a..82f4479b42 100644 --- a/transformer_engine/paddle/csrc/extensions.cu +++ b/transformer_engine/paddle/csrc/extensions.cu @@ -15,6 +15,7 @@ PYBIND11_MODULE(transformer_engine_paddle, m) { // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); + m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); // Data structures py::enum_(m, "DType", py::module_local()) .value("kByte", DType::kByte) @@ -36,9 +37,21 @@ PYBIND11_MODULE(transformer_engine_paddle, m) { .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED) - .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) - .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED); + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index 02aa53b042..55d117251c 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -15,8 +15,8 @@ from .layernorm_linear import LayerNormLinear from .linear import Linear from .softmax import FusedScaleMaskSoftmax -from ..constants import (AttnTypes, TE_DType, QKVLayout, AttnBiasType, AttnMaskType, - FusedAttnBackend, dist_group_type) +from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend, + dist_group_type) from ..cpp_extensions import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, @@ -28,7 +28,6 @@ from ..utils import attention_mask_func, divide from ..recompute import recompute - __all__ = ["DotProductAttention", "MultiHeadAttention"] @@ -168,7 +167,7 @@ def __init__(self, self.attn_mask_type = attn_mask_type self.attention_dropout = attention_dropout self.attention_type = attention_type - self.qkv_layout = "qkv_interleaved" if attention_type == "self" else "kv_interleaved" + self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd" self.backend = backend @@ -237,7 +236,7 @@ def forward( max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1] self.fused_attention_backend = 
tex.get_fused_attn_backend( TE_DType[query_layer.dtype], TE_DType[query_layer.dtype], - QKVLayout[self.qkv_layout], AttnBiasType[core_attention_bias_type], + tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type], AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv, query_layer.shape[-1]) From c706ff8df5d3792c629aaec22c4e56dc9571001c Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Wed, 8 Nov 2023 13:07:16 -0800 Subject: [PATCH 2/6] Returning an empty tensor of param dtype for wgrad (#507) * Returning an empty tensor of param dtype for wgrad Signed-off-by: Selvaraj Anandaraj * lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 5 +++++ transformer_engine/pytorch/module/layernorm_mlp.py | 10 ++++++++++ transformer_engine/pytorch/module/linear.py | 5 +++++ 3 files changed, 20 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d4746ba3a0..0679f2dac2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -525,6 +525,11 @@ def backward( # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'): weight.grad_added_to_main_grad = True + wgrad = torch.empty(weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 40256dba6a..bda4c309ab 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -879,6 +879,11 @@ def backward( # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'): fc1_weight.grad_added_to_main_grad = True + fc1_wgrad = torch.empty(fc1_weight.main_grad.shape, + dtype=fc1_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False + ) elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: @@ -888,6 +893,11 @@ def backward( # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'): fc2_weight.grad_added_to_main_grad = True + fc2_wgrad = torch.empty(fc2_weight.main_grad.shape, + dtype=fc2_weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False + ) elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b14877e74b..66c1a8f012 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -465,6 +465,11 @@ def backward( # Handle custom DDP from mcore. 
if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'): weight.grad_added_to_main_grad = True + wgrad = torch.empty(weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False + ) elif ctx.fuse_wgrad_accumulation: wgrad = None else: From 64a3d1d565b2630ed370aa4ad1f8554cb5210407 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Wed, 8 Nov 2023 17:21:22 -0800 Subject: [PATCH 3/6] Make user buffer name configurable (#499) * Make user buffer name configurable Signed-off-by: Sangkug Lym * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Fix duplicate argument Signed-off-by: Kirthi Shankar Sivamani * Fix autograd Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Sangkug Lym Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/module/layernorm_linear.py | 15 ++++++++++++--- transformer_engine/pytorch/module/linear.py | 12 ++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0679f2dac2..64fc9c2c8c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -86,6 +86,7 @@ def forward( ub_bulk_dgrad: bool, ub_split_ag: bool, ub_atomic_gemm_ag: bool, + ub_name: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -111,7 +112,7 @@ def forward( if ub_split_ag or ub_atomic_gemm_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("qkv_fprop") + ub_obj_lnout = get_ub(ub_name+"_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype @@ -268,6 +269,7 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -310,7 +312,7 @@ def backward( if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("qkv_dgrad") + ub_obj_lnout = get_ub(ctx.ub_name+"_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) ( grad_output, @@ -350,7 +352,7 @@ def backward( dgrad_size = list(grad_output.size()) dgrad_size[1] = weight.size(1) if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("qkv_wgrad") + ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad") dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) @@ -567,6 +569,7 @@ def backward( None, None, None, + None, ) @@ -674,6 +677,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, ub_atomic_gemm_ag: bool = False, + ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -694,6 +698,10 @@ def __init__( self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_ag = ub_atomic_gemm_ag + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]): + assert ub_name is not None, "Userbuffer name [string] is not set." 
+ self.ub_name = ub_name + if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag: assert ( @@ -978,6 +986,7 @@ def forward( self.ub_bulk_dgrad, self.ub_split_ag, self.ub_atomic_gemm_ag, + self.ub_name, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 66c1a8f012..7adfb035d3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -82,6 +82,7 @@ def forward( ub_split_ag: bool, ub_atomic_gemm_rs: bool, ub_atomic_gemm_ag: bool, + ub_name: str, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -180,7 +181,7 @@ def forward( proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) if ub_split_rs or ub_atomic_gemm_rs: - ub_obj_projout = get_ub("proj_fprop") + ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size @@ -285,6 +286,7 @@ def forward( ctx.tp_group = tp_group ctx.ub_split_ag = ub_split_ag ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag + ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -326,7 +328,7 @@ def backward( if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub("proj_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad") ( grad_output, grad_output_c, @@ -499,6 +501,7 @@ def backward( None, None, None, + None, ) @@ -588,6 +591,7 @@ def __init__( ub_split_ag: bool = False, ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_ag: bool = False, + ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -604,6 +608,9 @@ def __init__( self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs self.ub_atomic_gemm_ag = ub_atomic_gemm_ag + if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]): + assert ub_name is not None, "Userbuffer name [string] is not set." 
+ self.ub_name = ub_name if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: assert ( @@ -848,6 +855,7 @@ def forward( self.ub_split_ag, self.ub_atomic_gemm_rs, self.ub_atomic_gemm_ag, + self.ub_name, ) out = linear_fn(*args) From bfaec64489bf05dcbd80aa0a1a167b2d560747fe Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 14 Nov 2023 05:00:54 +0800 Subject: [PATCH 4/6] [C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469) * Move bias to float32 Signed-off-by: Reese Wang * Enable varlen Signed-off-by: Reese Wang * Increase neg infinity abs values Signed-off-by: Reese Wang * Enable varlen tests Signed-off-by: Reese Wang * Remove unnecessary code Signed-off-by: Reese Wang * Fix lint Signed-off-by: Reese Wang * Support variable sequence length after cuDNN 8.9.6 Signed-off-by: Reese Wang * Use unique_ptr instead of shared_ptr Signed-off-by: Reese Wang * Add a new mask type: PADDING_CAUSAL_MASK Signed-off-by: Reese Wang * Support flash padding mask after 8.9.6 Signed-off-by: Reese Wang * Enhance the Max512 handling for causal masking and add the related tests Signed-off-by: Reese Wang * Update the fused attn support lists Signed-off-by: Reese Wang * Remove padding_aware from the caching Signed-off-by: Reese Wang * Fix libtransformer.so issue Signed-off-by: Reese Wang * Reduce the pad ratio tests Signed-off-by: Reese Wang * Fix a bug with cuDNN 8.9.5 Signed-off-by: Reese Wang * Release backend resource after the module level unit test Signed-off-by: Reese Wang * Clean the jax live arrays before running the unit tests Signed-off-by: Reese Wang * Fix too-few-public-methods lint Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_custom_call_compute.py | 10 + tests/jax/test_custom_call_shape.py | 11 + tests/jax/test_fused_attn.py | 98 ++-- tests/jax/test_layer.py | 10 + tests/jax/test_praxis_layers.py | 10 + .../common/fused_attn/fused_attn.cpp | 8 +- .../fused_attn_f16_arbitrary_seqlen.cu | 440 ++++++++++++++---- .../fused_attn_f16_max512_seqlen.cu | 9 +- .../include/transformer_engine/fused_attn.h | 42 +- transformer_engine/jax/csrc/extensions.cpp | 3 +- transformer_engine/jax/flax/module.py | 16 +- transformer_engine/jax/flax/transformer.py | 8 +- transformer_engine/jax/fused_attn.py | 1 + 13 files changed, 500 insertions(+), 166 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 10d5b5f987..05b7bc3603 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -37,6 +37,16 @@ is_fp8_supported, reason = is_fp8_available() +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() + + class TestFP8Dot: @pytest.mark.skipif(not is_fp8_supported, reason=reason) diff --git a/tests/jax/test_custom_call_shape.py b/tests/jax/test_custom_call_shape.py index 0539575629..32d645b668 100644 --- a/tests/jax/test_custom_call_shape.py +++ b/tests/jax/test_custom_call_shape.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
import pytest +import jax import jax.numpy as jnp from jax.core import ShapedArray @@ -31,6 +32,16 @@ TRANSPOSE = [True, False] +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() + + class TestGEMMShapeInfer: @staticmethod diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index c4f70c85aa..20e6ab5f77 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -21,12 +21,22 @@ from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available -from transformer_engine_jax import get_device_compute_capability +from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order # Type annotations Array = jnp.ndarray +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() + + class Backend(Enum): """ Fused attn backend. @@ -52,6 +62,13 @@ def fixture_backend(request): DTYPES = [jnp.bfloat16, jnp.float16] +def is_causal_mask(mask: AttnMaskType): + """ + Check if the mask is a causal mask + """ + return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] + + def make_decoder_mask(tokens: Array) -> Array: """ Create padded causal mask @@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): Self attention with JAX native implementation """ attn_mask_type = kwargs['attn_mask_type'] - if attn_mask_type == AttnMaskType.CAUSAL_MASK: + if is_causal_mask(attn_mask_type): mask = make_decoder_mask(q_token) else: mask = make_attention_mask(q_token > 0, kv_token > 0) @@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): deterministic=not kwargs['is_training'], dropout_rate=kwargs['dropout_probability'], dropout_rng=dropout_rng, - dtype=qkv.dtype) - return output + dtype=jnp.float32) + return output.astype(qkv.dtype) def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): @@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): assert q.dtype == kv.dtype attn_mask_type = kwargs['attn_mask_type'] - if attn_mask_type == AttnMaskType.CAUSAL_MASK: + if is_causal_mask(attn_mask_type): raise NotImplementedError mask = make_attention_mask(q_token > 0, kv_token > 0) @@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): deterministic=not kwargs['is_training'], dropout_rate=kwargs['dropout_probability'], dropout_rng=dropout_rng, - dtype=q.dtype) - return output + dtype=jnp.float32) + return output.astype(q.dtype) def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): """ Self fused attention """ - if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: + attn_mask_type = kwargs['attn_mask_type'] + if is_causal_mask(attn_mask_type): mask = make_decoder_mask(q_token) else: mask = make_attention_mask(q_token > 0, kv_token > 0) @@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) """ assert q.dtype == kv.dtype - if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: + attn_mask_type = kwargs['attn_mask_type'] + if is_causal_mask(attn_mask_type): raise 
NotImplementedError mask = make_attention_mask(q_token > 0, kv_token > 0) @@ -149,32 +168,28 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) @pytest.mark.parametrize('b, s, h, d', SELF_CASES) @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) -@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) +@pytest.mark.parametrize('attn_mask_type', [ + AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK +]) @pytest.mark.parametrize('dropout_probability', [0., 0.1]) @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('is_training', [True, False]) -@pytest.mark.parametrize('pad_ratio', [0, 0.3]) class TestSelfFusedAttn(): """Tests for transformer_engine.jax.fused_attn.self_fused_attn""" @staticmethod def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype, - head_dim, pad_ratio): - if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0: - pytest.skip("Arbitrary seqlen backend hasn't support padded input.") + head_dim): + + assert isinstance(backend, Backend) if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, attn_mask_type, dropout_probability, s, s, head_dim): pytest.skip("Unsupported inputs combination or device compute capability.") - compute_capability = get_device_compute_capability(0) - if (backend == Backend.Max512 - and not (compute_capability == 80 or compute_capability >= 90)): - pytest.skip("Unsupported compute capability for " - "fused attention with <=512 sequence length") - def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend, - dropout_probability, dtype, is_training, pad_ratio): + dropout_probability, dtype, is_training): """Setup the test inputs""" self.__class__._check_inputs(s, attn_bias_type=attn_bias_type, @@ -182,8 +197,13 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend, backend=backend, dropout_probability=dropout_probability, dtype=dtype, - head_dim=d, - pad_ratio=pad_ratio) + head_dim=d) + + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + pad_ratio = 0.0 + else: + pad_ratio = 0.3 + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) @@ -212,7 +232,7 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend, self.is_training = is_training def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability, - dtype, is_training, pad_ratio): + dtype, is_training): """ Test forward without using JIT """ @@ -225,8 +245,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop backend=backend, dropout_probability=dropout_probability, dtype=dtype, - is_training=is_training, - pad_ratio=pad_ratio) + is_training=is_training) primitive_out = customcall_self_fused_attn(self.qkv, self.bias, @@ -265,7 +284,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop jnp.zeros_like(pri_invalid, jnp.float32)) def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, - dropout_probability, dtype, is_training, pad_ratio): + dropout_probability, dtype, is_training): """ Test forward, backward, and autodiff by jax.value_and_grad """ @@ -281,13 +300,12 @@ def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, back backend=backend, dropout_probability=dropout_probability, dtype=dtype, 
- is_training=is_training, - pad_ratio=pad_ratio) + is_training=is_training) def grad_func(fused_attn_func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the graident gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000 - if attn_mask_type == AttnMaskType.CAUSAL_MASK: + if is_causal_mask(attn_mask_type): gradient_multiplier = gradient_multiplier / 10 # Keep only valid result for the gradient # fused_attn output has shape (b, s, h, d) @@ -333,15 +351,15 @@ def grad_func(fused_attn_func, *args, **kwargs): rtol=1e-4, atol=1e-5) - valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,), - axis=1) - valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,), - axis=1) + valid_primitive_dqkv, invalid_primitive_dqkv = \ + jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1) + valid_reference_dqkv, invalid_reference_dqkv = \ + jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1) - valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split( - valid_primitive_dqkv.astype(jnp.float32), 3, axis=2) - valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split( - valid_reference_dqkv.astype(jnp.float32), 3, axis=2) + valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \ + jnp.split(valid_primitive_dqkv, 3, axis=2) + valid_reference_dq, valid_reference_dk, valid_reference_dv = \ + jnp.split(valid_reference_dqkv, 3, axis=2) np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5) np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5) @@ -482,9 +500,7 @@ def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_prob def grad_func(fused_attn_func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the graident - gradient_multiplier = 10000 - if attn_mask_type == AttnMaskType.CAUSAL_MASK: - gradient_multiplier = gradient_multiplier / 10 + gradient_multiplier = 1e4 # Keep only valid result for the gradient # fused_attn output has shape (b, s_q, h, d) valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 4f9e224663..0037142b64 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -19,6 +19,16 @@ is_fp8_supported, reason = is_fp8_available() +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() + + def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs): output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs) return jnp.mean(output) diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 5a1bf41fb2..3c440abb70 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -38,6 +38,16 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID] +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() + + def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: assert key in test_fd, \ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ed120a26a7..aa37719817 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -87,6 +87,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const int sm_arch_ = cuda::sm_arch(device_id); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + auto cudnn_runtime_version = cudnnGetVersion(); if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) && (sm_arch_ >= 90) && (max_seqlen_q == max_seqlen_kv) @@ -111,6 +112,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) @@ -131,7 +133,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && (max_seqlen_q == max_seqlen_kv) && ((head_dim == 64) || (head_dim == 128)) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) - && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || ((cudnn_runtime_version >= 8906) && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index dd4bf301a3..61676e483a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -79,7 +79,7 @@ createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, static cudnn_frontend::Tensor createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, std::vector* ops) { // Creates the necessary tensor descriptors int64_t q_dim[4] = {b, h, s_q, d}; @@ -95,6 +95,9 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t s_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); auto kTransposeTensor = tensor_create( tensorType, K_ID, k_dim, k_stride, false, false); // is virtual @@ -105,21 +108,150 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, // Define the matmul 1 desc auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); + auto seqlenQTensor = tensor_create( + CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = tensor_create( + CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + // Create a matmul 1 node - auto matmul_op1 = 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(kTransposeTensor) - .setcMatDesc(sTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); + auto&& matmul_op_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op_builder.setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(sTensor) + .setmatmulDesc(matmul_1_Desc); + + if (padding_aware) { + matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor); + } + + auto matmul_op1 = matmul_op_builder.build(); ops->push_back(std::move(matmul_op1)); return sTensor; } +static cudnn_frontend::Tensor +createPaddingMask(int64_t b, + int64_t h, + int64_t s_q, + int64_t s_kv, + int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor& prevBlockOutputTensor) { + CUDNN_FRONTEND_UNUSED(d); + CUDNN_FRONTEND_UNUSED(layout); + CUDNN_FRONTEND_UNUSED(tensorType); + + NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one"); + + // subtraction output + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t maskVal_dim[4] = {1, 1, 1, 1}; + int64_t maskVal_stride[4] = {1, 1, 1, 1}; + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + // mask value to put in the masked pixels + auto maskValTensor = tensor_create( + CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true); + auto seqlenQTensor = tensor_create( + CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = tensor_create( + CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + + // gen index row output + auto rowIndexTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false); + // gen index column output + auto columnIndexTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false); + // less than row output + auto lessThanRowTensor = tensor_create( + CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false); + // less than column output + auto lessThanColTensor = tensor_create( + CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false); + // padding mask (lessthanRow && lessthanCol) + auto paddingMaskTensor = tensor_create( + CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false); + + // output after masking + auto maskOutputTensor = tensor_create( + CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false); + + // Define the gen index for row descriptor + auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(2) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a gen index Node. + auto genIndexRow_op = unary_pw_op_create( + prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); + + // Define the gen index for row descriptor + auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setAxis(3) + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a gen index Node. 
+ auto genIndexColumn_op = unary_pw_op_create( + prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); + + // Define the less than comparison for row descriptor + auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); + + // Create a less than comparison for row Node. + auto lessThanRow_op = binary_pw_op_create( + rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc); + + // Define the less than comparison for column descriptor + auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); + + // Create a less than comparison for col Node. + auto lessThanCol_op = binary_pw_op_create( + columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc); + + // Define the less than comparison for column descriptor + auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); + + // Create a and node for combining lessThanRow and lessThanCol + auto paddingMaskAnd_op = binary_pw_op_create( + lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc); + + /////////////////// Apply the mask ////////////////////////// + + // Define the binary select to perform masking descriptor + auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); + + // Create a binary select Node. + auto mask_op = ternary_pw_op_create( + prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc); + + ops->push_back(std::move(genIndexRow_op)); + ops->push_back(std::move(genIndexColumn_op)); + ops->push_back(std::move(lessThanRow_op)); + ops->push_back(std::move(lessThanCol_op)); + ops->push_back(std::move(paddingMaskAnd_op)); + ops->push_back(std::move(mask_op)); + + return maskOutputTensor; +} + static cudnn_frontend::Tensor createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, @@ -502,7 +634,7 @@ createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d static void createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, + bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType, std::vector* ops, cudnn_frontend::Tensor const &afterScaleDropoutTensor) { NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one"); @@ -515,6 +647,14 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t o_stride[4]; generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + + auto seqlenQTensor = tensor_create( + CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto seqlenKTensor = tensor_create( + CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); + auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); // second GEMM output auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); @@ -522,15 +662,23 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, // Define the matmul 2 desc auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); // Create a matmul 2 node - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(afterScaleDropoutTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - 
.setmatmulDesc(matmul_2_Desc) - .build(); + auto&& matmul_op_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op_builder.setaMatDesc(afterScaleDropoutTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmatmulDesc(matmul_2_Desc); + + if (padding_aware) { + matmul_op_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor); + } + + auto matmul_op2 = matmul_op_builder.build(); ops->push_back(std::move(matmul_op2)); } @@ -538,9 +686,10 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, + NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrSoftmaxStats, void *devPtrO, + void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnnDataType_t tensorType, void *workspace, size_t *workspace_size, @@ -552,12 +701,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( dropout_probability = 0.0f; } + // also known as variable_sequence_length + bool padding_aware = (mask_type == NVTE_PADDING_MASK) || + (mask_type == NVTE_PADDING_CAUSAL_MASK); + FADescriptor descriptor{b, h, s_q, s_kv, d, scaling_factor, is_training, dropout_probability, layout, NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType, + mask_type, tensorType, false}; using CacheType = std::map; @@ -577,15 +730,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector ops; // Q * K^T - auto sTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops); + auto sTensor = createQKBMM( + b, h, s_q, s_kv, d, padding_aware, layout, tensorType, &ops); // Q * K^T * bmmScale auto sScaleTensor = createScale( b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops); - // Causual mask - auto sAfterMaskTensor = createCausalMask( - b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor); + auto& sAfterMaskTensor = sScaleTensor; + + if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) { + sAfterMaskTensor = createCausalMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor); + } + + if (padding_aware) { + sAfterMaskTensor = createPaddingMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, sAfterMaskTensor); + } NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0"); @@ -597,7 +759,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto dropout_output = createDropoutForward( b, h, s_q, s_kv, d, dropout_probability, tensorType, &ops, softmax_output); - createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dropout_output); + createSVBMM(b, h, s_q, s_kv, d, padding_aware, + layout, tensorType, &ops, dropout_output); for (unsigned int i = 0; i < ops.size(); i++) { all_ops.push_back(&ops[i]); @@ -636,13 +799,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // Exit to request upper level API to allocate memory if needed if (workspace == nullptr) { - *workspace_size = plan_workspace_size; + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; return; } + // Prepare actual seqlen + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void *devActualSeqlenK = 
static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + + if (padding_aware) { + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), + static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + std::set> data_ptrs; // Add all the data pointers to be used in the variant pack - float negInfinity = -1.0E+10f; + float negInfinity = -1.0E+30f; float scale_dropout = 1.0f/(1.0f - dropout_probability); data_ptrs.insert(std::pair(Q_ID, devPtrQ)); @@ -655,6 +834,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( data_ptrs.insert(std::pair(D_OFFSET_ID, devPtrDropoutOffset)); data_ptrs.insert(std::pair(D_CONST_ID, &scale_dropout)); + if (padding_aware) { + data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); + data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); + } + // If training mode, we write out softmax stats if (is_training) { data_ptrs.insert(std::pair(S_STATS_ID, devPtrSoftmaxStats)); @@ -675,21 +859,26 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose, - void* devPtrO, void* devPtrSoftmaxStats, + NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrKTranspose, + void* devPtrVTranspose, void* devPtrO, void* devPtrSoftmaxStats, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, + void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnnDataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle, bool use_workspace_opt) { try { NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + // also known as variable_sequence_length + bool padding_aware = (mask_type == NVTE_PADDING_MASK) || + (mask_type == NVTE_PADDING_CAUSAL_MASK); + FADescriptor descriptor{b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, layout, NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType, + mask_type, tensorType, use_workspace_opt}; using CacheType = std::map; @@ -747,9 +936,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + int64_t scale_dim[4] = {1, 1, 1, 1}; int64_t scale_stride[4] = {1, 1, 1, 1}; + auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, + seqlen_stride, false, false); + auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, + seqlen_stride, false, false); + /******************************************************************************* * Dot product dO * O */ @@ -823,15 +1020,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( // matmul to calculate dvTensor auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - auto matmul_op0 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(kTransposeTensor) - .setcMatDesc(pTensor) - .setmatmulDesc(matmul_0_Desc) - .build(); + auto&& matmul_op_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op_builder.setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(pTensor) 
+ .setmatmulDesc(matmul_0_Desc); + + if (padding_aware) { + matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor); + } + + auto matmul_op0 = matmul_op_builder.build(); ops.push_back(std::move(matmul_op0)); @@ -851,8 +1055,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( /******************************************************************************* * Causal masking -> pAfterMaskTensor */ - auto pAfterMaskTensor = createCausalMask( - b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor); + auto& pAfterMaskTensor = pAfterScaleTensor; + + if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) { + pAfterMaskTensor = createCausalMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor); + } + + if (padding_aware) { + pAfterMaskTensor = createPaddingMask( + b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterMaskTensor); + } /******************************************************************************* * pAfterMaskTensor - softmaxStats -> pAfterSubtract */ @@ -930,15 +1143,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - auto matmul_op1 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(sTransposeTensor) - .setbMatDesc(dOTensor) - .setcMatDesc(dVTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); + auto&& matmul_op1_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op1_builder.setaMatDesc(sTransposeTensor) + .setbMatDesc(dOTensor) + .setcMatDesc(dVTensor) + .setmatmulDesc(matmul_1_Desc); + + if (padding_aware) { + matmul_op1_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor); + } + + auto matmul_op1 = matmul_op1_builder.build(); ops.push_back(std::move(matmul_op1)); @@ -954,15 +1174,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - auto matmul_op2 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dOTensor) - .setbMatDesc(vTransposeTensor) - .setcMatDesc(dSTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); + auto&& matmul_op2_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op2_builder.setaMatDesc(dOTensor) + .setbMatDesc(vTransposeTensor) + .setcMatDesc(dSTensor) + .setmatmulDesc(matmul_2_Desc); + + if (padding_aware) { + matmul_op2_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor); + } + + auto matmul_op2 = matmul_op2_builder.build(); ops.push_back(std::move(matmul_op2)); @@ -1059,30 +1286,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - if (!use_workspace_opt) { - auto matmul_op3 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dPScaledTensor) - .setbMatDesc(kTensor) - .setcMatDesc(dqAccumTensor) - .setmatmulDesc(matmul_3_Desc) - .build(); - - ops.push_back(std::move(matmul_op3)); + auto&& matmul_op3_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op3_builder.setaMatDesc(dPScaledTensor) + .setbMatDesc(kTensor) + .setmatmulDesc(matmul_3_Desc); + + if (use_workspace_opt) { + 
matmul_op3_builder.setcMatDesc(dQTensor); } else { - auto matmul_op3 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dPScaledTensor) - .setbMatDesc(kTensor) - .setcMatDesc(dQTensor) - .setmatmulDesc(matmul_3_Desc) - .build(); - - ops.push_back(std::move(matmul_op3)); + matmul_op3_builder.setcMatDesc(dqAccumTensor); + } + + if (padding_aware) { + matmul_op3_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor); } + auto matmul_op3 = matmul_op3_builder.build(); + + ops.push_back(std::move(matmul_op3)); + /******************************************************************************* * dP.T @ Q -> dK */ @@ -1098,14 +1325,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder() .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0.0f) .build(); - auto matmul_op4 = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dPTransposeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(dKTensor) - .setmatmulDesc(matmul_4_Desc) - .build(); + + auto&& matmul_op4_builder = + cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR); + + matmul_op4_builder.setaMatDesc(dPTransposeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(dKTensor) + .setmatmulDesc(matmul_4_Desc); + + if (padding_aware) { + matmul_op4_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor); + } + + auto matmul_op4 = matmul_op4_builder.build(); ops.push_back(std::move(matmul_op4)); @@ -1153,29 +1388,36 @@ void fused_attn_arbitrary_seqlen_bwd_impl( // Exit to request upper level API to allocate memory if needed size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float); - size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float); + size_t dqAccum_workspace_size = use_workspace_opt ? 
0 : b * s_q * h * d * sizeof(float); + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); if (workspace == nullptr) { - if (use_workspace_opt) { - *workspace_size = plan_workspace_size + softmaxSum_workspace_size; - } else { - *workspace_size = plan_workspace_size + softmaxSum_workspace_size - + dqAccum_workspace_size; - } + *workspace_size = plan_workspace_size + softmaxSum_workspace_size + + dqAccum_workspace_size + actual_seqlen_workspace_size; return; } void *devPtrSoftmaxSum = static_cast(workspace) + plan_workspace_size; - void *devPtrdQAccumulator = nullptr; - if (!use_workspace_opt) { - devPtrdQAccumulator = static_cast(devPtrSoftmaxSum) + void *devPtrdQAccumulator = static_cast(devPtrSoftmaxSum) + softmaxSum_workspace_size; + if (!use_workspace_opt) { NVTE_CHECK_CUDA(cudaMemsetAsync( devPtrdQAccumulator, 0, dqAccum_workspace_size, stream)); } + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void *devActualSeqlenQ = + static_cast(devPtrdQAccumulator) + dqAccum_workspace_size; + void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, static_cast(devPtrCuSeqlenQ), + static_cast(devPtrCuSeqlenKV), + static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); + NVTE_CHECK_CUDA(cudaGetLastError()); + std::set> data_ptrs; // add all the data pointers to be used in the variant pack - float negInfinity = -1.0E+10f; + float negInfinity = -1.0E+31f; float scale_dropout = 1.0f/(1.0f - dropout_probability); data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); if (!use_workspace_opt) { @@ -1194,6 +1436,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( data_ptrs.insert(std::pair(D_SEED_ID, devPtrDropoutSeed)); data_ptrs.insert(std::pair(D_OFFSET_ID, devPtrDropoutOffset)); data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); + if (padding_aware) { + data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); + data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); + } float scaleProb = 1.0f - dropout_probability; data_ptrs.insert(std::pair(D_CONST_ID, &scale_dropout)); @@ -1254,6 +1500,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } + void *devPtrCuSeqlens = cu_seqlens->data.dptr; + void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); @@ -1262,8 +1510,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, + is_training, attn_scale, p_dropout, qkv_layout, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1318,6 +1567,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); + void *devPtrCuSeqlens = cu_seqlens->data.dptr; + const auto qkv_type = input_QKV->data.dtype; size_t workspace_size = 0; @@ -1349,9 +1600,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, #endif fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, - attn_scale, p_dropout, qkv_layout, + attn_scale, p_dropout, qkv_layout, 
mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(qkv_type), workspace->data.dptr, &workspace_size, stream, handle, use_workspace_opt); @@ -1412,11 +1664,15 @@ void fused_attn_arbitrary_seqlen_fwd( void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, - is_training, attn_scale, p_dropout, qkv_layout, + is_training, attn_scale, p_dropout, qkv_layout, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1467,6 +1723,9 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; + void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + size_t workspace_size = 0; bool use_workspace_opt = false; @@ -1497,9 +1756,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m #endif fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, - attn_scale, p_dropout, qkv_layout, + attn_scale, p_dropout, qkv_layout, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle, use_workspace_opt); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 663ff37187..bfda5fec30 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -298,7 +298,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 /////////////////// Apply the mask ////////////////////////// - auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ? 
std::move(causalMaskTensor) : std::move(paddingMaskTensor); @@ -314,7 +315,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 ops.push_back(std::move(lessThanRow_op)); ops.push_back(std::move(lessThanCol_op)); ops.push_back(std::move(paddingMaskAnd_op)); - if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) { + if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { ops.push_back(std::move(rowGreaterCol_op)); ops.push_back(std::move(causalMaskAnd_op)); } @@ -680,7 +682,8 @@ void fused_attn_max_512_fwd_impl( // WAR: causal_mask without bias needs memset the S buffer // inference mode doesn't need the S auxiliary auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || - (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training; + (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && is_training; std::shared_ptr maskInput; auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6de3c63512..46f4ec6a1c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -146,6 +146,8 @@ enum NVTE_Mask_Type { NVTE_PADDING_MASK = 1, /*! Causal attention mask */ NVTE_CAUSAL_MASK = 2, + /*! Padding and causal attention mask */ + NVTE_PADDING_CAUSAL_MASK = 3, }; /*! \enum NVTE_Fused_Attn_Backend @@ -209,10 +211,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * * Support Matrix: \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | - | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | - | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | \endverbatim * * \param[in] QKV The QKV tensor in packed format, @@ -254,10 +256,10 @@ void nvte_fused_attn_fwd_qkvpacked( * * Support Matrix: \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | - | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | - | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | \endverbatim * * \param[in] QKV The QKV tensor in packed format, @@ -308,8 +310,8 @@ void nvte_fused_attn_bwd_qkvpacked( * * Support Matrix: \verbatim - | backend | 
precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. @@ -356,8 +358,8 @@ void nvte_fused_attn_fwd_kvpacked( * * Support Matrix: \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. @@ -415,10 +417,10 @@ void nvte_fused_attn_bwd_kvpacked( * * Support Matrix: \verbatim - | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | - | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | - | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor. @@ -467,10 +469,10 @@ void nvte_fused_attn_fwd( * * Support Matrix: \verbatim - | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | - | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | - | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | + | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 | + | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 | + | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | \endverbatim * * \param[in] Q The Q tensor. 
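The padding-aware path added above feeds cuDNN's m/n/k override tensors with per-sequence lengths that it derives on the fly from the cumulative ``cu_seqlens`` arrays: the forward and backward implementations reserve ``2 * b * sizeof(int32_t)`` at the tail of the workspace and launch a ``cu_seqlens_to_actual_seqlens`` helper over ``b`` entries before binding the results to ``Q_SEQLEN_ID`` and ``K_SEQLEN_ID`` in the variant pack. The body of that helper is not part of this diff, so the following is only a minimal standalone sketch, under the assumption that it takes adjacent differences of the length-``(b + 1)`` prefix-sum arrays; the kernel name suffix, the toy ``main`` driver, and the sample lengths are illustrative and are not Transformer Engine API.

.. code-block:: cuda

    // Sketch only: convert cumulative sequence lengths (prefix sums of length b+1)
    // into per-sequence lengths, one thread per batch entry. This mirrors what the
    // padding-aware fused-attention path is assumed to do before filling the
    // seqlen override tensors; it is not the library's implementation.
    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void cu_seqlens_to_actual_seqlens_sketch(int b,
                                                        const int32_t *cu_seqlens_q,
                                                        const int32_t *cu_seqlens_kv,
                                                        int32_t *actual_seqlens_q,
                                                        int32_t *actual_seqlens_kv) {
        // The i-th actual length is the difference of two consecutive prefix-sum entries.
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < b) {
            actual_seqlens_q[i]  = cu_seqlens_q[i + 1]  - cu_seqlens_q[i];
            actual_seqlens_kv[i] = cu_seqlens_kv[i + 1] - cu_seqlens_kv[i];
        }
    }

    int main() {
        constexpr int b = 4;
        // cu_seqlens for a toy batch with sequences of length 128, 64, 200, 32.
        int32_t h_cu[b + 1] = {0, 128, 192, 392, 424};
        int32_t h_actual[b] = {};

        int32_t *d_cu = nullptr, *d_actual = nullptr;
        cudaMalloc(&d_cu, sizeof(h_cu));
        cudaMalloc(&d_actual, sizeof(h_actual));
        cudaMemcpy(d_cu, h_cu, sizeof(h_cu), cudaMemcpyHostToDevice);

        constexpr int threads = 128;
        const int grid = (b + threads - 1) / threads;
        // Self-attention case: Q and KV share the same cumulative lengths.
        cu_seqlens_to_actual_seqlens_sketch<<<grid, threads>>>(b, d_cu, d_cu, d_actual, d_actual);

        cudaMemcpy(h_actual, d_actual, sizeof(h_actual), cudaMemcpyDeviceToHost);
        for (int i = 0; i < b; ++i) printf("seq %d: %d\n", i, h_actual[i]);  // 128 64 200 32

        cudaFree(d_cu);
        cudaFree(d_actual);
        return 0;
    }

Deriving the lengths inside the workspace keeps the public API unchanged (callers keep passing ``cu_seqlens`` tensors), and the conversion cost is negligible next to the attention GEMMs, which is presumably why the patch recomputes it per call instead of caching it.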
diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 5c8e534126..d9042ff664 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -83,7 +83,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 89da212367..6ef520a8c2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -102,7 +102,7 @@ def _combine_biases(*masks: List[Array]): return mask -class Softmax(nn.Module): +class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. The input's shape should be [batch, heads, q_seqlen, k_seqlen]. @@ -176,7 +176,7 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp return outputs -class LayerNorm(nn.Module): +class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods r""" Applies layer normalization over a mini-batch of inputs. There are two types of normalization supported by this module, @@ -431,8 +431,9 @@ def __call__(self, inputs: Array) -> Array: bias = nn_partitioning.param_with_axes('bias', self.bias_init, features, - self.dtype, + jnp.float32, axes=self.bias_axes) + bias = bias.astype(self.dtype) else: bias = None @@ -656,8 +657,9 @@ def __call__(self, inputs: Array) -> Array: bias = nn_partitioning.param_with_axes('bias', self.bias_init, features, - self.dtype, + jnp.float32, axes=self.bias_axes) + bias = bias.astype(self.dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -969,8 +971,9 @@ def fp8_meta_generator(): bias = nn_partitioning.param_with_axes('wi_bias', self.bias_init, intermediate_dim, - self.dtype, + jnp.float32, axes=self.bias_axes_1) + bias = bias.astype(self.dtype) bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape x += jnp.reshape(bias, bias_shape) @@ -1029,8 +1032,9 @@ def fp8_meta_generator(): if self.use_bias: bias = nn_partitioning.param_with_axes('wo_bias', self.bias_init, (hidden_size,), - self.dtype, + jnp.float32, axes=self.bias_axes_2) + bias = bias.astype(self.dtype) out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 316bfcdd56..a21b9901ea 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -247,7 +247,7 @@ def core_attention(query: Array, dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) -class MultiHeadAttention(nn.Module): +class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Multi-head Attention (MHA), including Query, Key, Value and Output projection. 
@@ -422,7 +422,7 @@ def canonicalize_attn_mask_type(attn_mask_type): Convert the string to AttnMaskType """ if attn_mask_type == 'causal': - return AttnMaskType.CAUSAL_MASK + return AttnMaskType.PADDING_CAUSAL_MASK if attn_mask_type == 'padding': return AttnMaskType.PADDING_MASK raise ValueError(f"Unsupported {attn_mask_type=}, " @@ -741,7 +741,7 @@ def convert_to_softmax_type(attn_mask_type, mask): return out, residual -class RelativePositionBiases(nn.Module): +class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods """ T5-style relative positional embeddings to the attention logits. @@ -848,7 +848,7 @@ class TransformerLayerType(Enum): DECODER = "decoder" -class TransformerLayer(nn.Module): +class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods r""" TransformerLayer is made up of a relative embedding, an attention block and a feedforward network (MLP). diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index 100ad89f46..a8a6421a89 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -35,6 +35,7 @@ class AttnMaskType(Enum): NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK + PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK class QKVLayout(Enum): From a9cfbfd3841932d60f4b5c78bb48393f1e8e3d31 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 13 Nov 2023 15:13:40 -0800 Subject: [PATCH 5/6] [PyTorch] Improve memory usage in backward of LayerNormLinear and LayerNormMLP (#509) Improve PyTorch memory usage Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/module/layernorm_linear.py | 26 ++++---- .../pytorch/module/layernorm_mlp.py | 59 ++++++++++++++----- transformer_engine/pytorch/module/linear.py | 4 ++ transformer_engine/pytorch/utils.py | 12 ++++ 4 files changed, 75 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 64fc9c2c8c..cec0577d60 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -29,6 +29,7 @@ get_default_init_method, cast_if_needed, assert_dim_for_fp8_exec, + clear_tensor_data, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -40,13 +41,13 @@ ) from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo - from ._common import _apply_normalization - from ..float8_tensor import Float8Tensor + __all__ = ["LayerNormLinear"] + class _LayerNormLinear(torch.autograd.Function): """LayerNormLinear semi-top level module Calls custom cuda extensions. 
@@ -355,7 +356,7 @@ def backward( ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad") dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: - dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype( @@ -393,6 +394,7 @@ def backward( fp8_meta_tensor = meta_tensor, D_dtype = out_te_type, ) + clear_tensor_data(grad_output_c) else: # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = tex.gemm( @@ -453,6 +455,7 @@ def backward( ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor ) + clear_tensor_data(ln_out_total_t, grad_output_t) else: ln_out_total_c = tex.cast_from_fp8( ln_out_total, @@ -475,6 +478,7 @@ def backward( ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor ) + clear_tensor_data(ln_out_total_c) else: # WGRAD wgrad, grad_bias, _ = tex.gemm( @@ -490,6 +494,7 @@ def backward( ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) + clear_tensor_data(ln_out_total) if ctx.ub_bulk_wgrad: dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output @@ -501,25 +506,24 @@ def backward( handle.wait() # LayerNorm gradient - d_ln_out = dgrad.view(inputmat.shape) + dgrad = dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output: - d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) + dgrad = dgrad + grad_outputs[1].view_as(dgrad) if ctx.normalization == "LayerNorm": - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, + dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) elif ctx.normalization == "RMSNorm": - dxmat, dgamma = tex.rmsnorm_bwd( - d_ln_out, inputmat, rsigma, ln_weight, + dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, inputmat, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) dbeta = None - if not ctx.use_bias: grad_bias = None @@ -538,7 +542,7 @@ def backward( wgrad = None return ( - dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bda4c309ab..62bde2ff82 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -32,6 +32,7 @@ get_default_init_method, cast_if_needed, assert_dim_for_fp8_exec, + clear_tensor_data, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -276,6 +277,8 @@ def forward( ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) + if not is_grad_enabled: + clear_tensor_data(ln_out_total) gelu_out = activation_func( fc1_out, @@ -283,6 +286,8 @@ def forward( tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, ) + if not is_grad_enabled: + clear_tensor_data(fc1_out) fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( None, None, None, activation_dtype) @@ -329,6 +334,8 @@ def forward( fp8_meta_tensor = fc2_meta_tensor, D_dtype = fc2_te_type, ) + if not is_grad_enabled: + clear_tensor_data(gelu_out) else: # Cast for native AMP fc1_weight = cast_if_needed(fc1_weight, activation_dtype) @@ -360,6 +367,8 @@ def forward( ub=ub_obj_lnout if ub_split_ag 
else None, extra_output_tensor=ln_out if ub_split_ag else None, ) + if not is_grad_enabled: + clear_tensor_data(ln_out_total) if bias_gelu_nvfusion: fc1_out, _, _ = fc1_outputs @@ -373,6 +382,8 @@ def forward( None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype]) + if not is_grad_enabled: + clear_tensor_data(fc1_out) if fp8_calibration: # amax of fc2 input @@ -405,6 +416,8 @@ def forward( ub=ub_obj_fc2out if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None, ) + if not is_grad_enabled: + clear_tensor_data(gelu_out) if is_grad_enabled: ctx.save_for_backward( @@ -519,6 +532,7 @@ def backward( ) = TransformerEngineBaseModule.grad_output_preprocess( ctx, grad_outputs[0], True ) + if ctx.ub_bulk_wgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: @@ -571,10 +585,13 @@ def backward( ) if ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + clear_tensor_data(grad_output_c) + # FC2 WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if fc2_weight.requires_grad: gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) + clear_tensor_data(gelu_out) fc2_wgrad, _ = tex.fp8_gemm( gelu_out_t, fwd_scale_inverses, @@ -592,6 +609,7 @@ def backward( else None, use_split_accumulator=_2X_ACC_WGRAD, ) + clear_tensor_data(gelu_out_t, grad_output_t) if ctx.activation == 'gelu': fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( @@ -610,6 +628,7 @@ def backward( tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ) + clear_tensor_data(fc1_out) else: if fc2_weight.requires_grad: gelu_out_c = tex.cast_from_fp8( @@ -619,6 +638,7 @@ def backward( fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) + clear_tensor_data(gelu_out) fc2_wgrad, _, _ = tex.gemm( gelu_out_c, grad_output, @@ -632,6 +652,7 @@ def backward( if ctx.fuse_wgrad_accumulation else None, ) + clear_tensor_data(gelu_out_c) if ctx.activation == 'gelu': fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( @@ -642,6 +663,7 @@ def backward( fc1_out, TE_DType[fc2_dgrad.dtype]) fc1_bias_grad = dgelu_no_fp8.sum(dim=0) + clear_tensor_data(fc1_out) dgelu = tex.cast_to_fp8( dgelu_no_fp8, @@ -716,21 +738,24 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + clear_tensor_data(gelu_out) if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu': - fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) + fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) else: - if ctx.activation == 'gelu': - dgelu = fc2_dgrad - else: - dgelu = activation_func(fc2_dgrad, - fc1_out, - TE_DType[fc2_dgrad.dtype]) + if ctx.activation != 'gelu': + fc2_dgrad = activation_func(fc2_dgrad, + fc1_out, + TE_DType[fc2_dgrad.dtype]) # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM # and will not be calculated in case wgrad is not required. if not fc1_weight.requires_grad: - fc1_bias_grad = dgelu.sum(dim=0) + fc1_bias_grad = fc2_dgrad.sum(dim=0) + + # Overwrite data. Deleting the tensor does not release underlying memory. 
+ clear_tensor_data(fc1_out) + dgelu = fc2_dgrad fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) @@ -741,6 +766,7 @@ def backward( fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device ) + # FC1 DGRAD: Unconditional _ = tex.gemm( fc1_weight, @@ -802,6 +828,7 @@ def backward( ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, ) + clear_tensor_data(ln_out_total_t, dgelu_t) else: ln_out_total_c = tex.cast_from_fp8( ln_out_total, @@ -826,6 +853,7 @@ def backward( ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, ) + clear_tensor_data(ln_out_total_c, dgelu_no_fp8) else: # FC1 WGRAD fc1_wgrad_outputs = tex.gemm( @@ -841,6 +869,7 @@ def backward( ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) + clear_tensor_data(ln_out_total, dgelu) if ctx.bias_gelu_nvfusion: fc1_wgrad, _, _ = fc1_wgrad_outputs @@ -857,20 +886,20 @@ def backward( handle.wait() # LayerNorm gradient - d_ln_out = fc1_dgrad.view(inputmat.shape) + dgrad = fc1_dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output: - d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) + dgrad = dgrad + grad_outputs[1].view_as(dgrad) if ctx.normalization == "LayerNorm": - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, + dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) elif ctx.normalization == "RMSNorm": - dxmat, dgamma = tex.rmsnorm_bwd( - d_ln_out, inputmat, rsigma, ln_weight, + dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, inputmat, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) dbeta = None @@ -904,7 +933,7 @@ def backward( fc2_wgrad = None return ( - dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, fc1_wgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7adfb035d3..3c9c9eacbd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -26,6 +26,7 @@ get_default_init_method, cast_if_needed, assert_dim_for_fp8_exec, + clear_tensor_data, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -431,6 +432,7 @@ def backward( out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ) + clear_tensor_data(inputmat_t_total) else: wgrad, _, _ = gemm( inputmat_total, @@ -442,6 +444,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + clear_tensor_data(inputmat_total) else: # WGRAD wgrad, grad_bias, _ = gemm( @@ -455,6 +458,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + clear_tensor_data(inputmat_total) # Column Parallel Linear if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index fc2be596f1..83b75df281 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,6 +8,18 @@ import torch +def clear_tensor_data(*tensors: Tuple[torch.Tensor, ...]) -> None: + """ + Trick to deallocate tensor memory when 
delete operation does not + release the tensor due to PyTorch override. + + Must be used carefully. + """ + for t in tensors: + t.data = torch.Tensor() + del t + + def get_device_compute_capability() -> Tuple[int, int]: """CUDA compute capability of current GPU""" props = torch.cuda.get_device_properties(torch.cuda.current_device()) From 7976bd003fcf084dd068069b92a9a79b1743316a Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Mon, 13 Nov 2023 17:28:35 -0600 Subject: [PATCH 6/6] Update README.rst - Installation section (#502) * Update README.rst - Installation section Added pip install instructions and cleaned up pre-reqs and FlashAttention-2 section Signed-off-by: Santosh Bhavani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Santosh Bhavani Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/README.rst b/README.rst index 492422f250..ba6bd0e112 100644 --- a/README.rst +++ b/README.rst @@ -135,37 +135,47 @@ Installation ---------- .. installation -In the NGC container +Pre-requisites ^^^^^^^^^^^^^^^^^^^^ +* Linux x86_64 +* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada +* NVIDIA Driver supporting CUDA 11.8 or later +* cuDNN 8.1 or later +* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. -The quickest way to get started with Transformer Engine is the NGC PyTorch container on -`NVIDIA GPU Cloud Catalog `_ (versions 22.09 and later). +Docker +^^^^^^^^^^^^^^^^^^^^ + +The quickest way to get started with Transformer Engine is by using Docker images on +`NVIDIA GPU Cloud (NGC) Catalog `_. For example to use the NGC PyTorch container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.04-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3 -Where 23.04 is the container version. For example, 23.04 for the April 2023 release. +Where 23.10 is the container version. For example, 23.10 for the October 2023 release. -Pre-requisites +pip ^^^^^^^^^^^^^^^^^^^^ -* Linux x86_64 -* CUDA 11.8 or later -* NVIDIA Driver supporting CUDA 11.8 or later -* cuDNN 8.1 or later -* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. +To install the latest stable version of Transformer Engine, + +.. code-block:: bash + + pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + +This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). From source ^^^^^^^^^^^ +`See the installation guide `_. -`See the installation guide `_. - -Compiling with Flash Attention 2 +Compiling with FlashAttention-2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. + +It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9). 
-TransformerEngine release v0.11.0 adds support for Flash Attention 2.0 for improved performance. It is a known issue that Flash Attention 2.0 compilation is -resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory -errors during the installation of TransformerEngine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of Flash Attention 1 (v1.0.6 to v1.0.9). +Note that NGC PyTorch 23.08+ containers include FlashAttention-2. Model Support ----------