allow for qk norm to be turned off for na vit nested tensor
lucidrains committed Nov 20, 2024
1 parent f6d7287 commit 24196a3
Showing 3 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
-version = '1.8.7',
+version = '1.8.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
13 changes: 7 additions & 6 deletions vit_pytorch/na_vit_nested_tensor.py
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)

class Attention(Module):
-def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)

@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors

-self.query_norm = nn.LayerNorm(dim_head, bias = False)
-self.key_norm = nn.LayerNorm(dim_head, bias = False)
+self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

self.dropout = dropout

@@ -111,13 +111,13 @@ def transpose_head_seq(t):
return self.to_out(out)

class Transformer(Module):
-def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])

for _ in range(depth):
self.layers.append(ModuleList([
-Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

@@ -146,6 +146,7 @@ def __init__(
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
+qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
@@ -184,7 +185,7 @@ def __init__(

self.dropout = nn.Dropout(emb_dropout)

-self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)

# final attention pooling queries

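With this change, the nested-tensor NaViT accepts a `qk_rmsnorm` flag that is threaded through `Transformer` into each `Attention` block as `qk_norm`. A minimal usage sketch, assuming the `NaViT` constructor arguments from this repository's README example (everything other than the new `qk_rmsnorm` flag is carried over from that example, not from this diff):

import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    qk_rmsnorm = False,      # new in this commit: swap the per-head query/key LayerNorms for nn.Identity
    token_dropout_prob = 0.1
)

# variable resolution images, passed as a plain list and packed into nested tensors internally

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 96, 192)
]

preds = v(images)   # (3, 1000)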
13 changes: 7 additions & 6 deletions vit_pytorch/na_vit_nested_tensor_3d.py
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)

class Attention(Module):
-def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)

@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors

-self.query_norm = nn.LayerNorm(dim_head, bias = False)
-self.key_norm = nn.LayerNorm(dim_head, bias = False)
+self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

self.dropout = dropout

@@ -123,13 +123,13 @@ def transpose_head_seq(t):
return self.to_out(out)

class Transformer(Module):
-def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])

for _ in range(depth):
self.layers.append(ModuleList([
-Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

@@ -161,6 +161,7 @@ def __init__(
dropout = 0.,
emb_dropout = 0.,
num_registers = 4,
+qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
@@ -209,7 +210,7 @@ def __init__(

self.dropout = nn.Dropout(emb_dropout)

-self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)

# final attention pooling queries

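Both files make the same change: with `qk_norm` off, the query/key LayerNorms are replaced by `nn.Identity`, so q and k reach the attention scores unnormalized. A self-contained toy sketch of that toggle (illustration only, not the repository's `Attention` class):

import torch
from torch import nn

class QKNorm(nn.Module):
    # mirrors the pattern in the diff: LayerNorm over dim_head when enabled, Identity otherwise
    def __init__(self, dim_head = 64, qk_norm = True):
        super().__init__()
        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

    def forward(self, q, k):
        return self.query_norm(q), self.key_norm(k)

q = k = torch.randn(2, 8, 16, 64)   # (batch, heads, seq, dim_head)

normed_q, normed_k = QKNorm(qk_norm = True)(q, k)
plain_q, plain_k = QKNorm(qk_norm = False)(q, k)

assert torch.equal(plain_q, q) and torch.equal(plain_k, k)   # Identity path leaves q and k untouched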
