@@ -109,6 +109,7 @@ def __init__(
micro_batch_size=-1,
use_fused_head_and_loss_fn=False,
token_balance_loss=False,
pp_first_stage_layers=0,
token_balance_seqlen=False, # calculated based on batchsize and seqlen
loss_subbatch_seqlen=32768,
cachekv_quant: bool = False,
@@ -199,6 +200,7 @@ def __init__(
self.fuse_ln = fuse_ln
self.use_rmsnorm = use_rmsnorm
self.micro_batch_size = micro_batch_size
self.pp_first_stage_layers = pp_first_stage_layers

self.max_sequence_length = max_sequence_length
self.use_bias = use_bias
@@ -562,6 +564,7 @@ def __init__(
super().__init__(**kwargs)

self.vision_config = DFNRopeVisionTransformerConfig(**vision_config) if vision_config else None
self.audio_config = None
self.im_patch_id = im_patch_id
self.pixel_hidden_size = pixel_hidden_size
self.modality_detach = modality_detach
@@ -342,6 +342,7 @@ def __init__(self, config) -> None:
Args:
config (dict): model configuration
"""
print("DFNRopeVisionTransformerConfig", config)
super().__init__(config)
self.spatial_merge_size = config.spatial_merge_size

@@ -44,6 +44,7 @@ def __init__(self, config, use_full_recompute=False):
if self.use_full_recompute:
logger.info("use full recompute, vision model will NOT use recompute inner")
config.vision_config.recompute = False
print("zhui debug vision config", config.vision_config)
super().__init__(config.vision_config)
if self.config.tensor_parallel_degree > 1:
logger.info("use sp extract feature, vit parameter will be marked as sequence parallel")
@@ -73,6 +73,8 @@ def _fusion_flash_attention(
if attn_mask_start_row_indices is not None:
if use_sparse_flash_attn:
if rr_flash_attn is None:
print("zhui debug:", q.shape, k.shape, v.shape, attn_mask_start_row_indices.shape)
print(attn_mask_start_row_indices)
out = flashmask_attention(
q,
k,
@@ -90,8 +92,10 @@
causal=True,
)
else:
print("zhui debug attn_mask_start_row_indices", attn_mask_start_row_indices)
attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, q.dtype)
if rr_flash_attn is None:
print("zhui debug 94", attention_mask.shape if attention_mask is not None else None)
out = F.scaled_dot_product_attention(
q,
k,
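For context on the flashmask path above: flashmask_attention, as called in this hunk, takes per-key start-row indices instead of a dense mask, and the dense fallback (_gen_from_sparse_attn_mask_indices) has to expand those indices back into a full mask. The helper's actual implementation is not shown in this diff; the following is only a minimal numpy sketch of the assumed convention, where start_rows[j] is the first query row that may no longer attend to key j, and dense_allowed_mask is an illustrative name rather than the repo's helper.

import numpy as np

def dense_allowed_mask(start_rows):
    # start_rows: [seq_len] ints; True in the result means query i may attend to key j.
    seq_len = start_rows.shape[0]
    i = np.arange(seq_len)[:, None]            # query rows
    j = np.arange(seq_len)[None, :]            # key columns
    causal = j <= i                            # standard causal constraint
    still_visible = i < start_rows[None, :]    # rows at/after start_rows[j] are masked out
    return causal & still_visible

# Two packed documents of lengths 3 and 2: keys of the first document stop being
# visible at query row 3, giving a block-diagonal causal mask.
print(dense_allowed_mask(np.array([3, 3, 3, 5, 5])).astype(int))

The sparse form costs O(seq_len) integers per head, while the densified mask is O(seq_len^2), which is why the sparse branch is taken whenever flashmask attention is available.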
@@ -1062,6 +1062,7 @@ def rope_attn(
# tensors, so that we can clear the cache tensors for memory efficiency.
past_key_value = [key_states, value_states] if use_cache else None
seq_length = query_states.shape[1]
print("zhui debug 1065", attention_mask.shape if attention_mask is not None else None)
attn_output, attn_weights = self.attn_func(
query_states,
key_states,
@@ -670,6 +670,7 @@ def forward(
use_reentrant=self.config.recompute_use_reentrant,
)
else:
print("zhui debug 673", attention_mask.shape if attention_mask is not None else None)
(hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
@@ -927,7 +927,23 @@ def forward(self, args):

token_type_ids = token_type_ids.clone()
if inbatch_pack_offset is not None:
attn_mask_start_row_indices = inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset)
print("zhui debug modeling pp", inbatch_pack_offset)
if len(inbatch_pack_offset.shape) == 2:
causal_mask_indices, attn_mask_min_start_row = inbatch_pack_offset_to_attn_mask_start_row_indices(
inbatch_pack_offset
)
attn_mask_start_row_indices = causal_mask_indices.unsqueeze(-1)
else:
# startend_row_indices (inbatch_pack_offset) shape: [batch_size, num_heads, seq_len, {1, 2, 4}]
assert len(inbatch_pack_offset.shape) == 4, (
f"inbatch_pack_offset needs to be a 2 or 4-dimensional tensor when use flashmask attention, "
f"but got shape of {inbatch_pack_offset.shape}"
)
attn_mask_start_row_indices = inbatch_pack_offset
# attn_mask_start_row_indices = inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset)
print("zhui debug modeling pp attn", attn_mask_start_row_indices)
attn_mask_start_row_indices = attn_mask_start_row_indices.astype("int32")
attn_mask_start_row_indices.squeeze_(-1)
else:
attn_mask_start_row_indices = None

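For the 2-D branch above, inbatch_pack_offset_to_attn_mask_start_row_indices is expected to turn per-batch pack offsets into per-token start rows so that tokens cannot attend across document boundaries. Its real implementation is not part of this diff; the sketch below is illustrative only, with pack_offsets_to_start_rows as a hypothetical stand-in that assumes the offsets describe packed document lengths.

import numpy as np

def pack_offsets_to_start_rows(doc_lens):
    # doc_lens: lengths of the documents packed into one sequence, e.g. [3, 2].
    ends = np.cumsum(doc_lens)                          # end offset of each packed document
    return np.repeat(ends, doc_lens).astype("int32")    # every token gets its document's end offset

print(pack_offsets_to_start_rows([3, 2]))  # -> [3 3 3 5 5]

Under this reading, the 2-D branch only needs to reshape and cast the helper's output, while a 4-D inbatch_pack_offset appears to be passed through unchanged because it is assumed to already be in flashmask's [batch_size, num_heads, seq_len, {1, 2, 4}] layout.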