Skip to content

Commit

Permalink
fix hardcoded float16 and .float()
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Sep 20, 2024
1 parent 6902c7b commit b1c23db
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,14 @@ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=N
def forward(self, pixel_values, pixel_mask):
if pixel_mask is None:
raise ValueError("No pixel mask provided")
y_embed = pixel_mask.cumsum(1, dtype=torch.float16)
x_embed = pixel_mask.cumsum(2, dtype=torch.float16)
y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
if self.normalize:
eps = 1e-6
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
Expand Down Expand Up @@ -708,7 +708,7 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
Expand Down Expand Up @@ -1609,7 +1609,7 @@ def get_proposal_pos_embed(self, proposals):
temperature = 10000
scale = 2 * math.pi

dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale
Expand Down Expand Up @@ -1740,7 +1740,9 @@ def forward(
source = self.input_proj[level](features[-1][0])
else:
source = self.input_proj[level](sources[-1])
mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
torch.bool
)[0]
pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
sources.append(source)
masks.append(mask)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/rt_detr/modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
Expand Down

0 comments on commit b1c23db

Please sign in to comment.