diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 78dca3be..8e072973 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -284,6 +284,10 @@ enable_profiler: False
 skip_first_n_steps_for_profiler: 5
 profiler_steps: 10
 
+# Enable JAX named scopes for detailed profiling and debugging
+# When enabled, adds named scopes around key operations in transformer and attention layers
+enable_jax_named_scopes: False
+
 # Generation parameters
 prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 6a578899..cfe3c1fc 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import functools
 import math
 from typing import Optional, Callable, Tuple
@@ -805,6 +806,7 @@ def __init__(
       is_self_attention: bool = True,
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -820,6 +822,7 @@ def __init__(
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     if is_self_attention:
       axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
@@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
 
     return xq_out, xk_out
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -966,29 +973,41 @@ def __call__(
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
 
-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+    with self.conditional_named_scope("attn_qkv_proj"):
+      with self.conditional_named_scope("proj_query"):
+        query_proj = self.query(hidden_states)
+      with self.conditional_named_scope("proj_key"):
+        key_proj = self.key(encoder_hidden_states)
+      with self.conditional_named_scope("proj_value"):
+        value_proj = self.value(encoder_hidden_states)
 
     if self.qk_norm:
-      query_proj = self.norm_q(query_proj)
-      key_proj = self.norm_k(key_proj)
+      with self.conditional_named_scope("attn_q_norm"):
+        query_proj = self.norm_q(query_proj)
+      with self.conditional_named_scope("attn_k_norm"):
+        key_proj = self.norm_k(key_proj)
+
     if rotary_emb is not None:
-      query_proj = _unflatten_heads(query_proj, self.heads)
-      key_proj = _unflatten_heads(key_proj, self.heads)
-      value_proj = _unflatten_heads(value_proj, self.heads)
-      # output of _unflatten_heads Batch, heads, seq_len, head_dim
-      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+      with self.conditional_named_scope("attn_rope"):
+        query_proj = _unflatten_heads(query_proj, self.heads)
+        key_proj = _unflatten_heads(key_proj, self.heads)
+        value_proj = _unflatten_heads(value_proj, self.heads)
+        # output of _unflatten_heads Batch, heads, seq_len, head_dim
+        query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
 
     query_proj = checkpoint_name(query_proj, "query_proj")
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")
 
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    with self.conditional_named_scope("attn_compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
-    hidden_states = self.proj_attn(attn_output)
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+
+    with self.conditional_named_scope("attn_out_proj"):
+      hidden_states = self.proj_attn(attn_output)
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
 
 
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 4dc21d43..5d7aec10 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -15,6 +15,7 @@
 """
 from typing import Tuple, Optional, Dict, Union, Any
+import contextlib
 import math
 import jax
 import jax.numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if inner_dim is None:
       inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.act_fn = nnx.data(None)
     if activation_fn == "gelu-approximate":
       self.act_fn = ApproximateGELU(
@@ -236,11 +239,17 @@ def __init__(
         ),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
+      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    with self.conditional_named_scope("mlp_down_proj"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
@@ -265,8 +274,11 @@ def __init__(
       attention: str = "dot_product",
      dropout: float = 0.0,
      mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     # 1. Self-attention
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
     self.attn1 = FlaxWanAttention(
@@ -287,6 +299,7 @@ def __init__(
         is_self_attention=True,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     # 1. Cross-attention
@@ -308,6 +321,7 @@ def __init__(
         is_self_attention=False,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -322,6 +336,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
@@ -330,6 +345,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -339,45 +358,59 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    hidden_states = checkpoint_name(hidden_states, "hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-        hidden_states.dtype
-    )
-    attn_output = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=norm_hidden_states,
-        rotary_emb=rotary_emb,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-    attn_output = self.attn2(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = hidden_states + attn_output
-
-    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-        hidden_states.dtype
-    )
-    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-        hidden_states.dtype
-    )
-    return hidden_states
+    with self.conditional_named_scope("transformer_block"):
+      with self.conditional_named_scope("adaln"):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+      hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+      hidden_states = checkpoint_name(hidden_states, "hidden_states")
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          hidden_states = hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+              hidden_states.dtype
+          )
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+              hidden_states.dtype
+          )
+      return hidden_states
 
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -416,11 +449,13 @@ def __init__(
       names_which_can_be_offloaded: list = [],
      mask_padding_tokens: bool = True,
      scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -472,6 +507,7 @@ def init_block(rngs):
           attention=attention,
           dropout=dropout,
           mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
       )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -497,6 +533,7 @@ def init_block(rngs):
             weights_dtype=weights_dtype,
             precision=precision,
             attention=attention,
+            enable_jax_named_scopes=enable_jax_named_scopes,
         )
         blocks.append(block)
       self.blocks = blocks
@@ -517,6 +554,10 @@ def init_block(rngs):
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -536,14 +577,15 @@ def __call__(
     post_patch_width = width // p_w
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("PatchEmbedding"):
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
@@ -583,9 +625,10 @@ def layer_forward(hidden_states):
       hidden_states = rematted_layer_forward(hidden_states)
 
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
-
-    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
-    hidden_states = self.proj_out(hidden_states)
+    with self.conditional_named_scope("output_norm"):
+      hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
+    with self.conditional_named_scope("output_proj"):
+      hidden_states = self.proj_out(hidden_states)
 
     hidden_states = hidden_states.reshape(
         batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 7ed8007b..9068d256 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["dropout"] = config.dropout
   wan_config["mask_padding_tokens"] = config.mask_padding_tokens
   wan_config["scan_layers"] = config.scan_layers
+  wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
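
For reviewers unfamiliar with the idiom, below is a minimal, self-contained sketch of the conditional named-scope pattern this change introduces. `TinyBlock`, its dimensions, and the trace directory are illustrative stand-ins rather than MaxDiffusion code; only the `conditional_named_scope` helper mirrors the diff.

```python
import contextlib

import jax
import jax.numpy as jnp
from flax import nnx


class TinyBlock(nnx.Module):
  """Stand-in module showing the same conditional scoping as the diff."""

  def __init__(self, dim: int, rngs: nnx.Rngs, enable_jax_named_scopes: bool = False):
    self.proj = nnx.Linear(dim, dim, rngs=rngs)
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    """Return a JAX named scope if enabled, otherwise a null context."""
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

  def __call__(self, x: jax.Array) -> jax.Array:
    # With the flag on, ops traced inside these blocks carry the
    # "block/out_proj" name-scope prefix in HLO metadata and profiler traces;
    # with it off, nullcontext() turns the wrappers into no-ops.
    with self.conditional_named_scope("block"):
      with self.conditional_named_scope("out_proj"):
        return self.proj(x)


block = TinyBlock(dim=8, rngs=nnx.Rngs(0), enable_jax_named_scopes=True)
x = jnp.ones((2, 8))

# Capture a trace; the scope names show up in the profiler's trace viewer.
with jax.profiler.trace("/tmp/jax-trace"):
  block(x).block_until_ready()
```

In MaxDiffusion itself the behavior is toggled through the new config key, i.e. setting `enable_jax_named_scopes: True` in `base_wan_14b.yml`, which `wan_pipeline.py` forwards into the WanModel constructor as shown above.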