diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/configuration.py b/paddleformers/transformers/ernie4_5_moe_vl/model/configuration.py
index 59fcc1000b1..c77a3803cb9 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/configuration.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/configuration.py
@@ -109,6 +109,7 @@ def __init__(
         micro_batch_size=-1,
         use_fused_head_and_loss_fn=False,
         token_balance_loss=False,
+        pp_first_stage_layers=0,
         token_balance_seqlen=False,  # calculated based on batchsize and seqlen
         loss_subbatch_seqlen=32768,
         cachekv_quant: bool = False,
@@ -199,6 +200,7 @@ def __init__(
         self.fuse_ln = fuse_ln
         self.use_rmsnorm = use_rmsnorm
         self.micro_batch_size = micro_batch_size
+        self.pp_first_stage_layers = pp_first_stage_layers

         self.max_sequence_length = max_sequence_length
         self.use_bias = use_bias
@@ -562,6 +564,7 @@ def __init__(
         super().__init__(**kwargs)

         self.vision_config = DFNRopeVisionTransformerConfig(**vision_config) if vision_config else None
+        self.audio_config = None
         self.im_patch_id = im_patch_id
         self.pixel_hidden_size = pixel_hidden_size
         self.modality_detach = modality_detach
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling.py b/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling.py
index 608dc7feb49..277e7b9c128 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling.py
@@ -342,6 +342,7 @@ def __init__(self, config) -> None:
         Args:
             config (dict): model configuration
         """
+        print("DFNRopeVisionTransformerConfig", config)
         super().__init__(config)

         self.spatial_merge_size = config.spatial_merge_size
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling_pp.py b/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling_pp.py
index 928fe8bdffe..b0d4baaf4a1 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling_pp.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/dfnrope/modeling_pp.py
@@ -44,6 +44,7 @@ def __init__(self, config, use_full_recompute=False):
         if self.use_full_recompute:
             logger.info("use full recompute, vision model will NOT use recompute inner")
             config.vision_config.recompute = False
+        print("zhui debug vision config", config.vision_config)
         super().__init__(config.vision_config)
         if self.config.tensor_parallel_degree > 1:
             logger.info("use sp extract feature, vit parameter will be marked as sequence parallel")
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/fusion_ops/common_fusion_ops.py b/paddleformers/transformers/ernie4_5_moe_vl/model/fusion_ops/common_fusion_ops.py
index c936474e55a..08eb6d35d59 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/fusion_ops/common_fusion_ops.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/fusion_ops/common_fusion_ops.py
@@ -73,6 +73,8 @@ def _fusion_flash_attention(
     if attn_mask_start_row_indices is not None:
         if use_sparse_flash_attn:
             if rr_flash_attn is None:
+                print("zhui debug:", q.shape, k.shape, v.shape, attn_mask_start_row_indices.shape)
+                print(attn_mask_start_row_indices)
                 out = flashmask_attention(
                     q,
                     k,
@@ -90,8 +92,10 @@ def _fusion_flash_attention(
                     causal=True,
                 )
         else:
+            print("zhui debug attn_mask_start_row_indices", attn_mask_start_row_indices)
             attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, q.dtype)
             if rr_flash_attn is None:
+                print("zhui debug 94", attention_mask.shape if attention_mask is not None else None)
                 out = F.scaled_dot_product_attention(
                     q,
                     k,
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling.py b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling.py
index c9d24eaacf5..86cce6f9f09 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling.py
@@ -1062,6 +1062,7 @@ def rope_attn(
         # tensors, so that we can clear the cache tensors for memory efficiency.
         past_key_value = [key_states, value_states] if use_cache else None
         seq_length = query_states.shape[1]
+        print("zhui debug 1065", attention_mask.shape if attention_mask is not None else None)
         attn_output, attn_weights = self.attn_func(
             query_states,
             key_states,
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe.py b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe.py
index 6024b41a05e..0bcadc1fe27 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe.py
@@ -670,6 +670,7 @@ def forward(
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
+            print("zhui debug 673", attention_mask.shape if attention_mask is not None else None)
             (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = self.self_attn(
                 hidden_states=hidden_states,
                 past_key_value=past_key_value,
diff --git a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe_vl_pp.py b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe_vl_pp.py
index efe7183f9cd..d72fdcdca3e 100644
--- a/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe_vl_pp.py
+++ b/paddleformers/transformers/ernie4_5_moe_vl/model/modeling_moe_vl_pp.py
@@ -927,7 +927,23 @@ def forward(self, args):
             token_type_ids = token_type_ids.clone()

         if inbatch_pack_offset is not None:
-            attn_mask_start_row_indices = inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset)
+            print("zhui debug modeling pp", inbatch_pack_offset)
+            if len(inbatch_pack_offset.shape) == 2:
+                causal_mask_indices, attn_mask_min_start_row = inbatch_pack_offset_to_attn_mask_start_row_indices(
+                    inbatch_pack_offset
+                )
+                attn_mask_start_row_indices = causal_mask_indices.unsqueeze(-1)
+            else:
+                # startend_row_indices (inbatch_pack_offset) shape: [batch_size, seq_len, {1, 2, 4}]
+                assert len(inbatch_pack_offset.shape) == 4, (
+                    f"inbatch_pack_offset needs to be a 2- or 4-dimensional tensor when using flashmask attention, "
+                    f"but got shape {inbatch_pack_offset.shape}"
+                )
+                attn_mask_start_row_indices = inbatch_pack_offset
+            # attn_mask_start_row_indices = inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset)
+            print("zhui debug modeling pp attn", attn_mask_start_row_indices)
+            attn_mask_start_row_indices = attn_mask_start_row_indices.astype("int32")
+            attn_mask_start_row_indices.squeeze_(-1)
         else:
             attn_mask_start_row_indices = None
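
For reference, the shape handling that the modeling_moe_vl_pp.py hunk introduces can be read in isolation as follows. This is a minimal sketch, not part of the patch: the helper name normalize_pack_offset is hypothetical, and it omits the inbatch_pack_offset_to_attn_mask_start_row_indices conversion that the real 2-D branch performs, showing only how both input layouts end up on a common int32 index tensor.

import paddle


def normalize_pack_offset(inbatch_pack_offset):
    # Hypothetical helper mirroring the new branch in modeling_moe_vl_pp.py.
    # 2-D input: per-token row indices, shape [batch_size, seq_len]; the real code
    # additionally runs them through inbatch_pack_offset_to_attn_mask_start_row_indices.
    # 4-D input: FlashMask-style startend_row_indices with a trailing {1, 2, 4} axis.
    if len(inbatch_pack_offset.shape) == 2:
        indices = inbatch_pack_offset.unsqueeze(-1)  # add the trailing axis
    else:
        assert len(inbatch_pack_offset.shape) == 4, (
            f"expected a 2- or 4-dimensional tensor, got {inbatch_pack_offset.shape}"
        )
        indices = inbatch_pack_offset
    # Both branches finish on the dtype the sparse attention path expects; the final
    # squeeze only drops the trailing axis when its size is 1.
    indices = indices.astype("int32")
    return indices.squeeze(-1)


# Toy usage: one row packing two sequences of lengths 3 and 2 into seq_len 5.
offsets_2d = paddle.to_tensor([[3, 3, 3, 5, 5]])
print(normalize_pack_offset(offsets_2d).shape)  # [1, 5]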
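The non-sparse branch in common_fusion_ops.py calls _gen_from_sparse_attn_mask_indices, which is not shown in this diff. The sketch below is an assumption about the kind of dense mask that branch would hand to the scaled-dot-product path: it assumes attn_mask_start_row_indices[b, h, j] is the first row at which column j becomes masked (packed-sequence document masking on top of the causal constraint). The function name and shapes are illustrative only.

import paddle


def dense_mask_from_start_rows(start_rows):
    # Assumed encoding: position (i, j) is attendable iff j <= i (causal) and
    # i < start_rows[..., j] (query and key lie in the same packed segment).
    # start_rows: int tensor of shape [batch_size, num_heads, seq_len].
    bsz, num_heads, seq_len = start_rows.shape
    rows = paddle.arange(seq_len).reshape([1, 1, seq_len, 1])   # query index i
    cols = paddle.arange(seq_len).reshape([1, 1, 1, seq_len])   # key index j
    start = start_rows.astype(rows.dtype).unsqueeze(-2)         # [b, h, 1, seq_len]
    allowed = paddle.logical_and(cols <= rows, rows < start)    # [b, h, seq_len, seq_len]
    return allowed  # boolean mask; the real helper may emit an additive float mask instead


# Toy usage: one head, two packed sequences of lengths 3 and 2.
start_rows = paddle.to_tensor([[[3, 3, 3, 5, 5]]])
print(dense_mask_from_start_rows(start_rows).astype("int32"))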