
Commit bca88e9

address #300
1 parent 96f66d2 commit bca88e9


2 files changed: +172, -1 lines


setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.7',
+  version = '1.6.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange
from einops.layers.torch import Rearrange

# constants

Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    # patches is shaped (batch, frames, height, width, dim)
    _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    z, y, x = torch.meshgrid(
        torch.arange(f, device = device),
        torch.arange(h, device = device),
        torch.arange(w, device = device),
    indexing = 'ij')

    fourier_dim = dim // 6

    omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** omega)

    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

    pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
    return pe.type(dtype)

# main class

class Attend(Module):
    def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
        super().__init__()
        self.config = config
        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

    def flash_attn(self, q, k, v):
        # flash attention - https://arxiv.org/abs/2205.14135

        with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)

        return out

    def forward(self, q, k, v):
        # q, k, v are each shaped (batch, heads, seq, dim_head)

        scale = q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = Attend(use_flash = use_flash)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
        super().__init__()
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

class SimpleViT(Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        # video is shaped (batch, channels, frames, height, width)

        x = self.to_patch_embedding(video)
        pe = posemb_sincos_3d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
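
For reference, a minimal usage sketch of the SimpleViT video model added above. This is illustrative only and not part of the commit: the import path vit_pytorch.simple_flash_attn_vit_3d and all hyperparameter values below are assumptions.

import torch

# assumed import path for the file added in this commit; adjust to your checkout
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

model = SimpleViT(
    image_size = 128,        # frame height/width
    image_patch_size = 16,   # spatial patch size
    frames = 16,             # frames per clip
    frame_patch_size = 2,    # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True    # requires pytorch >= 2.0; set False to use the einsum path
)

video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = model(video)                     # (2, 1000)

Setting use_flash_attn = False routes attention through the plain einsum implementation in Attend.forward, which can be used to sanity-check outputs against the F.scaled_dot_product_attention path.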

0 commit comments
