Multiscale Vision Transformers #40

Closed
wants to merge 69 commits into from
Changes from 47 commits
Commits
69 commits
9b717ff
Create multiscale.py
Amapocho Nov 26, 2021
f35e6e8
Update multiscale.py
Amapocho Nov 27, 2021
a3f6095
Create mlp.py
Amapocho Nov 28, 2021
441df76
Delete mlp.py
Amapocho Nov 28, 2021
03fc0f1
Create droppath.py
Amapocho Nov 28, 2021
cd159f8
Create mlp.py
Amapocho Nov 28, 2021
7aef807
Add docstrings
Amapocho Nov 28, 2021
5c8a00e
Update Mlp and DropPath import
Amapocho Nov 28, 2021
fb41ba3
Update docstring
Amapocho Nov 28, 2021
9aa3800
Add import statements
Amapocho Nov 28, 2021
6d38d9d
Add docstring
Amapocho Nov 28, 2021
38d9bc0
Merge branch 'SforAiDl:main' into main
Amapocho Nov 28, 2021
3788eb6
Run pre-commit
Amapocho Nov 28, 2021
db284a8
Update DropPath import
Amapocho Nov 28, 2021
60aa249
Delete droppath.py
Amapocho Nov 28, 2021
8df5268
Delete mlp.py
Amapocho Nov 28, 2021
5a08e33
Update Mlp import
Amapocho Nov 28, 2021
dd72553
Run pre-commit
Amapocho Nov 28, 2021
f144ed7
Update docstring
Amapocho Nov 28, 2021
77e9925
Update mlp usage for feedforward
Amapocho Nov 28, 2021
d6fcc13
Run pre-commit
Amapocho Nov 28, 2021
d716b63
Update docstring
Amapocho Nov 28, 2021
f857764
Remove multiscale block
Amapocho Nov 28, 2021
23a81f8
Update imports
Amapocho Nov 28, 2021
7a41441
Create multiscale.py
Amapocho Nov 28, 2021
f9b58da
Create multiscale.py
Amapocho Jan 4, 2022
a454ede
Update test_attention.py
Amapocho Jan 4, 2022
7731f1f
Delete test_attention.py
Amapocho Jan 4, 2022
ef649b0
Update
Amapocho Jan 4, 2022
6cbe9e5
Merge pull request #1 from SforAiDl/main
Amapocho Jan 4, 2022
6289f17
Create patch_multiscale.py
Amapocho Jan 4, 2022
b3e8506
Update multiscale.py
Amapocho Jan 4, 2022
a0ea7d1
Update multiscale.py
Amapocho Jan 4, 2022
c2bffb8
Update multiscale.py
Amapocho Jan 4, 2022
d616f0c
Update multiscale.py
Amapocho Jan 14, 2022
ebbc04b
Update vformer/encoder/multiscale.py
Amapocho Jan 14, 2022
f57e04a
Update vformer/attention/multiscale.py
Amapocho Jan 14, 2022
290e79f
Update multiscale.py
Amapocho Jan 14, 2022
c9d4044
Update __init__.py
Amapocho Jan 14, 2022
60d9e97
Update multiscale.py
Amapocho Jan 14, 2022
a0d2063
Update __init__.py
Amapocho Jan 14, 2022
930c596
Update test_attention.py
Amapocho Jan 14, 2022
0e61971
Update test_encoder.py
Amapocho Jan 14, 2022
313b2d0
Update multiscale.py
Amapocho Jan 14, 2022
8a9a06c
Update multiscale.py
Amapocho Jan 14, 2022
36010fd
Update test_models.py
Amapocho Jan 14, 2022
558dcff
update
Amapocho Jan 14, 2022
e555aa0
Update vformer/encoder/embedding/patch_multiscale.py
Amapocho Jan 14, 2022
6fac3a7
Update tests
Amapocho Jan 16, 2022
7cc4cea
Update tests
Amapocho Jan 16, 2022
077e2e7
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
6fea53d
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
e43a858
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
372aa69
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
5ad17bd
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
8b0b82e
Update vformer/models/classification/multiscale.py
Amapocho Jan 17, 2022
a0a2e18
Import trunc_normal_
Amapocho Jan 18, 2022
14ddfb8
Create multiscale.py
Amapocho Jan 18, 2022
5b603a2
Update multiscale.py
Amapocho Jan 18, 2022
c40f844
Update multiscale.py
Amapocho Jan 18, 2022
9c14715
Update __init__.py
Amapocho Jan 18, 2022
3947d2c
Update multiscale.py
Amapocho Jan 18, 2022
014c6e5
Update multiscale.py
Amapocho Jan 18, 2022
85fd771
Update multiscale.py
Amapocho Jan 18, 2022
b43eadb
Update
Amapocho Jan 18, 2022
161ca3f
Update multiscale.py
Amapocho Jan 24, 2022
beed6ea
remove logger
Amapocho Jan 25, 2022
4926fed
Update multiscale.py
Amapocho Jan 25, 2022
209b510
Update HEAD_ACT
Amapocho Jan 25, 2022
16 changes: 16 additions & 0 deletions tests/test_attention.py
@@ -51,6 +51,22 @@ def test_CrossAttention():
    assert out.shape == test_tensor1.shape
    del attention

def test_MultiScaleAttention():

    test_tensor1 = torch.randn(8, 56 * 56 + 1, 192)
    test_tensor2 = torch.randn(8, 14 * 14 + 1, 768)

    # stride_q of (1, 2, 2) pools the 56x56 query grid down to 28x28
    attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=192, kernel_q=(1, 3, 3), stride_q=(1, 2, 2))
    out, thw = attention(test_tensor1, [1, 56, 56])
    assert out.shape == (8, 28 * 28 + 1, 192)
    assert thw == [1, 28, 28]
    del attention

    # default kernel and stride of (1, 1, 1) skip pooling, so the shape is unchanged
    attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=768)
    out, thw = attention(test_tensor2, [1, 14, 14])
    assert out.shape == test_tensor2.shape
    assert thw == [1, 14, 14]
    del attention


def test_SpatialAttention():

17 changes: 16 additions & 1 deletion tests/test_encoder.py
@@ -56,7 +56,22 @@ def test_SwinEncoder():
    out = encoder_block(test_tensor)
    assert out.shape == test_tensor.shape


def test_MultiScaleBlock():

    test_tensor1 = torch.randn(96, 8, 56, 56)
    encoder1 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=192)
    out1 = encoder1(test_tensor1)
    assert out1.shape == (192, 8, 28, 28)

    test_tensor2 = torch.randn(768, 8, 14, 14)
    encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=768)
    out2 = encoder2(test_tensor2)
    assert out2.shape == (768, 8, 14, 14)  # shape remains the same

    del encoder1, encoder2, test_tensor1, test_tensor2


def test_PVTEncoder():

    test_tensor = torch.randn(4, 3136, 64)
8 changes: 8 additions & 0 deletions tests/test_models.py
@@ -159,7 +159,15 @@ def test_CrossVit():
    assert out.shape == (2, 10)
    del model

def test_MultiScale():

    model = MODEL_REGISTRY.get("MultiScaleViT")()
    out = model(img_3channels_224)
    assert out.shape == (8, 400)
    del model



def test_pvt():
    # classification
    model = MODEL_REGISTRY.get("PVTClassification")(
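For reviewers who want to poke at the new model outside pytest, here is a minimal sketch of the same registry-based construction the test uses. The `vformer.utils.MODEL_REGISTRY` import path and the `(8, 3, 224, 224)` input shape are assumptions inferred from this PR's other registry imports and from the assertion above, not something this diff spells out.

```python
import torch

# Assumed import path: the attention module in this PR pulls ATTENTION_REGISTRY
# from ..utils, so MODEL_REGISTRY is expected to live in vformer.utils as well.
from vformer.utils import MODEL_REGISTRY

# Build the classifier with its default arguments, exactly as the test above does.
model = MODEL_REGISTRY.get("MultiScaleViT")()

images = torch.randn(8, 3, 224, 224)  # illustrative batch; the test asserts an (8, 400) output
with torch.no_grad():
    logits = model(images)

print(logits.shape)  # expected by the test above to be (8, 400)
```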
1 change: 1 addition & 0 deletions vformer/attention/__init__.py
@@ -2,3 +2,4 @@
from .spatial import SpatialAttention
from .vanilla import VanillaSelfAttention
from .window import WindowAttention
from .multiscale import MultiScaleAttention
270 changes: 270 additions & 0 deletions vformer/attention/multiscale.py
@@ -0,0 +1,270 @@
import numpy
import torch
import torch.nn as nn

from ..utils import ATTENTION_REGISTRY


def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    """
    Attention pooling

    Parameters
    ----------
    tensor: torch.Tensor
        Input tensor
    pool: nn.Module
        Pooling module
    thw_shape: list of int
        Current space-time resolution as [T, H, W]
    has_cls_embed: bool, optional
        Set to True if a classification embedding is prepended to the tensor
    norm: nn.Module, optional
        Normalization module applied after pooling
    """

    if pool is None:
        return tensor, thw_shape
    tensor_dim = tensor.ndim
    if tensor_dim == 4:
        pass
    elif tensor_dim == 3:
        tensor = tensor.unsqueeze(1)
    else:
        raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")

    if has_cls_embed:
        cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]

    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()

    tensor = pool(tensor)

    thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
    L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
    tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)
    # Assert tensor_dim in [3, 4]
    if tensor_dim == 4:
        pass
    else:  # tensor_dim == 3
        tensor = tensor.squeeze(1)
    return tensor, thw_shape

@ATTENTION_REGISTRY.register()
class MultiScaleAttention(nn.Module):
    """
    Multiscale Attention

    Parameters
    ----------
    dim: int
        Dimension of the embedding
    num_heads: int
        Number of attention heads
    qkv_bias: bool, optional
        If True, add a learnable bias to the query, key and value projections
    drop_rate: float, optional
        Dropout rate
    kernel_q: tuple of int, optional
        Kernel size of the query pooling
    kernel_kv: tuple of int, optional
        Kernel size of the key and value pooling
    stride_q: tuple of int, optional
        Stride size of the query pooling
    stride_kv: tuple of int, optional
        Stride size of the key and value pooling
    norm_layer: nn.Module, optional
        Normalization layer
    has_cls_embed: bool, optional
        Set to True if a classification embedding is prepended to the input
    mode: str, optional
        Pooling mode to be used. Options include `conv`, `avg`, and `max`
    pool_first: bool, optional
        Set to True to pool before the linear projection
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        drop_rate=0.0,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
        has_cls_embed=True,
        mode="conv",
        pool_first=False,
    ):
        super().__init__()
        self.pool_first = pool_first
        self.drop_rate = drop_rate
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.has_cls_embed = has_cls_embed
        padding_q = [int(q // 2) for q in kernel_q]
        padding_kv = [int(kv // 2) for kv in kernel_kv]

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        if drop_rate > 0.0:
            self.proj_drop = nn.Dropout(drop_rate)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
            kernel_q = ()
        if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
            kernel_kv = ()

        if mode in ("avg", "max"):
            pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
            self.pool_q = (
                pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
                if len(kernel_q) > 0
                else None
            )
            self.pool_k = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
            self.pool_v = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
        elif mode == "conv":
            self.pool_q = (
                nn.Conv3d(
                    head_dim,
                    head_dim,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=head_dim,
                    bias=False,
                )
                if len(kernel_q) > 0
                else None
            )
            self.norm_q = norm_layer(head_dim) if len(kernel_q) > 0 else None
            self.pool_k = (
                nn.Conv3d(
                    head_dim,
                    head_dim,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=head_dim,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_k = norm_layer(head_dim) if len(kernel_kv) > 0 else None
            self.pool_v = (
                nn.Conv3d(
                    head_dim,
                    head_dim,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=head_dim,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_v = norm_layer(head_dim) if len(kernel_kv) > 0 else None
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

    def forward(self, x, thw_shape):
        B, N, C = x.shape
        if self.pool_first:
            x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            q = k = v = x
            q = (
                self.q(q)
                .reshape(B, N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )
            k = (
                self.k(k)
                .reshape(B, N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )
            v = (
                self.v(v)
                .reshape(B, N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )

        q, q_shape = attention_pool(
            q,
            self.pool_q,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, "norm_q") else None,
        )
        k, k_shape = attention_pool(
            k,
            self.pool_k,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, "norm_k") else None,
        )
        v, v_shape = attention_pool(
            v,
            self.pool_v,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, "norm_v") else None,
        )

        if self.pool_first:
            q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape)
            k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape)
            v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape)

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
            q = (
                self.q(q)
                .reshape(B, q_N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
            v = (
                self.v(v)
                .reshape(B, v_N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
            k = (
                self.k(k)
                .reshape(B, k_N, self.num_heads, C // self.num_heads)
                .permute(0, 2, 1, 3)
            )

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        N = q.shape[2]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        if self.drop_rate > 0.0:
            x = self.proj_drop(x)
        return x, q_shape
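For reviewers trying the branch locally, a minimal usage sketch of the new attention layer; it assumes the PR branch is installed as `vformer`, and the grid size, `kernel_q`, and `stride_q` values below are illustrative rather than taken from the PR's tests.

```python
import torch

from vformer.attention import MultiScaleAttention  # exported by this PR's __init__.py change

# One frame of a 56x56 token grid plus a cls token, embedding dim 96.
x = torch.randn(2, 56 * 56 + 1, 96)
thw = [1, 56, 56]

# A (1, 3, 3) query kernel with stride (1, 2, 2) pools the query grid down to 28x28;
# keys and values keep the default (1, 1, 1) kernel, so they are not pooled.
attn = MultiScaleAttention(dim=96, num_heads=8, kernel_q=(1, 3, 3), stride_q=(1, 2, 2))

out, out_thw = attn(x, thw)
print(out.shape)  # torch.Size([2, 785, 96]) -> 28 * 28 + 1 tokens
print(out_thw)    # [1, 28, 28]
```

With the default `(1, 1, 1)` kernels and strides, both pooling branches are skipped and the layer degenerates to plain multi-head self-attention with an unchanged token count.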


1 change: 1 addition & 0 deletions vformer/encoder/__init__.py
@@ -4,3 +4,4 @@
from .pyramid import PVTEncoder
from .swin import SwinEncoder, SwinEncoderBlock
from .vanilla import VanillaEncoder
from .multiscale import MultiScaleBlock
47 changes: 47 additions & 0 deletions vformer/encoder/embedding/patch_multiscale.py
@@ -0,0 +1,47 @@
import torch.nn as nn


class PatchEmbed(nn.Module):
    """
    Multiscale patch embedding

    Parameters
    ----------
    dim_in: int
        Number of input channels in the image
    dim_out: int
        Number of linear projection output channels
    kernel: tuple of int
        Kernel size of the projection convolution
    stride: tuple of int
        Stride of the projection convolution
    padding: tuple of int
        Padding of the projection convolution
    conv_2d: bool
        Use nn.Conv2d if True, nn.Conv3d if False
    """

    def __init__(
        self,
        dim_in=3,
        dim_out=768,
        kernel=(1, 16, 16),
        stride=(1, 4, 4),
        padding=(1, 7, 7),
        conv_2d=False,
    ):
        super().__init__()
        if conv_2d:
            conv = nn.Conv2d
        else:
            conv = nn.Conv3d
        self.proj = conv(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        return x.flatten(2).transpose(1, 2)
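A quick sketch of how the new patch embedding behaves with its 3-D defaults; it assumes the PR branch is installed and the module is importable from the path added here, and the clip shape is purely illustrative.

```python
import torch

from vformer.encoder.embedding.patch_multiscale import PatchEmbed  # path added in this PR

# Illustrative clip: batch of 2, 3 channels, 8 frames, 224x224 spatial resolution.
video = torch.randn(2, 3, 8, 224, 224)

# Defaults: nn.Conv3d projection with kernel (1, 16, 16), stride (1, 4, 4), padding (1, 7, 7).
patch_embed = PatchEmbed(dim_in=3, dim_out=768)
tokens = patch_embed(video)

# B C T H W -> B (T' * H' * W') C_out; with the defaults T' = 10 and H' = W' = 56.
print(tokens.shape)  # torch.Size([2, 31360, 768])
```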