Skip to content

Commit

Permalink
add value residual based simple vit
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 28, 2024
1 parent e300cdd commit 0b5c9b4
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 1 deletion.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2152,4 +2152,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.8.5',
version = '1.8.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
151 changes: 151 additions & 0 deletions vit_pytorch/simple_vit_with_value_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def exists(v):
return v is not None

def default(v, d):
return v if exists(v) else d

def pair(t):
return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)

y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)

# classes

def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)

class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)

self.attend = nn.Softmax(dim = -1)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(self, x, value_residual = None):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

if exists(value_residual):
v = v + value_residual

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

return self.to_out(out), v

class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
value_residual = None

for attn, ff in self.layers:

attn_out, values = attn(x, value_residual = value_residual)
value_residual = default(value_residual, values)

x = attn_out + x
x = ff(x) + x

return self.norm(x)

class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

patch_dim = channels * patch_height * patch_width

self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)

self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.pool = "mean"
self.to_latent = nn.Identity()

self.linear_head = nn.Linear(dim, num_classes)

def forward(self, img):
device = img.device

x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)

x = self.transformer(x)
x = x.mean(dim = 1)

x = self.to_latent(x)
return self.linear_head(x)

# quick test

if __name__ == '__main__':
v = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)

images = torch.randn(2, 3, 256, 256)

logits = v(images)

0 comments on commit 0b5c9b4

Please sign in to comment.