Skip to content

Commit

Permalink
Few improvements from Amy's review
Browse files Browse the repository at this point in the history
  • Loading branch information
andimarafioti committed Aug 12, 2024
1 parent 767c81d commit 0fc4880
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/transformers/models/idefics3/modeling_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def forward(
return attn_output, attn_weights


# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2
class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
"""
Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stay
Expand Down Expand Up @@ -386,10 +387,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class Idefics3SimpleMLP(nn.Module):
def __init__(self, input_size, output_size):
def __init__(self, config):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.config = config

input_size=config.vision_config.hidden_size * (config.scale_factor**2)
output_size=config.text_config.hidden_size
self.proj = nn.Linear(input_size, output_size, bias=False)

def forward(self, x):
Expand Down Expand Up @@ -550,6 +553,7 @@ def forward(
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer
@add_start_docstrings(
"The Idefics3 Vision Transformer Model outputting raw image embedding.",
IDEFICS3_VISION_START_DOCSTRING,
Expand Down Expand Up @@ -671,20 +675,17 @@ class Idefics3Connector(nn.Module):
def __init__(self, config):
super().__init__()
self.scale_factor = config.scale_factor
self.modality_projection = Idefics3SimpleMLP(
input_size=config.vision_config.hidden_size * (self.scale_factor**2),
output_size=config.text_config.hidden_size,
)
self.modality_projection = Idefics3SimpleMLP(config)

def pixel_shuffle(self, x, scale_factor=2):
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
return x

def forward(self, image_hidden_states):
Expand Down

0 comments on commit 0fc4880

Please sign in to comment.