From b1c23dba91ee393eb9f18367a558f89e97980dc6 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 20 Sep 2024 02:41:11 +0000 Subject: [PATCH] fix hardcoded float16 and .float() --- .../deformable_detr/modeling_deformable_detr.py | 14 ++++++++------ .../models/rt_detr/modeling_rt_detr.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 29ef8724ad1e44..a315a321a4a860 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -523,14 +523,14 @@ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=N def forward(self, pixel_values, pixel_mask): if pixel_mask is None: raise ValueError("No pixel mask provided") - y_embed = pixel_mask.cumsum(1, dtype=torch.float16) - x_embed = pixel_mask.cumsum(2, dtype=torch.float16) + y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype) + x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype) if self.normalize: eps = 1e-6 y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() + dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -708,7 +708,7 @@ def forward( batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape - total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list) + total_elements = sum(height * width for height, width in spatial_shapes_list) if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" @@ -1609,7 +1609,7 @@ def get_proposal_pos_embed(self, proposals): temperature = 10000 scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float() + dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) # batch_size, num_queries, 4 proposals = proposals.sigmoid() * scale @@ -1740,7 +1740,9 @@ def forward( source = self.input_proj[level](features[-1][0]) else: source = self.input_proj[level](sources[-1]) - mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0] + mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to( + torch.bool + )[0] pos_l = self.backbone.position_embedding(source, mask).to(source.dtype) sources.append(source) masks.append(mask) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 562e5c73a56ae7..69745a1ad969ed 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -862,7 +862,7 @@ def forward( batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape - total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list) + total_elements = sum(height * width for height, width in spatial_shapes_list) if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"