Optim deformable detr #33600

Open · wants to merge 4 commits into base: main
54 changes: 39 additions & 15 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -523,14 +523,14 @@ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=N
def forward(self, pixel_values, pixel_mask):
if pixel_mask is None:
raise ValueError("No pixel mask provided")
y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
if self.normalize:
eps = 1e-6
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
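
For context, a minimal sketch (toy sizes, not the model code itself) of what deriving the cumulative sums and `dim_t` from `pixel_values.dtype` buys: with the previous hard-coded `torch.float32`/`torch.int64`, the sine position embedding silently upcast half-precision activations, whereas following the input dtype keeps the whole computation in the caller's precision.

```python
import torch


def sine_embed_dtype(pixel_values: torch.Tensor, pixel_mask: torch.Tensor, embedding_dim: int = 64) -> torch.dtype:
    # Follow pixel_values.dtype instead of hard-coding float32 / int64, so a
    # half-precision forward pass is not upcast by the position embedding.
    y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
    x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
    dim_t = torch.arange(embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / embedding_dim)
    pos_x = x_embed[:, :, :, None] / dim_t
    return pos_x.dtype  # matches pixel_values.dtype (fp32, fp16 or bf16)


pixel_values = torch.randn(1, 3, 4, 4)               # e.g. dtype=torch.float16 on GPU
pixel_mask = torch.ones(1, 4, 4, dtype=torch.long)
assert sine_embed_dtype(pixel_values, pixel_mask) == pixel_values.dtype
```
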
@@ -580,11 +580,14 @@ def build_position_encoding(config):


def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
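
A small illustration (toy shapes, not the real feature maps) of the new `value_spatial_shapes` handling: when the shapes arrive as a Python list of `(height, width)` tuples, `height * width` is a plain int, so `Tensor.split` needs no `.item()` calls; the tensor version forces one GPU-to-CPU synchronization per level.

```python
import torch

spatial_shapes_list = [(8, 8), (4, 4)]                         # plain Python ints
value = torch.randn(2, sum(h * w for h, w in spatial_shapes_list), 4, 16)

# Tensor-based shapes: each .item() is a device sync when value lives on GPU.
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long)
split_from_tensor = value.split([h.item() * w.item() for h, w in spatial_shapes], dim=1)

# List-based shapes: the same split with plain ints, no sync required.
split_from_list = value.split([h * w for h, w in spatial_shapes_list], dim=1)

assert all(a.shape == b.shape for a, b in zip(split_from_tensor, split_from_list))
```
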
@@ -695,6 +698,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -704,7 +708,8 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@@ -739,9 +744,11 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

if self.disable_custom_kernels:
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
else:
try:
# custom kernel
@@ -755,7 +762,9 @@
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
output = self.output_proj(output)

return output, attention_weights
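
A hedged sketch of the dispatch the added `or MultiScaleDeformableAttention is None` guard implements: when the custom CUDA op was never loaded (no compiler or CUDA at import time), the attention module drops to the pure-PyTorch `multi_scale_deformable_attention` instead of raising. The helper names below are illustrative stand-ins, not the real module layout.

```python
MultiScaleDeformableAttention = None  # stays None when the custom CUDA op could not be built


def pytorch_fallback(*args):
    return "pytorch path"


def kernel_path(*args):
    return "custom CUDA kernel"


def dispatch(disable_custom_kernels: bool, *args):
    # PyTorch implementation when kernels are disabled *or* unavailable;
    # otherwise try the kernel and still fall back if it fails at run time.
    if disable_custom_kernels or MultiScaleDeformableAttention is None:
        return pytorch_fallback(*args)
    try:
        return kernel_path(*args)
    except Exception:
        return pytorch_fallback(*args)


assert dispatch(False) == "pytorch path"  # missing kernel degrades gracefully
```
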
@@ -900,6 +909,7 @@ def forward(
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -932,6 +942,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -997,6 +1008,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1048,6 +1060,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1219,6 +1232,7 @@ def forward(
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1260,7 +1274,8 @@ def forward(
hidden_states = inputs_embeds
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
spatial_shapes_tuple = tuple(spatial_shapes_list)
reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1275,6 +1290,7 @@
position_embeddings,
reference_points,
spatial_shapes,
spatial_shapes_list,
level_start_index,
output_attentions,
)
@@ -1285,6 +1301,7 @@
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1341,6 +1358,7 @@ def forward(
position_embeddings=None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1416,6 +1434,7 @@ def forward(
position_embeddings,
reference_points_input,
spatial_shapes,
spatial_shapes_list,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
@@ -1428,6 +1447,7 @@
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
@@ -1589,7 +1609,7 @@ def get_proposal_pos_embed(self, proposals):
temperature = 10000
scale = 2 * math.pi

dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale
@@ -1720,7 +1740,9 @@ def forward(
source = self.input_proj[level](features[-1][0])
else:
source = self.input_proj[level](sources[-1])
mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
torch.bool
)[0]
pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
sources.append(source)
masks.append(mask)
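
One more dtype detail worth a minimal, self-contained sketch (hypothetical sizes): `nn.functional.interpolate` does not accept boolean tensors, so the extra-level mask is resized in the compute dtype of `pixel_values` (rather than an unconditional `.float()`) and cast back to bool, avoiding an fp32 round-trip in half-precision runs.

```python
import torch
from torch import nn

pixel_values = torch.randn(1, 3, 32, 32)             # compute dtype; fp16/bf16 on GPU
pixel_mask = torch.ones(1, 32, 32, dtype=torch.bool)
source = torch.randn(1, 256, 8, 8)                   # extra feature level from input_proj

# interpolate() rejects bool input, so resize in the compute dtype and cast back.
mask = nn.functional.interpolate(
    pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]
).to(torch.bool)[0]

assert mask.shape == (1, 8, 8) and mask.dtype == torch.bool
```
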
@@ -1735,11 +1757,11 @@ def forward(
source_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
spatial_shapes_list = []
for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
batch_size, num_channels, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes.append(spatial_shape)
spatial_shapes_list.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
@@ -1750,7 +1772,7 @@ def forward(
source_flatten = torch.cat(source_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
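
A toy walk-through (hypothetical sizes) of the bookkeeping in the flattening loop above: the Python-level `spatial_shapes_list` is collected alongside the `spatial_shapes` tensor, so downstream code can sum and split with plain ints while the tensor still serves `level_start_index` and the custom kernel.

```python
import torch

sources = [torch.randn(2, 256, 8, 8), torch.randn(2, 256, 4, 4)]  # toy multi-scale features

spatial_shapes_list = []
source_flatten = []
for source in sources:
    batch_size, num_channels, height, width = source.shape
    spatial_shapes_list.append((height, width))        # plain ints for later sums/splits
    source_flatten.append(source.flatten(2).transpose(1, 2))

source_flatten = torch.cat(source_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

assert source_flatten.shape[1] == sum(h * w for h, w in spatial_shapes_list)  # 64 + 16
assert level_start_index.tolist() == [0, 64]
```
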

@@ -1762,6 +1784,7 @@ def forward(
attention_mask=mask_flatten,
position_embeddings=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
@@ -1819,6 +1842,7 @@ def forward(
encoder_attention_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
Changes to a second file (file name not shown in this capture):
@@ -583,11 +583,14 @@ def build_position_encoding(config):

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -699,6 +702,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -708,6 +712,7 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
# Ignore copy
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -743,7 +748,7 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

if self.disable_custom_kernels:
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
else:
9 changes: 6 additions & 3 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -17,7 +17,7 @@
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -800,11 +800,14 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
9 changes: 6 additions & 3 deletions src/transformers/models/oneformer/modeling_oneformer.py
@@ -18,7 +18,7 @@
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -63,11 +63,14 @@ def _get_clones(module, N):

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
12 changes: 5 additions & 7 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -733,13 +733,14 @@ def forward(self, hidden_state):

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
value: Tensor,
value_spatial_shapes: Union[Tensor, List[Tuple]],
sampling_locations: Tensor,
attention_weights: Tensor,
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
# Ignore copy
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)

sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -861,9 +862,7 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape

# Ignore copy
total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -899,7 +898,6 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

# Ignore copy
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
# PyTorch implementation
output = multi_scale_deformable_attention(