diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 9ebf962b2a746a..19ca944187e622 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -265,6 +265,7 @@ def forward( return attn_output, attn_weights +# copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 class Idefics3VisionFlashAttention2(Idefics3VisionAttention): """ Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stays @@ -386,10 +387,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Idefics3SimpleMLP(nn.Module): - def __init__(self, input_size, output_size): + def __init__(self, config): super().__init__() - self.input_size = input_size - self.output_size = output_size + self.config = config + + input_size=config.vision_config.hidden_size * (config.scale_factor**2) + output_size=config.text_config.hidden_size self.proj = nn.Linear(input_size, output_size, bias=False) def forward(self, x): @@ -550,6 +553,7 @@ def forward( [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" +# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer @add_start_docstrings( "The Idefics3 Vision Transformer Model outputting raw image embedding.", IDEFICS3_VISION_START_DOCSTRING, @@ -671,20 +675,17 @@ class Idefics3Connector(nn.Module): def __init__(self, config): super().__init__() self.scale_factor = config.scale_factor - self.modality_projection = Idefics3SimpleMLP( - input_size=config.vision_config.hidden_size * (self.scale_factor**2), - output_size=config.text_config.hidden_size, - ) + self.modality_projection = Idefics3SimpleMLP(config) def pixel_shuffle(self, x, scale_factor=2): bsz, seq, embed_dim = x.size() height = width = int(seq**0.5) x = x.view(bsz, height, width, embed_dim) x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) - x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) - x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) return x def forward(self, image_hidden_states):