adding positional encoder changes and tests #32600

Merged
merged 54 commits into from
Sep 25, 2024
Changes from 8 commits
Commits
54 commits
63e8a34
adding positional encoder changes and tests
Aug 11, 2024
bf6ddf2
adding ruff suggestions
Aug 11, 2024
c1e5058
changes added by python utils/check_copies.py --fix_and_overwrite
Aug 11, 2024
19aaa92
removing pos_encoding added by script
Aug 11, 2024
b282796
adding interpolation to clipseg
Aug 11, 2024
14d6001
formatting
Aug 12, 2024
48128b1
adding further testing to altclip and better documentation to kosmos2
Aug 12, 2024
8eb1beb
skipping test_inputs_embeds_matches_input_ids_with_generate in git model
Aug 12, 2024
7ced086
fixing clipseg comment suggestions
Aug 15, 2024
cac7886
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Aug 15, 2024
a17b554
fixing bridgetower test
Aug 15, 2024
c4e56fb
fixing altclip tensor output POS test
Aug 15, 2024
e303547
adding ruff formatting
Aug 15, 2024
ee8318d
fixing several tests
Aug 16, 2024
20778a3
formatting with ruff
Aug 16, 2024
da9108a
Merge branch 'huggingface:main' into interpolate-clip-b
manuelsh Aug 20, 2024
024ea6e
adding positional encoder changes and tests
Aug 11, 2024
9c645e3
adding ruff suggestions
Aug 11, 2024
19ad494
changes added by python utils/check_copies.py --fix_and_overwrite
Aug 11, 2024
b383517
removing pos_encoding added by script
Aug 11, 2024
578411c
adding interpolation to clipseg
Aug 11, 2024
5517dab
formatting
Aug 12, 2024
34d8999
adding further testing to altclip and better documentation to kosmos2
Aug 12, 2024
48be16e
skipping test_inputs_embeds_matches_input_ids_with_generate in git model
Aug 12, 2024
633310a
fixing clipseg comment suggestions
Aug 15, 2024
9c3ccdd
fixing bridgetower test
Aug 15, 2024
3a62e94
fixing altclip tensor output POS test
Aug 15, 2024
153938f
adding ruff formatting
Aug 15, 2024
ca9682d
fixing several tests
Aug 16, 2024
8567408
formatting with ruff
Aug 16, 2024
9941dbd
adding right pretrained model
Aug 21, 2024
962989d
adding correct pretrained model to git
Aug 21, 2024
09301c5
Merge branch 'huggingface:main' into interpolate-clip-b
manuelsh Aug 21, 2024
b70ab52
Merge branch 'huggingface:main' into interpolate-clip-b
manuelsh Aug 31, 2024
9d05572
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 3, 2024
16363f6
fixing test_inference_image_segmentation
Sep 5, 2024
e3e2272
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 5, 2024
e35729a
fixing test_inference_interpolate_pos_encoding for the git model as t…
Sep 5, 2024
58a02f1
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 5, 2024
fcbf2d2
adding ruff formatting
Sep 5, 2024
d44e070
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 5, 2024
ea54d25
adding new interpolate_pos_encoding function
Sep 15, 2024
9d751a6
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 15, 2024
f36537b
fixing interpolate_POS function
Sep 18, 2024
4170cba
adapting output tensor in tests
Sep 18, 2024
5c593bc
fixing conflict to merge
Sep 24, 2024
d00d7b3
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 24, 2024
44f9695
modifying output tensor
Sep 24, 2024
d70c2b3
[run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, ko…
Sep 24, 2024
299b979
adding the correct tensor
Sep 25, 2024
55572b4
[run_slow] clipseg
Sep 25, 2024
d121d89
fixing spaces
Sep 25, 2024
7afedcf
[run_slow] clipseg
Sep 25, 2024
3be2b60
[run_slow] clipseg
Sep 25, 2024
60 changes: 56 additions & 4 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -100,6 +100,8 @@
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -137,6 +139,8 @@
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -1009,15 +1013,56 @@ def __init__(self, config: AltCLIPVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
)
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


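For reference, here is a minimal standalone sketch of what this interpolation does, assuming the ViT-style layout used above (index 0 of the table is the class-token position, the remaining entries form a square patch grid). It uses `size=` directly instead of the diff's `scale_factor` plus 0.1 workaround, and every name in it is illustrative rather than taken from the PR.

import math

import torch
import torch.nn.functional as F


def resize_position_embeddings(pos_embed: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor:
    """pos_embed: (1, 1 + num_positions, dim) pre-trained table -> (1, 1 + h*w, dim) for the new resolution."""
    class_pos_embed = pos_embed[:, :1]   # class-token position stays untouched
    patch_pos_embed = pos_embed[:, 1:]   # square grid of patch positions
    dim = pos_embed.shape[-1]
    grid = int(math.sqrt(patch_pos_embed.shape[1]))
    h, w = height // patch_size, width // patch_size  # target grid size in patches

    # (1, grid*grid, dim) -> (1, dim, grid, grid), bicubic resize, then flatten back
    patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(patch_pos_embed, size=(h, w), mode="bicubic", align_corners=False)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, h * w, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


# a 224px checkpoint with patch_size=32 has a 7x7 grid; stretching to 480x480 gives 15x15 + 1 = 226 positions
table = torch.randn(1, 1 + 7 * 7, 768)
print(resize_position_embeddings(table, 480, 480, patch_size=32).shape)  # torch.Size([1, 226, 768])
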
@@ -1097,6 +1142,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -1111,7 +1157,7 @@
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
@@ -1156,6 +1202,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -1186,6 +1233,7 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1546,6 +1594,7 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
@@ -1578,6 +1627,7 @@ def get_image_features(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1598,6 +1648,7 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, AltCLIPOutput]:
r"""
@@ -1642,6 +1693,7 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

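A rough usage sketch of the new flag on the AltCLIP side follows. The checkpoint name and the tensor upsampling step are assumptions for illustration, not part of this PR; the upsampling only exists to produce an input larger than the resolution the checkpoint was trained on (in real use you would re-run the image processor at the higher resolution instead).

import requests
import torch
from PIL import Image
from transformers import AltCLIPModel, AltCLIPProcessor

model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")          # assumed checkpoint
processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt").pixel_values  # (1, 3, 224, 224) with default preprocessing

# blow the tensor up past the training resolution purely to exercise the new code path
pixel_values = torch.nn.functional.interpolate(pixel_values, size=(336, 336), mode="bilinear")

with torch.no_grad():
    features = model.get_image_features(pixel_values=pixel_values, interpolate_pos_encoding=True)
print(features.shape)  # (1, projection_dim)
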
75 changes: 65 additions & 10 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -111,6 +111,8 @@
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -276,15 +278,56 @@ def __init__(self, config: BridgeTowerVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
)
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


@@ -302,8 +345,13 @@ def __init__(self, config):
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
)

def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = self.embeddings(pixel_values)
def forward(
self,
pixel_values: torch.Tensor,
attention_mask,
interpolate_pos_encoding: bool = False,
):
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
hidden_states = hidden_states.permute(1, 0, 2)
@@ -324,8 +372,12 @@ def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = torch.stack(hidden_states_stack, dim=0)
return hidden_states

def forward_pre(self, pixel_values: torch.Tensor):
hidden_states = self.embeddings(pixel_values)
def forward_pre(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
):
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
hidden_states = hidden_states.permute(1, 0, 2)
@@ -1015,8 +1067,8 @@ def __init__(self, config):
def dtype(self):
return self.visual.embeddings.patch_embedding.weight.dtype

def forward(self, image, image_mask=None):
return self.visual(image.type(self.dtype), image_mask)
def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)


class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
@@ -1280,6 +1332,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]:
r"""
output_hidden_states (`bool`, *optional*):
@@ -1352,7 +1405,9 @@ def forward(
all_hidden_states_text += (text_embeds,)

if image_embeds is None:
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
image_embeds = self.vision_model.visual.forward_pre(
pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
)
else:
# Permute as BridgeTowerResidualAttention has batch_first=True
image_embeds = image_embeds.permute(1, 0, 2)
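The same flag threaded through BridgeTower can be exercised along these lines; again a sketch, with the checkpoint and caption chosen for illustration. At the checkpoint's native resolution the interpolation is effectively a no-op, so in practice you would pair the flag with pixel values preprocessed at a larger size.

import requests
import torch
from PIL import Image
from transformers import BridgeTowerModel, BridgeTowerProcessor

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")  # assumed checkpoint
model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(image, "two cats lying on a couch", return_tensors="pt")

with torch.no_grad():
    # the flag is forwarded through BridgeTowerVisionModel down to the vision embeddings,
    # where the position table is resized instead of raising a size-mismatch error
    outputs = model(**inputs, interpolate_pos_encoding=True)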