diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 586abd0541..86d22b7944 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -42,6 +42,7 @@ jobs: || github.actor == 'kocchop' || github.actor == 'youngeunkwon0405' || github.actor == 'KshitijLakhani' + || github.actor == 'jberchtold-nvidia' ) steps: - name: Check if comment is issued by authorized person diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md index bb3ba209ed..fc8458844b 100644 --- a/examples/pytorch/comm_gemm_overlap/README.md +++ b/examples/pytorch/comm_gemm_overlap/README.md @@ -16,7 +16,7 @@ Forward and backward passes with layer weights distributed over all GPUs in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7] @@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across groups in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2 +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2 # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3] diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9a11ccc008..4e52153db9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index ccb6690a87..5bb86c6081 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -20,7 +20,7 @@ def clear_live_arrays(): @pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): +def enable_fused_attn_after_hopper(): """ Enable fused attn for hopper+ arch. Fused attn kernels on pre-hopper arch are not deterministic. 
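The renamed `enable_fused_attn_after_hopper` fixture makes the arch gating explicit in its name; its body is not shown in this hunk. As a hedged sketch only (the env var `NVTE_FUSED_ATTN` and the compute-capability query below are assumptions, not taken from this diff), a module-scoped autouse fixture of this kind typically looks like:

```python
# Hedged sketch: the real fixture body is not part of this hunk. It assumes the fused
# attention backend is toggled via the NVTE_FUSED_ATTN environment variable and that
# the compute capability can be read from the first local JAX device.
import os

import jax
import pytest


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """Enable fused attn only on Hopper+ (sm90+); pre-Hopper kernels are not deterministic."""
    device = jax.local_devices()[0]
    # On CUDA devices, compute_capability is a string such as "9.0".
    cc = getattr(device, "compute_capability", None)
    if cc is not None and float(cc) >= 9.0:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)
```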
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index e194a228d2..1538062975 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -20,7 +20,6 @@ from utils import ( make_causal_mask, make_self_mask, - assert_tree_like_allclose, assert_allclose, print_debug_tensor_stats, ) @@ -32,7 +31,6 @@ AttnMaskType, QKVLayout, QKVFormat, - get_qkv_format, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, CPStrategy, @@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn( dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape - qkv_format = get_qkv_format(qkv_layout) + qkv_format = qkv_layout.get_qkv_format() batch, seqlen, num_head, hidden = data_shape @@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient _, max_seq_len, num_heads, _ = data_shape gradient_multiplier = max_seq_len * num_heads - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: + if attn_mask_type.is_causal(): gradient_multiplier /= 10 ret_valid = func(*args, **kwargs) return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index af05538ef5..759ea893ef 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -28,7 +28,6 @@ QKVFormat, fused_attn, fused_attn_thd, - get_qkv_format, make_swa_mask, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -50,6 +49,7 @@ def init(): yield +@partial(jax.jit, static_argnums=(5, 6, 7, 9)) def general_dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -102,29 +102,36 @@ def general_dot_product_attention( return context -def is_causal_mask(mask: AttnMaskType): - """ - Check if the mask is a causal mask - """ - return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] - - -def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: +@jax.jit +def make_causal_mask( + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike = None, + segment_pos_kv: ArrayLike = None, +) -> Array: """ Create inverse padded causal mask where `True` means allowing the corresponding position to participate in attention and `False` means masking out that position. + If segment_pos is not provided, an arange of the segment_ids will be applied.
""" - q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape) - kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape) - inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal) + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal) return inv_causal_mask +@partial(jax.jit, static_argnums=(4, 5)) def make_mask( - q_token: ArrayLike, - kv_token: ArrayLike, - segment_pad_q: ArrayLike, - segment_pad_kv: ArrayLike, + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike, + segment_pos_kv: ArrayLike, attn_mask_type: AttnMaskType, window_size: Optional[Tuple[int, int]] = None, ) -> Array: @@ -132,18 +139,31 @@ def make_mask( Create attention mask based on mask type. A `True` value in the mask means masking out the corresponding position and a `False` value means allowing that position to participate in attention. + + - segment_ids should start with 1, and using 0s for the paddings. + Expected that each segment starts without paddings. + - segment_pos marks the token position in the segments. + + A example pair of segments_ids and segment_pos: + segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] + segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ inv_mask = make_attention_mask( - q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) + segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) - if is_causal_mask(attn_mask_type): - inv_causal_mask = make_causal_mask(q_token, kv_token) - inv_mask = combine_masks(inv_causal_mask, inv_mask) - if segment_pad_q is not None and segment_pad_kv is not None: - inv_pad_mask = make_attention_mask( - segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) + if attn_mask_type.is_causal(): + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask( + segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) - inv_mask = combine_masks(inv_pad_mask, inv_mask) + inv_mask = combine_masks(inv_causal_mask, inv_mask) if window_size is not None: max_seqlen_q = inv_mask.shape[-2] @@ -157,7 +177,8 @@ def make_mask( return mask -def get_seqlens_and_offsets(segment_ids, segment_pad): +@jax.jit +def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) @@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad): def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) - first_column = jnp.ones((x.shape[0], 1), dtype=bool) + first_column = x[..., :1] != 0 same_as_previous = jnp.hstack((first_column, same_as_previous)) return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))( same_as_previous @@ -173,13 +194,9 @@ def _find_offsets(x): 
offsets = _find_offsets(segment_ids) offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - if segment_pad is not None: - segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids) - padding_aware_seqlen = bincount_vmap(segment_id_with_paddings) - output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1) - else: - output = jnp.insert(seqlens, -1, values=0, axis=-1) - return output, offsets + seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + seqlens = jnp.where(seqlens, seqlens, -1) + return seqlens, offsets @jax.jit @@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): query, key, value, - bias=bias, - mask=mask, + bias, + mask, deterministic=not kwargs["is_training"], scale_factor=kwargs["scaling_factor"], dropout_rate=kwargs["dropout_probability"], @@ -228,7 +245,6 @@ def customcall_fused_dpa( TE customcall dot product attention implementation """ qkv_layout = kwargs["qkv_layout"] - is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD match qkv_layout: case QKVLayout.BS3HD | QKVLayout.T3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -242,7 +258,7 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not is_thd: + if not qkv_layout.is_thd(): kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( @@ -262,10 +278,10 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. """ - BIAS_1HSS = "1HSS" - BIAS_B1SS = "B1SS" - BIAS_BHSS = "BHSS" - BIAS_11SS = "11SS" + _1HSS = "1HSS" + _B1SS = "B1SS" + _BHSS = "BHSS" + _11SS = "11SS" @dataclass @@ -300,18 +316,12 @@ def _get_max_segments_per_sequence(self): def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available - if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ - AttnMaskType.PADDING_MASK, - AttnMaskType.PADDING_CAUSAL_MASK, - ]: + if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") - qkv_format = get_qkv_format(self.qkv_layout) - if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") - - if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: if self.num_heads_q != self.num_heads_kv: pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") @@ -339,15 +349,11 @@ def _check_configs(self): if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS - and self.bias_shape != BiasShape.BIAS_1HSS + and self.bias_shape != BiasShape._1HSS ): - if self.attn_mask_type not in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - ]: + if self.attn_mask_type.is_padding(): pytest.skip( - "B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." 
+ "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask" ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip( @@ -370,18 +376,18 @@ def _setup_inputs(self): if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None - elif self.bias_shape == BiasShape.BIAS_1HSS: + elif self.bias_shape == BiasShape._1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_B1SS: + elif self.bias_shape == BiasShape._B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_BHSS: + elif self.bias_shape == BiasShape._BHSS: bias_shape = ( self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv, ) - elif self.bias_shape == BiasShape.BIAS_11SS: + elif self.bias_shape == BiasShape._11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") @@ -391,7 +397,7 @@ def _setup_inputs(self): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.bias_shape == BiasShape._1HSS: self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for @@ -408,10 +414,10 @@ def _setup_inputs(self): else: self.bias = None - if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pad_ratio = 0.0 - else: + if self.attn_mask_type.is_padding(): pad_ratio = 0.3 + else: + pad_ratio = 0.0 def gen_valid(bs, max_seqlen, pad_ratio): pad_len = int(max_seqlen * pad_ratio) @@ -425,6 +431,8 @@ def generate_random_segment_ids( rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad segment_ids = np.zeros((batch_size, sequence_length), dtype=int) + segment_pos = np.zeros((batch_size, sequence_length), dtype=int) + # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad segment_pad = np.zeros((batch_size, sequence_length), dtype=int) @@ -440,58 +448,62 @@ def generate_random_segment_ids( break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id + segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: num_valid = rng.integers(1, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 segment_pad[i, current_pos:sequence_length] = 1 - return segment_ids, segment_pad - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: + segment_ids, segment_pos, segment_pad = map( + jnp.asarray, [segment_ids, segment_pos, segment_pad] + ) + segment_ids = jnp.where(segment_pad, 0, segment_ids) + return segment_ids, segment_pos, segment_pad + + if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 - self.token_q, self.segment_pad_q = generate_random_segment_ids( + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - # TODO(rewang): Check if qkvpacked supported different q/kv - # TODO(rewang): Causal with different q/kv segment_id fails - if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type): - self.token_kv = self.token_q - self.segment_pad_kv = self.segment_pad_q + if self.qkv_layout == QKVLayout.T3HD: + self.segment_ids_kv = 
self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q else: - self.token_kv, self.segment_pad_kv = generate_random_segment_ids( + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, ) - self.pad_q = self.segment_pad_q - self.pad_kv = self.segment_pad_kv + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 - self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio) - self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) - self.segment_pad_q = self.segment_pad_kv = None + self.segment_ids_q, self.pad_q = gen_valid( + self.batch_size, self.max_seqlen_q, pad_ratio + ) + self.segment_ids_kv, self.pad_kv = gen_valid( + self.batch_size, self.max_seqlen_kv, pad_ratio + ) + self.segment_pos_q = self.segment_pos_kv = None + self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + # For reference code self.mask = make_mask( - self.token_q, - self.token_kv, - self.segment_pad_q, - self.segment_pad_kv, + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, self.attn_mask_type, self.window_size, ) - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets( - self.token_q, self.segment_pad_q - ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.token_kv, self.segment_pad_kv - ) + if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None self.mask_for_customcall = self.mask self.dropout_rng = dropout_key if self.dropout_prob > 0 else None @@ -547,13 +559,11 @@ def test_backward(self): """ self._setup_inputs() - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS: - pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.") def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q - if is_causal_mask(self.attn_mask_type): + if self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient ret_valid = jnp.where( @@ -586,7 +596,7 @@ def grad_func(func, *args, **kwargs): } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2) + arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -629,7 +639,7 @@ def check_dqkv(primitive, reference, pad): check_dqkv(primitive_dk, reference_dk, self.pad_kv) check_dqkv(primitive_dv, reference_dv, self.pad_kv) - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad): ) -@pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - 
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), - ], -) @pytest.mark.parametrize( "attn_mask_type", [ @@ -736,6 +736,16 @@ class TestFusedAttn: pytest.param(False, id="INFERENCE"), ], ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), + ], + ) def _test_forward( b, s_q, @@ -779,6 +789,13 @@ def _test_forward( runner.test_forward() @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) def test_backward( b, s_q, diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 78a6225e1f..242bafa5e2 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -19,7 +19,11 @@ from jax import nn as jax_nn from jax import random as jax_random -from transformer_engine.jax.attention import AttnMaskType, make_swa_mask +from transformer_engine.jax.attention import ( + AttnMaskType, + canonicalize_attn_mask_type, + make_swa_mask, +) from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -913,15 +917,7 @@ def apply_swa_mask( window_size: Tuple[int, int] = (-1, -1), ) -> Array: """Apply the sliding window mask to a given mask""" - mask_map = { - "no_mask": AttnMaskType.NO_MASK, - "padding": AttnMaskType.PADDING_MASK, - "causal": AttnMaskType.CAUSAL_MASK, - "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, - "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - } - _attn_mask_type = mask_map.get(attn_mask_type, None) + _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py new file mode 100644 index 0000000000..0f00a6717b --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -0,0 +1,181 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ import os + import sys + import argparse + + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import Format, DelayedScaling + + import torch + import torch.distributed as dist + import torch.nn.functional as F + from torch import nn, optim + from torch.distributed import DeviceMesh + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + from contextlib import nullcontext + + + class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(SimpleNet, self).__init__() + self.fc1 = te.Linear(input_size, hidden_size) + self.fc2 = te.Linear(hidden_size, output_size) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + + + def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + + def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + + def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") + parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") + parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") + parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for the input data") + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+ ) + parser.add_argument( + "--iter", type=int, default=10, help="Number of iterations for forward pass" + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + # Adding hsdp_dim as a list argument, comma-separated + parser.add_argument( + "--sharding-dims", + type=int, + nargs="+", + help='FSDP/HSDP sharding dimensions ("replicate", "shard")', + ) + args = parser.parse_args(argv, namespace) + if args.sharding_dims: + assert len(args.sharding_dims) <= 2 + return args + + +sub_modules_to_wrap = [te.Linear] + + +def _train(args): + assert "TORCHELASTIC_RUN_ID" in os.environ + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + assert LOCAL_SIZE == WORLD_SIZE + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + device = torch.device(f"cuda:{LOCAL_RANK}") + + # FP8 Configuration + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext + build_model_context_args = {} + + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Build the model with the specified context + with build_model_context(**build_model_context_args): + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + else: + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + # Move the model to the correct device + + model.to(device) + + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") + # Creating a DeviceMesh for fully_shard + world_size = int(WORLD_SIZE) + device_ids = list(range(world_size)) + if LOCAL_RANK == 0: + print(f"sharding-dims:{args.sharding_dims}") + # Setup the sharding mesh for FSDP/HSDP + if args.sharding_dims == None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 1: + assert args.sharding_dims[0] == device_ids[-1] + 1 + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 2: # HSDP + assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 + mesh = init_device_mesh( + "cuda", + (args.sharding_dims[0], args.sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + + # Apply FSDP/HSDP + custom_attrs = save_custom_attrs(model) + for sub_module in model.modules(): + if any( + isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for iteration in range(args.iter): + # Zero the parameter gradients + optimizer.zero_grad() + input_data = torch.randn(args.batch_size, args.input_size).to(device) + output = model(input_data) + target = torch.randn(args.batch_size, args.output_size).to(device) + loss = F.mse_loss(output, 
target) + loss.backward() + optimizer.step() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + + dist.destroy_process_group() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Done...") + return 0 + + + if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py new file mode 100644 index 0000000000..3c9197c322 --- /dev/null +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -0,0 +1,67 @@ + # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + # + # See LICENSE for license information. + + import os + import pytest + import subprocess + from pathlib import Path + from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + import torch + from packaging.version import Version as PkgVersion + + + def get_torch_version(): + """Get pytorch version from __version__""" + + def get_torch_version_str(): + import torch + + return str(torch.__version__) + + return PkgVersion(get_torch_version_str()) + + + if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs.") + + if torch.cuda.device_count() % 2 != 0: + pytest.skip("Number of devices should be divisible by 2.") + + if not get_torch_version() >= PkgVersion("2.4"): + pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.") + + fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + TEST_ROOT = Path(__file__).parent.resolve() + NUM_PROCS: int = torch.cuda.device_count() + LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + + def _run_test(fp_init, sharding_dims): + test_path = TEST_ROOT / "run_fsdp2_model.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + if fp_init: + test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: + test_cmd += ["--sharding-dims", str(sharding_dims[0])] + elif len(sharding_dims) == 2: + test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] + else: + assert False + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if result.returncode != 0: + raise AssertionError(result.stderr.decode()) + + + all_boolean = [True, False] + sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]] + + + @pytest.mark.parametrize("sharding_dims", sharding_dims) + @pytest.mark.parametrize("fp8_init", all_boolean) + def test_distributed(fp8_init, sharding_dims): + if fp8_init and not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(fp8_init, sharding_dims) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4e995dabb1..dea31b5971 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), -
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "swa_6_0": ModelConfig( + 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_1": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), } diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 51f4c695dc..a25ffa773c 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -339,7 +339,7 @@ def test_serialization( del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(io.BytesIO(x_bytes)) + x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False) del x_bytes # Check results diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4f057c12fe..32d517460a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1101,7 +1101,7 @@ def get_model(dtype, config): del block block = get_model(dtype, config) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) torch.set_rng_state(_cpu_rng_state_new) torch.cuda.set_rng_state(_cuda_rng_state_new) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 7bf8fb99d5..be77109cb7 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -124,7 +124,7 @@ def forward(self, inp, weight): torch.save(model_in.state_dict(), tmp_filename) model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename)) + model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) model_out.eval() # scaling fwd @@ -263,7 +263,7 @@ def test_fp8_model_checkpoint( # to load the fp8 metadata before loading tensors. 
# # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that loaded model matches saved model @@ -450,7 +450,7 @@ def train_step( torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that new model's FP8 metadata matches saved model diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index f242502261..b706eadace 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -661,6 +661,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); + sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3ecc9bcd75..53451b6a78 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -46,6 +46,42 @@ class AttnMaskType(Enum): CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK + def is_causal(self): + """Returns True if the mask is a causal mask""" + return self in [ + AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_padding(self): + """Returns True if the mask includes padding""" + return self in [ + AttnMaskType.PADDING_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_bottom_right(self): + """Returns True if the causal mask is calculated from the bottom-right section""" + return self in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + +class QKVFormat(Enum): + """ + SBHD: q,k,v memory layout with [s, b, ..., h, d] + BSHD: q,k,v memory layout with [b, s, ..., h, d] + THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. + """ + + SBHD = NVTE_QKV_Format.NVTE_SBHD + BSHD = NVTE_QKV_Format.NVTE_BSHD + THD = NVTE_QKV_Format.NVTE_THD + class QKVLayout(Enum): """ @@ -66,17 +102,35 @@ class QKVLayout(Enum): THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD - -class QKVFormat(Enum): - """ - SBHD: q,k,v memory layout with [s, b, ..., h, d] - BSHD: q,k,v memory layout with [b, s, ..., h, d] - THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. 
- """ - - SBHD = NVTE_QKV_Format.NVTE_SBHD - BSHD = NVTE_QKV_Format.NVTE_BSHD - THD = NVTE_QKV_Format.NVTE_THD + def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ + return QKVFormat(nvte_get_qkv_format(self.value)) + + def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ + return self in [QKVLayout.BS3HD, QKVLayout.T3HD] + + def is_kvpacked(self): + """ + Return True if the key, value is packed + """ + return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] + + def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ + return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] + + def is_thd(self): + """ + Return True if the layout belongs to THD + """ + return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] class CPStrategy(Enum): @@ -92,13 +146,6 @@ class CPStrategy(Enum): RING = 2 -def get_qkv_format(qkv_layout): - """ - Get qkv_format from qkv_layout - """ - return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) - - def make_swa_mask( max_seqlen_q: int, max_seqlen_kv: int, @@ -136,12 +183,8 @@ def make_swa_mask( swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) if window_size is None: return swa_mask - bottom_right_masks = [ - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - ] left_window, right_window = window_size - if attn_mask_type in bottom_right_masks: + if attn_mask_type.is_bottom_right(): if left_window < 0: left_window = max_seqlen_kv if right_window < 0: @@ -310,7 +353,7 @@ def fused_attn( (jnp.ndarray): The output tensor from the fused attention. """ assert ( - get_qkv_format(qkv_layout) != QKVFormat.THD + not qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format." # Check inputs qkv @@ -327,11 +370,7 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - ]: + if not attn_mask_type.is_padding(): batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -448,7 +487,7 @@ def fused_attn_thd( QKVLayout.T3HD, 0.125, 0, True, 3) """ assert ( - get_qkv_format(qkv_layout) == QKVFormat.THD + qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." 
# Check inputs qkv diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 44b396ad55..7f09e6f900 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument assert x_shape[-2] == 2 or x_shape[-2] == 1 hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval out_shape = (batch_shapes) + (hidden_size,) out_aval = out_aval.update(shape=out_shape, dtype=dtype) @@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] assert i_hidden_size == g_hidden_size - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval return out_aval diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6591861057..f3dfca21ef 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce, cache +from functools import partial, reduce import operator import os from typing import Optional, Tuple @@ -133,7 +133,6 @@ def get_fused_attn_backend(self): ) @staticmethod - @cache def is_non_deterministic_allowed(): """Check if non-deterministic kernels are allowed""" return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 3d88c1f078..3715e6f20c 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod from functools import partial -from jax import core +from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 69d7962b62..8ad7ee4fcb 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): mu_rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) assert gamma_aval.size == beta_aval.size @@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] assert mu_dtype == rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = dbeta_aval = 
core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = dbeta_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size @@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs): rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) hidden_size = gamma_aval.size @@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): assert rsigma_aval.shape == x_aval.shape[:-1] assert rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index a12943f4c2..67053ecd8e 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor): assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported assert q_seqlen > 1 - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod @@ -237,7 +237,7 @@ def backward_abstract( assert dz_aval.shape == softmax_out_aval.shape - dx_aval = core.raise_to_shaped(dz_aval) + dx_aval = dz_aval return dx_aval @staticmethod @@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-ar assert mask_shape[-2] == q_seqlen assert mask_shape[-1] == k_seqlen - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a319b74d76..a986b91b30 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -61,26 +61,23 @@ pybind11::dict Registrations() { dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); - dict["te_dact_lu_dbias_cast_transpose_ffi"] = - EncapsulateFunction(DActLuDBiasCastTransposeHandler); - dict["te_dgated_act_lu_cast_transpose_ffi"] = - EncapsulateFunction(DGatedActLuCastTransposeHandler); + dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler); + dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler); // Quantization dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax - dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); - dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); - dict["te_scaled_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = 
EncapsulateFFI(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler); dict["te_scaled_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler); dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler); dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization dict["te_layernorm_forward_ffi"] = diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8c529c58d0..be0d176520 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1024,27 +1024,51 @@ def swap_key_value_dict(self, batch_indices): @torch.no_grad() -def get_swa_mask( - window_size: Tuple[int, int], +def get_full_mask( max_seqlen_q: int, max_seqlen_kv: int, attn_mask_type: str = "no_mask", - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + window_size: Tuple[int, int] = None, + attention_type: str = "self", + bottom_right_alignment: bool = True, ) -> torch.Tensor: """ - Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. - For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner, - and for other mask types, the bottom right corner. + Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`, + `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends + on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.:: + + attn_mask_type output shape diagonal alignment + -------------------------------------------------------------------------------------------- + no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left + causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right + padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left + padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right + arbitrary same as attention_mask follow bottom_right_alignment + + .. note:: + + For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right + diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, + i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. 
For example, with max_seqlen_q = 4, + max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = ( + [[False, False, True, True], [False, False, False, False]], + [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4] + shape and is,:: + + [[[False, False, False, True], + [False, False, False, True], + [ True, True, True, True], + [ True, True, True, True]], + [[False, True, True, True], + [False, True, True, True], + [False, True, True, True], + [False, True, True, True]]] Parameters ---------- - window_size: Tuple[int, int] - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. max_seqlen_q: int Maximum sequence length for queries. max_seqlen_kv: int @@ -1052,33 +1076,105 @@ def get_swa_mask( attn_mask_type: str, default = `no_mask` Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` - Boolean tensor(s) used to mask out attention softmax input. + Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention + for the requirements of `attention_mask` for different `attn_mask_type`s. + window_size: Tuple[int, int], default = `None` + Sliding window size for local attention, where query at position i attends to keys + in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window and causal mask specifically. Both `causal` and `causal_bottom_right` masks + map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on + `attn_mask_type`. + attention_type: str, default = "self" + Attention type, {"self", "cross"} + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the sliding window attention to the bottom right (`True`) + or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly + specifies "causal" or "causal_bottom_right". Returns ---------- + attn_mask_type: str + For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` attention_mask: torch.Tensor - Combined `attention_mask` (input) and sliding window attention mask. - The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; - else, the same shape as input `attention_mask`. + The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` + actual_seqlens_q: torch.Tensor + For padding masks, the actual sequence lengths for queries, in shape [batch_size]. + For other masks, `None`. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. + For other masks, `None`. 
""" - mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda") - if attn_mask_type in ["causal"]: - left = window_size[0] if window_size[0] != -1 else max_seqlen_q - right = window_size[1] if window_size[1] != -1 else max_seqlen_q - mask_upper = torch.triu(mask, diagonal=-left) - mask_lower = torch.tril(mask_upper, diagonal=right) - else: - left = window_size[0] if window_size[0] != -1 else max_seqlen_kv - right = window_size[1] if window_size[1] != -1 else max_seqlen_kv - mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left) - mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right) - attn_mask_type = "arbitrary" - mask = mask_lower.logical_not() + # perform basic checks + change_type = window_size is not None and ( + window_size[0] != -1 or window_size[1] not in [-1, 0] + ) + if window_size is None: + window_size = (-1, -1) + if "causal" in attn_mask_type: + window_size = (window_size[0], 0) + window_size = ( + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1], + ) + + # apply padding mask + actual_seqlens_q = None + actual_seqlens_kv = None + if "padding" in attn_mask_type: + if attention_type == "self": + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + m = attention_mask.logical_not() + actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) + actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + + # apply SWA mask + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + swa_left = None + swa_right = None + if attn_mask_type == "causal_bottom_right" or ( + attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment + ): + swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] + swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] + elif attn_mask_type in ["causal", "padding_causal"] or ( + attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment + ): + swa_left = mask - window_size[0] + swa_right = mask + window_size[1] + elif attn_mask_type == "padding_causal_bottom_right" or ( + attn_mask_type == "padding" and bottom_right_alignment + ): + batch_size = attention_mask.shape[0] + swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0] + ).view(batch_size, 1, 1, 1) + swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1] + ).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not( + torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) + ) if attention_mask is not None: - mask = torch.logical_and(attention_mask, mask) - return attn_mask_type, mask + attention_mask = torch.logical_or(swa_mask, attention_mask) + else: + attention_mask = swa_mask + + # change mask type + if change_type: + attn_mask_type = "arbitrary" + + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv @torch.no_grad() @@ -4733,6 +4829,7 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + window_size: 
Optional[Tuple[int, int]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -4752,53 +4849,15 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type: - if self.attention_type == "self": - assert attention_mask.shape == ( - batch_size, - 1, - 1, - max_seqlen_q, - ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - assert ( - len(attention_mask) == 2 - and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) - and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) - ), ( - "attention_mask should be a tuple of two tensors with shapes " - "[b, 1, 1, sq] and [b, 1, 1, skv]!" - ) - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - mask = attention_mask.squeeze(1).logical_not() - actual_seqlens_q = mask[:, :, 0].sum(dim=1) - actual_seqlens_kv = mask[:, 0, :].sum(dim=1) - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if attn_mask_type == "padding_causal": - attention_mask = torch.logical_or( - torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), - attention_mask, - ) - if attn_mask_type == "padding_causal_bottom_right": - attention_mask = torch.logical_or( - torch.where( - mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - < 0, - 1, - 0, - ), - attention_mask, - ) + + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -8274,12 +8333,6 @@ def forward( ) if use_unfused_attention: - if window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ): - attn_mask_type, attention_mask = get_swa_mask( - window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask - ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -8291,6 +8344,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -8304,6 +8358,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7ace68a222..414e819f53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -24,6 +24,19 @@ aten = torch.ops.aten updated_fp8_params = {} +_ops_to_preserve_subclass_in_fsdp2 = { + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + 
torch.ops.aten.slice.Tensor, + torch.ops.aten.copy_.default, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, +} + def _make_fp8_attr_property_funcs(name: str) -> Any: """Make accessors for an FP8 attribute @@ -430,6 +443,37 @@ def __new__( return self + def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument + """ + A hook function used in torch fsdp2, called before all-gather + return (all-gather input), (metadata) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + + return (self._data,), (self,) + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, # pylint: disable=unused-argument + *, + out: Optional[torch.Tensor] = None, + ): + """ + A hook function used in torch fsdp2, called after all-gather + return (Float8Tensor class instance of all-gathered input), (Things to free after forward) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + (data,) = all_gather_outputs + (sample,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + return None + return Float8Tensor.make_like(sample, data=data), all_gather_outputs + @classmethod def make_like( cls, @@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8Tensor.make_like(tensor, data=data_view) - # Default case + # Related to FSDP2 + if func == aten.split.Tensor: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] + if func == aten.new_zeros.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.as_strided.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.detach.default: + return cls.detach(args[0]) + if func == torch.ops.aten.clone.default: + return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: + # Implementation in the superclass (QuantizedTensor) returns a proper output + pass + elif func in _ops_to_preserve_subclass_in_fsdp2: + # Ops in _ops_to_preserve_subclass_in_fsdp2 are recommended to return the same class instance to work correctly with torch FSDP2 + warnings.warn( + f"A function call({func}) in {cls} may not return {cls} tensor as an output. It" + " might cause an error in torch FSDP2!" + ) + else: + pass + return super().__torch_dispatch__(func, types, args, kwargs) @classmethod
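Taken together with the new `run_fsdp2_model.py` script above, these hooks are what let FSDP2's `fully_shard()` handle FP8 primary weights: only the raw uint8 `_data` buffer is all-gathered, and a `Float8Tensor` is rebuilt from it on each rank after the gather. A minimal sketch of the intended usage, assuming a `torchrun` launch; the layer size and single-dimension mesh below are illustrative, not part of this diff:

```python
# Minimal sketch, assuming torchrun sets RANK/LOCAL_RANK/WORLD_SIZE; the 1024-wide
# te.Linear and the flat device mesh are illustrative choices, not prescriptive.
import os

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")

# With FP8 primary weights, each te.Linear parameter is a Float8Tensor, so
# fully_shard() routes its parameter all-gather through the
# fsdp_pre_all_gather() / fsdp_post_all_gather() hooks added in this diff.
with fp8_model_init(enabled=True):
    model = te.Linear(1024, 1024)
model.to(torch.device("cuda"))

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
fully_shard(model, mesh=mesh)

out = model(torch.randn(16, 1024, device="cuda"))
```

The `_ops_to_preserve_subclass_in_fsdp2` set mirrors the aten ops FSDP2 issues against sharded parameters (split, new_zeros, as_strided, copy_, and so on), which is why `__torch_dispatch__` now special-cases them and warns when an op may not return a `Float8Tensor`.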