Commit

Add ViViT variant with factorized self-attention (#327)
* Add FactorizedTransformer

* Add variant param and check in fwd method

* Check if variant is implemented

* Describe new ViViT variant
roydenwa authored Aug 22, 2024
1 parent 5e808f4 commit 9d43e4d
Showing 2 changed files with 58 additions and 20 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -1218,7 +1218,8 @@ pred = cct(video)
 
 <img src="./images/vivit.png" width="350px"></img>
 
-This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
+This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
+The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
 
 ```python
 import torch
@@ -1234,7 +1235,8 @@ v = ViT(
     spatial_depth = 6,              # depth of the spatial transformer
     temporal_depth = 6,             # depth of the temporal transformer
     heads = 8,
-    mlp_dim = 2048
+    mlp_dim = 2048,
+    variant = 'factorized_encoder', # or 'factorized_self_attention'
 )
 
 video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
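For a concrete sense of the new option, here is a minimal end-to-end sketch of the factorized self-attention variant. The frame count and resolution follow the `torch.randn(4, 3, 16, 128, 128)` input above, and the depths, `heads`, `mlp_dim`, and `variant` values come from this commit; the patch sizes, `num_classes`, and `dim` are illustrative assumptions, since those lines are truncated from the README snippet.

```python
import torch
from vit_pytorch.vivit import ViT

# Illustrative hyperparameters for the sketch -- patch sizes, num_classes, and dim
# are assumptions, not the values elided from the README snippet above.
v = ViT(
    image_size = 128,
    frames = 16,
    image_patch_size = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    spatial_depth = 6,        # must equal temporal_depth for this variant (asserted in __init__)
    temporal_depth = 6,
    heads = 8,
    mlp_dim = 2048,
    variant = 'factorized_self_attention'   # alternates spatial and temporal self-attention per layer
)

video = torch.randn(4, 3, 16, 128, 128)     # (batch, channels, frames, height, width)
pred = v(video)                             # (4, 1000) class logits
print(pred.shape)
```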
72 changes: 54 additions & 18 deletions vit_pytorch/vivit.py
@@ -78,6 +78,30 @@ def forward(self, x):
             x = ff(x) + x
         return self.norm(x)
 
+class FactorizedTransformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
+            ]))
+
+    def forward(self, x):
+        b, f, n, _ = x.shape
+        for spatial_attn, temporal_attn, ff in self.layers:
+            x = rearrange(x, 'b f n d -> (b f) n d')
+            x = spatial_attn(x) + x
+            x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
+            x = temporal_attn(x) + x
+            x = ff(x) + x
+            x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
+
+        return self.norm(x)
+
 class ViT(nn.Module):
     def __init__(
         self,
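As a quick shape check on the block added above: each layer applies self-attention over the patches within a frame, then over the frames at each patch position, so a `(batch, frames, patches, dim)` input comes back with the same shape. A minimal sketch, assuming `FactorizedTransformer` can be imported from `vit_pytorch.vivit` once this commit is installed:

```python
import torch
from vit_pytorch.vivit import FactorizedTransformer  # module-level class added by this commit

block = FactorizedTransformer(dim = 64, depth = 2, heads = 4, dim_head = 16, mlp_dim = 128)

x = torch.randn(2, 8, 65, 64)   # (batch, frame tokens, spatial tokens per frame, dim)
out = block(x)

print(out.shape)                # torch.Size([2, 8, 65, 64]) -- shape is preserved
```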
@@ -96,14 +120,16 @@ def __init__(
         channels = 3,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        variant = 'factorized_encoder',
     ):
         super().__init__()
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(image_patch_size)
 
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
+        assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
 
         num_image_patches = (image_height // patch_height) * (image_width // patch_width)
         num_frame_patches = (frames // frame_patch_size)
@@ -125,15 +151,20 @@ def __init__(
         self.dropout = nn.Dropout(emb_dropout)
 
         self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
 
-        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
-        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+        if variant == 'factorized_encoder':
+            self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
+            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+        elif variant == 'factorized_self_attention':
+            assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
+            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
 
         self.pool = pool
         self.to_latent = nn.Identity()
 
         self.mlp_head = nn.Linear(dim, num_classes)
+        self.variant = variant
 
     def forward(self, video):
         x = self.to_patch_embedding(video)
@@ -147,32 +178,37 @@ def forward(self, video):
 
         x = self.dropout(x)
 
-        x = rearrange(x, 'b f n d -> (b f) n d')
+        if self.variant == 'factorized_encoder':
+            x = rearrange(x, 'b f n d -> (b f) n d')
 
-        # attend across space
+            # attend across space
 
-        x = self.spatial_transformer(x)
+            x = self.spatial_transformer(x)
 
-        x = rearrange(x, '(b f) n d -> b f n d', b = b)
+            x = rearrange(x, '(b f) n d -> b f n d', b = b)
 
-        # excise out the spatial cls tokens or average pool for temporal attention
+            # excise out the spatial cls tokens or average pool for temporal attention
 
-        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
+            x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
 
-        # append temporal CLS tokens
+            # append temporal CLS tokens
 
-        if exists(self.temporal_cls_token):
-            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+            if exists(self.temporal_cls_token):
+                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
 
-            x = torch.cat((temporal_cls_tokens, x), dim = 1)
+                x = torch.cat((temporal_cls_tokens, x), dim = 1)
 
-        # attend across time
+            # attend across time
 
-        x = self.temporal_transformer(x)
+            x = self.temporal_transformer(x)
 
-        # excise out temporal cls token or average pool
+            # excise out temporal cls token or average pool
 
-        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
+            x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
 
+        elif self.variant == 'factorized_self_attention':
+            x = self.factorized_transformer(x)
+            x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
+
         x = self.to_latent(x)
         return self.mlp_head(x)
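The constructor now enforces two guards: the `variant` string must be one of the two implemented options, and the factorized self-attention path additionally requires `spatial_depth == temporal_depth`, since a single depth drives the `FactorizedTransformer`. A short sketch exercising both checks; the `make_vivit` helper and its hyperparameter values are hypothetical, for illustration only:

```python
from vit_pytorch.vivit import ViT

# Hypothetical helper for this sketch only; hyperparameter values are illustrative.
def make_vivit(**overrides):
    kwargs = dict(
        image_size = 128, frames = 16, image_patch_size = 16, frame_patch_size = 2,
        num_classes = 1000, dim = 512, spatial_depth = 6, temporal_depth = 6,
        heads = 8, mlp_dim = 2048,
    )
    kwargs.update(overrides)
    return ViT(**kwargs)

# Unknown variant names are rejected up front by the new assert in __init__.
try:
    make_vivit(variant = 'joint_space_time')
except AssertionError as e:
    print(e)   # variant = joint_space_time is not implemented

# Factorized self-attention requires matching spatial and temporal depths.
try:
    make_vivit(variant = 'factorized_self_attention', temporal_depth = 2)
except AssertionError as e:
    print(e)   # Spatial and temporal depth must be the same for factorized self-attention
```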
