Fix swapped indices and names in 2D sincos pos_embed #10877

Draft
wants to merge 1 commit into
base: main
from

Conversation

Seas0

@Seas0 Seas0 commented Feb 23, 2025

What does this PR do?

CURRENTLY STILL A DRAFT, I'M NOT SURE ALL NAMES ARE FIXED!

In get_2d_sincos_pos_embed and get_2d_sincos_pos_embed_from_grid, the conventional embedding scheme inherited from the MAE models appears to swap the order of h and w when calling meshgrid. As a result, emb_h actually encodes width information and emb_w actually encodes height information. Given the permutation equivariance of the attention mechanism, this should make no practical difference for training or inference, but it confuses anyone reading the code to understand the model's internal structure.

This commit should fix the confusion and the swapped names in these functions while remaining compatible with pre-trained models that use the old code.
It also removes unnecessary stacking and reshaping of the grid tensors.
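
To illustrate why the stacking and reshaping step can be dropped, here is a minimal NumPy sketch. It assumes the old code follows the pattern in MAE's `util/pos_embed.py` (stack the two `meshgrid` outputs, reshape to `(2, 1, h, w)`, then flatten each half); the variable names are illustrative and not taken from the diffusers source.

```python
import numpy as np

h, w = 3, 5
grid_h = np.arange(h, dtype=np.float32)
grid_w = np.arange(w, dtype=np.float32)

# MAE-style pattern: stack the two meshgrid outputs, then reshape to (2, 1, h, w)
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, h, w)
stacked_first = grid[0].reshape(-1)
stacked_second = grid[1].reshape(-1)

# Direct pattern: flatten the two meshgrid outputs without stacking
first, second = np.meshgrid(grid_w, grid_h)
direct_first = first.reshape(-1)
direct_second = second.reshape(-1)

# Both routes yield the same flattened coordinate vectors
same = np.array_equal(stacked_first, direct_first) and np.array_equal(
    stacked_second, direct_second
)
print(same)
```

Since each half is flattened before the 1D embedding is computed, the intermediate `(2, 1, h, w)` tensor carries no extra information in this sketch.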

For reference on the behavior of the meshgrid function, see also: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html https://pytorch.org/docs/stable/generated/torch.meshgrid.html
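
As a quick illustration of the `meshgrid` behavior in question, the sketch below compares NumPy's default `indexing="xy"` with `indexing="ij"` on a small 2 x 3 grid. The names `ys` and `xs` are illustrative only. Note that `torch.meshgrid` uses `"ij"` semantics when `indexing` is not given, so the two libraries differ by default.

```python
import numpy as np

h, w = 2, 3
ys = np.arange(h)  # row (height) coordinates
xs = np.arange(w)  # column (width) coordinates

# Default indexing="xy": the FIRST argument varies along columns (width),
# and both outputs have shape (len(second), len(first)) = (h, w).
a, b = np.meshgrid(xs, ys)
print(a)  # each row is [0, 1, 2]: varies along width
print(b)  # each column is [0, 1]: varies along height

# indexing="ij": the FIRST argument varies along rows,
# and both outputs have shape (len(first), len(second)) = (h, w).
i, j = np.meshgrid(ys, xs, indexing="ij")

# The two calls return the same pair of arrays in opposite order.
print(np.array_equal(a, j), np.array_equal(b, i))
```

So when the width coordinates are passed first under `"xy"` indexing, the first returned array holds width positions, whatever name it is later given.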

See also the original PR that reimplemented the sincos pos_embed in PyTorch, #10156, kindly written by @hlky.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed.

@yiyixuxu

@Seas0
Author

Seas0 commented Feb 23, 2025

Sample code for visualizing 2d positional embeddings

# Visualization of the 2D positional embeddings

import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats("svg")

from diffusers.models.embeddings import get_2d_sincos_pos_embed

# Create a grid for visualization
h = 16
w = 32
grid_size = (h, w)
embed_dim = 128  # Example embedding dimension
base_size = 16
interpolation_scale = 1.0

# Get positional embeddings
# First half of the hidden dimension is for the width, second half is for the height
pos_embed = get_2d_sincos_pos_embed(
    embed_dim, (h, w), base_size=base_size, interpolation_scale=interpolation_scale
)
pos_embed = pos_embed.reshape(h, w, embed_dim)

# Visualize first few dimensions
num_dims_to_show = 2
assert num_dims_to_show <= embed_dim
assert num_dims_to_show % 2 == 0
fig, axes = plt.subplots(1, num_dims_to_show, figsize=(15, 5))
axes = axes.flatten()

for i in range(num_dims_to_show):
    # dim = embed_dim//2 - 1 - i if i < num_dims_to_show//2 else embed_dim - 1 - i
    dim = i if i < num_dims_to_show // 2 else embed_dim // 2 + i - num_dims_to_show // 2
    sns.heatmap(
        pos_embed[:, :, dim],
        ax=axes[i],
        cmap="viridis",
        square=True,
        cbar=False,
        vmin=-1,
        vmax=1,
    )
    print(pos_embed[:, :, dim])
    axes[i].set_title(f"Dimension {dim}")
    axes[i].set_xlabel("Width")
    axes[i].set_ylabel("Height")

plt.tight_layout()
plt.show()
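
For readers who prefer a numeric check over a plot, the following is a self-contained NumPy sketch of the MAE-style scheme. The helpers `sincos_1d` and `sincos_2d_mae_style` are hypothetical names written for this example; they follow the structure of MAE's `util/pos_embed.py` and omit the `base_size` and `interpolation_scale` handling in diffusers, so treat it as an approximation rather than the library's implementation.

```python
import numpy as np


def sincos_1d(embed_dim, pos):
    """1D sin-cos embedding of flattened positions, shape (N, embed_dim)."""
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


def sincos_2d_mae_style(embed_dim, h, w):
    """2D embedding with width passed first to meshgrid, as in MAE."""
    grid_h = np.arange(h, dtype=np.float64)
    grid_w = np.arange(w, dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # default indexing="xy"
    first_half = sincos_1d(embed_dim // 2, grid[0])   # called emb_h in the old code
    second_half = sincos_1d(embed_dim // 2, grid[1])  # called emb_w in the old code
    emb = np.concatenate([first_half, second_half], axis=1)
    return emb.reshape(h, w, embed_dim)


h, w, embed_dim = 4, 6, 16
half = embed_dim // 2
emb = sincos_2d_mae_style(embed_dim, h, w)

# First half: identical across rows, so it depends only on the width position.
first_constant_along_height = np.allclose(emb[:, :, :half], emb[:1, :, :half])
# Second half: identical across columns, so it depends only on the height position.
second_constant_along_width = np.allclose(emb[:, :, half:], emb[:, :1, half:])
print(first_constant_along_height, second_constant_along_width)
```

In this sketch the half that the old code names `emb_h` does not change when moving along the height axis, which is consistent with the naming mismatch described in the PR.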

In `get_3d_sincos_pos_embed`, `get_2d_sincos_pos_embed`
and `get_2d_sincos_pos_embed_from_grid`,
the conventional embedding scheme inherited from the MAE model appears to have
swapped the order of `h` and `w` when calling `meshgrid`,
causing `emb_h` to actually encode width information
and `emb_w` to actually encode height information.

This commit should fix the confusion and the swapped names in these functions
while remaining compatible with pre-trained models using the old code.
Additionally, this commit removes unnecessary stacking and reshaping
of the grid tensors.

For reference on the original code, see also:
https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
https://github.com/facebookresearch/DiT/blob/main/models.py

For reference on the behavior of the `meshgrid` function, see also:
https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
https://pytorch.org/docs/stable/generated/torch.meshgrid.html

@Seas0 force-pushed the 2d_sincos_pos_emb_naming_fix branch from 3bab727 to 9f3da6f on February 24, 2025 01:22