EfficientAdditiveAttnetion from a QKV interaction perspective #8

lartpang opened this issue Oct 8, 2023 · 1 comment
lartpang commented Oct 8, 2023

Although the concept of a "value" does not appear in the paper's description or in the code, the implementation is actually very similar to the interaction form in MobileViT-v2.

Below, I have annotated and reorganized the author's code.

As we can see, the interaction of Q and K is implicitly folded into Q's own transformation, and the "key" in the code acts more like a "value".
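
In equation form (my own compact reading of the code below, using its symbols):

$$A=\operatorname{normalize}\!\Big(\frac{Q\,w_g}{\sqrt{d}}\Big),\qquad G=\sum_{n=1}^{N}A_n\,Q_n,\qquad \text{out}=\operatorname{Proj}(G\odot K)+Q,$$

where $d$ is token_dim and the normalization of $A$ runs over the token axis. Read this way, $A$ plays the "query", $Q$ plays the "key" pooled into the global context $G$, and $K$ receives the context like a "value".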

# https://github.com/Amshaker/SwiftFormer/blob/cd1f854e59f9e010279f8ff657a991d71ed9f13f/models/swiftformer.py#L141C1-L181C19
import torch
import torch.nn as nn
import einops


class EfficientAdditiveAttnetion(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor in shape [B, N, D]
    Output: tensor in shape [B, N, token_dim]
    """
    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()
        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)
        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        query = torch.nn.functional.normalize(query, dim=-1)  # BxNxD

        key = self.to_key(x)
        key = torch.nn.functional.normalize(key, dim=-1)  # BxNxD

        # convert query to the context scores
        query_weight = query @ self.w_g  # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor  # BxNx1
        A = torch.nn.functional.normalize(A, dim=1)  # BxNx1

        # similar to the interaction of query and key in MobileViT-v2;
        # here A can be seen as the "query" and query as the "key"
        G = torch.sum(A * query, dim=1)  # BxD
        G = einops.repeat(G, "b d -> b repeat d", repeat=key.shape[1])  # BxNxD

        # here key can be seen as the "value"
        out = self.Proj(G * key) + query  # BxNxD
        return self.final(out)  # B x N x token_dim
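
For completeness, a quick shape check of the module above (the tensor sizes are illustrative; in the SwiftFormer repo the module appears to be instantiated with token_dim equal to the embedding dim and num_heads=1, in which case input and output dims match):

attn = EfficientAdditiveAttnetion(in_dims=512, token_dim=256, num_heads=2)
x = torch.randn(2, 49, 512)  # B=2 batch, N=49 tokens, D=512 channels (made-up sizes)
y = attn(x)
print(y.shape)  # torch.Size([2, 49, 256]) -- self.final projects to token_dim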
Amshaker (Owner) commented Oct 8, 2023

Hi @lartpang,
Thank you for your insights. SwiftFormer and MobileViT-v2 do compute the interactions in a somewhat similar way, and we already showed that in the attention comparison figure. However, there are two major differences:

(1) We build on Additive Attention, which has learnable weights (self.w_g) that learn where to attend. There are no such learnable weights inside the linear attention of MobileViT-v2.

(2) We eliminate the need for a third interaction (what we call "KV interactions" in the paper). In MobileViT-v2, the attention weights (the context vector) are shared through a third branch, V. In our case, we revise this interaction and replace it with a linear transformation and a skip connection with the Q matrix. The skip connection acts as sharing the global context weights with the input Q, instead of having a third branch (see the sketch below).
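
For comparison, here is a minimal sketch of the MobileViT-v2-style separable self-attention being contrasted (my own paraphrase of the mechanism; the class name SeparableSelfAttention and the use of nn.Linear in place of 1x1 convolutions are illustrative choices, not the official implementation):

import torch
import torch.nn as nn


class SeparableSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_i = nn.Linear(dim, 1)    # score branch I: projects each token to a scalar
        self.to_k = nn.Linear(dim, dim)  # key branch
        self.to_v = nn.Linear(dim, dim)  # the third branch V that SwiftFormer removes
        self.out = nn.Linear(dim, dim)

    def forward(self, x):  # x: BxNxD
        scores = torch.softmax(self.to_i(x), dim=1)                      # BxNx1, over tokens
        context = torch.sum(scores * self.to_k(x), dim=1, keepdim=True)  # Bx1xD context vector
        out = torch.relu(self.to_v(x)) * context                         # BxNxD: context shared via V
        return self.out(out)                                             # BxNxD

The contrast is the last two lines: in SwiftFormer, out = self.Proj(G * key) + query replaces this third branch, with the skip connection to query redistributing the global context instead of V.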

To summarize, there is a common factor between them, which we already showed in the attention comparison figure, but there are two major differences.

I hope it is clear now.

Best regards,
Abdelrahman.
