From 9b717ff4d3f9841d670d59d08d27540965142515 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sat, 27 Nov 2021 01:11:30 +0530 Subject: [PATCH 01/67] Create multiscale.py --- vformer/attention/multiscale.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 vformer/attention/multiscale.py diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py new file mode 100644 index 00000000..739d82e9 --- /dev/null +++ b/vformer/attention/multiscale.py @@ -0,0 +1,13 @@ +import numpy + +import torch +import torch.nn as nn + +def pool_attention(input, thw, pool, norm): + dim = input.dim() + if dim == 3: + input = input.unsqueeze(1) + elif dim != 4: + raise NotImplementedError(f"Unsupported input dimension {input.shape}") + T,H,W = thw + B,N,L,C = input.shape From f35e6e8375771ad33de9ec17c77c6a6bebc32ad3 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sat, 27 Nov 2021 11:38:19 +0530 Subject: [PATCH 02/67] Update multiscale.py --- vformer/attention/multiscale.py | 333 +++++++++++++++++++++++++++++++- 1 file changed, 324 insertions(+), 9 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 739d82e9..89a4d0f5 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -1,13 +1,328 @@ import numpy - import torch import torch.nn as nn -def pool_attention(input, thw, pool, norm): - dim = input.dim() - if dim == 3: - input = input.unsqueeze(1) - elif dim != 4: - raise NotImplementedError(f"Unsupported input dimension {input.shape}") - T,H,W = thw - B,N,L,C = input.shape +from slowfast.models.common import DropPath, Mlp + + +def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): + if pool is None: + return tensor, thw_shape + tensor_dim = tensor.ndim + if tensor_dim == 4: + pass + elif tensor_dim == 3: + tensor = tensor.unsqueeze(1) + else: + raise NotImplementedError(f"Unsupported input dimension {tensor.shape}") + + if has_cls_embed: + cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :] + + B, N, L, C = tensor.shape + T, H, W = thw_shape + tensor = ( + tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + ) + + tensor = pool(tensor) + + thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] + L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4] + tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3) + if has_cls_embed: + tensor = torch.cat((cls_tok, tensor), dim=2) + if norm is not None: + tensor = norm(tensor) + # Assert tensor_dim in [3, 4] + if tensor_dim == 4: + pass + else: # tensor_dim == 3: + tensor = tensor.squeeze(1) + return tensor, thw_shape + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + drop_rate=0.0, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + norm_layer=nn.LayerNorm, + has_cls_embed=True, + # Options include `conv`, `avg`, and `max`. + mode="conv", + # If True, perform pool before projection. 
+ pool_first=False, + ): + super().__init__() + self.pool_first = pool_first + self.drop_rate = drop_rate + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.has_cls_embed = has_cls_embed + padding_q = [int(q // 2) for q in kernel_q] + padding_kv = [int(kv // 2) for kv in kernel_kv] + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + if drop_rate > 0.0: + self.proj_drop = nn.Dropout(drop_rate) + + # Skip pooling with kernel and stride size of (1, 1, 1). + if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1: + kernel_q = () + if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1: + kernel_kv = () + + if mode in ("avg", "max"): + pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d + self.pool_q = ( + pool_op(kernel_q, stride_q, padding_q, ceil_mode=False) + if len(kernel_q) > 0 + else None + ) + self.pool_k = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 + else None + ) + self.pool_v = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 + else None + ) + elif mode == "conv": + self.pool_q = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=head_dim, + bias=False, + ) + if len(kernel_q) > 0 + else None + ) + self.norm_q = norm_layer(head_dim) if len(kernel_q) > 0 else None + self.pool_k = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=head_dim, + bias=False, + ) + if len(kernel_kv) > 0 + else None + ) + self.norm_k = norm_layer(head_dim) if len(kernel_kv) > 0 else None + self.pool_v = ( + nn.Conv3d( + head_dim, + head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=head_dim, + bias=False, + ) + if len(kernel_kv) > 0 + else None + ) + self.norm_v = norm_layer(head_dim) if len(kernel_kv) > 0 else None + else: + raise NotImplementedError(f"Unsupported model {mode}") + + def forward(self, x, thw_shape): + B, N, C = x.shape + if self.pool_first: + x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute( + 0, 2, 1, 3 + ) + q = k = v = x + else: + q = k = v = x + q = ( + self.q(q) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k(k) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = ( + self.v(v) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + q, q_shape = attention_pool( + q, + self.pool_q, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_q if hasattr(self, "norm_q") else None, + ) + k, k_shape = attention_pool( + k, + self.pool_k, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_k if hasattr(self, "norm_k") else None, + ) + v, v_shape = attention_pool( + v, + self.pool_v, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_v if hasattr(self, "norm_v") else None, + ) + + if self.pool_first: + q_N = ( + numpy.prod(q_shape) + 1 + if self.has_cls_embed + else numpy.prod(q_shape) + ) + k_N = ( + numpy.prod(k_shape) + 1 + if self.has_cls_embed + else numpy.prod(k_shape) + ) + v_N = ( + numpy.prod(v_shape) + 1 + if self.has_cls_embed + else numpy.prod(v_shape) + ) + + q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) + q = ( + self.q(q) + .reshape(B, q_N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + v = v.permute(0, 2, 1, 
3).reshape(B, v_N, C) + v = ( + self.v(v) + .reshape(B, v_N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) + k = ( + self.k(k) + .reshape(B, k_N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + N = q.shape[2] + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + if self.drop_rate > 0.0: + x = self.proj_drop(x) + return x, q_shape + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + up_rate=None, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + mode="conv", + has_cls_embed=True, + pool_first=False, + ): + super().__init__() + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + self.attn = MultiScaleAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=nn.LayerNorm, + has_cls_embed=has_cls_embed, + mode=mode, + pool_first=pool_first, + ) + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.has_cls_embed = has_cls_embed + # TODO: check the use case for up_rate, and merge the following lines + if up_rate is not None and up_rate > 1: + mlp_dim_out = dim * up_rate + else: + mlp_dim_out = dim_out + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=mlp_dim_out, + act_layer=act_layer, + drop_rate=drop_rate, + ) + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + self.pool_skip = ( + nn.MaxPool3d( + kernel_skip, stride_skip, padding_skip, ceil_mode=False + ) + if len(kernel_skip) > 0 + else None + ) + + def forward(self, x, thw_shape): + x_block, thw_shape_new = self.attn(self.norm1(x), thw_shape) + x_res, _ = attention_pool( + x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed + ) + x = x_res + self.drop_path(x_block) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + if self.dim != self.dim_out: + x = self.proj(x_norm) + x = x + self.drop_path(x_mlp) + return x, thw_shape_new From a3f609530f1bb4c345f440490c70acbab042789e Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:24:10 +0530 Subject: [PATCH 03/67] Create mlp.py --- vformer/attention/mlp.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 vformer/attention/mlp.py diff --git a/vformer/attention/mlp.py b/vformer/attention/mlp.py new file mode 100644 index 00000000..f773d53d --- /dev/null +++ b/vformer/attention/mlp.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop_rate=0.0, + ): + super().__init__() + self.drop_rate = drop_rate + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, 
out_features) + if self.drop_rate > 0.0: + self.drop = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + if self.drop_rate > 0.0: + x = self.drop(x) + x = self.fc2(x) + if self.drop_rate > 0.0: + x = self.drop(x) + return x From 441df7670a5b4cd188892a058434d7d2183ff81c Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:25:14 +0530 Subject: [PATCH 04/67] Delete mlp.py --- vformer/attention/mlp.py | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 vformer/attention/mlp.py diff --git a/vformer/attention/mlp.py b/vformer/attention/mlp.py deleted file mode 100644 index f773d53d..00000000 --- a/vformer/attention/mlp.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import torch.nn as nn - - -class Mlp(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop_rate=0.0, - ): - super().__init__() - self.drop_rate = drop_rate - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - if self.drop_rate > 0.0: - self.drop = nn.Dropout(drop_rate) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - if self.drop_rate > 0.0: - x = self.drop(x) - x = self.fc2(x) - if self.drop_rate > 0.0: - x = self.drop(x) - return x From 03fc0f11d3174b77ec77c442a18f387ffd4c2e62 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:25:43 +0530 Subject: [PATCH 05/67] Create droppath.py --- vformer/common/droppath.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 vformer/common/droppath.py diff --git a/vformer/common/droppath.py b/vformer/common/droppath.py new file mode 100644 index 00000000..2cddabcc --- /dev/null +++ b/vformer/common/droppath.py @@ -0,0 +1,9 @@ +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) From cd159f8f8fe7f2e60ad907b249703812c7c8189b Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:26:00 +0530 Subject: [PATCH 06/67] Create mlp.py --- vformer/common/mlp.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 vformer/common/mlp.py diff --git a/vformer/common/mlp.py b/vformer/common/mlp.py new file mode 100644 index 00000000..f773d53d --- /dev/null +++ b/vformer/common/mlp.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop_rate=0.0, + ): + super().__init__() + self.drop_rate = drop_rate + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + if self.drop_rate > 0.0: + self.drop = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + if self.drop_rate > 0.0: + x = self.drop(x) + x = self.fc2(x) + if self.drop_rate > 0.0: + x = self.drop(x) 
+ return x From 7aef8072d1169b1f76cb889cdac436fa80d59dff Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:34:40 +0530 Subject: [PATCH 07/67] Add docstrings --- vformer/attention/multiscale.py | 94 +++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 3 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 89a4d0f5..77301d0a 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -6,6 +6,22 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): + """ + Attention pooling + Parameters: + ----------- + tensor: tensor + Input tensor + pool: nn.Module + Pooling function + thw_shape: list of int + Reduced space-time resolution + has_cls_embed: boolean, optional + Set to true if classification embeddding is provided + norm : nn.Module, optional + Normalization function + """ + if pool is None: return tensor, thw_shape tensor_dim = tensor.ndim @@ -43,10 +59,40 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): class MultiScaleAttention(nn.Module): + """ + Multiscale Attention + Parameters: + ----------- + dim: int + Dimension of the embedding + num_heads: int + Number of attention heads + qkv_bias: boolean,optional + + drop_rate: float, optional + Dropout rate + kernel_q: tuple of int, optional + Kernel size of query + kernel_kv: tuple of int, optional + Kernel size of key and value + stride_q: tuple of int, optional + Kernel size of query + stride_kv: tuple of int, optional + Kernel size of key and value + norm_layer: nn.Module, optional + Normalization function + has_cls_embed: boolean, optional + Set to true if classification embeddding is provided + mode: str, optional + Pooling function to be used. Options include `conv`, `avg`, and `max' + pool_first: boolean, optional + Set to True to perform pool before projection + """ + def __init__( self, dim, - num_heads=8, + num_heads = 8, qkv_bias=False, drop_rate=0.0, kernel_q=(1, 1, 1), @@ -55,9 +101,7 @@ def __init__( stride_kv=(1, 1, 1), norm_layer=nn.LayerNorm, has_cls_embed=True, - # Options include `conv`, `avg`, and `max`. mode="conv", - # If True, perform pool before projection. pool_first=False, ): super().__init__() @@ -243,6 +287,50 @@ def forward(self, x, thw_shape): class MultiScaleBlock(nn.Module): + """ + Multiscale Attention Block + Parameters: + ----------- + dim: int + Dimension of the embedding + dim_out: int + Output dimension of the embedding + num_heads: int + Number of attention heads + mlp_ratio: float, optional + + qkv_bias: boolean, optional + + qk_scale: + + drop_rate: float, optional + Dropout rate + drop_path: float, optional + Dropout rate for dropping paths in mlp + act_layer= nn.Module, optional + Normalization function + norm_layer= nn.Module, optional + Normalization function + p_rate= + + kernel_q: tuple of int, optional + Kernel size of query + kernel_kv: tuple of int, optional + Kernel size of key and value + stride_q: tuple of int, optional + Kernel size of query + stride_kv: tuple of int, optional + Kernel size of key and value + norm_layer: nn.Module, optional + Normalization function + mode: str, optional + Pooling function to be used. 
Options include `conv`, `avg`, and `max' + has_cls_embed: boolean, optional + Set to true if classification embeddding is provided + pool_first: boolean, optional + Set to True to perform pool before projection + """ + def __init__( self, dim, From 5c8a00efeb0418440686e7d6c19412ec689ff755 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:35:22 +0530 Subject: [PATCH 08/67] Update Mlp and DropPath import --- vformer/attention/multiscale.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 77301d0a..d208cf2d 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn -from slowfast.models.common import DropPath, Mlp - +from vformer.vformer.common.mlp import Mlp +from vformer.vformer.common.dropapath import DropPath def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): """ From fb41ba360c5450ea4dc8503a6f5b5bcf65c25f43 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:37:48 +0530 Subject: [PATCH 09/67] Update docstring --- vformer/common/droppath.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vformer/common/droppath.py b/vformer/common/droppath.py index 2cddabcc..3c68ec5f 100644 --- a/vformer/common/droppath.py +++ b/vformer/common/droppath.py @@ -1,5 +1,11 @@ class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) + Parameters: + ----------- + dim: float, optional + Probability of dropping paths + """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() From 9aa3800dd7f190eec5bf72f194c78c52909babf9 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:38:09 +0530 Subject: [PATCH 10/67] Add import statements --- vformer/common/droppath.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vformer/common/droppath.py b/vformer/common/droppath.py index 3c68ec5f..714244a0 100644 --- a/vformer/common/droppath.py +++ b/vformer/common/droppath.py @@ -1,3 +1,6 @@ +import torch +import torch.nn as nn + class DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) From 6d38d9d5acd8e8977d359f897163ec9f77b474f4 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:42:28 +0530 Subject: [PATCH 11/67] Add docstring --- vformer/common/mlp.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/vformer/common/mlp.py b/vformer/common/mlp.py index f773d53d..2b5c4c99 100644 --- a/vformer/common/mlp.py +++ b/vformer/common/mlp.py @@ -3,6 +3,21 @@ class Mlp(nn.Module): + """ + Multilayer Perceptron + Parameters: + ----------- + in_features: int + Size of input + hidden_features: int, optional + Size of hidden layer + out_features: int, optional + Size of output + act_layer: nn.Module, optional + Activation function + drop_rate: float, optional + Dropout rate + """ def __init__( self, in_features, From 3788eb66793c80da0860651dd5096eb7422ff469 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 11:47:21 +0530 Subject: [PATCH 12/67] Run 
pre-commit --- vformer/attention/multiscale.py | 55 +++++++++++---------------------- vformer/common/droppath.py | 1 + vformer/common/mlp.py | 1 + 3 files changed, 20 insertions(+), 37 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index d208cf2d..7bcbc14f 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -2,8 +2,9 @@ import torch import torch.nn as nn -from vformer.vformer.common.mlp import Mlp from vformer.vformer.common.dropapath import DropPath +from vformer.vformer.common.mlp import Mlp + def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): """ @@ -21,7 +22,7 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): norm : nn.Module, optional Normalization function """ - + if pool is None: return tensor, thw_shape tensor_dim = tensor.ndim @@ -37,9 +38,7 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): B, N, L, C = tensor.shape T, H, W = thw_shape - tensor = ( - tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - ) + tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous() tensor = pool(tensor) @@ -68,7 +67,7 @@ class MultiScaleAttention(nn.Module): num_heads: int Number of attention heads qkv_bias: boolean,optional - + drop_rate: float, optional Dropout rate kernel_q: tuple of int, optional @@ -88,11 +87,11 @@ class MultiScaleAttention(nn.Module): pool_first: boolean, optional Set to True to perform pool before projection """ - + def __init__( self, dim, - num_heads = 8, + num_heads=8, qkv_bias=False, drop_rate=0.0, kernel_q=(1, 1, 1), @@ -193,9 +192,7 @@ def __init__( def forward(self, x, thw_shape): B, N, C = x.shape if self.pool_first: - x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute( - 0, 2, 1, 3 - ) + x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) q = k = v = x else: q = k = v = x @@ -238,21 +235,9 @@ def forward(self, x, thw_shape): ) if self.pool_first: - q_N = ( - numpy.prod(q_shape) + 1 - if self.has_cls_embed - else numpy.prod(q_shape) - ) - k_N = ( - numpy.prod(k_shape) + 1 - if self.has_cls_embed - else numpy.prod(k_shape) - ) - v_N = ( - numpy.prod(v_shape) + 1 - if self.has_cls_embed - else numpy.prod(v_shape) - ) + q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape) + k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape) + v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape) q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) q = ( @@ -298,11 +283,11 @@ class MultiScaleBlock(nn.Module): num_heads: int Number of attention heads mlp_ratio: float, optional - + qkv_bias: boolean, optional - + qk_scale: - + drop_rate: float, optional Dropout rate drop_path: float, optional @@ -312,7 +297,7 @@ class MultiScaleBlock(nn.Module): norm_layer= nn.Module, optional Normalization function p_rate= - + kernel_q: tuple of int, optional Kernel size of query kernel_kv: tuple of int, optional @@ -330,7 +315,7 @@ class MultiScaleBlock(nn.Module): pool_first: boolean, optional Set to True to perform pool before projection """ - + def __init__( self, dim, @@ -373,9 +358,7 @@ def __init__( mode=mode, pool_first=pool_first, ) - self.drop_path = ( - DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.has_cls_embed = has_cls_embed @@ 
-395,9 +378,7 @@ def __init__( self.proj = nn.Linear(dim, dim_out) self.pool_skip = ( - nn.MaxPool3d( - kernel_skip, stride_skip, padding_skip, ceil_mode=False - ) + nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) if len(kernel_skip) > 0 else None ) diff --git a/vformer/common/droppath.py b/vformer/common/droppath.py index 714244a0..0aefec3c 100644 --- a/vformer/common/droppath.py +++ b/vformer/common/droppath.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) diff --git a/vformer/common/mlp.py b/vformer/common/mlp.py index 2b5c4c99..620f9184 100644 --- a/vformer/common/mlp.py +++ b/vformer/common/mlp.py @@ -18,6 +18,7 @@ class Mlp(nn.Module): drop_rate: float, optional Dropout rate """ + def __init__( self, in_features, From db284a82a36786997e0e1ef3e348d866a30bd4a9 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 12:40:10 +0530 Subject: [PATCH 13/67] Update DropPath import --- vformer/attention/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 7bcbc14f..e3059a2d 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from vformer.vformer.common.dropapath import DropPath +from timm.models.layers import DropPath from vformer.vformer.common.mlp import Mlp From 60aa249311b80e90269d15ae3e6856c353aaab88 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 12:40:31 +0530 Subject: [PATCH 14/67] Delete droppath.py --- vformer/common/droppath.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 vformer/common/droppath.py diff --git a/vformer/common/droppath.py b/vformer/common/droppath.py deleted file mode 100644 index 0aefec3c..00000000 --- a/vformer/common/droppath.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn as nn - - -class DropPath(nn.Module): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) - Parameters: - ----------- - dim: float, optional - Probability of dropping paths - """ - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) From 8df5268fb5c7ad6e1f25f519a76f7c5db303ed07 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 12:41:57 +0530 Subject: [PATCH 15/67] Delete mlp.py --- vformer/common/mlp.py | 48 ------------------------------------------- 1 file changed, 48 deletions(-) delete mode 100644 vformer/common/mlp.py diff --git a/vformer/common/mlp.py b/vformer/common/mlp.py deleted file mode 100644 index 620f9184..00000000 --- a/vformer/common/mlp.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torch.nn as nn - - -class Mlp(nn.Module): - """ - Multilayer Perceptron - Parameters: - ----------- - in_features: int - Size of input - hidden_features: int, optional - Size of hidden layer - out_features: int, optional - Size of output - act_layer: nn.Module, optional - Activation function - drop_rate: float, optional - Dropout rate - """ - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop_rate=0.0, 
- ): - super().__init__() - self.drop_rate = drop_rate - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - if self.drop_rate > 0.0: - self.drop = nn.Dropout(drop_rate) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - if self.drop_rate > 0.0: - x = self.drop(x) - x = self.fc2(x) - if self.drop_rate > 0.0: - x = self.drop(x) - return x From 5a08e33ba14ceca3a3959fd42b3bced18cf370e2 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 12:43:38 +0530 Subject: [PATCH 16/67] Update Mlp import --- vformer/attention/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index e3059a2d..fbcd5fd3 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -3,7 +3,7 @@ import torch.nn as nn from timm.models.layers import DropPath -from vformer.vformer.common.mlp import Mlp +from vformer.vformer.encoder.nn import FeedForward as Mlp def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): From dd72553625b520cb6563ef51e061df953277ba86 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 12:49:49 +0530 Subject: [PATCH 17/67] Run pre-commit --- vformer/attention/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index fbcd5fd3..78e4bb2f 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -1,8 +1,8 @@ import numpy import torch import torch.nn as nn - from timm.models.layers import DropPath + from vformer.vformer.encoder.nn import FeedForward as Mlp From f144ed7f3071c9e98d4a97a53c33b63bf4f3a456 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 13:17:34 +0530 Subject: [PATCH 18/67] Update docstring --- vformer/attention/multiscale.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 78e4bb2f..02a7b863 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -283,7 +283,7 @@ class MultiScaleBlock(nn.Module): num_heads: int Number of attention heads mlp_ratio: float, optional - + Ratio of hidden dimension to input dimension for feedforward qkv_bias: boolean, optional qk_scale: @@ -296,8 +296,8 @@ class MultiScaleBlock(nn.Module): Normalization function norm_layer= nn.Module, optional Normalization function - p_rate= - + up_rate= float, optional + Ratio of output dimension to input dimension for feedforward kernel_q: tuple of int, optional Kernel size of query kernel_kv: tuple of int, optional From 77e992586a548d4262597a7d5277f15f629adcc0 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 13:21:43 +0530 Subject: [PATCH 19/67] Update mlp usage for feedforward --- vformer/attention/multiscale.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 02a7b863..8d07191b 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -292,8 +292,6 @@ class MultiScaleBlock(nn.Module): 
Dropout rate drop_path: float, optional Dropout rate for dropping paths in mlp - act_layer= nn.Module, optional - Normalization function norm_layer= nn.Module, optional Normalization function up_rate= float, optional @@ -326,7 +324,6 @@ def __init__( qk_scale=None, drop_rate=0.0, drop_path=0.0, - act_layer=nn.GELU, norm_layer=nn.LayerNorm, up_rate=None, kernel_q=(1, 1, 1), @@ -368,11 +365,10 @@ def __init__( else: mlp_dim_out = dim_out self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - out_features=mlp_dim_out, - act_layer=act_layer, - drop_rate=drop_rate, + dim=dim, + hidden_dim=mlp_hidden_dim, + out_dim=mlp_dim_out, + p_dropout=drop_rate, ) if dim != dim_out: self.proj = nn.Linear(dim, dim_out) From d6fcc132c660e4b07e8229e00552e3e330bf6de3 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 13:23:10 +0530 Subject: [PATCH 20/67] Run pre-commit --- vformer/attention/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 8d07191b..ca02b203 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -283,7 +283,7 @@ class MultiScaleBlock(nn.Module): num_heads: int Number of attention heads mlp_ratio: float, optional - Ratio of hidden dimension to input dimension for feedforward + Ratio of hidden dimension to input dimension for feedforward qkv_bias: boolean, optional qk_scale: From d716b63d8dcb132198cbd8fde65f08046f31a51b Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 13:36:18 +0530 Subject: [PATCH 21/67] Update docstring --- vformer/attention/multiscale.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index ca02b203..c80d033f 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -17,7 +17,7 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): Pooling function thw_shape: list of int Reduced space-time resolution - has_cls_embed: boolean, optional + has_cls_embed: bool, optional Set to true if classification embeddding is provided norm : nn.Module, optional Normalization function @@ -66,8 +66,8 @@ class MultiScaleAttention(nn.Module): Dimension of the embedding num_heads: int Number of attention heads - qkv_bias: boolean,optional - + qkv_bias :bool, optional + If True, add a learnable bias to query, key, value drop_rate: float, optional Dropout rate kernel_q: tuple of int, optional @@ -80,11 +80,11 @@ class MultiScaleAttention(nn.Module): Kernel size of key and value norm_layer: nn.Module, optional Normalization function - has_cls_embed: boolean, optional + has_cls_embed: bool, optional Set to true if classification embeddding is provided mode: str, optional Pooling function to be used. Options include `conv`, `avg`, and `max' - pool_first: boolean, optional + pool_first: bool, optional Set to True to perform pool before projection """ @@ -283,11 +283,11 @@ class MultiScaleBlock(nn.Module): num_heads: int Number of attention heads mlp_ratio: float, optional - Ratio of hidden dimension to input dimension for feedforward - qkv_bias: boolean, optional - - qk_scale: - + Ratio of hidden dimension to input dimension for MLP + qkv_bias :bool, optional + If True, add a learnable bias to query, key, value. 
+ qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 if set drop_rate: float, optional Dropout rate drop_path: float, optional @@ -295,7 +295,7 @@ class MultiScaleBlock(nn.Module): norm_layer= nn.Module, optional Normalization function up_rate= float, optional - Ratio of output dimension to input dimension for feedforward + Ratio of output dimension to input dimension for MLP kernel_q: tuple of int, optional Kernel size of query kernel_kv: tuple of int, optional @@ -308,9 +308,9 @@ class MultiScaleBlock(nn.Module): Normalization function mode: str, optional Pooling function to be used. Options include `conv`, `avg`, and `max' - has_cls_embed: boolean, optional + has_cls_embed: bool, optional Set to true if classification embeddding is provided - pool_first: boolean, optional + pool_first: bool, optional Set to True to perform pool before projection """ From f8577649e5cc0a4f7cd5ef3a44dda6d14b49a1a6 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 16:38:44 +0530 Subject: [PATCH 22/67] Remove multiscale block --- vformer/attention/multiscale.py | 120 -------------------------------- 1 file changed, 120 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index c80d033f..01c71286 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -271,123 +271,3 @@ def forward(self, x, thw_shape): return x, q_shape -class MultiScaleBlock(nn.Module): - """ - Multiscale Attention Block - Parameters: - ----------- - dim: int - Dimension of the embedding - dim_out: int - Output dimension of the embedding - num_heads: int - Number of attention heads - mlp_ratio: float, optional - Ratio of hidden dimension to input dimension for MLP - qkv_bias :bool, optional - If True, add a learnable bias to query, key, value. - qk_scale: float, optional - Override default qk scale of head_dim ** -0.5 if set - drop_rate: float, optional - Dropout rate - drop_path: float, optional - Dropout rate for dropping paths in mlp - norm_layer= nn.Module, optional - Normalization function - up_rate= float, optional - Ratio of output dimension to input dimension for MLP - kernel_q: tuple of int, optional - Kernel size of query - kernel_kv: tuple of int, optional - Kernel size of key and value - stride_q: tuple of int, optional - Kernel size of query - stride_kv: tuple of int, optional - Kernel size of key and value - norm_layer: nn.Module, optional - Normalization function - mode: str, optional - Pooling function to be used. 
Options include `conv`, `avg`, and `max' - has_cls_embed: bool, optional - Set to true if classification embeddding is provided - pool_first: bool, optional - Set to True to perform pool before projection - """ - - def __init__( - self, - dim, - dim_out, - num_heads, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop_rate=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - up_rate=None, - kernel_q=(1, 1, 1), - kernel_kv=(1, 1, 1), - stride_q=(1, 1, 1), - stride_kv=(1, 1, 1), - mode="conv", - has_cls_embed=True, - pool_first=False, - ): - super().__init__() - self.dim = dim - self.dim_out = dim_out - self.norm1 = norm_layer(dim) - kernel_skip = [s + 1 if s > 1 else s for s in stride_q] - stride_skip = stride_q - padding_skip = [int(skip // 2) for skip in kernel_skip] - self.attn = MultiScaleAttention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - drop_rate=drop_rate, - kernel_q=kernel_q, - kernel_kv=kernel_kv, - stride_q=stride_q, - stride_kv=stride_kv, - norm_layer=nn.LayerNorm, - has_cls_embed=has_cls_embed, - mode=mode, - pool_first=pool_first, - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.has_cls_embed = has_cls_embed - # TODO: check the use case for up_rate, and merge the following lines - if up_rate is not None and up_rate > 1: - mlp_dim_out = dim * up_rate - else: - mlp_dim_out = dim_out - self.mlp = Mlp( - dim=dim, - hidden_dim=mlp_hidden_dim, - out_dim=mlp_dim_out, - p_dropout=drop_rate, - ) - if dim != dim_out: - self.proj = nn.Linear(dim, dim_out) - - self.pool_skip = ( - nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) - if len(kernel_skip) > 0 - else None - ) - - def forward(self, x, thw_shape): - x_block, thw_shape_new = self.attn(self.norm1(x), thw_shape) - x_res, _ = attention_pool( - x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed - ) - x = x_res + self.drop_path(x_block) - x_norm = self.norm2(x) - x_mlp = self.mlp(x_norm) - if self.dim != self.dim_out: - x = self.proj(x_norm) - x = x + self.drop_path(x_mlp) - return x, thw_shape_new From 23a81f89e306b118e2bc9e0cf81d8660086c59ad Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 16:39:42 +0530 Subject: [PATCH 23/67] Update imports --- vformer/attention/multiscale.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 01c71286..637d6750 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -1,10 +1,6 @@ import numpy import torch import torch.nn as nn -from timm.models.layers import DropPath - -from vformer.vformer.encoder.nn import FeedForward as Mlp - def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): """ From 7a41441fe7982b64a7b37e695382b5deae536346 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sun, 28 Nov 2021 16:45:38 +0530 Subject: [PATCH 24/67] Create multiscale.py --- vformer/encoder/multiscale.py | 126 ++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 vformer/encoder/multiscale.py diff --git a/vformer/encoder/multiscale.py b/vformer/encoder/multiscale.py new file mode 100644 index 00000000..6e4f9506 --- /dev/null +++ b/vformer/encoder/multiscale.py @@ -0,0 +1,126 @@ +import torch +from torch import nn + +from timm.models.layers import DropPath +from .nn import FeedForward as Mlp 
+from ..attention import MultiScaleAttention + +class MultiScaleBlock(nn.Module): + """ + Multiscale Attention Block + Parameters: + ----------- + dim: int + Dimension of the embedding + dim_out: int + Output dimension of the embedding + num_heads: int + Number of attention heads + mlp_ratio: float, optional + Ratio of hidden dimension to input dimension for MLP + qkv_bias :bool, optional + If True, add a learnable bias to query, key, value. + qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 if set + drop_rate: float, optional + Dropout rate + drop_path: float, optional + Dropout rate for dropping paths in mlp + norm_layer= nn.Module, optional + Normalization function + up_rate= float, optional + Ratio of output dimension to input dimension for MLP + kernel_q: tuple of int, optional + Kernel size of query + kernel_kv: tuple of int, optional + Kernel size of key and value + stride_q: tuple of int, optional + Kernel size of query + stride_kv: tuple of int, optional + Kernel size of key and value + norm_layer: nn.Module, optional + Normalization function + mode: str, optional + Pooling function to be used. Options include `conv`, `avg`, and `max' + has_cls_embed: bool, optional + Set to true if classification embeddding is provided + pool_first: bool, optional + Set to True to perform pool before projection + """ + + def __init__( + self, + dim, + dim_out, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + up_rate=None, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + mode="conv", + has_cls_embed=True, + pool_first=False, + ): + super().__init__() + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + self.attn = MultiScaleAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=nn.LayerNorm, + has_cls_embed=has_cls_embed, + mode=mode, + pool_first=pool_first, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.has_cls_embed = has_cls_embed + if up_rate is not None and up_rate > 1: + mlp_dim_out = dim * up_rate + else: + mlp_dim_out = dim_out + self.mlp = Mlp( + dim=dim, + hidden_dim=mlp_hidden_dim, + out_dim=mlp_dim_out, + p_dropout=drop_rate, + ) + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + self.pool_skip = ( + nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False) + if len(kernel_skip) > 0 + else None + ) + + def forward(self, x, thw_shape): + x_block, thw_shape_new = self.attn(self.norm1(x), thw_shape) + x_res, _ = attention_pool( + x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed + ) + x = x_res + self.drop_path(x_block) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + if self.dim != self.dim_out: + x = self.proj(x_norm) + x = x + self.drop_path(x_mlp) + return x, thw_shape_new From f9b58da2accce2c28cabeba984457adbb3cc1b43 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 14:23:16 +0530 Subject: [PATCH 25/67] Create multiscale.py --- vformer/models/classification/multiscale.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) create 
mode 100644 vformer/models/classification/multiscale.py diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py new file mode 100644 index 00000000..80cf910d --- /dev/null +++ b/vformer/models/classification/multiscale.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + +from vformer.common import BaseClassificationModel +from vformer.decoder.mlp import MLPDecoder +from vformer.encoder.multiscale import MultiScaleBlock + +attention_block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=self.drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + kernel_q=pool_q[i] if len(pool_q) > i else [], + kernel_kv=pool_kv[i] if len(pool_kv) > i else [], + stride_q=stride_q[i] if len(stride_q) > i else [], + stride_kv=stride_kv[i] if len(stride_kv) > i else [], + mode=mode, + has_cls_embed=self.cls_embed_on, + pool_first=pool_first, + ) From a454edecd566d79c7740d2ee793d38f58c779758 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 14:25:10 +0530 Subject: [PATCH 26/67] Update test_attention.py --- tests/test_attention.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 1d1b3a93..40c372d7 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -4,6 +4,7 @@ from vformer.attention.spatial import SpatialAttention from vformer.attention.vanilla import VanillaSelfAttention from vformer.attention.window import WindowAttention +from vformer.attention.multiscale import MultiScaleAttention def test_VanillaSelfAttention(): @@ -43,6 +44,7 @@ def test_CrossAttention(): del attention + def test_SpatialAttention(): test_tensor1 = torch.randn(4, 3136, 64) test_tensor2 = torch.randn(4, 50, 512) @@ -63,3 +65,8 @@ def test_SpatialAttention(): attention = SpatialAttention(dim=64, num_heads=1, sr_ratio=8, linear=True) out = attention(test_tensor1, 56, 56) assert out.shape == test_tensor1.shape + + + def test_MultiScaleAttention(): + test_tensor1 = torch.randn(256, 49, 96) + test_tensor2 = torch.randn(32, 64, 96) From 7731f1fa1a414252ebc3f7f7a333374ee81daded Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 14:56:43 +0530 Subject: [PATCH 27/67] Delete test_attention.py --- tests/test_attention.py | 72 ----------------------------------------- 1 file changed, 72 deletions(-) delete mode 100644 tests/test_attention.py diff --git a/tests/test_attention.py b/tests/test_attention.py deleted file mode 100644 index 40c372d7..00000000 --- a/tests/test_attention.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch - -from vformer.attention.cross import CrossAttention -from vformer.attention.spatial import SpatialAttention -from vformer.attention.vanilla import VanillaSelfAttention -from vformer.attention.window import WindowAttention -from vformer.attention.multiscale import MultiScaleAttention - - -def test_VanillaSelfAttention(): - test_tensor1 = torch.randn(2, 65, 1024) - test_tensor2 = torch.randn(2, 257, 1024) - attention = VanillaSelfAttention(dim=1024) - out = attention(test_tensor1) - - assert out.shape == (2, 65, 1024) - del attention - attention = VanillaSelfAttention(dim=1024, heads=16) - out = attention(test_tensor2) - assert out.shape == (2, 257, 1024) - del attention - - -def test_WindowAttention(): - test_tensor1 = torch.randn(256, 49, 96) - test_tensor2 = torch.randn(32, 64, 96) - 
attention = WindowAttention(dim=96, window_size=7, num_heads=3) - out = attention(test_tensor1) - assert out.shape == test_tensor1.shape - del attention - - attention = WindowAttention(dim=96, window_size=8, num_heads=4) - out = attention(test_tensor2) - assert out.shape == test_tensor2.shape - del attention - - -def test_CrossAttention(): - test_tensor1 = torch.randn(64, 1, 64) - test_tensor2 = torch.randn(64, 24, 128) - attention = CrossAttention(64, 128, 64) - out = attention(test_tensor1, test_tensor2) - assert out.shape == test_tensor1.shape - del attention - - - -def test_SpatialAttention(): - test_tensor1 = torch.randn(4, 3136, 64) - test_tensor2 = torch.randn(4, 50, 512) - - attention = SpatialAttention( - dim=64, - num_heads=1, - sr_ratio=8, - ) - out = attention(test_tensor1, H=56, W=56) - assert out.shape == test_tensor1.shape - del attention - attention = SpatialAttention(dim=512, num_heads=8, sr_ratio=1, linear=False) - out = attention(test_tensor2, H=7, W=7) - assert out.shape == test_tensor2.shape - del attention - - attention = SpatialAttention(dim=64, num_heads=1, sr_ratio=8, linear=True) - out = attention(test_tensor1, 56, 56) - assert out.shape == test_tensor1.shape - - - def test_MultiScaleAttention(): - test_tensor1 = torch.randn(256, 49, 96) - test_tensor2 = torch.randn(32, 64, 96) From ef649b038ec2e7024a7a60ff3532ce72f45fe6f6 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 15:07:02 +0530 Subject: [PATCH 28/67] Update --- .github/ISSUE_TEMPLATE.md | 7 - .github/workflows/linting.yml | 2 +- .github/workflows/package-test.yml | 2 +- AUTHORS.rst | 3 +- README.md | 35 ++- README.rst | 37 --- docs/api/attention/attention_cross.rst | 7 + docs/api/attention/attention_spatial.rst | 7 + docs/api/attention/attention_vanilla.rst | 7 + docs/api/attention/attention_window.rst | 7 + docs/api/attention/index.rst | 11 + docs/api/common/common_base_model.rst | 5 + docs/api/common/common_block.rst | 5 + docs/api/common/index.rst | 8 + docs/api/decoder/index.rst | 9 + docs/api/decoder/mlp.rst | 6 + .../api/decoder/task_heads/detection/head.rst | 0 docs/api/decoder/task_heads/index.rst | 9 + .../decoder/task_heads/segmentation/head.rst | 6 + .../decoder/task_heads/segmentation/index.rst | 8 + docs/api/encoder/cross.rst | 5 + docs/api/encoder/embedding/index.rst | 11 + docs/api/encoder/embedding/linear.rst | 5 + docs/api/encoder/embedding/overlappatch.rst | 5 + docs/api/encoder/embedding/patch.rst | 5 + docs/api/encoder/embedding/pos_embedding.rst | 5 + docs/api/encoder/index.rst | 13 ++ docs/api/encoder/nn.rst | 5 + docs/api/encoder/pyramid.rst | 5 + docs/api/encoder/swin.rst | 5 + docs/api/encoder/vanilla.rst | 5 + docs/api/functional/index.rst | 9 + docs/api/functional/merge.rst | 5 + docs/api/functional/norm.rst | 5 + docs/api/models/classification/cct.rst | 5 + docs/api/models/classification/cross.rst | 5 + docs/api/models/classification/cvt.rst | 5 + docs/api/models/classification/index.rst | 13 ++ docs/api/models/classification/pyramid.rst | 5 + docs/api/models/classification/swin.rst | 5 + docs/api/models/classification/vanilla.rst | 5 + docs/api/models/dense/PVT/detection.rst | 5 + docs/api/models/dense/PVT/index.rst | 9 + docs/api/models/dense/PVT/segmentation.rst | 5 + docs/api/models/dense/index.rst | 8 + docs/api/models/index.rst | 9 + docs/api/utils/index.rst | 9 + docs/api/utils/utils.rst | 5 + docs/api/utils/window_utils.rst | 5 + docs/api/viz/index.rst | 9 + docs/api/viz/vit_grad_rollout.rst | 5 + 
docs/api/viz/vit_rollout.rst | 5 + docs/conf.py | 6 +- docs/history.rst | 1 - docs/index.rst | 21 +- docs/installation.rst | 41 ++-- docs/readme.rst | 26 ++- docs/usage.rst | 7 - tests/test_attention.py | 80 +++++++ tests/test_decoder.py | 24 +- tests/test_encoder.py | 39 ++-- tests/test_models.py | 211 +++++++++++++----- vformer/attention/cross.py | 34 ++- vformer/attention/spatial.py | 28 ++- vformer/attention/vanilla.py | 39 +++- vformer/attention/window.py | 39 +++- vformer/common/base_model.py | 2 +- vformer/common/blocks.py | 31 ++- vformer/decoder/mlp.py | 18 +- .../decoder/task_heads/segmentation/head.py | 11 +- vformer/encoder/cross.py | 71 +++--- vformer/encoder/embedding/__init__.py | 3 +- vformer/encoder/embedding/cvt.py | 110 +++++++++ vformer/encoder/embedding/linear.py | 19 +- vformer/encoder/embedding/overlappatch.py | 37 ++- vformer/encoder/embedding/patch.py | 36 ++- vformer/encoder/embedding/pos_embedding.py | 51 ++++- vformer/encoder/nn.py | 21 +- vformer/encoder/pyramid.py | 56 +++-- vformer/encoder/swin.py | 104 ++++++--- vformer/encoder/vanilla.py | 69 ++++-- vformer/functional/merge.py | 5 +- vformer/functional/norm.py | 6 +- vformer/models/classification/__init__.py | 2 + vformer/models/classification/cct.py | 210 +++++++++++++++++ vformer/models/classification/cross.py | 170 ++++++++------ vformer/models/classification/cvt.py | 196 ++++++++++++++++ vformer/models/classification/pyramid.py | 75 ++++--- vformer/models/classification/swin.py | 72 +++--- vformer/models/classification/vanilla.py | 60 +++-- vformer/models/dense/PVT/detection.py | 107 ++++++--- vformer/models/dense/PVT/segmentation.py | 102 ++++++--- vformer/utils/__init__.py | 1 + vformer/utils/registry.py | 96 ++++++++ vformer/utils/utils.py | 4 +- vformer/utils/window_utils.py | 22 +- 96 files changed, 2178 insertions(+), 571 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE.md delete mode 100644 README.rst create mode 100644 docs/api/attention/attention_cross.rst create mode 100644 docs/api/attention/attention_spatial.rst create mode 100644 docs/api/attention/attention_vanilla.rst create mode 100644 docs/api/attention/attention_window.rst create mode 100644 docs/api/attention/index.rst create mode 100644 docs/api/common/common_base_model.rst create mode 100644 docs/api/common/common_block.rst create mode 100644 docs/api/common/index.rst create mode 100644 docs/api/decoder/index.rst create mode 100644 docs/api/decoder/mlp.rst create mode 100644 docs/api/decoder/task_heads/detection/head.rst create mode 100644 docs/api/decoder/task_heads/index.rst create mode 100644 docs/api/decoder/task_heads/segmentation/head.rst create mode 100644 docs/api/decoder/task_heads/segmentation/index.rst create mode 100644 docs/api/encoder/cross.rst create mode 100644 docs/api/encoder/embedding/index.rst create mode 100644 docs/api/encoder/embedding/linear.rst create mode 100644 docs/api/encoder/embedding/overlappatch.rst create mode 100644 docs/api/encoder/embedding/patch.rst create mode 100644 docs/api/encoder/embedding/pos_embedding.rst create mode 100644 docs/api/encoder/index.rst create mode 100644 docs/api/encoder/nn.rst create mode 100644 docs/api/encoder/pyramid.rst create mode 100644 docs/api/encoder/swin.rst create mode 100644 docs/api/encoder/vanilla.rst create mode 100644 docs/api/functional/index.rst create mode 100644 docs/api/functional/merge.rst create mode 100644 docs/api/functional/norm.rst create mode 100644 docs/api/models/classification/cct.rst create mode 100644 
docs/api/models/classification/cross.rst create mode 100644 docs/api/models/classification/cvt.rst create mode 100644 docs/api/models/classification/index.rst create mode 100644 docs/api/models/classification/pyramid.rst create mode 100644 docs/api/models/classification/swin.rst create mode 100644 docs/api/models/classification/vanilla.rst create mode 100644 docs/api/models/dense/PVT/detection.rst create mode 100644 docs/api/models/dense/PVT/index.rst create mode 100644 docs/api/models/dense/PVT/segmentation.rst create mode 100644 docs/api/models/dense/index.rst create mode 100644 docs/api/models/index.rst create mode 100644 docs/api/utils/index.rst create mode 100644 docs/api/utils/utils.rst create mode 100644 docs/api/utils/window_utils.rst create mode 100644 docs/api/viz/index.rst create mode 100644 docs/api/viz/vit_grad_rollout.rst create mode 100644 docs/api/viz/vit_rollout.rst delete mode 100644 docs/history.rst delete mode 100644 docs/usage.rst create mode 100644 tests/test_attention.py create mode 100644 vformer/encoder/embedding/cvt.py create mode 100644 vformer/models/classification/cct.py create mode 100644 vformer/models/classification/cvt.py create mode 100644 vformer/utils/registry.py diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 8b7abb80..00000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,7 +0,0 @@ -* Paper: -* Paper Link: - -### Brief Description -``` -Mention the salient ideas of the paper/model/concept. -``` diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 33b80590..0950b178 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.7'] + python-version: ['3.7', '3.8'] steps: diff --git a/.github/workflows/package-test.yml b/.github/workflows/package-test.yml index 45e1efad..7bb98c85 100644 --- a/.github/workflows/package-test.yml +++ b/.github/workflows/package-test.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7] + python-version: ['3.7', '3.8'] steps: diff --git a/AUTHORS.rst b/AUTHORS.rst index 4a632157..cf4d347a 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -4,6 +4,7 @@ Credits * Neelay Shah -* Abhijit Deo +* Abhijit Deo +* Aditya Agrawal diff --git a/README.md b/README.md index c4fe7966..c1774575 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@
[![Tests](https://github.com/SforAiDl/vformer/actions/workflows/package-test.yml/badge.svg)](https://github.com/SforAiDl/vformer/actions/workflows/package-test.yml) +[![Docs](https://readthedocs.org/projects/vformer/badge/?version=latest)](https://vformer.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/SforAiDl/vformer/branch/main/graph/badge.svg?token=5QKCZ67CM2)](https://codecov.io/gh/SforAiDl/vformer) @@ -29,6 +30,9 @@ python setup.py install - [x] [Vanilla ViT](https://arxiv.org/abs/2010.11929) - [x] [Swin Transformer](https://arxiv.org/abs/2103.14030) - [x] [Pyramid Vision Transformer](https://arxiv.org/abs/2102.12122) +- [x] [CrossViT](https://arxiv.org/abs/2103.14899) +- [x] [Compact Vision Transformer](https://arxiv.org/abs/2104.05704) +- [x] [Compact Convolutional Transformer](https://arxiv.org/abs/2104.05704) ## Example usage @@ -51,7 +55,7 @@ model = SwinTransformer( window_size=7, drop_rate=0.2, ) -logits = model(image) +logits = model(image) ``` `VFormer` has a modular design and allows for easy experimentation using blocks/modules of different architectures. For example, if desired, you can use just the encoder or the windowed attention layer of the Swin Transformer model. @@ -92,6 +96,8 @@ swin_encoder = SwinEncoder( - [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) - [PVT](https://github.com/whai362/PVT) - [vit-explain](https://github.com/jacobgil/vit-explain) +- [CrossViT](https://github.com/IBM/CrossViT) +- [Compact-Transformers](https://github.com/SHI-Labs/Compact-Transformers)
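The README hunks above advertise the modular design and the newly added CCT/CVT/CrossViT models; a minimal sketch of how those pieces can be exercised through the registries introduced later in this patch is shown below. The module names, constructor arguments, tensor shapes, and the 1000-class default output are taken from the `tests/test_attention.py` and `tests/test_models.py` files added further down, not from any published API documentation.

```python
import torch

from vformer.utils import ATTENTION_REGISTRY, MODEL_REGISTRY

# Windowed attention used on its own, mirroring tests/test_attention.py:
# the input is (num_windows * batch, window_size ** 2, dim).
window_attn = ATTENTION_REGISTRY.get("WindowAttention")(
    dim=96, window_size=7, num_heads=3
)
windows = torch.randn(256, 49, 96)
assert window_attn(windows).shape == windows.shape

# A full classifier fetched through the model registry, mirroring
# tests/test_models.py (1000 classes by default).
cct = MODEL_REGISTRY.get("CCT")(img_size=256, patch_size=4, in_channels=3)
images = torch.randn(2, 3, 256, 256)
assert cct(images).shape == (2, 1000)
```

Looking modules up through the registries rather than importing them directly is what enables the mix-and-match experimentation described in the README.
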
@@ -123,7 +129,7 @@ swin_encoder = SwinEncoder( Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions ```bibtex @misc{wang2021pyramid, - title={Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions}, + title={Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions}, author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and Tong Lu and Ping Luo and Ling Shao}, year={2021}, eprint={2102.12122}, @@ -131,5 +137,28 @@ swin_encoder = SwinEncoder( primaryClass={cs.CV} } ``` - + CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification + +```bibtex +@inproceedings{chen2021crossvit, + title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}}, + author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda}, + booktitle={International Conference on Computer Vision (ICCV)}, + year={2021} +} +``` + + Escaping the Big Data Paradigm with Compact Transformers + +```bibtex +@article{hassani2021escaping, + title = {Escaping the Big Data Paradigm with Compact Transformers}, + author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi}, + year = 2021, + url = {https://arxiv.org/abs/2104.05704}, + eprint = {2104.05704}, + archiveprefix = {arXiv}, + primaryclass = {cs.CV} +} +``` diff --git a/README.rst b/README.rst deleted file mode 100644 index 515031a8..00000000 --- a/README.rst +++ /dev/null @@ -1,37 +0,0 @@ -======= -vformer -======= - - -.. image:: https://img.shields.io/pypi/v/vformer.svg - :target: https://pypi.python.org/pypi/vformer - -.. image:: https://img.shields.io/travis/SforAiDl/vformer.svg - :target: https://travis-ci.com/SforAiDl/vformer - -.. image:: https://readthedocs.org/projects/vformer/badge/?version=latest - :target: https://vformer.readthedocs.io/en/latest/?version=latest - :alt: Documentation Status - - - - -A PyTorch library for vision transformer models - - -* Free software: MIT license -* Documentation: https://vformer.readthedocs.io. - - -Features --------- - -* TODO - -Credits -------- - -This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. - -.. _Cookiecutter: https://github.com/audreyr/cookiecutter -.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage diff --git a/docs/api/attention/attention_cross.rst b/docs/api/attention/attention_cross.rst new file mode 100644 index 00000000..9d211c64 --- /dev/null +++ b/docs/api/attention/attention_cross.rst @@ -0,0 +1,7 @@ +Cross +================= + +.. automodule:: vformer.attention.cross + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/attention/attention_spatial.rst b/docs/api/attention/attention_spatial.rst new file mode 100644 index 00000000..65fb60b0 --- /dev/null +++ b/docs/api/attention/attention_spatial.rst @@ -0,0 +1,7 @@ +Spatial +================= + +.. automodule:: vformer.attention.spatial + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/attention/attention_vanilla.rst b/docs/api/attention/attention_vanilla.rst new file mode 100644 index 00000000..fab3c775 --- /dev/null +++ b/docs/api/attention/attention_vanilla.rst @@ -0,0 +1,7 @@ +Vanilla O(n^2) +================= + +.. 
automodule:: vformer.attention.vanilla + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/attention/attention_window.rst b/docs/api/attention/attention_window.rst new file mode 100644 index 00000000..b3bb8161 --- /dev/null +++ b/docs/api/attention/attention_window.rst @@ -0,0 +1,7 @@ +Window +================= + +.. automodule:: vformer.attention.window + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/attention/index.rst b/docs/api/attention/index.rst new file mode 100644 index 00000000..227f8aee --- /dev/null +++ b/docs/api/attention/index.rst @@ -0,0 +1,11 @@ +Attention +================ + +.. toctree:: + :maxdepth: 2 + :caption: Contents + + attention_vanilla + attention_cross + attention_spatial + attention_window diff --git a/docs/api/common/common_base_model.rst b/docs/api/common/common_base_model.rst new file mode 100644 index 00000000..6ca01b8a --- /dev/null +++ b/docs/api/common/common_base_model.rst @@ -0,0 +1,5 @@ +Base Classification Model +============================ + +.. automodule:: vformer.common.base_model + :members: diff --git a/docs/api/common/common_block.rst b/docs/api/common/common_block.rst new file mode 100644 index 00000000..ce0c14d9 --- /dev/null +++ b/docs/api/common/common_block.rst @@ -0,0 +1,5 @@ +Blocks +======== + +.. automodule:: vformer.common.blocks + :members: diff --git a/docs/api/common/index.rst b/docs/api/common/index.rst new file mode 100644 index 00000000..56f96080 --- /dev/null +++ b/docs/api/common/index.rst @@ -0,0 +1,8 @@ +Common +============ + +.. toctree:: + :maxdepth: 2 + + common_base_model + common_block \ No newline at end of file diff --git a/docs/api/decoder/index.rst b/docs/api/decoder/index.rst new file mode 100644 index 00000000..9f91ae56 --- /dev/null +++ b/docs/api/decoder/index.rst @@ -0,0 +1,9 @@ +Decoder +============ + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + mlp + task_heads/index \ No newline at end of file diff --git a/docs/api/decoder/mlp.rst b/docs/api/decoder/mlp.rst new file mode 100644 index 00000000..6693c3d9 --- /dev/null +++ b/docs/api/decoder/mlp.rst @@ -0,0 +1,6 @@ +MLP +=========== + +.. automodule:: vformer.decoder.mlp + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/decoder/task_heads/detection/head.rst b/docs/api/decoder/task_heads/detection/head.rst new file mode 100644 index 00000000..e69de29b diff --git a/docs/api/decoder/task_heads/index.rst b/docs/api/decoder/task_heads/index.rst new file mode 100644 index 00000000..5a321009 --- /dev/null +++ b/docs/api/decoder/task_heads/index.rst @@ -0,0 +1,9 @@ +Task Heads +========== + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + detection/index + segmentation/index \ No newline at end of file diff --git a/docs/api/decoder/task_heads/segmentation/head.rst b/docs/api/decoder/task_heads/segmentation/head.rst new file mode 100644 index 00000000..c80d0451 --- /dev/null +++ b/docs/api/decoder/task_heads/segmentation/head.rst @@ -0,0 +1,6 @@ +Double Convolution +=================== + +.. automodule:: vformer.decoder.task_heads.segmentation.head + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/decoder/task_heads/segmentation/index.rst b/docs/api/decoder/task_heads/segmentation/index.rst new file mode 100644 index 00000000..db79673d --- /dev/null +++ b/docs/api/decoder/task_heads/segmentation/index.rst @@ -0,0 +1,8 @@ +Segmentation +============= + +.. 
toctree:: + :caption: Contents + :maxdepth: 2 + + head \ No newline at end of file diff --git a/docs/api/encoder/cross.rst b/docs/api/encoder/cross.rst new file mode 100644 index 00000000..2dc94225 --- /dev/null +++ b/docs/api/encoder/cross.rst @@ -0,0 +1,5 @@ +Cross +============= + +.. automodule:: vformer.encoder.cross + :members: \ No newline at end of file diff --git a/docs/api/encoder/embedding/index.rst b/docs/api/encoder/embedding/index.rst new file mode 100644 index 00000000..4f8890a9 --- /dev/null +++ b/docs/api/encoder/embedding/index.rst @@ -0,0 +1,11 @@ +Embedding Layers +================= + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + linear + overlappatch + patch + pos_embedding \ No newline at end of file diff --git a/docs/api/encoder/embedding/linear.rst b/docs/api/encoder/embedding/linear.rst new file mode 100644 index 00000000..eddf68eb --- /dev/null +++ b/docs/api/encoder/embedding/linear.rst @@ -0,0 +1,5 @@ +Linear +================ + +.. automodule:: vformer.encoder.embedding.linear + :members: \ No newline at end of file diff --git a/docs/api/encoder/embedding/overlappatch.rst b/docs/api/encoder/embedding/overlappatch.rst new file mode 100644 index 00000000..41daa893 --- /dev/null +++ b/docs/api/encoder/embedding/overlappatch.rst @@ -0,0 +1,5 @@ +Patch Overlap +======================= + +.. automodule:: vformer.encoder.embedding.overlappatch + :members: \ No newline at end of file diff --git a/docs/api/encoder/embedding/patch.rst b/docs/api/encoder/embedding/patch.rst new file mode 100644 index 00000000..124c494b --- /dev/null +++ b/docs/api/encoder/embedding/patch.rst @@ -0,0 +1,5 @@ +Patch +================ + +.. automodule:: vformer.encoder.embedding.patch + :members: \ No newline at end of file diff --git a/docs/api/encoder/embedding/pos_embedding.rst b/docs/api/encoder/embedding/pos_embedding.rst new file mode 100644 index 00000000..ed84bc94 --- /dev/null +++ b/docs/api/encoder/embedding/pos_embedding.rst @@ -0,0 +1,5 @@ +Positional +================ + +.. automodule:: vformer.encoder.embedding.pos_embedding + :members: \ No newline at end of file diff --git a/docs/api/encoder/index.rst b/docs/api/encoder/index.rst new file mode 100644 index 00000000..cdb7beb0 --- /dev/null +++ b/docs/api/encoder/index.rst @@ -0,0 +1,13 @@ +Encoder +========= + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + cross + embedding/index + nn + pyramid + swin + vanilla \ No newline at end of file diff --git a/docs/api/encoder/nn.rst b/docs/api/encoder/nn.rst new file mode 100644 index 00000000..a142705d --- /dev/null +++ b/docs/api/encoder/nn.rst @@ -0,0 +1,5 @@ +NN +============= + +.. automodule:: vformer.encoder.nn + :members: \ No newline at end of file diff --git a/docs/api/encoder/pyramid.rst b/docs/api/encoder/pyramid.rst new file mode 100644 index 00000000..f54cb5ad --- /dev/null +++ b/docs/api/encoder/pyramid.rst @@ -0,0 +1,5 @@ +Pyramid +============= + +.. automodule:: vformer.encoder.pyramid + :members: \ No newline at end of file diff --git a/docs/api/encoder/swin.rst b/docs/api/encoder/swin.rst new file mode 100644 index 00000000..f4285b83 --- /dev/null +++ b/docs/api/encoder/swin.rst @@ -0,0 +1,5 @@ +Swin +============= + +.. automodule:: vformer.encoder.swin + :members: \ No newline at end of file diff --git a/docs/api/encoder/vanilla.rst b/docs/api/encoder/vanilla.rst new file mode 100644 index 00000000..633e0996 --- /dev/null +++ b/docs/api/encoder/vanilla.rst @@ -0,0 +1,5 @@ +Vanilla Transformer +==================== + +.. 
automodule:: vformer.encoder.vanilla + :members: \ No newline at end of file diff --git a/docs/api/functional/index.rst b/docs/api/functional/index.rst new file mode 100644 index 00000000..abbc7aae --- /dev/null +++ b/docs/api/functional/index.rst @@ -0,0 +1,9 @@ +Functional +=========== + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + merge + norm \ No newline at end of file diff --git a/docs/api/functional/merge.rst b/docs/api/functional/merge.rst new file mode 100644 index 00000000..c8e41c61 --- /dev/null +++ b/docs/api/functional/merge.rst @@ -0,0 +1,5 @@ +Patch Merging +================ + +.. automodule:: vformer.functional.merge + :members: \ No newline at end of file diff --git a/docs/api/functional/norm.rst b/docs/api/functional/norm.rst new file mode 100644 index 00000000..84986f11 --- /dev/null +++ b/docs/api/functional/norm.rst @@ -0,0 +1,5 @@ +Normalization Layers +====================== + +.. automodule:: vformer.functional.norm + :members: \ No newline at end of file diff --git a/docs/api/models/classification/cct.rst b/docs/api/models/classification/cct.rst new file mode 100644 index 00000000..80f6d68b --- /dev/null +++ b/docs/api/models/classification/cct.rst @@ -0,0 +1,5 @@ +Compact Convolutional Transformer +================================ + +.. automodule:: vformer.models.classification.cct + :members: diff --git a/docs/api/models/classification/cross.rst b/docs/api/models/classification/cross.rst new file mode 100644 index 00000000..bd877024 --- /dev/null +++ b/docs/api/models/classification/cross.rst @@ -0,0 +1,5 @@ +Cross-attention Transformer +================================ + +.. automodule:: vformer.models.classification.cross + :members: \ No newline at end of file diff --git a/docs/api/models/classification/cvt.rst b/docs/api/models/classification/cvt.rst new file mode 100644 index 00000000..bae128bc --- /dev/null +++ b/docs/api/models/classification/cvt.rst @@ -0,0 +1,5 @@ +Compact Vision Transformer +================================ + +.. automodule:: vformer.models.classification.cvt + :members: diff --git a/docs/api/models/classification/index.rst b/docs/api/models/classification/index.rst new file mode 100644 index 00000000..e4f8254d --- /dev/null +++ b/docs/api/models/classification/index.rst @@ -0,0 +1,13 @@ +Classification +====================== + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + cct + cross + cvt + pyramid + swin + vanilla diff --git a/docs/api/models/classification/pyramid.rst b/docs/api/models/classification/pyramid.rst new file mode 100644 index 00000000..2c0d139b --- /dev/null +++ b/docs/api/models/classification/pyramid.rst @@ -0,0 +1,5 @@ +Pyramid Vision Transformer +================== + +.. automodule:: vformer.models.classification.pyramid + :members: \ No newline at end of file diff --git a/docs/api/models/classification/swin.rst b/docs/api/models/classification/swin.rst new file mode 100644 index 00000000..23dac9cd --- /dev/null +++ b/docs/api/models/classification/swin.rst @@ -0,0 +1,5 @@ +Swin Transformer +=================== + +.. automodule:: vformer.models.classification.swin + :members: \ No newline at end of file diff --git a/docs/api/models/classification/vanilla.rst b/docs/api/models/classification/vanilla.rst new file mode 100644 index 00000000..799abb93 --- /dev/null +++ b/docs/api/models/classification/vanilla.rst @@ -0,0 +1,5 @@ +Vanilla Vision Transformer +============================== + +.. 
automodule:: vformer.models.classification.vanilla + :members: \ No newline at end of file diff --git a/docs/api/models/dense/PVT/detection.rst b/docs/api/models/dense/PVT/detection.rst new file mode 100644 index 00000000..5f050dde --- /dev/null +++ b/docs/api/models/dense/PVT/detection.rst @@ -0,0 +1,5 @@ +Detection +================ + +.. automodule:: vformer.models.dense.PVT.detection + :members: diff --git a/docs/api/models/dense/PVT/index.rst b/docs/api/models/dense/PVT/index.rst new file mode 100644 index 00000000..ec7d4f47 --- /dev/null +++ b/docs/api/models/dense/PVT/index.rst @@ -0,0 +1,9 @@ +Pyramid Vision Transformer +============ + +.. toctree:: + :caption: Contents + :maxdepth: 2 + + detection + segmentation \ No newline at end of file diff --git a/docs/api/models/dense/PVT/segmentation.rst b/docs/api/models/dense/PVT/segmentation.rst new file mode 100644 index 00000000..3294d311 --- /dev/null +++ b/docs/api/models/dense/PVT/segmentation.rst @@ -0,0 +1,5 @@ +Segmentation +================= + +.. automodule:: vformer.models.dense.PVT.segmentation + :members: \ No newline at end of file diff --git a/docs/api/models/dense/index.rst b/docs/api/models/dense/index.rst new file mode 100644 index 00000000..067b2b31 --- /dev/null +++ b/docs/api/models/dense/index.rst @@ -0,0 +1,8 @@ +Dense Prediction +========================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents + + PVT/index diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst new file mode 100644 index 00000000..cf451436 --- /dev/null +++ b/docs/api/models/index.rst @@ -0,0 +1,9 @@ +Models +======== + +.. toctree:: + :maxdepth: 2 + :caption: Contents + + classification/index + dense/index diff --git a/docs/api/utils/index.rst b/docs/api/utils/index.rst new file mode 100644 index 00000000..cbe9b31c --- /dev/null +++ b/docs/api/utils/index.rst @@ -0,0 +1,9 @@ +Utilities +========= + +.. toctree:: + :maxdepth: 2 + :caption: Contents + + utils + window_utils diff --git a/docs/api/utils/utils.rst b/docs/api/utils/utils.rst new file mode 100644 index 00000000..c3fb0685 --- /dev/null +++ b/docs/api/utils/utils.rst @@ -0,0 +1,5 @@ +Generic Utilities +==================== + +.. automodule:: vformer.utils.utils + :members: \ No newline at end of file diff --git a/docs/api/utils/window_utils.rst b/docs/api/utils/window_utils.rst new file mode 100644 index 00000000..575ca679 --- /dev/null +++ b/docs/api/utils/window_utils.rst @@ -0,0 +1,5 @@ +Window Attention Utilities +============================ + +.. automodule:: vformer.utils.window_utils + :members: \ No newline at end of file diff --git a/docs/api/viz/index.rst b/docs/api/viz/index.rst new file mode 100644 index 00000000..ef4075a0 --- /dev/null +++ b/docs/api/viz/index.rst @@ -0,0 +1,9 @@ +Visualisation +==================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents + + vit_rollout + vit_grad_rollout \ No newline at end of file diff --git a/docs/api/viz/vit_grad_rollout.rst b/docs/api/viz/vit_grad_rollout.rst new file mode 100644 index 00000000..026e51a6 --- /dev/null +++ b/docs/api/viz/vit_grad_rollout.rst @@ -0,0 +1,5 @@ +Gradient Rollout +====================== + +.. automodule:: vformer.viz.vit_grad_rollout + :members: \ No newline at end of file diff --git a/docs/api/viz/vit_rollout.rst b/docs/api/viz/vit_rollout.rst new file mode 100644 index 00000000..a12d9d30 --- /dev/null +++ b/docs/api/viz/vit_rollout.rst @@ -0,0 +1,5 @@ +Rollout +=========== + +.. 
automodule:: vformer.viz.vit_rollout + :members: \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 8184dfe5..ad8dea7d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -54,8 +54,8 @@ # General information about the project. project = "vformer" -copyright = "2021, Neelay Shah" -author = "Neelay Shah" +copyright = "2021, Neelay Shah, Abhijit Deo, Aditya Agrawal" +author = "Neelay Shah, Abhijit Deo, Aditya Agrawal" # The version info for the project you're documenting, acts as replacement # for |version| and |release|, also used in various other places throughout @@ -159,6 +159,6 @@ "Miscellaneous", ), ] - +autodoc_mock_imports = ["einops", "timm"] napoleon_google_docstring = False napoleon_numpy_docstring = True diff --git a/docs/history.rst b/docs/history.rst deleted file mode 100644 index 25064996..00000000 --- a/docs/history.rst +++ /dev/null @@ -1 +0,0 @@ -.. include:: ../HISTORY.rst diff --git a/docs/index.rst b/docs/index.rst index 3092aaed..c9dac0e6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,16 +1,25 @@ -Welcome to vformer's documentation! +Welcome to VFormer's documentation! ====================================== .. toctree:: :maxdepth: 2 - :caption: Contents: + :caption: Contents readme installation - usage - modules - contributing - history + +.. toctree:: + :maxdepth: 2 + :caption: API Reference + + api/attention/index + api/common/index + api/decoder/index + api/encoder/index + api/functional/index + api/models/index + api/utils/index + api/viz/index Indices and tables ================== diff --git a/docs/installation.rst b/docs/installation.rst index a8ca0740..0f1baab2 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -5,47 +5,38 @@ Installation ============ -Stable release --------------- - -To install vformer, run this command in your terminal: +From source (recommended) +------------ -.. code-block:: console +VFormer can be installed from the `GitHub repo`_. - $ pip install vformer +Clone the public repository: -This is the preferred method to install vformer, as it will always install the most recent stable release. +.. code-block:: console -If you don't have `pip`_ installed, this `Python installation guide`_ can guide -you through the process. + $ git clone https://github.com/SforAiDl/vformer.git -.. _pip: https://pip.pypa.io -.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ +and then run the following command to install VFormer: +.. code-block:: console -From sources ------------- + $ python setup.py install -The sources for vformer can be downloaded from the `Github repo`_. -You can either clone the public repository: +.. _Github repo: https://github.com/SforAiDl/vformer -.. code-block:: console - $ git clone git://github.com/SforAiDl/vformer +Stable release +-------------- -Or download the `tarball`_: +To install VFormer, run this command in your terminal: .. code-block:: console - $ curl -OJL https://github.com/SforAiDl/vformer/tarball/master + $ pip install vformer -Once you have a copy of the source, you can install it with: +Note that VFormer is an active project and routinely publishes new releases. In order to upgrade VFormer to the latest version, use pip as follows: .. code-block:: console - $ python setup.py install - - -.. _Github repo: https://github.com/SforAiDl/vformer -.. 
_tarball: https://github.com/SforAiDl/vformer/tarball/master + $ pip install -U vformer \ No newline at end of file diff --git a/docs/readme.rst b/docs/readme.rst index 72a33558..57de499c 100644 --- a/docs/readme.rst +++ b/docs/readme.rst @@ -1 +1,25 @@ -.. include:: ../README.rst +======= +VFormer +======= + + +.. image:: https://github.com/SforAiDl/vformer/actions/workflows/package-test.yml/badge.svg + :target: https://github.com/SforAiDl/vformer/actions/workflows/package-test.yml + +.. image:: https://readthedocs.org/projects/vformer/badge/?version=latest + :target: https://vformer.readthedocs.io/en/latest/?version=latest + :alt: Documentation Status + +.. image:: https://codecov.io/gh/SforAiDl/vformer/branch/main/graph/badge.svg?token=5QKCZ67CM2 + :target: https://codecov.io/gh/SforAiDl/vformer/branch/main + :alt: Code Coverage + +.. image:: https://img.shields.io/pypi/v/vformer.svg + :target: https://pypi.python.org/pypi/vformer + + +A modular PyTorch library for vision transformer models + + +* Free software: MIT license +* Documentation: https://vformer.readthedocs.io. diff --git a/docs/usage.rst b/docs/usage.rst deleted file mode 100644 index e90f0dbc..00000000 --- a/docs/usage.rst +++ /dev/null @@ -1,7 +0,0 @@ -===== -Usage -===== - -To use vformer in a project:: - - import vformer diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 00000000..2b38692d --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,80 @@ +import torch + +from vformer.utils import ATTENTION_REGISTRY + +attention_modules = ATTENTION_REGISTRY.get_list() + + +def test_VanillaSelfAttention(): + + test_tensor1 = torch.randn(2, 65, 1024) + test_tensor2 = torch.randn(2, 257, 1024) + + attention = ATTENTION_REGISTRY.get("VanillaSelfAttention")(dim=1024) + out = attention(test_tensor1) + assert out.shape == (2, 65, 1024) + del attention + + attention = ATTENTION_REGISTRY.get("VanillaSelfAttention")(dim=1024, num_heads=16) + out = attention(test_tensor2) + assert out.shape == (2, 257, 1024) + del attention + + +def test_WindowAttention(): + + test_tensor1 = torch.randn(256, 49, 96) + test_tensor2 = torch.randn(32, 64, 96) + + attention = ATTENTION_REGISTRY.get("WindowAttention")( + dim=96, window_size=7, num_heads=3 + ) + out = attention(test_tensor1) + assert out.shape == test_tensor1.shape + del attention + + attention = ATTENTION_REGISTRY.get("WindowAttention")( + dim=96, window_size=8, num_heads=4 + ) + out = attention(test_tensor2) + assert out.shape == test_tensor2.shape + del attention + + +def test_CrossAttention(): + + test_tensor1 = torch.randn(64, 1, 64) + test_tensor2 = torch.randn(64, 24, 128) + + attention = ATTENTION_REGISTRY.get("CrossAttention")(64, 128, 64) + out = attention(test_tensor1, test_tensor2) + assert out.shape == test_tensor1.shape + del attention + + +def test_SpatialAttention(): + + test_tensor1 = torch.randn(4, 3136, 64) + test_tensor2 = torch.randn(4, 50, 512) + + attention = ATTENTION_REGISTRY.get("SpatialAttention")( + dim=64, + num_heads=1, + sr_ratio=8, + ) + out = attention(test_tensor1, H=56, W=56) + assert out.shape == test_tensor1.shape + del attention + + attention = ATTENTION_REGISTRY.get("SpatialAttention")( + dim=512, num_heads=8, sr_ratio=1, linear=False + ) + out = attention(test_tensor2, H=7, W=7) + assert out.shape == test_tensor2.shape + del attention + + attention = ATTENTION_REGISTRY.get("SpatialAttention")( + dim=64, num_heads=1, sr_ratio=8, linear=True + ) + out = attention(test_tensor1, 56, 56) + assert out.shape == 
test_tensor1.shape diff --git a/tests/test_decoder.py b/tests/test_decoder.py index 0afb0839..e1ffd5cf 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -1,25 +1,31 @@ import torch -from vformer.decoder import MLPDecoder -from vformer.decoder.task_heads import SegmentationHead +from vformer.utils import DECODER_REGISTRY + +decoder_modules = DECODER_REGISTRY.get_list() def test_MLPDecoder(): + test_tensor = torch.randn(2, 3, 100) - decoder = MLPDecoder(config=100, n_classes=10) + + decoder = DECODER_REGISTRY.get("MLPDecoder")(config=100, n_classes=10) out = decoder(test_tensor) assert out.shape == (2, 3, 10) del decoder - decoder = MLPDecoder(config=(100, 50), n_classes=10) + + decoder = DECODER_REGISTRY.get("MLPDecoder")(config=(100, 50), n_classes=10) out = decoder(test_tensor) assert out.shape == (2, 3, 10) del decoder - decoder = MLPDecoder(config=[100, 10], n_classes=5) + + decoder = DECODER_REGISTRY.get("MLPDecoder")(config=[100, 10], n_classes=5) out = decoder(test_tensor) assert out.shape == (2, 3, 5) def test_SegmentationHead(): + test_tensor_segmentation_head_256 = [ torch.randn([2, 64, 64, 64]), torch.randn([2, 128, 32, 32]), @@ -39,18 +45,20 @@ def test_SegmentationHead(): torch.randn([3, 1024, 12, 12]), ] - head = SegmentationHead( + head = DECODER_REGISTRY.get("SegmentationHead")( out_channels=1, ) out = head(test_tensor_segmentation_head_256) assert out.shape == (2, 1, 256, 256) - head = SegmentationHead( + head = DECODER_REGISTRY.get("SegmentationHead")( out_channels=10, ) out = head(test_tensor_segmentation_head_224) assert out.shape == (2, 10, 224, 224) - head = SegmentationHead(out_channels=2, embed_dims=[128, 256, 512, 1024]) + head = DECODER_REGISTRY.get("SegmentationHead")( + out_channels=2, embed_dims=[128, 256, 512, 1024] + ) out = head(test_tensor_segmentation_head) assert out.shape == (3, 2, 384, 384) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index cea824aa..cae8386a 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -1,31 +1,29 @@ import torch import torch.nn as nn -from vformer.encoder import ( - CrossEncoder, - PVTEncoder, - SwinEncoder, - SwinEncoderBlock, - VanillaEncoder, -) from vformer.functional import PatchMerging +from vformer.utils import ENCODER_REGISTRY + +encoder_modules = ENCODER_REGISTRY.get_list() def test_VanillaEncoder(): + test_tensor = torch.randn(2, 65, 1024) - encoder = VanillaEncoder( - latent_dim=1024, depth=6, heads=16, dim_head=64, mlp_dim=2048 + encoder = ENCODER_REGISTRY.get("VanillaEncoder")( + embedding_dim=1024, depth=6, num_heads=16, head_dim=64, mlp_dim=2048 ) out = encoder(test_tensor) assert out.shape == test_tensor.shape # shape remains same - del encoder - del test_tensor + del encoder, test_tensor def test_SwinEncoder(): + test_tensor = torch.randn(3, 3136, 96) + # when downsampled - encoder = SwinEncoder( + encoder = ENCODER_REGISTRY.get("SwinEncoder")( dim=96, input_resolution=(224 // 4, 224 // 4), depth=2, @@ -37,8 +35,9 @@ def test_SwinEncoder(): assert out.shape == (3, 784, 192) del encoder + # when not downsampled - encoder = SwinEncoder( + encoder = ENCODER_REGISTRY.get("SwinEncoder")( dim=96, input_resolution=(224 // 4, 224 // 4), depth=2, @@ -51,7 +50,7 @@ def test_SwinEncoder(): assert out.shape == (3, 3136, 96) del encoder - encoder_block = SwinEncoderBlock( + encoder_block = ENCODER_REGISTRY.get("SwinEncoderBlock")( dim=96, input_resolution=(224 // 4, 224 // 4), num_heads=3, window_size=7 ) out = encoder_block(test_tensor) @@ -59,14 +58,16 @@ def test_SwinEncoder(): 
def test_PVTEncoder(): + test_tensor = torch.randn(4, 3136, 64) - encoder = PVTEncoder( + + encoder = ENCODER_REGISTRY.get("PVTEncoder")( dim=64, depth=3, qkv_bias=True, qk_scale=0.0, p_dropout=0.0, - attn_drop=0.1, + attn_dropout=0.1, drop_path=[0.0] * 3, act_layer=nn.GELU, sr_ratio=1, @@ -80,10 +81,12 @@ def test_PVTEncoder(): def test_CrossEncoder(): + test_tensor1 = torch.randn(3, 5, 128) test_tensor2 = torch.randn(3, 5, 256) - encoder = CrossEncoder(128, 256) + + encoder = ENCODER_REGISTRY.get("CrossEncoder")(128, 256) out = encoder(test_tensor1, test_tensor2) assert out[0].shape == test_tensor1.shape - assert out[1].shape == test_tensor2.shape # shape remains same + assert out[1].shape == test_tensor2.shape del encoder diff --git a/tests/test_models.py b/tests/test_models.py index be1373cb..5995aaf1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,17 +1,9 @@ import torch import torch.nn as nn -from vformer.models import ( - CrossViT, - PVTClassification, - PVTClassificationV2, - PVTDetection, - PVTDetectionV2, - PVTSegmentation, - PVTSegmentationV2, - SwinTransformer, - VanillaViT, -) +from vformer.utils import MODEL_REGISTRY + +models = MODEL_REGISTRY.get_list() img_3channels_256 = torch.randn(2, 3, 256, 256) img_3channels_224 = torch.randn(4, 3, 224, 224) @@ -20,15 +12,18 @@ def test_VanillaViT(): - model = VanillaViT(img_size=256, patch_size=32, n_classes=10, in_channels=3) + model = MODEL_REGISTRY.get("VanillaViT")( + img_size=256, patch_size=32, n_classes=10, in_channels=3 + ) out = model(img_3channels_256) assert out.shape == (2, 10) del model - model = VanillaViT( + + model = MODEL_REGISTRY.get("VanillaViT")( img_size=256, patch_size=32, n_classes=10, - latent_dim=1024, + embedding_dim=1024, decoder_config=(1024, 512), ) out = model(img_3channels_256) @@ -37,20 +32,21 @@ def test_VanillaViT(): def test_SwinTransformer(): - model = SwinTransformer( + + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=224, patch_size=4, in_channels=3, n_classes=1000, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, + p_dropout=0.0, + attn_dropout=0.0, drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, @@ -59,78 +55,83 @@ def test_SwinTransformer(): out = model(img_3channels_224) assert out.shape == (4, 1000) del model + # tiny_patch4_window7_224 - model = SwinTransformer( + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=224, patch_size=4, in_channels=3, n_classes=10, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, - drop_rate=0.2, + p_dropout=0.2, ) out = model(img_3channels_224) assert out.shape == (4, 10) del model + # tiny_c24_patch4_window8_256 - model = SwinTransformer( + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=256, patch_size=4, in_channels=3, n_classes=10, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], window_size=8, - drop_rate=0.2, + p_dropout=0.2, ) out = model(img_3channels_256) assert out.shape == (2, 10) del model + # for greyscale image - model = SwinTransformer( + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=224, patch_size=4, in_channels=1, n_classes=10, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, - drop_rate=0.2, + p_dropout=0.2, ) out = model(img_1channels_224) assert out.shape == (2, 10) del model + # testing for decoder_config parameter - 
model = SwinTransformer( + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=224, patch_size=4, in_channels=3, n_classes=10, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, - drop_rate=0.2, + p_dropout=0.2, decoder_config=(768, 256, 10, 2), ) out = model(img_3channels_224) del model assert out.shape == (4, 10) + # ape=false - model = SwinTransformer( + model = MODEL_REGISTRY.get("SwinTransformer")( img_size=224, patch_size=4, in_channels=3, n_classes=10, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, - drop_rate=0.2, + p_dropout=0.2, decoder_config=(768, 256, 10, 2), ape=False, ) @@ -140,11 +141,13 @@ def test_SwinTransformer(): def test_CrossVit(): - model = CrossViT(256, 16, 64, 10) + + model = MODEL_REGISTRY.get("CrossViT")(256, 16, 64, 10) out = model(img_3channels_256) assert out.shape == (2, 10) del model - model = CrossViT( + + model = MODEL_REGISTRY.get("CrossViT")( 256, 16, 64, @@ -159,8 +162,7 @@ def test_CrossVit(): def test_pvt(): # classification - - model = PVTClassification( + model = MODEL_REGISTRY.get("PVTClassification")( patch_size=[7, 3, 3, 3], embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], @@ -176,7 +178,7 @@ def test_pvt(): assert out.shape == (4, 10) del model - model = PVTClassification( + model = MODEL_REGISTRY.get("PVTClassification")( patch_size=[7, 3, 3, 3], embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], @@ -188,33 +190,32 @@ def test_pvt(): decoder_config=512, num_classes=10, ) - out = model(img_3channels_224) assert out.shape == (4, 10) del model - model = PVTClassificationV2(linear=False) + model = MODEL_REGISTRY.get("PVTClassificationV2")(linear=False) out = model(img_3channels_224) assert out.shape == (4, 1000) del model - model = PVTClassificationV2(num_classes=10) + model = MODEL_REGISTRY.get("PVTClassificationV2")(num_classes=10) out = model(img_3channels_224) assert out.shape == (4, 10) del model - model = PVTClassificationV2(num_classes=10) + model = MODEL_REGISTRY.get("PVTClassificationV2")(num_classes=10) out = model(img_3channels_224) assert out.shape == (4, 10) del model - model = PVTClassification(num_classes=12) + model = MODEL_REGISTRY.get("PVTClassification")(num_classes=12) out = model(img_3channels_224) assert out.shape == (4, 12) del model - model = PVTClassificationV2( - embed_dims=[64, 128, 320, 512], + model = MODEL_REGISTRY.get("PVTClassificationV2")( + embedding_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratio=[8, 8, 4, 4], qkv_bias=True, @@ -224,8 +225,10 @@ def test_pvt(): linear=True, ) out = model(img_3channels_224) + assert out.shape == (4, 1000) + # segmentation - model = PVTSegmentation() + model = MODEL_REGISTRY.get("PVTSegmentation")() outs = model(img_3channels_224) assert outs.shape == ( 4, @@ -235,7 +238,7 @@ def test_pvt(): ), f"expected: {(4,1,224,224)}, got : {outs.shape}" del model - model = PVTSegmentation() + model = MODEL_REGISTRY.get("PVTSegmentation")() outs = model(img_3channels_256) assert outs.shape == ( 2, @@ -245,7 +248,7 @@ def test_pvt(): ), f"expected: {(4,1,256,256)}, got : {outs.shape}" del model - model = PVTSegmentation() + model = MODEL_REGISTRY.get("PVTSegmentation")() outs = model(img_3channels_256) assert outs.shape == ( 2, @@ -255,7 +258,7 @@ def test_pvt(): ), f"expected: {(4,1,256,256)}, got : {outs.shape}" del model - model = PVTSegmentationV2(return_pyramid=False) + model = MODEL_REGISTRY.get("PVTSegmentationV2")(return_pyramid=False) outs = 
model(img_3channels_224) assert outs.shape == ( 4, @@ -265,10 +268,10 @@ def test_pvt(): ), f"expected: {(4,1,224,224)}, got : {outs.shape}" del model - model = PVTSegmentationV2(return_pyramid=True) + model = MODEL_REGISTRY.get("PVTSegmentationV2")(return_pyramid=True) out = model(img_3channels_224) - model = PVTSegmentationV2(return_pyramid=False) + model = MODEL_REGISTRY.get("PVTSegmentationV2")(return_pyramid=False) outs = model(img_3channels_256) assert outs.shape == ( 2, @@ -279,12 +282,112 @@ def test_pvt(): del model # detection + model = MODEL_REGISTRY.get("PVTDetection")() + outs = model(img_3channels_224) + del model - model = PVTDetection() + model = MODEL_REGISTRY.get("PVTDetectionV2")() outs = model(img_3channels_224) + del model + + +def test_cvt(): + model = MODEL_REGISTRY.get("CVT")(img_size=256, patch_size=4, in_channels=3) + out = model(img_3channels_256) + assert out.shape == (2, 1000) del model - model = PVTDetectionV2() - outs = model(img_3channels_224) + model = MODEL_REGISTRY.get("CVT")( + img_size=224, + patch_size=4, + in_channels=3, + seq_pool=False, + embedding_dim=768, + num_heads=1, + mlp_ratio=4.0, + num_classes=10, + p_dropout=0.5, + attn_dropout=0.3, + drop_path=0.2, + positional_embedding="sine", + decoder_config=(768, 12024, 512, 256, 128, 64, 32), + ) + out = model(img_3channels_224) + assert out.shape == (4, 10) + del model + + model = MODEL_REGISTRY.get("CVT")( + img_size=224, + in_channels=3, + patch_size=4, + positional_embedding="none", + seq_pool=False, + decoder_config=None, + ) + f = model(img_3channels_224) + assert f.shape == (4, 1000) + del model + + model = MODEL_REGISTRY.get("CVT")( + img_size=224, + in_channels=3, + patch_size=4, + positional_embedding="none", + seq_pool=True, + decoder_config=768, + ) + f = model(img_3channels_224) + assert f.shape == (4, 1000) + del model + + +def test_cct(): + + model = MODEL_REGISTRY.get("CCT")(img_size=256, patch_size=4, in_channels=3) + out = model(img_3channels_256) + assert out.shape == (2, 1000) + del model + + model = MODEL_REGISTRY.get("CCT")( + img_size=224, + patch_size=4, + in_channels=3, + seq_pool=False, + embedding_dim=768, + num_heads=1, + mlp_ratio=4.0, + num_classes=10, + p_dropout=0.5, + attn_dropout=0.3, + drop_path=0.2, + positional_embedding="sine", + decoder_config=(768, 12024, 512, 256, 128, 64, 32), + ) + out = model(img_3channels_224) + assert out.shape == (4, 10) + del model + + model = MODEL_REGISTRY.get("CCT")( + img_size=224, + in_channels=3, + patch_size=4, + positional_embedding="none", + seq_pool=False, + decoder_config=None, + ) + f = model(img_3channels_224) + assert f.shape == (4, 1000) + del model + + model = MODEL_REGISTRY.get("CCT")( + img_size=224, + in_channels=3, + patch_size=4, + positional_embedding="none", + seq_pool=True, + decoder_config=768, + ) + f = model(img_3channels_224) + assert f.shape == (4, 1000) del model diff --git a/vformer/attention/cross.py b/vformer/attention/cross.py index e6714743..8f4695fc 100644 --- a/vformer/attention/cross.py +++ b/vformer/attention/cross.py @@ -2,10 +2,13 @@ import torch.nn as nn from einops import rearrange +from ..utils import ATTENTION_REGISTRY + class _Projection(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() + if not in_dim == out_dim: self.l1 = nn.Linear(in_dim, out_dim) else: @@ -15,26 +18,30 @@ def forward(self, x): return self.l1(x) +@ATTENTION_REGISTRY.register() class CrossAttention(nn.Module): """ Cross-Attention Fusion - Parameters: - ----------- + + Parameters + ---------- cls_dim: int 
Dimension of cls token embedding patch_dim: int Dimension of patch token embeddings cls token to be fused with - heads: int + num_heads: int Number of cross-attention heads - dim_head: int + head_dim: int Dimension of each head + """ - def __init__(self, cls_dim, patch_dim, heads=8, dim_head=64): + def __init__(self, cls_dim, patch_dim, num_heads=8, head_dim=64): super().__init__() - inner_dim = heads * dim_head - self.heads = heads - self.scale = dim_head ** -0.5 + + inner_dim = num_heads * head_dim + self.num_heads = num_heads + self.scale = head_dim ** -0.5 self.fl = _Projection(cls_dim, patch_dim) self.gl = _Projection(patch_dim, cls_dim) self.to_k = nn.Linear(patch_dim, inner_dim) @@ -44,20 +51,25 @@ def __init__(self, cls_dim, patch_dim, heads=8, dim_head=64): self.attend = nn.Softmax(dim=-1) def forward(self, cls, patches): + cls = self.fl(cls) + x = torch.cat([cls, patches], dim=-2) q = self.to_q(cls) k = self.to_k(x) v = self.to_v(x) - k = rearrange(k, "b n (h d) -> b h n d", h=self.heads) - q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) - v = rearrange(v, "b n (h d) -> b h n d", h=self.heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) k = torch.transpose(k, -2, -1) + attention = (q @ k) * self.scale attention = self.attend(attention) attention_value = attention @ v attention_value = rearrange(attention_value, "b h n d -> b n (h d)") attention_value = self.cls_project(attention_value) + ycls = cls + attention_value ycls = self.gl(ycls) + return ycls diff --git a/vformer/attention/spatial.py b/vformer/attention/spatial.py index 87568772..d809f0ae 100644 --- a/vformer/attention/spatial.py +++ b/vformer/attention/spatial.py @@ -1,10 +1,13 @@ import torch.nn as nn from ..functional import PreNorm +from ..utils import ATTENTION_REGISTRY +@ATTENTION_REGISTRY.register() class SpatialAttention(nn.Module): """ + Spatial Reduction Attention- Linear complexity attention layer Parameters @@ -24,9 +27,10 @@ class SpatialAttention(nn.Module): proj_drop :float, optional Dropout rate linear : bool - Whether to use linear spatial attention,default is False - act_fn : activation function + Whether to use linear Spatial attention,default is False + act_fn : nn.Module Activation function, default is False + """ def __init__( @@ -42,9 +46,11 @@ def __init__( act_fn=nn.GELU, ): super(SpatialAttention, self).__init__() + assert ( dim % num_heads == 0 ), f"dim {dim} should be divided by num_heads {num_heads}." 
+ self.num_heads = num_heads self.sr_ratio = sr_ratio head_dim = dim // num_heads @@ -61,6 +67,7 @@ def __init__( self.linear = linear self.sr_ratio = sr_ratio self.norm = PreNorm(dim=dim, fn=act_fn() if linear else nn.Identity()) + if not linear: if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) @@ -69,6 +76,22 @@ def __init__( self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) def forward(self, x, H, W): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + H: int + Height of image patches + W: int + Width of image patches + Returns + ---------- + torch.Tensor + Returns output tensor by applying spatial attention on input tensor + + """ B, N, C = x.shape q = ( self.q(x) @@ -106,4 +129,5 @@ def forward(self, x, H, W): attn = self.attn(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) + return self.to_out(x) diff --git a/vformer/attention/vanilla.py b/vformer/attention/vanilla.py index e320510b..8844f1d3 100644 --- a/vformer/attention/vanilla.py +++ b/vformer/attention/vanilla.py @@ -2,30 +2,35 @@ import torch.nn as nn from einops import rearrange +from ..utils import ATTENTION_REGISTRY + +@ATTENTION_REGISTRY.register() class VanillaSelfAttention(nn.Module): """ Vanilla O(n^2) Self attention - Parameters: - ----------- + + Parameters + ---------- dim: int Dimension of the embedding - heads: int + num_heads: int Number of the attention heads - dim_head: int + head_dim: int Dimension of each head p_dropout: float Dropout Probability + """ - def __init__(self, dim, heads=8, dim_head=64, p_dropout=0.0): + def __init__(self, dim, num_heads=8, head_dim=64, p_dropout=0.0): super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) + inner_dim = head_dim * num_heads + project_out = not (num_heads == 1 and head_dim == dim) - self.heads = heads - self.scale = dim_head ** -0.5 + self.num_heads = num_heads + self.scale = head_dim ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) @@ -37,8 +42,22 @@ def __init__(self, dim, heads=8, dim_head=64, p_dropout=0.0): ) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns output tensor by applying self-attention on input tensor + + """ qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), qkv + ) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale diff --git a/vformer/attention/window.py b/vformer/attention/window.py index 2d006d75..5094b783 100644 --- a/vformer/attention/window.py +++ b/vformer/attention/window.py @@ -2,13 +2,14 @@ import torch.nn as nn from timm.models.layers import trunc_normal_ -from ..utils import get_relative_position_bias_index, pair +from ..utils import ATTENTION_REGISTRY, get_relative_position_bias_index, pair +@ATTENTION_REGISTRY.register() class WindowAttention(nn.Module): """ - Parameters: - ----------- + Parameters + ---------- dim: int Number of input channels. window_size : int or tuple[int] @@ -19,9 +20,9 @@ class WindowAttention(nn.Module): If True, add a learnable bias to query, key, value. 
qk_scale: float, optional Override default qk scale of head_dim ** -0.5 if set - attn_drop: float, optional + attn_dropout: float, optional Dropout rate - proj_drop: float, optional + proj_dropout: float, optional Dropout rate """ @@ -33,10 +34,11 @@ def __init__( num_heads, qkv_bias=True, qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, + attn_dropout=0.0, + proj_dropout=0.0, ): super(WindowAttention, self).__init__() + self.dim = dim self.window_size = pair(window_size) self.num_heads = num_heads @@ -52,11 +54,29 @@ def __init__( self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.to_out_1 = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(attn_drop)) - self.to_out_2 = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(proj_drop)) + self.to_out_1 = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(attn_dropout)) + self.to_out_2 = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(proj_dropout)) trunc_normal_(self.relative_position_bias_table, std=0.2) def forward(self, x, mask=None): + """ + + Parameters + ---------- + x: torch.Tensor + input Tensor + mask: torch.Tensor + Attention mask used for shifted window attention, if None, window attention will be used, + else attention mask will be taken into consideration. + for better understanding you may refer `this ` + + Returns + ---------- + torch.Tensor + Returns output tensor by applying Window-Attention or Shifted-Window-Attention on input tensor + + """ + B_, N, C = x.shape qkv = ( self.qkv(x) @@ -95,4 +115,5 @@ def forward(self, x, mask=None): attn = self.to_out_1(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.to_out_2(x) + return x diff --git a/vformer/common/base_model.py b/vformer/common/base_model.py index b9f45642..f4c8deec 100644 --- a/vformer/common/base_model.py +++ b/vformer/common/base_model.py @@ -7,7 +7,7 @@ class BaseClassificationModel(nn.Module): """ img_size: int Size of the image - patch_size: int or list(int) + patch_size: int or tuple(int) Size of the patch in_channels: int Number of channels in input image diff --git a/vformer/common/blocks.py b/vformer/common/blocks.py index dc7757ba..b0bb4952 100644 --- a/vformer/common/blocks.py +++ b/vformer/common/blocks.py @@ -4,18 +4,20 @@ class DWConv(nn.Module): """ Depth Wise Convolution - Parameters: - ----------- + + Parameters + ---------- dim: int Dimension of the input tensor kernel_size_dwconv: int,optional - Size of the convolution kernel + Size of the convolution kernel, default is 3 stride_dwconv: int - Stride of the convolution + Stride of the convolution, default is 1 padding_dwconv: int or tuple or str - Padding added to all sides of the input + Padding added to all sides of the input, default is 1 bias_dwconv:bool Whether to add learnable bias to the output,default is True. 
+ """ def __init__( @@ -38,8 +40,27 @@ def __init__( ) def forward(self, x, H, W): + """ + + Parameters: + ---------- + x: torch.Tensor + Input tensor + H: int + Height of image patch + W: int + Width of image patch + + Returns: + ---------- + torch.Tensor + Returns output tensor after performing depth-wise convolution operation + + """ + B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) return x diff --git a/vformer/decoder/mlp.py b/vformer/decoder/mlp.py index eddc19f8..e8abeddc 100644 --- a/vformer/decoder/mlp.py +++ b/vformer/decoder/mlp.py @@ -1,10 +1,13 @@ import torch.nn as nn +from ..utils import DECODER_REGISTRY + +@DECODER_REGISTRY.register() class MLPDecoder(nn.Module): """ Parameters - ----------- + ---------- config : int or tuple or list Configuration of the hidden layer(s) n_classes : int @@ -31,5 +34,16 @@ def __init__(self, config=(1024,), n_classes=10): self.decoder = nn.Sequential(*self.decoder) def forward(self, x): - + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns output tensor of size `n_classes`, Note that `torch.nn.Softmax` is not applied to the output tensor. + + """ return self.decoder(x) diff --git a/vformer/decoder/task_heads/segmentation/head.py b/vformer/decoder/task_heads/segmentation/head.py index 747f2313..9612836d 100644 --- a/vformer/decoder/task_heads/segmentation/head.py +++ b/vformer/decoder/task_heads/segmentation/head.py @@ -2,11 +2,12 @@ import torch.nn as nn from torchvision.transforms.functional import resize +from ....utils import DECODER_REGISTRY + -# U-net like structre is used here class DoubleConv(nn.Module): """ - This is a module consisting two convolution layers and activations, we will use this in up-sampling block + Module consisting of two convolution layers and activations """ def __init__( @@ -15,6 +16,7 @@ def __init__( out_channels, ): super(DoubleConv, self).__init__() + self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), @@ -28,6 +30,7 @@ def forward(self, x): return self.conv(x) +@DECODER_REGISTRY.register() class SegmentationHead(nn.Module): """ U-net like up-sampling block @@ -39,6 +42,7 @@ def __init__( embed_dims=[64, 128, 256, 512], ): super(SegmentationHead, self).__init__() + self.ups = nn.ModuleList() self.pool = nn.MaxPool2d(kernel_size=2, stride=2) @@ -63,7 +67,9 @@ def forward(self, skip_connections): x = self.bottleneck(skip_connections[-1]) skip_connections = skip_connections[::-1] + for idx in range(0, len(self.ups), 2): + x = self.ups[idx](x) skip_connection = skip_connections[idx // 2] @@ -74,4 +80,5 @@ def forward(self, skip_connections): x = self.ups[idx + 1](concat_skip) x = self.conv1(x) + return self.conv2(x) diff --git a/vformer/encoder/cross.py b/vformer/encoder/cross.py index 3543d827..25e7a220 100644 --- a/vformer/encoder/cross.py +++ b/vformer/encoder/cross.py @@ -1,59 +1,61 @@ import torch import torch.nn as nn -from vformer.attention import CrossAttention -from vformer.encoder.embedding import LinearEmbedding -from vformer.encoder.vanilla import VanillaEncoder +from ..attention import CrossAttention +from ..utils import ENCODER_REGISTRY +from .vanilla import VanillaEncoder +@ENCODER_REGISTRY.register() class CrossEncoder(nn.Module): """ + Parameters ---------- - latent_dim_s : int - Dimension of the embedding of smaller patches - latent_dim_l : int - Dimension of the embedding of larger patches + embedding_dim_s : 
int + Dimension of the embedding of smaller patches, default is 1024 + embedding_dim_l : int + Dimension of the embedding of larger patches, default is 1024 attn_heads_s : int - Number of self-attention heads for the smaller patches + Number of self-attention heads for the smaller patches, default is 16 attn_heads_l : int - Number of self-attention heads for the larger patches + Number of self-attention heads for the larger patches, default is 16 cross_head_s : int - Number of cross-attention heads for the smaller patches + Number of cross-attention heads for the smaller patches, default is 8 cross_head_l : int - Number of cross-attention heads for the larger patches - dim_head_s : int - Dimension of the head of the attention for the smaller patches - dim_head_l : int - Dimension of the head of the attention for the larger patches + Number of cross-attention heads for the larger patches, default is 8 + head_dim_s : int + Dimension of the head of the attention for the smaller patches, default is 64 + head_dim_l : int + Dimension of the head of the attention for the larger patches, default is 64 cross_dim_head_s : int - Dimension of the head of the cross-attention for the smaller patches + Dimension of the head of the cross-attention for the smaller patches, default is 64 cross_dim_head_l : int - Dimension of the head of the cross-attention for the larger patches + Dimension of the head of the cross-attention for the larger patches, default is 64 depth_s : int - Number of self-attention layers in encoder for the smaller patches + Number of self-attention layers in encoder for the smaller patches, default is 6 depth_l : int - Number of self-attention layers in encoder for the larger patches + Number of self-attention layers in encoder for the larger patches, default is 6 mlp_dim_s : int - Dimension of the hidden layer in the feed-forward layer for the smaller patches + Dimension of the hidden layer in the feed-forward layer for the smaller patches, default is 2048 mlp_dim_l : int - Dimension of the hidden layer in the feed-forward layer for the larger patches + Dimension of the hidden layer in the feed-forward layer for the larger patches, default is 2048 p_dropout_s : float - Dropout probability for the smaller patches + Dropout probability for the smaller patches, default is 0.0 p_dropout_l : float - Dropout probability for the larger patches + Dropout probability for the larger patches, default is 0.0 """ def __init__( self, - latent_dim_s=1024, - latent_dim_l=1024, + embedding_dim_s=1024, + embedding_dim_l=1024, attn_heads_s=16, attn_heads_l=16, cross_head_s=8, cross_head_l=8, - dim_head_s=64, - dim_head_l=64, + head_dim_s=64, + head_dim_l=64, cross_dim_head_s=64, cross_dim_head_l=64, depth_s=6, @@ -64,30 +66,32 @@ def __init__( p_dropout_l=0.0, ): super().__init__() + self.s = VanillaEncoder( - latent_dim_s, + embedding_dim_s, depth_s, attn_heads_s, - dim_head_s, + head_dim_s, mlp_dim_s, p_dropout_s, ) self.l = VanillaEncoder( - latent_dim_l, + embedding_dim_l, depth_l, attn_heads_l, - dim_head_l, + head_dim_l, mlp_dim_l, p_dropout_l, ) self.attend_s = CrossAttention( - latent_dim_s, latent_dim_l, cross_head_s, cross_dim_head_s + embedding_dim_s, embedding_dim_l, cross_head_s, cross_dim_head_s ) self.attend_l = CrossAttention( - latent_dim_l, latent_dim_s, cross_head_l, cross_dim_head_l + embedding_dim_l, embedding_dim_s, cross_head_l, cross_dim_head_l ) def forward(self, emb_s, emb_l): + emb_s = self.s(emb_s) emb_l = self.l(emb_l) s_cls, s_patches = (lambda t: (t[:, 0:1, :], t[:, 1:, 
:]))(emb_s) @@ -96,4 +100,5 @@ def forward(self, emb_s, emb_l): l_cls = self.attend_l(l_cls, s_patches) emb_l = torch.cat([l_cls, l_patches], dim=1) emb_s = torch.cat([s_cls, s_patches], dim=1) + return emb_s, emb_l diff --git a/vformer/encoder/embedding/__init__.py b/vformer/encoder/embedding/__init__.py index e18b9d84..94e49c27 100644 --- a/vformer/encoder/embedding/__init__.py +++ b/vformer/encoder/embedding/__init__.py @@ -1,4 +1,5 @@ +from .cvt import CVTEmbedding from .linear import LinearEmbedding from .overlappatch import OverlapPatchEmbed from .patch import PatchEmbedding -from .pos_embedding import AbsolutePositionEmbedding +from .pos_embedding import * diff --git a/vformer/encoder/embedding/cvt.py b/vformer/encoder/embedding/cvt.py new file mode 100644 index 00000000..084599f4 --- /dev/null +++ b/vformer/encoder/embedding/cvt.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn + + +class CVTEmbedding(nn.Module): + """ + This class converts the image patches to tensors. Size of the image patches is controlled by `stride` parameter. + + Parameters + ---------- + kernel_size: int or tuple + Size of the kernel used in convolution + stride: int or tuple + Stride of the convolution operation + padding: int + Padding to all sides of the input + pooling_kernel_size: int or tuple(int) + Size of the kernel used in MaxPool2D,default is 3 + pooling_stride: int or tuple(int) + Size of the stride in MaxPool2D, default is 2 + pooling_padding: int + padding in the MaxPool2D + num_conv_layers: int + Number of Convolution layers in the encoder,default is 1 + in_channels: int + Number of input channels in image, default is 3 + out_channels: int + Number of output channels, default is 64 + in_planes: int + This will be number of channels in the self.conv_layer's convolution except 1st layer and last layer. 
+ activation: nn.Module, optional + Activation Layer, default is None + max_pool: bool + Whether to have max-pooling or not, change this parameter to False when using in CVT model, default is True + conv_bias:bool + Whether to add learnable bias in the convolution operation,default is False + """ + + def __init__( + self, + kernel_size, + stride, + padding, + pooling_kernel_size=3, + pooling_stride=2, + pooling_padding=1, + num_conv_layers=1, + in_channels=3, + out_channels=64, + in_planes=64, + activation=None, + max_pool=True, + conv_bias=False, + ): + super(CVTEmbedding, self).__init__() + + n_filter_list = ( + [in_channels] + + [in_planes for _ in range(num_conv_layers - 1)] + + [out_channels] + ) + self.conv_layers = nn.ModuleList([]) + for i in range(num_conv_layers): + self.conv_layers.append( + nn.ModuleList( + [ + nn.Conv2d( + n_filter_list[i], + n_filter_list[i + 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=conv_bias, + ), + nn.Identity() if activation is None else activation(), + nn.MaxPool2d( + kernel_size=pooling_kernel_size, + stride=pooling_stride, + padding=pooling_padding, + ) + if max_pool + else nn.Identity(), + ] + ) + ) + + self.flatten = nn.Flatten(2, 3) + + def sequence_length(self, n_channels=3, height=224, width=224): + return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] + + def forward(self, x): + """ + + Parameters + ---------- + x: torch.tensor + Input tensor + + Returns + ----------- + torch.Tensor + Returns output tensor (embedding) by applying multiple convolution and max-pooling operations on input tensor + + """ + + for conv2d, activation, maxpool in self.conv_layers: + x = maxpool(activation(conv2d(x))) + + return self.flatten(x).transpose(-2, -1) diff --git a/vformer/encoder/embedding/linear.py b/vformer/encoder/embedding/linear.py index b07f8455..0b22eeb6 100644 --- a/vformer/encoder/embedding/linear.py +++ b/vformer/encoder/embedding/linear.py @@ -4,8 +4,9 @@ class LinearEmbedding(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- embedding_dim: int Dimension of the resultant embedding patch_height: int @@ -14,6 +15,7 @@ class LinearEmbedding(nn.Module): Width of the patch patch_dim: int Dimension of the patch + """ def __init__( @@ -35,5 +37,18 @@ def __init__( ) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + + Returns + ---------- + torch.Tensor + Returns patch embeddings of size `embedding_dim` + + """ return self.patch_embedding(x) diff --git a/vformer/encoder/embedding/overlappatch.py b/vformer/encoder/embedding/overlappatch.py index ef00a78c..0077901e 100644 --- a/vformer/encoder/embedding/overlappatch.py +++ b/vformer/encoder/embedding/overlappatch.py @@ -5,20 +5,22 @@ class OverlapPatchEmbed(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- img_size: int Image Size - patch_size: int + patch_size: int or tuple(int) Patch Size stride: int Stride of the convolution, default is 4 in_channels: int Number of input channels in the image, default is 3 - embed_dim: int + embedding_dim: int Number of linear projection output channels,default is 768 norm_layer: nn.Module, optional Normalization layer, default is nn.LayerNorm + """ def __init__( @@ -27,10 +29,11 @@ def __init__( patch_size, stride=4, in_channels=3, - embed_dim=768, + embedding_dim=768, norm_layer=nn.LayerNorm, ): super(OverlapPatchEmbed, self).__init__() + img_size = pair(img_size) patch_size = pair(patch_size) @@ -41,16 +44,34 @@ def __init__( self.proj 
= nn.Conv2d( in_channels=in_channels, - out_channels=embed_dim, + out_channels=embedding_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2), ) - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embedding_dim) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + + Returns + ---------- + x: torch.Tensor + Input tensor + H: int + Height of Patch + W: int + Width of Patch + + """ + x = self.proj(x) H, W = x.shape[2:] - x = self.norm(x.flatten(2).transpose(1, 2)) + return x, H, W diff --git a/vformer/encoder/embedding/patch.py b/vformer/encoder/embedding/patch.py index 7f8115b9..14b303fa 100644 --- a/vformer/encoder/embedding/patch.py +++ b/vformer/encoder/embedding/patch.py @@ -5,18 +5,20 @@ class PatchEmbedding(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- img_size: int Image Size patch_size: int Patch Size in_channels: int - Number of input channels in the image, default is 3 - embed_dim: int + Number of input channels in the image + embedding_dim: int Number of linear projection output channels - norm_layer: nn.Module, optional - Normalization layer + norm_layer: nn.Module, + Normalization layer, Default is `nn.LayerNorm` + """ def __init__( @@ -24,10 +26,11 @@ def __init__( img_size, patch_size, in_channels, - embed_dim, + embedding_dim, norm_layer=nn.LayerNorm, ): super(PatchEmbedding, self).__init__() + self.img_size = pair(img_size) self.patch_size = pair(patch_size) self.patch_resolution = [ @@ -37,18 +40,33 @@ def __init__( self.proj = nn.Conv2d( in_channels=in_channels, - out_channels=embed_dim, + out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size, ) - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embedding_dim) def forward(self, x): + """ + + Parameters + ---------- + x:torch.Tensor + Input tensor + + Returns + ---------- + torch.Tensor + Returns output tensor by applying convolution operation with same `kernel_size` and `stride` on input tensor. + + """ B, C, H, W = x.shape assert ( H == self.img_size[0] and W == self.img_size[1] ), f"Input Image Size {H}*{W} doesnt match model {self.img_size[0]}*{self.img_size[1]}" + x = self.proj(x).flatten(2).transpose(1, 2) x = self.norm(x) + return x diff --git a/vformer/encoder/embedding/pos_embedding.py b/vformer/encoder/embedding/pos_embedding.py index 08fb780a..4859cfc5 100644 --- a/vformer/encoder/embedding/pos_embedding.py +++ b/vformer/encoder/embedding/pos_embedding.py @@ -6,10 +6,11 @@ from ...utils import pair -class AbsolutePositionEmbedding(nn.Module): +class PVTPosEmbedding(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- pos_shape : int or tuple(int) The shape of the absolute position embedding. 
pos_dim : int @@ -24,19 +25,19 @@ def __init__(self, pos_shape, pos_dim, p_dropout=0.0, std=0.02): super().__init__() pos_shape = pair(pos_shape) - self.pos_shape = pos_shape - self.pos_dim = pos_dim - self.pos_embed = nn.Parameter( torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim) ) + self.pos_shape = pos_shape + self.pos_dim = pos_dim + self.drop = nn.Dropout(p=p_dropout) trunc_normal_(self.pos_embed, std=std) def resize_pos_embed(self, pos_embed, shape, mode="bilinear", **kwargs): """ - Parameters: - ----------- + Parameters + ---------- pos_embed : torch.Tensor Position embedding weights shape : tuple @@ -45,6 +46,7 @@ def resize_pos_embed(self, pos_embed, shape, mode="bilinear", **kwargs): Algorithm used for up/down sampling, default is 'bilinear' """ assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]" + pos_h, pos_w = self.pos_shape pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :] pos_embed_weight = ( @@ -63,5 +65,34 @@ def resize_pos_embed(self, pos_embed, shape, mode="bilinear", **kwargs): return pos_embed def forward(self, x, H, W, mode="bilinear"): - pos_embed = self.resize_pos_embed(self.pos_embed, (H, W), mode) - return self.drop(x + pos_embed) + try: + x = x + self.pos_embed + + except: + x = x + self.resize_pos_embed(self.pos_embed, (H, W), mode) + + return self.drop(x) + + +class PosEmbedding(nn.Module): + def __init__(self, shape, dim, drop=None, sinusoidal=False, std=0.02): + super(PosEmbedding, self).__init__() + if not sinusoidal: + self.pos_embed = torch.zeros(1, shape, dim) + else: + pe = torch.FloatTensor( + [ + [p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] + for p in range(shape) + ] + ) + pe[:, 0::2] = torch.sin(pe[:, 0::2]) + pe[:, 1::2] = torch.cos(pe[:, 1::2]) + self.pos_embed = pe + self.pos_embed.requires_grad = False + trunc_normal_(self.pos_embed, std=std) + self.drop = nn.Dropout(drop) if drop is not None else nn.Identity() + + def forward(self, x): + x = x + self.pos_embed + return self.drop(x) diff --git a/vformer/encoder/nn.py b/vformer/encoder/nn.py index 47e6729d..4be25bc8 100644 --- a/vformer/encoder/nn.py +++ b/vformer/encoder/nn.py @@ -3,8 +3,9 @@ class FeedForward(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- dim: int Dimension of the input tensor hidden_dim: int, optional @@ -13,10 +14,12 @@ class FeedForward(nn.Module): Dimension of the output tensor p_dropout: float Dropout probability, default=0.0 + """ def __init__(self, dim, hidden_dim=None, out_dim=None, p_dropout=0.0): super().__init__() + out_dim = out_dim if out_dim is not None else dim hidden_dim = hidden_dim if hidden_dim is not None else dim @@ -29,4 +32,18 @@ def __init__(self, dim, hidden_dim=None, out_dim=None, p_dropout=0.0): ) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + + torch.Tensor + Returns output tensor by performing linear operations and activation on input tensor + + """ + return self.net(x) diff --git a/vformer/encoder/pyramid.py b/vformer/encoder/pyramid.py index 724fa3c7..0c95d8a0 100644 --- a/vformer/encoder/pyramid.py +++ b/vformer/encoder/pyramid.py @@ -4,29 +4,30 @@ from ..attention import SpatialAttention from ..common.blocks import DWConv from ..functional import PreNorm +from ..utils import ENCODER_REGISTRY class PVTFeedForward(nn.Module): """ + + Parameters + ---------- dim: int Dimension of the input tensor hidden_dim: int, optional Dimension of hidden layer out_dim:int, optional Dimension of output tensor - act_layer: Activation 
class + act_layer: nn.Module Activation Layer, default is nn.GELU p_dropout: float Dropout probability/rate, default is 0.0 linear: bool - default=False + Whether to use linear Spatial attention,default is False use_dwconv: bool - default=False - + Whether to use Depth-wise convolutions, default is False - Kwargs: - ---------- - kernel_size_dwconv: int,optional + kernel_size_dwconv: int `kernel_size` parameter for 2D convolution used in Depth wise convolution stride_dwconv: int `stride` parameter for 2D convolution used in Depth wise convolution @@ -48,13 +49,16 @@ def __init__( **kwargs ): super(PVTFeedForward, self).__init__() + out_dim = out_dim if out_dim is not None else dim hidden_dim = hidden_dim if hidden_dim is not None else dim self.use_dwconv = use_dwconv self.fc1 = nn.Linear(dim, hidden_dim) self.relu = nn.ReLU(inplace=True) if linear else nn.Identity() + if use_dwconv: self.dw_conv = DWConv(dim=hidden_dim, **kwargs) + self.to_out = nn.Sequential( act_layer(), nn.Dropout(p=p_dropout), @@ -63,16 +67,37 @@ def __init__( ) def forward(self, x, **kwargs): + + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + H: int + Height of image patch + W: int + Width of image patch + + Returns + -------- + torch.Tensor + Returns output tensor + + """ x = self.relu(self.fc1(x)) + if self.use_dwconv: x = self.dw_conv(x, **kwargs) + return self.to_out(x) +@ENCODER_REGISTRY.register() class PVTEncoder(nn.Module): """ - Parameters: - ------------ + Parameters + ---------- dim: int Dimension of the input tensor num_heads: int @@ -84,9 +109,10 @@ class PVTEncoder(nn.Module): qkv_bias: bool Whether to add a bias vector to the q,k, and v matrices qk_scale:float, optional + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set p_dropout: float Dropout probability - attn_drop: float + attn_dropout: float Dropout probability drop_path: tuple(float) List of stochastic drop rate @@ -97,7 +123,7 @@ class PVTEncoder(nn.Module): sr_ratio: float Spatial Reduction ratio linear: bool - + Whether to use linear Spatial attention, default is False """ def __init__( @@ -109,7 +135,7 @@ def __init__( qkv_bias, qk_scale, p_dropout, - attn_drop, + attn_dropout, drop_path, act_layer, use_dwconv, @@ -117,7 +143,9 @@ def __init__( linear=False, ): super(PVTEncoder, self).__init__() + self.encoder = nn.ModuleList([]) + for i in range(depth): self.encoder.append( nn.ModuleList( @@ -129,7 +157,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, + attn_drop=attn_dropout, proj_drop=p_dropout, sr_ratio=sr_ratio, linear=linear, @@ -156,7 +184,9 @@ def __init__( ) def forward(self, x, **kwargs): + for prenorm_attn, prenorm_ff in self.encoder: x = x + self.drop_path(prenorm_attn(x, **kwargs)) x = x + self.drop_path(prenorm_ff(x, **kwargs)) + return x diff --git a/vformer/encoder/swin.py b/vformer/encoder/swin.py index b3dd35a4..720e3d1f 100644 --- a/vformer/encoder/swin.py +++ b/vformer/encoder/swin.py @@ -3,18 +3,27 @@ from timm.models.layers import DropPath from ..attention.window import WindowAttention -from ..utils import create_mask, cyclicshift, pair, window_partition, window_reverse +from ..utils import ( + ENCODER_REGISTRY, + create_mask, + cyclicshift, + pair, + window_partition, + window_reverse, +) from .nn import FeedForward +@ENCODER_REGISTRY.register() class SwinEncoderBlock(nn.Module): """ - Parameters: - ----------- + + Parameters + ---------- dim: int Number of the input channels input_resolution: int or tuple[int] - Input 
resolution + Input resolution of patches num_heads: int Number of attention heads window_size: int @@ -26,14 +35,16 @@ class SwinEncoderBlock(nn.Module): qkv_bias: bool, default= True Whether to add a bias vector to the q,k, and v matrices qk_scale: float, Optional - drop: float + + p_dropout: float Dropout rate - attn_drop: float - Attention dropout rate - drop_path: float - stochastic depth rate + attn_dropout: float + Dropout rate + drop_path_rate: float + Stochastic depth rate norm_layer:nn.Module - Normalization layer + Normalization layer, default is `nn.LayerNorm` + """ def __init__( @@ -46,12 +57,13 @@ def __init__( mlp_ratio=4.0, qkv_bias=True, qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, + p_dropout=0.0, + attn_dropout=0.0, + drop_path_rate=0.0, norm_layer=nn.LayerNorm, ): super(SwinEncoderBlock, self).__init__() + self.dim = dim self.input_resolution = pair(input_resolution) self.num_heads = num_heads @@ -59,9 +71,11 @@ def __init__( self.mlp_ratio = mlp_ratio self.shift_size = shift_size hidden_dim = int(dim * mlp_ratio) + if min(self.input_resolution) <= self.window_size: self.shift_size = 0 self.window_size = min(self.input_resolution) + assert ( 0 <= self.shift_size < window_size ), "shift size must range from 0 to window size" @@ -73,13 +87,15 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, + attn_dropout=attn_dropout, + proj_dropout=p_dropout, ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.drop_path = ( + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) self.norm2 = norm_layer(dim) - self.mlp = FeedForward(dim=dim, hidden_dim=hidden_dim, p_dropout=drop) + self.mlp = FeedForward(dim=dim, hidden_dim=hidden_dim, p_dropout=p_dropout) if self.shift_size > 0: attn_mask = create_mask( @@ -90,11 +106,24 @@ def __init__( ) else: attn_mask = None + self.register_buffer("attn_mask", attn_mask) def forward(self, x): - H, W = self.input_resolution + """ + Parameters + ---------- + x: torch.Tensor + + Returns + ---------- + torch.Tensor + Returns output tensor + + """ + + H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "Input tensor shape not compatible" @@ -125,9 +154,11 @@ def forward(self, x): x = skip_connection + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) + return x +@ENCODER_REGISTRY.register() class SwinEncoder(nn.Module): """ dim: int @@ -145,15 +176,16 @@ class SwinEncoder(nn.Module): qkv_bias: bool, default is True Whether to add a bias vector to the q,k, and v matrices qk_scale: float, optional - drop: float, + Override default qk scale of head_dim ** -0.5 in Window Attention if set + p_dropout: float, Dropout rate. - attn_drop: float, optional + attn_dropout: float, optional Attention dropout rate - drop_path: float,tuple[float] + drop_path_rate: float or tuple[float] Stochastic depth rate. - norm_layer (nn.Module, optional): + norm_layer: nn.Module Normalization layer. 
default is nn.LayerNorm - downsample (nn.Module | None, optional): + downsample: nn.Module, optional Downsample layer(like PatchMerging) at the end of the layer, default is None """ @@ -168,14 +200,15 @@ def __init__( mlp_ratio=4.0, qkv_bias=True, qkv_scale=None, - drop=0.0, - attn_drop=0.0, + p_dropout=0.0, + attn_dropout=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, ): super(SwinEncoder, self).__init__() + self.dim = dim self.input_resolution = input_resolution self.depth = depth @@ -192,9 +225,9 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qkv_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] + p_dropout=p_dropout, + attn_dropout=attn_dropout, + drop_path_rate=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, @@ -210,11 +243,26 @@ def __init__( self.downsample = None def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + + Returns + ---------- + torch.Tensor + Returns output tensor + + """ + for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) + if self.downsample is not None: x = self.downsample(x) + return x diff --git a/vformer/encoder/vanilla.py b/vformer/encoder/vanilla.py index 8d2782b2..b8885e22 100644 --- a/vformer/encoder/vanilla.py +++ b/vformer/encoder/vanilla.py @@ -1,55 +1,94 @@ import torch.nn as nn +from timm.models.layers import DropPath from ..attention import VanillaSelfAttention from ..functional import PreNorm +from ..utils import ENCODER_REGISTRY from .nn import FeedForward +@ENCODER_REGISTRY.register() class VanillaEncoder(nn.Module): """ - Parameters: - ----------- - latent_dim: int + + Parameters + ---------- + embedding_dim: int Dimension of the embedding depth: int Number of self-attention layers - heads: int + num_heads: int Number of the attention heads - dim_head: int + head_dim: int Dimension of each head mlp_dim: int Dimension of the hidden layer in the feed-forward layer p_dropout: float Dropout Probability + attn_dropout: float + Dropout Probability + drop_path_rate: float + Stochastic drop path rate """ - def __init__(self, latent_dim, depth, heads, dim_head, mlp_dim, p_dropout=0.0): + def __init__( + self, + embedding_dim, + depth, + num_heads, + head_dim, + mlp_dim, + p_dropout=0.0, + attn_dropout=0.0, + drop_path_rate=0.0, + ): super().__init__() + self.encoder = nn.ModuleList([]) for _ in range(depth): self.encoder.append( nn.ModuleList( [ PreNorm( - latent_dim, - VanillaSelfAttention( - latent_dim, - heads=heads, - dim_head=dim_head, - p_dropout=p_dropout, + dim=embedding_dim, + fn=VanillaSelfAttention( + dim=embedding_dim, + num_heads=num_heads, + head_dim=head_dim, + p_dropout=attn_dropout, ), ), PreNorm( - latent_dim, - FeedForward(latent_dim, mlp_dim, p_dropout=p_dropout), + dim=embedding_dim, + fn=FeedForward( + dim=embedding_dim, + hidden_dim=mlp_dim, + p_dropout=p_dropout, + ), ), ] ) ) + self.drop_path = ( + DropPath(drop_prob=drop_path_rate) + if drop_path_rate > 0.0 + else nn.Identity() + ) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + + Returns + ---------- + torch.Tensor + Returns output tensor + """ for attn, ff in self.encoder: x = attn(x) + x - x = ff(x) + x + x = self.drop_path(ff(x)) + x return x diff --git a/vformer/functional/merge.py b/vformer/functional/merge.py index 73d2e664..8e8e0f48 100644 --- a/vformer/functional/merge.py +++ b/vformer/functional/merge.py @@ -6,8 +6,9 @@ class PatchMerging(nn.Module): """ - Parameters : 
- ------------ + + Parameters + ---------- input_resolution: int or tuple[int] Resolution of input features dim : int diff --git a/vformer/functional/norm.py b/vformer/functional/norm.py index 16fe95fe..c0431d67 100644 --- a/vformer/functional/norm.py +++ b/vformer/functional/norm.py @@ -3,11 +3,11 @@ class PreNorm(nn.Module): """ - Parameters: - ----------- + Parameters + ---------- dim: int Dimension of the embedding - fn: + fn:nn.Module Attention class """ diff --git a/vformer/models/classification/__init__.py b/vformer/models/classification/__init__.py index 382e48a8..e6daa1ad 100644 --- a/vformer/models/classification/__init__.py +++ b/vformer/models/classification/__init__.py @@ -1,4 +1,6 @@ +from .cct import CCT from .cross import CrossViT +from .cvt import CVT from .pyramid import PVTClassification, PVTClassificationV2 from .swin import SwinTransformer from .vanilla import VanillaViT diff --git a/vformer/models/classification/cct.py b/vformer/models/classification/cct.py new file mode 100644 index 00000000..6bbc36d3 --- /dev/null +++ b/vformer/models/classification/cct.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...common import BaseClassificationModel +from ...decoder import MLPDecoder +from ...encoder import CVTEmbedding, PosEmbedding, VanillaEncoder +from ...utils import MODEL_REGISTRY, pair + + +@MODEL_REGISTRY.register() +class CCT(BaseClassificationModel): + """ + Implementation of Escaping the Big Data Paradigm with Compact Transformers: + https://arxiv.org/abs/2104.05704 + + Parameters: + ------------ + img_size: int + Size of the image + patch_size: int + Size of the single patch in the image + in_channels: int + Number of input channels in image + seq_pool:bool + Whether to use sequence pooling or not + embedding_dim: int + Patch embedding dimension + num_layers: int + Number of Encoders in encoder block + num_heads: int + Number of heads in each transformer layer + mlp_ratio:float + Ratio of mlp heads to embedding dimension + num_classes: int + Number of classes for classification + p_dropout: float + Dropout probability + attn_dropout: float + Dropout probability + drop_path: float + Stochastic depth rate, default is 0.1 + positional_embedding: str + One of the string values {'learnable','sine','None'}, default is learnable + decoder_config: tuple(int) or int + Configuration of the decoder. If None, the default configuration is used. 
+ pooling_kernel_size: int or tuple(int) + Size of the kernel in MaxPooling operation + pooling_stride: int or tuple(int) + Stride of MaxPooling operation + pooling_padding: int + Padding in MaxPooling operation + """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_channels=3, + seq_pool=True, + embedding_dim=768, + num_layers=1, + head_dim=96, + num_heads=1, + mlp_ratio=4.0, + num_classes=1000, + p_dropout=0.1, + attn_dropout=0.1, + drop_path=0.1, + positional_embedding="learnable", + decoder_config=( + 768, + 1024, + ), + pooling_kernel_size=3, + pooling_stride=2, + pooling_padding=1, + ): + super().__init__( + img_size=img_size, + patch_size=patch_size, + ) + + assert ( + img_size % patch_size == 0 + ), f"Image size ({img_size}) has to be divisible by patch size ({patch_size})" + + img_size = pair(img_size) + self.in_channels = in_channels + + self.embedding = CVTEmbedding( + in_channels=in_channels, + out_channels=embedding_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + max_pool=True, + pooling_kernel_size=pooling_kernel_size, + pooling_stride=pooling_stride, + pooling_padding=pooling_padding, + activation=nn.ReLU, + num_conv_layers=1, + conv_bias=True, + ) + + positional_embedding = ( + positional_embedding + if positional_embedding in ["sine", "learnable", "none"] + else "sine" + ) + hidden_dim = int(embedding_dim * mlp_ratio) + self.embedding_dim = embedding_dim + self.sequence_length = self.embedding.sequence_length( + n_channels=in_channels, height=img_size[0], width=img_size[1] + ) + self.seq_pool = seq_pool + + assert ( + self.sequence_length is not None or positional_embedding == "none" + ), f"Positional embedding is set to {positional_embedding} and the sequence length was not specified." + + if not seq_pool: + self.sequence_length += 1 + self.class_emb = nn.Parameter( + torch.zeros(1, 1, self.embedding_dim), requires_grad=True + ) + else: + self.attention_pool = nn.Linear(self.embedding_dim, 1) + + if positional_embedding != "none": + self.positional_emb = PosEmbedding( + self.sequence_length, + dim=embedding_dim, + drop=p_dropout, + sinusoidal=True if positional_embedding is "sine" else False, + ) + else: + self.positional_emb = None + + dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)] + self.encoder_blocks = nn.ModuleList( + [ + VanillaEncoder( + embedding_dim=embedding_dim, + num_heads=num_heads, + depth=1, + head_dim=head_dim, + mlp_dim=hidden_dim, + p_dropout=p_dropout, + attn_dropout=attn_dropout, + drop_path_rate=dpr[i], + ) + for i in range(num_layers) + ] + ) + if decoder_config is not None: + + if not isinstance(decoder_config, list) and not isinstance( + decoder_config, tuple + ): + decoder_config = [decoder_config] + + assert ( + decoder_config[0] == embedding_dim + ), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]} " + self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes) + + else: + self.decoder = MLPDecoder(config=embedding_dim, n_classes=num_classes) + + def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + + """ + x = self.embedding(x) + + if self.positional_emb is None and x.size(1) < self.sequence_length: + x = F.pad( + x, (0, 0, 0, self.in_channels - x.size(1)), mode="constant", value=0 + ) + + if not self.seq_pool: + cls_token = self.class_emb.expand(x.shape[0], -1, -1) + x = 
torch.cat((cls_token, x), dim=1) + + if self.positional_emb is not None: + x = self.positional_emb(x) + + for blk in self.encoder_blocks: + x = blk(x) + + if self.seq_pool: + x = torch.matmul( + F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x + ).squeeze(-2) + else: + x = x[:, 0] + + x = self.decoder(x) + + return x diff --git a/vformer/models/classification/cross.py b/vformer/models/classification/cross.py index 33884c7e..cedea344 100644 --- a/vformer/models/classification/cross.py +++ b/vformer/models/classification/cross.py @@ -2,10 +2,10 @@ import torch.nn as nn from einops import repeat -from vformer.common import BaseClassificationModel -from vformer.decoder.mlp import MLPDecoder -from vformer.encoder.cross import CrossEncoder -from vformer.encoder.embedding import LinearEmbedding +from ...common import BaseClassificationModel +from ...decoder import MLPDecoder +from ...encoder import CrossEncoder, LinearEmbedding +from ...utils import MODEL_REGISTRY class _cross_p(BaseClassificationModel): @@ -38,71 +38,72 @@ def forward(self, x): x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding[:, : (n + 1)] x = self.embedding_dropout(x) + return x +@MODEL_REGISTRY.register() class CrossViT(BaseClassificationModel): """ - Implementation of 'CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image - Classification' - https://arxiv.org/abs/2103.14899 - - Parameters: - ----------- - img_size: int - Size of the image - patch_size_s: int - Size of the smaller patches - patch_size_l: int - Size of the larger patches - n_classes: int - Number of classes for classification - cross_dim_head_s: int - Dimension of the head of the cross-attention for the smaller patches - cross_dim_head_l: int - Dimension of the head of the cross-attention for the larger patches - latent_dim_s: int - Dimension of the hidden layer for the smaller patches - latent_dim_l: int - Dimension of the hidden layer for the larger patches - dim_head_s: int - Dimension of the head of the attention for the smaller patches - dim_head_l: int - Dimension of the head of the attention for the larger patches - depth_s: int - Number of attention layers in encoder for the smaller patches - depth_l: int - Number of attention layers in encoder for the larger patches - attn_heads_s: int - Number of attention heads for the smaller patches - attn_heads_l: int - Number of attention heads for the larger patches - cross_head_s: int - Number of CrossAttention heads for the smaller patches - cross_head_l: int - Number of CrossAttention heads for the larger patches - encoder_mlp_dim_s: int - Dimension of hidden layer in the encoder for the smaller patches - encoder_mlp_dim_l: int - Dimension of hidden layer in the encoder for the larger patches - in_channels: int - Number of input channels - decoder_config_s: int or tuple or list, optional - Configuration of the decoder for the smaller patches - decoder_config_l: int or tuple or list, optional - Configuration of the decoder for the larger patches - pool_s: {"cls","mean"} - Feature pooling type for the smaller patches - pool_l: {"cls","mean"} - Feature pooling type for the larger patches - p_dropout_encoder_s: float - Dropout probability in the encoder for the smaller patches - p_dropout_encoder_l: float - Dropout probability in the encoder for the larger patches - p_dropout_embedding_s: float - Dropout probability in the embedding layer for the smaller patches - p_dropout_embedding_l: float - Dropout probability in the embedding layer for the larger patches + 
Implementation of 'CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification' + https://arxiv.org/abs/2103.14899 + + Parameters + ---------- + img_size: int + Size of the image + patch_size_s: int + Size of the smaller patches + patch_size_l: int + Size of the larger patches + n_classes: int + Number of classes for classification + cross_dim_head_s: int + Dimension of the head of the cross-attention for the smaller patches + cross_dim_head_l: int + Dimension of the head of the cross-attention for the larger patches + latent_dim_s: int + Dimension of the hidden layer for the smaller patches + latent_dim_l: int + Dimension of the hidden layer for the larger patches + head_dim_s: int + Dimension of the head of the attention for the smaller patches + head_dim_l: int + Dimension of the head of the attention for the larger patches + depth_s: int + Number of attention layers in encoder for the smaller patches + depth_l: int + Number of attention layers in encoder for the larger patches + attn_heads_s: int + Number of attention heads for the smaller patches + attn_heads_l: int + Number of attention heads for the larger patches + cross_head_s: int + Number of CrossAttention heads for the smaller patches + cross_head_l: int + Number of CrossAttention heads for the larger patches + encoder_mlp_dim_s: int + Dimension of hidden layer in the encoder for the smaller patches + encoder_mlp_dim_l: int + Dimension of hidden layer in the encoder for the larger patches + in_channels: int + Number of input channels + decoder_config_s: int or tuple or list, optional + Configuration of the decoder for the smaller patches + decoder_config_l: int or tuple or list, optional + Configuration of the decoder for the larger patches + pool_s: {"cls","mean"} + Feature pooling type for the smaller patches + pool_l: {"cls","mean"} + Feature pooling type for the larger patches + p_dropout_encoder_s: float + Dropout probability in the encoder for the smaller patches + p_dropout_encoder_l: float + Dropout probability in the encoder for the larger patches + p_dropout_embedding_s: float + Dropout probability in the embedding layer for the smaller patches + p_dropout_embedding_l: float + Dropout probability in the embedding layer for the larger patches """ def __init__( @@ -115,8 +116,8 @@ def __init__( cross_dim_head_l=64, latent_dim_s=1024, latent_dim_l=1024, - dim_head_s=64, - dim_head_l=64, + head_dim_s=64, + head_dim_l=64, depth_s=6, depth_l=6, attn_heads_s=16, @@ -137,6 +138,7 @@ def __init__( ): super().__init__(img_size, patch_size_s, in_channels, pool_s) super().__init__(img_size, patch_size_l, in_channels, pool_l) + self.s = _cross_p( img_size, patch_size_s, latent_dim_s, in_channels, p_dropout_embedding_s ) @@ -146,8 +148,8 @@ def __init__( self.encoder = CrossEncoder( latent_dim_s, latent_dim_l, - dim_head_s, - dim_head_l, + head_dim_s, + head_dim_l, cross_dim_head_s, cross_dim_head_l, depth_s, @@ -163,29 +165,48 @@ def __init__( ) self.pool_s = lambda x: x.mean(dim=1) if pool_s == "mean" else x[:, 0] self.pool_l = lambda x: x.mean(dim=1) if pool_l == "mean" else x[:, 0] + if decoder_config_s is not None: + if not isinstance(decoder_config_s, list): - decoder_config = list(decoder_config_l) + decoder_config_s = list(decoder_config_s) + assert ( - decoder_config[0] == latent_dim_s + decoder_config_s[0] == latent_dim_s ), "`latent_dim` should be equal to the first item of `decoder_config`" - self.decoder_s = MLPDecoder(decoder_config, n_classes) + + self.decoder_s = MLPDecoder(decoder_config_s, 
n_classes) else: self.decoder_s = MLPDecoder(latent_dim_s, n_classes) if decoder_config_l is not None: + if not isinstance(decoder_config_l, list): - decoder_config = list(decoder_config_l) + decoder_config_l = list(decoder_config_l) + assert ( - decoder_config[0] == latent_dim_l + decoder_config_l[0] == latent_dim_l ), "`latent_dim` should be equal to the first item of `decoder_config`" - self.decoder_l = MLPDecoder(decoder_config, n_classes) + + self.decoder_l = MLPDecoder(decoder_config_l, n_classes) else: self.decoder_l = MLPDecoder(latent_dim_l, n_classes) def forward(self, img): + """ + + Parameters + ---------- + img: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + + """ emb_s = self.s(img) emb_l = self.l(img) emb_s, emb_l = self.encoder(emb_s, emb_l) @@ -194,4 +215,5 @@ def forward(self, img): n_s = self.decoder_s(cls_s) n_l = self.decoder_l(cls_l) n = n_s + n_l + return n diff --git a/vformer/models/classification/cvt.py b/vformer/models/classification/cvt.py new file mode 100644 index 00000000..da2662e0 --- /dev/null +++ b/vformer/models/classification/cvt.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...common import BaseClassificationModel +from ...decoder import MLPDecoder +from ...encoder import CVTEmbedding, PosEmbedding, VanillaEncoder +from ...utils import MODEL_REGISTRY, pair + + +@MODEL_REGISTRY.register() +class CVT(BaseClassificationModel): + """ + Implementation of Escaping the Big Data Paradigm with Compact Transformers: + https://arxiv.org/abs/2104.05704 + + Parameters: + ------------ + img_size: int + Size of the image, default is 224 + patch_size:int + Size of the single patch in the image, default is 4 + in_channels:int + Number of input channels in image, default is 3 + seq_pool:bool + Whether to use sequence pooling, default is True + embedding_dim: int + Patch embedding dimension, default is 768 + num_layers: int + Number of Encoders in encoder block, default is 1 + num_heads: int + Number of heads in each transformer layer, default is 1 + mlp_ratio:float + Ratio of mlp heads to embedding dimension, default is 4.0 + num_classes: int + Number of classes for classification, default is 1000 + p_dropout: float + Dropout probability, default is 0.0 + attn_dropout: float + Dropout probability, defualt is 0.0 + drop_path: float + Stochastic depth rate, default is 0.1 + positional_embedding: str + One of the string values {'learnable','sine','None'}, default is learnable + decoder_config: tuple(int) or int + Configuration of the decoder. If None, the default configuration is used. 
+ """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_channels=3, + seq_pool=True, + embedding_dim=768, + head_dim=96, + num_layers=1, + num_heads=1, + mlp_ratio=4.0, + num_classes=1000, + p_dropout=0.1, + attn_dropout=0.1, + drop_path=0.1, + positional_embedding="learnable", + decoder_config=( + 768, + 1024, + ), + ): + super().__init__( + img_size=img_size, + patch_size=patch_size, + ) + + assert ( + img_size % patch_size == 0 + ), f"Image size ({img_size}) has to be divisible by patch size ({patch_size})" + + img_size = pair(img_size) + self.in_channels = in_channels + self.embedding = CVTEmbedding( + in_channels=in_channels, + out_channels=embedding_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + max_pool=False, + activation=None, + num_conv_layers=1, + conv_bias=True, + ) + + positional_embedding = ( + positional_embedding + if positional_embedding in ["sine", "learnable", "none"] + else "sine" + ) + hidden_dim = int(embedding_dim * mlp_ratio) + self.embedding_dim = embedding_dim + self.sequence_length = self.embedding.sequence_length( + n_channels=in_channels, height=img_size[0], width=img_size[1] + ) + self.seq_pool = seq_pool + + assert ( + self.sequence_length is not None or positional_embedding == "none" + ), f"Positional embedding is set to {positional_embedding} and the sequence length was not specified." + + if not seq_pool: + self.sequence_length += 1 + self.class_emb = nn.Parameter( + torch.zeros(1, 1, self.embedding_dim), requires_grad=True + ) + else: + self.attention_pool = nn.Linear(self.embedding_dim, 1) + + if positional_embedding != "none": + self.positional_emb = PosEmbedding( + shape=self.sequence_length, + dim=embedding_dim, + drop=p_dropout, + sinusoidal=True if positional_embedding is "sine" else False, + ) + else: + self.positional_emb = None + + dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)] + self.encoder_blocks = nn.ModuleList( + [ + VanillaEncoder( + embedding_dim=embedding_dim, + num_heads=num_heads, + depth=1, + mlp_dim=hidden_dim, + head_dim=head_dim, + p_dropout=p_dropout, + attn_dropout=attn_dropout, + drop_path_rate=dpr[i], + ) + for i in range(num_layers) + ] + ) + if decoder_config is not None: + + if not isinstance(decoder_config, list) and not isinstance( + decoder_config, tuple + ): + decoder_config = [decoder_config] + assert ( + decoder_config[0] == embedding_dim + ), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]} " + self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes) + else: + self.decoder = MLPDecoder(config=embedding_dim, n_classes=num_classes) + + def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + + """ + + x = self.embedding(x) + + if self.positional_emb is None and x.size(1) < self.sequence_length: + x = F.pad( + x, (0, 0, 0, self.in_channels - x.size(1)), mode="constant", value=0 + ) + + if not self.seq_pool: + cls_token = self.class_emb.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + + if self.positional_emb is not None: + x = self.positional_emb(x) + + for blk in self.encoder_blocks: + x = blk(x) + + if self.seq_pool: + x = torch.matmul( + F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x + ).squeeze(-2) + else: + x = x[:, 0] + + x = self.decoder(x) + + return x diff --git a/vformer/models/classification/pyramid.py 
b/vformer/models/classification/pyramid.py index c321d0d5..a0b74c3e 100644 --- a/vformer/models/classification/pyramid.py +++ b/vformer/models/classification/pyramid.py @@ -4,15 +4,18 @@ from timm.models.layers import trunc_normal_ from ...decoder import MLPDecoder -from ...encoder import AbsolutePositionEmbedding, OverlapPatchEmbed, PVTEncoder +from ...encoder import OverlapPatchEmbed, PVTEncoder, PVTPosEmbedding +from ...utils import MODEL_REGISTRY +@MODEL_REGISTRY.register() class PVTClassification(nn.Module): """ - Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v1 + Implementation of Pyramid Vision Transformer: + https://arxiv.org/abs/2102.12122v1 - Parameters: - ----------- + Parameters + ---------- img_size: int Image size patch_size: list(int) @@ -32,24 +35,25 @@ class PVTClassification(nn.Module): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 Spatial Attention if set p_dropout: float, Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float, Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 norm_layer: Normalization layer, default is nn.LayerNorm - sr_ratio: float + sr_ratios: float Spatial reduction ratio decoder_config:int or tuple[int], optional Configuration of the decoder. If None, the default configuration is used. linear: bool - Whether to use + Whether to use linear Spatial attention, default is False use_dwconv: bool - Whether to use Depth-wise convolutions in the overlap-patch embedding layer + Whether to use Depth-wise convolutions, default is False ape: bool - Whether to use absolute position embedding + Whether to use absolute position embedding, default is True """ def __init__( @@ -64,7 +68,7 @@ def __init__( qkv_bias=False, qk_scale=None, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], @@ -94,7 +98,7 @@ def __init__( patch_size=patch_size[i], stride=4 if i == 0 else 2, in_channels=in_channels if i == 0 else embed_dims[i - 1], - embed_dim=embed_dims[i], + embedding_dim=embed_dims[i], ) ] ) @@ -104,7 +108,7 @@ def __init__( self.pos_embeds.append( nn.ModuleList( [ - AbsolutePositionEmbedding( + PVTPosEmbedding( pos_shape=img_size // np.prod(patch_size[: i + 1]), pos_dim=embed_dims[i], ) @@ -131,7 +135,7 @@ def __init__( qk_scale=qk_scale, p_dropout=p_dropout, depth=depths[i], - attn_drop=attn_drop_rate, + attn_dropout=attn_dropout, drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], sr_ratio=sr_ratios[i], linear=linear, @@ -160,6 +164,18 @@ def __init__( self.decoder = MLPDecoder(config=embed_dims[-1], n_classes=num_classes) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + + """ B = x.shape[0] for i in range(len(self.depths)): patch_embed = self.patch_embeds[i] @@ -186,21 +202,23 @@ def forward(self, x): return x +@MODEL_REGISTRY.register() class PVTClassificationV2(PVTClassification): """ - Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v2 + Implementation of Pyramid Vision Transformer: + https://arxiv.org/abs/2102.12122v2 - Parameters: - ----------- + Parameters + ---------- img_size: int Image size patch_size: list(int) List of patch size in_channels: int - Input channels in image, default=3 + Input channels in image, default is 3 num_classes: int Number of 
classes for classification - embed_dims: int + embedding_dims: int Patch Embedding dimension num_heads:tuple[int] Number of heads in each transformer layer @@ -211,24 +229,25 @@ class PVTClassificationV2(PVTClassification): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set p_dropout: float, Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float, Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - norm_layer: + norm_layer:nn.Module Normalization layer, default is nn.LayerNorm - sr_ratio: float + sr_ratios: float Spatial reduction ratio decoder_config:int or tuple[int], optional Configuration of the decoder. If None, the default configuration is used. linear: bool - Whether to use + Whether to use linear Spatial attention, default is False use_dwconv: bool - Whether to use Depth-wise convolutions in the overlap-patch embedding layer + Whether to use Depth-wise convolutions, default is True ape: bool - Whether to use absolute position embedding + Whether to use absolute position embedding, default is false """ def __init__( @@ -237,13 +256,13 @@ def __init__( patch_size=[7, 3, 3, 3], in_channels=3, num_classes=1000, - embed_dims=[64, 128, 256, 512], + embedding_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratio=[4, 4, 4, 4], qkv_bias=False, qk_scale=0.0, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], @@ -258,13 +277,13 @@ def __init__( patch_size=patch_size, in_channels=in_channels, num_classes=num_classes, - embed_dims=embed_dims, + embed_dims=embedding_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, p_dropout=p_dropout, - attn_drop_rate=attn_drop_rate, + attn_dropout=attn_dropout, drop_path_rate=drop_path_rate, norm_layer=norm_layer, depths=depths, diff --git a/vformer/models/classification/swin.py b/vformer/models/classification/swin.py index e2b0c9c3..41cbe9b6 100644 --- a/vformer/models/classification/swin.py +++ b/vformer/models/classification/swin.py @@ -4,17 +4,19 @@ from ...common import BaseClassificationModel from ...decoder import MLPDecoder -from ...encoder import PatchEmbedding, SwinEncoder +from ...encoder import PatchEmbedding, PosEmbedding, SwinEncoder from ...functional import PatchMerging +from ...utils import MODEL_REGISTRY +@MODEL_REGISTRY.register() class SwinTransformer(BaseClassificationModel): """ Implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` https://arxiv.org/abs/2103.14030v1 - Parameters: - ----------- + Parameters + ---------- img_size: int Size of an Image patch_size: int @@ -23,7 +25,7 @@ class SwinTransformer(BaseClassificationModel): Input channels in image, default=3 n_classes: int Number of classes for classification - embed_dim: int + embedding_dim: int Patch Embedding dimension depths: tuple[int] Depth in each Transformer layer @@ -36,13 +38,14 @@ class SwinTransformer(BaseClassificationModel): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional - drop_rate: float + Override default qk scale of head_dim ** -0.5 in Window Attention if set + p_dropout: float Dropout rate, default is 0.0 - attn_drop_rate: float + attn_dropout: float Attention dropout rate,default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - norm_layer: + norm_layer: nn.Module Normalization 
layer,default is nn.LayerNorm ape: bool, optional Whether to add relative/absolute position embedding to patch embedding, default is True @@ -58,15 +61,15 @@ def __init__( patch_size, in_channels, n_classes, - embed_dim=96, + embedding_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, + p_dropout=0.0, + attn_dropout=0.0, drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=True, @@ -81,25 +84,25 @@ def __init__( img_size=img_size, patch_size=patch_size, in_channels=in_channels, - embed_dim=embed_dim, + embedding_dim=embedding_dim, norm_layer=norm_layer if patch_norm else nn.Identity, ) self.patch_resolution = self.patch_embed.patch_resolution num_patches = self.patch_resolution[0] * self.patch_resolution[1] self.ape = ape - num_features = int(embed_dim * 2 ** (len(depths) - 1)) - - if self.ape: - self.absolute_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, embed_dim) - ) - trunc_normal_(self.absolute_pos_embed, std=0.02) + num_features = int(embedding_dim * 2 ** (len(depths) - 1)) + self.absolute_pos_embed = ( + PosEmbedding(shape=num_patches, dim=embedding_dim, drop=p_dropout, std=0.02) + if ape + else nn.Identity() + ) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] self.encoder = nn.ModuleList() + for i_layer in range(len(depths)): layer = SwinEncoder( - dim=int(embed_dim * (2 ** i_layer)), + dim=int(embedding_dim * (2 ** i_layer)), input_resolution=( (self.patch_resolution[0] // (2 ** i_layer)), self.patch_resolution[1] // (2 ** i_layer), @@ -110,8 +113,8 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qkv_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, + p_dropout=p_dropout, + attn_dropout=attn_dropout, drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if i_layer < len(depths) - 1 else None, @@ -119,28 +122,45 @@ def __init__( self.encoder.append(layer) if decoder_config is not None: + if not isinstance(decoder_config, list): decoder_config = list(decoder_config) + assert ( decoder_config[0] == num_features ), f"first item of `decoder_config` should be equal to the `num_features`; num_features=embed_dim * 2** (len(depths)-1) which is = {num_features} " + self.decoder = MLPDecoder(decoder_config, n_classes) + else: self.decoder = MLPDecoder(num_features, n_classes) + self.pool = nn.AdaptiveAvgPool1d(1) self.norm = norm_layer(num_features) if norm_layer is not None else nn.Identity - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=p_dropout) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + + """ x = self.patch_embed(x) - if self.ape: - x += self.absolute_pos_embed - x = self.pos_drop(x) + + x = self.absolute_pos_embed(x) + for layer in self.encoder: x = layer(x) x = self.norm(x) - x = self.pool(x.transpose(1, 2)).flatten(1) x = self.decoder(x) + return x diff --git a/vformer/models/classification/vanilla.py b/vformer/models/classification/vanilla.py index 43784760..eafff68b 100644 --- a/vformer/models/classification/vanilla.py +++ b/vformer/models/classification/vanilla.py @@ -4,25 +4,27 @@ from ...common import BaseClassificationModel from ...decoder import MLPDecoder -from ...encoder import LinearEmbedding, VanillaEncoder +from ...encoder import LinearEmbedding, PosEmbedding, VanillaEncoder +from ...utils import MODEL_REGISTRY 
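# --- Illustrative usage sketch (not part of this patch) ----------------------
# The hunks above swap the raw ``nn.Parameter`` absolute position embeddings in
# SwinTransformer/VanillaViT for the new ``PosEmbedding`` module introduced in
# ``pos_embedding.py``. A minimal, hedged example of driving that module on its
# own, using only the constructor signature shown in this patch; the import
# path is assumed from the ``from ...encoder import ... PosEmbedding`` lines.
import torch

from vformer.encoder import PosEmbedding  # assumed public import path

batch_size, n_patches, embedding_dim = 2, 64, 1024

# Learnable table of shape (1, n_patches + 1, embedding_dim); the extra slot
# mirrors the class token handling in VanillaViT above.
pos_emb = PosEmbedding(shape=n_patches + 1, dim=embedding_dim, drop=0.1)

tokens = torch.randn(batch_size, n_patches + 1, embedding_dim)
tokens = pos_emb(tokens)  # adds the position table, then applies dropout
# -----------------------------------------------------------------------------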
+@MODEL_REGISTRY.register() class VanillaViT(BaseClassificationModel): """ Implementation of 'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale' https://arxiv.org/abs/2010.11929 - Parameters: - ----------- + Parameters + ---------- img_size: int Size of the image patch_size: int Size of a patch n_classes: int Number of classes for classification - latent_dim: int + embedding_dim: int Dimension of hidden layer - dim_head: int + head_dim: int Dimension of the attention head depth: int Number of attention layers in the encoder @@ -30,7 +32,7 @@ class VanillaViT(BaseClassificationModel): Number of the attention heads encoder_mlp_dim: int Dimension of hidden layer in the encoder - in_channel: int + in_channels: int Number of input channels decoder_config: int or tuple or list, optional Configuration of the decoder. If None, the default configuration is used. @@ -47,8 +49,8 @@ def __init__( img_size, patch_size, n_classes, - latent_dim=1024, - dim_head=64, + embedding_dim=1024, + head_dim=64, depth=6, attn_heads=16, encoder_mlp_dim=2048, @@ -61,40 +63,60 @@ def __init__( super().__init__(img_size, patch_size, in_channels, pool) self.patch_embedding = LinearEmbedding( - latent_dim, self.patch_height, self.patch_width, self.patch_dim + embedding_dim, self.patch_height, self.patch_width, self.patch_dim ) - self.pos_embedding = nn.Parameter( - torch.randn(1, self.n_patches + 1, latent_dim) + self.pos_embedding = PosEmbedding( + shape=self.n_patches + 1, + dim=embedding_dim, + drop=p_dropout_embedding, + sinusoidal=False, ) - self.cls_token = nn.Parameter(torch.randn(1, 1, latent_dim)) - self.embedding_dropout = nn.Dropout(p_dropout_embedding) + self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) self.encoder = VanillaEncoder( - latent_dim, depth, attn_heads, dim_head, encoder_mlp_dim, p_dropout_encoder + embedding_dim=embedding_dim, + depth=depth, + num_heads=attn_heads, + head_dim=head_dim, + mlp_dim=encoder_mlp_dim, + p_dropout=p_dropout_encoder, ) self.pool = lambda x: x.mean(dim=1) if pool == "mean" else x[:, 0] if decoder_config is not None: + if not isinstance(decoder_config, list): decoder_config = list(decoder_config) + assert ( - decoder_config[0] == latent_dim - ), "`latent_dim` should be equal to the first item of `decoder_config`" + decoder_config[0] == embedding_dim + ), "`embedding_dim` should be equal to the first item of `decoder_config`" + self.decoder = MLPDecoder(decoder_config, n_classes) else: - self.decoder = MLPDecoder(latent_dim, n_classes) + self.decoder = MLPDecoder(embedding_dim, n_classes) def forward(self, x): + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns tensor of size `num_classes` + """ x = self.patch_embedding(x) b, n, _ = x.shape cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding[:, : (n + 1)] - x = self.embedding_dropout(x) + x = self.pos_embedding(x) x = self.encoder(x) x = self.pool(x) x = self.decoder(x) diff --git a/vformer/models/dense/PVT/detection.py b/vformer/models/dense/PVT/detection.py index e9115422..7b6a2c47 100644 --- a/vformer/models/dense/PVT/detection.py +++ b/vformer/models/dense/PVT/detection.py @@ -2,15 +2,19 @@ import torch import torch.nn as nn -from ....encoder import AbsolutePositionEmbedding, OverlapPatchEmbed, PVTEncoder +from ....encoder import OverlapPatchEmbed, PVTEncoder, PVTPosEmbedding +from ....utils import MODEL_REGISTRY +@MODEL_REGISTRY.register() class 
PVTDetection(nn.Module): """ - Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v1 + Implementation of Pyramid Vision Transformer: + https://arxiv.org/abs/2102.12122v1 - Parameters: - ----------- + + Parameters + ---------- img_size: int Image size patch_size: list(int) @@ -19,7 +23,7 @@ class PVTDetection(nn.Module): Input channels in image, default=3 num_classes: int Number of classes for classification - embed_dims: int + embedding_dims: int Patch Embedding dimension num_heads:tuple[int] Number of heads in each transformer layer @@ -30,20 +34,21 @@ class PVTDetection(nn.Module): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set p_dropout: float, Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float, Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - sr_ratio: float + sr_ratios: float Spatial reduction ratio linear: bool - Whether to use linear spatial attention + Whether to use linear spatial attention, default is False use_dwconv: bool - Whether to use Depth-wise convolutions in Overlap-patch embedding + Whether to use Depth-wise convolutions in Overlap-patch embedding, default is False ape: bool - Whether to use absolute position embedding + Whether to use absolute position embedding, default is True """ @@ -52,13 +57,13 @@ def __init__( img_size=224, patch_size=[7, 3, 3, 3], in_channels=3, - embed_dims=[64, 128, 256, 512], + embedding_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratio=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], @@ -68,17 +73,23 @@ def __init__( ape=True, ): super(PVTDetection, self).__init__() + self.ape = ape self.depths = depths + assert ( - len(depths) == len(num_heads) == len(embed_dims) + len(depths) == len(num_heads) == len(embedding_dims) ), "Configurations do not match" + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.patch_embeds = nn.ModuleList([]) self.blocks = nn.ModuleList([]) self.norms = nn.ModuleList() self.pos_embeds = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append( nn.ModuleList( [ @@ -86,19 +97,22 @@ def __init__( img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), patch_size=patch_size[i], stride=4 if i == 0 else 2, - in_channels=in_channels if i == 0 else embed_dims[i - 1], - embed_dim=embed_dims[i], + in_channels=in_channels + if i == 0 + else embedding_dims[i - 1], + embedding_dim=embedding_dims[i], ) ] ) ) + if ape: self.pos_embeds.append( nn.ModuleList( [ - AbsolutePositionEmbedding( + PVTPosEmbedding( pos_shape=img_size // np.prod(patch_size[: i + 1]), - pos_dim=embed_dims[i], + pos_dim=embedding_dims[i], ) ] ) @@ -108,14 +122,14 @@ def __init__( nn.ModuleList( [ PVTEncoder( - dim=embed_dims[i], + dim=embedding_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, qk_scale=qk_scale, p_dropout=p_dropout, depth=depths[i], - attn_drop=attn_drop_rate, + attn_dropout=attn_dropout, drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], sr_ratio=sr_ratios[i], linear=linear, @@ -125,36 +139,58 @@ def __init__( ] ) ) - self.norms.append(norm_layer(embed_dims[i])) - # cls_token - self.pool = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + self.norms.append(norm_layer(embedding_dims[i])) + + self.pool = 
nn.Parameter(torch.zeros(1, 1, embedding_dims[-1])) def forward(self, x): + + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns list containing output features from all pyramid stages + + """ B = x.shape[0] out = [] + for i in range(len(self.depths)): + patch_embed = self.patch_embeds[i] block = self.blocks[i] norm = self.norms[i] x, H, W = patch_embed[0](x) + if self.ape: pos_embed = self.pos_embeds[i] x = pos_embed[0](x, H=H, W=W) + for blk in block: x = blk(x, H=H, W=W) + x = norm(x) x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() out.append(x) + return out +@MODEL_REGISTRY.register() class PVTDetectionV2(PVTDetection): """ - Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v2 + Implementation of Pyramid Vision Transformer: + https://arxiv.org/abs/2102.12122v2 + - Parameters: - ----------- + Parameters + ---------- img_size: int Image size patch_size: list(int) @@ -163,7 +199,7 @@ class PVTDetectionV2(PVTDetection): Input channels in image, default=3 num_classes: int Number of classes for classification - embed_dims: int + embedding_dims: int Patch Embedding dimension num_heads:tuple[int] Number of heads in each transformer layer @@ -174,13 +210,14 @@ class PVTDetectionV2(PVTDetection): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set p_dropout: float, Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float, Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - sr_ratio: float + sr_ratios: float Spatial reduction ratio linear: bool Whether to use linear spatial attention @@ -188,8 +225,6 @@ class PVTDetectionV2(PVTDetection): Whether to use Depth-wise convolutions in Overlap-patch embedding ape: bool Whether to use absolute position embedding - return_pyramid: bool - Whether to return all pyramid feature layers, if false returns only last feature layer, default is True """ def __init__( @@ -197,18 +232,18 @@ def __init__( img_size=224, patch_size=[7, 3, 3, 3], in_channels=3, - embed_dims=[64, 128, 256, 512], + embedding_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratio=[4, 4, 4, 4], qkv_bias=False, qk_scale=0.0, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - use_abs_pos_embed=False, + ape=False, use_dwconv=True, linear=False, ): @@ -216,18 +251,18 @@ def __init__( img_size=img_size, patch_size=patch_size, in_channels=in_channels, - embed_dims=embed_dims, + embedding_dims=embedding_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, p_dropout=p_dropout, - attn_drop_rate=attn_drop_rate, + attn_dropout=attn_dropout, drop_path_rate=drop_path_rate, norm_layer=norm_layer, depths=depths, sr_ratios=sr_ratios, linear=linear, - ape=use_abs_pos_embed, + ape=ape, use_dwconv=use_dwconv, ) diff --git a/vformer/models/dense/PVT/segmentation.py b/vformer/models/dense/PVT/segmentation.py index 8ad36454..789aff66 100644 --- a/vformer/models/dense/PVT/segmentation.py +++ b/vformer/models/dense/PVT/segmentation.py @@ -3,22 +3,25 @@ import torch.nn as nn from ....decoder import SegmentationHead -from ....encoder import AbsolutePositionEmbedding, OverlapPatchEmbed, PVTEncoder +from ....encoder import OverlapPatchEmbed, PVTEncoder, PVTPosEmbedding +from ....utils import MODEL_REGISTRY 
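+# The decorator below registers PVTSegmentation in MODEL_REGISTRY under its
+# class name, so it can later be retrieved with MODEL_REGISTRY.get("PVTSegmentation").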
+@MODEL_REGISTRY.register() class PVTSegmentation(nn.Module): """ - Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v1 + Implementation of Pyramid Vision Transformer: + https://arxiv.org/abs/2102.12122v1 - Parameters: - ----------- + Parameters + ---------- img_size: int Image size patch_size: list(int) List of patch size in_channels: int Input channels in image, default=3 - embed_dims: int + embedding_dims: int Patch Embedding dimension num_heads:tuple[int] Number of heads in each transformer layer @@ -29,13 +32,14 @@ class PVTSegmentation(nn.Module): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional - p_dropout: float, + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set + p_dropout: float Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - sr_ratio: float + sr_ratios: float Spatial reduction ratio linear: bool Whether to use linear spatial attention @@ -44,7 +48,7 @@ class PVTSegmentation(nn.Module): ape: bool Whether to use absolute position embedding return_pyramid:bool - Whether to use all pyramid feature layers for up-sampling, default is true + Whether to use all pyramid feature layers for up-sampling, default is False """ def __init__( @@ -52,13 +56,13 @@ def __init__( img_size=224, patch_size=[7, 3, 3, 3], in_channels=3, - embed_dims=[64, 128, 256, 512], + embedding_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratio=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], @@ -70,18 +74,23 @@ def __init__( return_pyramid=False, ): super(PVTSegmentation, self).__init__() + self.ape = ape self.depths = depths self.return_pyramid = return_pyramid + assert ( - len(depths) == len(num_heads) == len(embed_dims) + len(depths) == len(num_heads) == len(embedding_dims) ), "Configurations do not match" + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] self.patch_embeds = nn.ModuleList([]) self.blocks = nn.ModuleList([]) self.norms = nn.ModuleList() self.pos_embeds = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append( nn.ModuleList( [ @@ -89,19 +98,22 @@ def __init__( img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), patch_size=patch_size[i], stride=4 if i == 0 else 2, - in_channels=in_channels if i == 0 else embed_dims[i - 1], - embed_dim=embed_dims[i], + in_channels=in_channels + if i == 0 + else embedding_dims[i - 1], + embedding_dim=embedding_dims[i], ) ] ) ) + if ape: self.pos_embeds.append( nn.ModuleList( [ - AbsolutePositionEmbedding( + PVTPosEmbedding( pos_shape=img_size // np.prod(patch_size[: i + 1]), - pos_dim=embed_dims[i], + pos_dim=embedding_dims[i], ) ] ) @@ -111,14 +123,14 @@ def __init__( nn.ModuleList( [ PVTEncoder( - dim=embed_dims[i], + dim=embedding_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, qk_scale=qk_scale, p_dropout=p_dropout, depth=depths[i], - attn_drop=attn_drop_rate, + attn_dropout=attn_dropout, drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], sr_ratio=sr_ratios[i], linear=linear, @@ -128,49 +140,70 @@ def __init__( ] ) ) - self.norms.append(norm_layer(embed_dims[i])) + self.norms.append(norm_layer(embedding_dims[i])) + self.head = SegmentationHead( out_channels=out_channels, - embed_dims=embed_dims if not return_pyramid else 
[embed_dims[-1]], + embed_dims=embedding_dims if not return_pyramid else [embedding_dims[-1]], ) def forward(self, x): + + """ + + Parameters + ---------- + x: torch.Tensor + Input tensor + Returns + ---------- + torch.Tensor + Returns output tensor + """ B = x.shape[0] out = [] + for i in range(len(self.depths)): + patch_embed = self.patch_embeds[i] block = self.blocks[i] norm = self.norms[i] x, H, W = patch_embed[0](x) + if self.ape: pos_embed = self.pos_embeds[i] x = pos_embed[0](x, H=H, W=W) + for blk in block: x = blk(x, H=H, W=W) + x = norm(x) x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() out.append(x) + if self.return_pyramid: out = out[3:4] out = self.head(out) + return out +@MODEL_REGISTRY.register() class PVTSegmentationV2(PVTSegmentation): """ Implementation of Pyramid Vision Transformer - https://arxiv.org/abs/2102.12122v1 - Parameters: - ----------- + Parameters + ---------- img_size: int Image size patch_size: list(int) List of patch size in_channels: int Input channels in image, default=3 - embed_dims: int + embedding_dims: int Patch Embedding dimension num_heads:tuple[int] Number of heads in each transformer layer @@ -181,20 +214,21 @@ class PVTSegmentationV2(PVTSegmentation): qkv_bias: bool, default= True Adds bias to the qkv if true qk_scale: float, optional + Override default qk scale of head_dim ** -0.5 in Spatial Attention if set p_dropout: float, Dropout rate,default is 0.0 - attn_drop_rate: float, + attn_dropout: float, Attention dropout rate, default is 0.0 drop_path_rate: float Stochastic depth rate, default is 0.1 - sr_ratio: float + sr_ratios: float Spatial reduction ratio linear: bool - Whether to use linear spatial attention + Whether to use linear spatial attention, default is False use_dwconv: bool - Whether to use Depth-wise convolutions in Overlap-patch embedding + Whether to use Depth-wise convolutions in Overlap-patch embedding, default is True ape: bool - Whether to use absolute position embedding + Whether to use absolute position embedding, default is False return_pyramid: bool Whether to use all pyramid feature layers for up-sampling, default is true """ @@ -204,18 +238,18 @@ def __init__( img_size=224, patch_size=[7, 3, 3, 3], in_channels=3, - embed_dims=[64, 128, 256, 512], + embedding_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratio=[4, 4, 4, 4], qkv_bias=False, qk_scale=0.0, p_dropout=0.0, - attn_drop_rate=0.0, + attn_dropout=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - use_abs_pos_embed=False, + ape=False, use_dwconv=True, linear=False, return_pyramid=False, @@ -224,19 +258,19 @@ def __init__( img_size=img_size, patch_size=patch_size, in_channels=in_channels, - embed_dims=embed_dims, + embedding_dims=embedding_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, p_dropout=p_dropout, - attn_drop_rate=attn_drop_rate, + attn_dropout=attn_dropout, drop_path_rate=drop_path_rate, norm_layer=norm_layer, depths=depths, sr_ratios=sr_ratios, linear=linear, - ape=use_abs_pos_embed, + ape=ape, use_dwconv=use_dwconv, return_pyramid=return_pyramid, ) diff --git a/vformer/utils/__init__.py b/vformer/utils/__init__.py index d52467cd..e14e44dd 100644 --- a/vformer/utils/__init__.py +++ b/vformer/utils/__init__.py @@ -1,2 +1,3 @@ +from .registry import * from .utils import pair from .window_utils import * diff --git a/vformer/utils/registry.py b/vformer/utils/registry.py new file mode 100644 index 00000000..595cc773 --- /dev/null +++ 
b/vformer/utils/registry.py @@ -0,0 +1,96 @@ +""" +Adapted from Detectron2 (https://github.com/facebookresearch/detectron2) +""" + + +class Registry: + """ + Class to register objects and then retrieve them by name. + Parameters + ---------- + name : str + Name of the registry + """ + + def __init__(self, name): + + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + + assert ( + name not in self._obj_map + ), f"An object named '{name}' was already registered in '{self._name}' registry!" + + self._obj_map[name] = obj + + def register(self, obj=None, name=None): + """ + Method to register an object in the registry + Parameters + ---------- + obj : object, optional + Object to register, defaults to None (which will return the decorator) + name : str, optional + Name of the object to register, defaults to None (which will use the name of the object) + """ + + if obj is None: + + def deco(func_or_class, name=name): + if name is None: + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + if name is None: # pragma: no cover + name = obj.__name__ + + self._do_register(name, obj) # pragma: no cover + + def get(self, name): + """ + Method to retrieve an object from the registry + Parameters + ---------- + name : str + Name of the object to retrieve + Returns + ------- + object + Object registered under the given name + """ + + ret = self._obj_map.get(name) + if ret is None: # pragma: no cover + raise KeyError( + f"No object named '{name}' found in '{self._name}' registry!" + ) + + return ret + + def get_list(self): + """ + Method to retrieve all objects from the registry + Returns + ------- + list + List of all objects registered in the registry + """ + + return list(self._obj_map.keys()) + + def __contains__(self, name): + return name in self._obj_map # pragma: no cover + + def __iter__(self): + return iter(self._obj_map.items()) # pragma: no cover + + +ATTENTION_REGISTRY = Registry("ATTENTION") +DECODER_REGISTRY = Registry("DECODER") +ENCODER_REGISTRY = Registry("ENCODER") +MODEL_REGISTRY = Registry("MODEL") diff --git a/vformer/utils/utils.py b/vformer/utils/utils.py index e999c4f0..bb84468f 100644 --- a/vformer/utils/utils.py +++ b/vformer/utils/utils.py @@ -1,7 +1,7 @@ def pair(t): """ - Parameters: - ----------- + Parameters + ---------- t: tuple[int] or int """ return t if isinstance(t, tuple) else (t, t) diff --git a/vformer/utils/window_utils.py b/vformer/utils/window_utils.py index 2e3d0835..234bc34f 100644 --- a/vformer/utils/window_utils.py +++ b/vformer/utils/window_utils.py @@ -5,13 +5,13 @@ def cyclicshift(input, shift_size, dims=None): """ - Parameters: + Parameters ---------- input: torch.Tensor input tensor - shift_size: int or tuple[int] + shift_size: int or tuple(int) Number of places by which input tensor is shifted - dims: int or tuple[int],optional + dims: int or tuple(int),optional Axis along which to roll """ @@ -22,8 +22,8 @@ def cyclicshift(input, shift_size, dims=None): def window_partition(x, window_size): """ - Parameters: - ----------- + Parameters + ---------- x: torch.Tensor input tensor window_size: int @@ -40,8 +40,8 @@ def window_partition(x, window_size): def window_reverse(windows, window_size, H, W): """ - Parameters: - ----------- + Parameters + ---------- windows: torch.Tensor window_size: int """ @@ -55,8 +55,8 @@ def window_reverse(windows, window_size, H, W): def get_relative_position_bias_index(window_size): """ - Parameters: - ------------ + Parameters + ---------- 
window_size: int or tuple[int] Window size """ @@ -78,8 +78,8 @@ def get_relative_position_bias_index(window_size): def create_mask(window_size, shift_size, H, W): """ - Parameters: - ----------- + Parameters + ---------- window_size: int Window Size shift_size: int From 6289f1745f430d7db5859b8fa486a400669c6870 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 15:22:50 +0530 Subject: [PATCH 29/67] Create patch_multiscale.py --- vformer/encoder/embedding/patch_multiscale.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 vformer/encoder/embedding/patch_multiscale.py diff --git a/vformer/encoder/embedding/patch_multiscale.py b/vformer/encoder/embedding/patch_multiscale.py new file mode 100644 index 00000000..fc1b2f05 --- /dev/null +++ b/vformer/encoder/embedding/patch_multiscale.py @@ -0,0 +1,47 @@ +class PatchEmbed(nn.Module): + """ + arameters + ---------- + img_size: int + Image Size + + dim_in: int + Number of input channels in the image + dim_out: int + Number of linear projection output channels + kernel: int + kernel Size + stride: int + stride Size + padding: int + padding Size + conv_2d : bool + Use nn.Conv2D if true, nn.conv3D if fals3 + """ + + def __init__( + self, + dim_in=3, + dim_out=768, + kernel=(1, 16, 16), + stride=(1, 4, 4), + padding=(1, 7, 7), + conv_2d=False, + ): + super().__init__() + if conv_2d: + conv = nn.Conv2d + else: + conv = nn.Conv3d + self.proj = conv( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + return x.flatten(2).transpose(1, 2) From b3e85068d499761bb6802000795a4481f9208771 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 15:59:56 +0530 Subject: [PATCH 30/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 255 +++++++++++++++++++- 1 file changed, 253 insertions(+), 2 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 80cf910d..66412537 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -2,10 +2,156 @@ import torch.nn as nn from vformer.common import BaseClassificationModel +from vformer.encoder.embedding.patch_multiscale import PatchEmbed from vformer.decoder.mlp import MLPDecoder -from vformer.encoder.multiscale import MultiScaleBlock +from vformer.encoder.multiscale import MultiScaleBlock -attention_block = MultiScaleBlock( +@MODEL_REGISTRY.register() +class MultiScaleViT(BaseClassificationModel): + """ + Implementation of 'Multiscale Vision Transformers' + https://arxiv.org/abs/2104.11227 + Parameters + ---------- + """ +def __init__(self, cfg): + super().__init__() + # Get parameters. + pool_first = False + # Prepare input. + spatial_size = 224 + temporal_size = 8 + in_chans = 3 + use_2d_patch = False + self.patch_stride = [2, 4, 4] + if use_2d_patch: + self.patch_stride = [1] + self.patch_stride + # Prepare output. 
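+ # These defaults (400 classes, a 96-dim patch embedding, depth 16) appear to
+ # follow the MViT-B Kinetics-400 configuration.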
+ num_classes = 400 + embed_dim = 96 + # Prepare backbone + num_heads = 1 + mlp_ratio = 4.0 + qkv_bias = True + self.drop_rate = 0.0 + depth = 16 + drop_path_rate = 0.1 + mode = "conv" + self.cls_embed_on = True + self.sep_pos_embed = False + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.patch_embed = stem_helper.PatchEmbed( + dim_in=in_chans, + dim_out=embed_dim, + kernel=cfg.MVIT.PATCH_KERNEL, + stride=cfg.MVIT.PATCH_STRIDE, + padding=cfg.MVIT.PATCH_PADDING, + conv_2d=use_2d_patch, + ) + self.input_dims = [temporal_size, spatial_size, spatial_size] + assert self.input_dims[1] == self.input_dims[2] + self.patch_dims = [ + self.input_dims[i] // self.patch_stride[i] + for i in range(len(self.input_dims)) + ] + num_patches = math.prod(self.patch_dims) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if self.cls_embed_on: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + pos_embed_dim = num_patches + 1 + else: + pos_embed_dim = num_patches + + if self.sep_pos_embed: + self.pos_embed_spatial = nn.Parameter( + torch.zeros( + 1, self.patch_dims[1] * self.patch_dims[2], embed_dim + ) + ) + self.pos_embed_temporal = nn.Parameter( + torch.zeros(1, self.patch_dims[0], embed_dim) + ) + if self.cls_embed_on: + self.pos_embed_class = nn.Parameter( + torch.zeros(1, 1, embed_dim) + ) + else: + self.pos_embed = nn.Parameter( + torch.zeros(1, pos_embed_dim, embed_dim) + ) + + if self.drop_rate > 0.0: + self.pos_drop = nn.Dropout(p=self.drop_rate) + + dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) + for i in range(len(cfg.MVIT.DIM_MUL)): + dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1] + for i in range(len(cfg.MVIT.HEAD_MUL)): + head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1] + + pool_q = [[] for i in range(cfg.MVIT.DEPTH)] + pool_kv = [[] for i in range(cfg.MVIT.DEPTH)] + stride_q = [[] for i in range(cfg.MVIT.DEPTH)] + stride_kv = [[] for i in range(cfg.MVIT.DEPTH)] + + for i in range(len(cfg.MVIT.POOL_Q_STRIDE)): + stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][ + 1: + ] + if cfg.MVIT.POOL_KVQ_KERNEL is not None: + pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL + else: + pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [ + s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:] + ] + + # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE. 
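+ # POOL_KV_STRIDE_ADAPTIVE supplies a single seed stride; the per-stage
+ # POOL_KV_STRIDE list is derived from it in the loop below.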
+ if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None: + _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE + cfg.MVIT.POOL_KV_STRIDE = [] + for i in range(cfg.MVIT.DEPTH): + if len(stride_q[i]) > 0: + _stride_kv = [ + max(_stride_kv[d] // stride_q[i][d], 1) + for d in range(len(_stride_kv)) + ] + cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv) + + for i in range(len(cfg.MVIT.POOL_KV_STRIDE)): + stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[ + i + ][1:] + if cfg.MVIT.POOL_KVQ_KERNEL is not None: + pool_kv[ + cfg.MVIT.POOL_KV_STRIDE[i][0] + ] = cfg.MVIT.POOL_KVQ_KERNEL + else: + pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [ + s + 1 if s > 1 else s + for s in cfg.MVIT.POOL_KV_STRIDE[i][1:] + ] + + self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None + + self.blocks = nn.ModuleList() + + if cfg.MODEL.ACT_CHECKPOINT: + validate_checkpoint_wrapper_import(checkpoint_wrapper) + + for i in range(depth): + num_heads = round_width(num_heads, head_mul[i]) + embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads) + dim_out = round_width( + embed_dim, + dim_mul[i + 1], + divisor=round_width(num_heads, head_mul[i + 1]), + ) + attention_block = MultiScaleBlock( dim=embed_dim, dim_out=dim_out, num_heads=num_heads, @@ -22,3 +168,108 @@ has_cls_embed=self.cls_embed_on, pool_first=pool_first, ) + if cfg.MODEL.ACT_CHECKPOINT: + attention_block = checkpoint_wrapper(attention_block) + self.blocks.append(attention_block) + + embed_dim = dim_out + self.norm = norm_layer(embed_dim) + + self.head = head_helper.TransformerBasicHead( + embed_dim, + num_classes, + dropout_rate=cfg.MODEL.DROPOUT_RATE, + act_func=cfg.MODEL.HEAD_ACT, + ) + if self.sep_pos_embed: + trunc_normal_(self.pos_embed_spatial, std=0.02) + trunc_normal_(self.pos_embed_temporal, std=0.02) + if self.cls_embed_on: + trunc_normal_(self.pos_embed_class, std=0.02) + else: + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_embed_on: + trunc_normal_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.cfg.MVIT.ZERO_DECAY_POS_CLS: + if self.sep_pos_embed: + if self.cls_embed_on: + return { + "pos_embed_spatial", + "pos_embed_temporal", + "pos_embed_class", + "cls_token", + } + else: + return { + "pos_embed_spatial", + "pos_embed_temporal", + "pos_embed_class", + } + else: + if self.cls_embed_on: + return {"pos_embed", "cls_token"} + else: + return {"pos_embed"} + else: + return {} + + def forward(self, x): + x = x[0] + x = self.patch_embed(x) + + T = self.cfg.DATA.NUM_FRAMES // self.patch_stride[0] + H = self.cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1] + W = self.cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2] + B, N, C = x.shape + + if self.cls_embed_on: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if self.sep_pos_embed: + pos_embed = self.pos_embed_spatial.repeat( + 1, self.patch_dims[0], 1 + ) + torch.repeat_interleave( + self.pos_embed_temporal, + self.patch_dims[1] * self.patch_dims[2], + dim=1, + ) + if self.cls_embed_on: + pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1) + x = x + pos_embed + else: + x = x + self.pos_embed + + if 
self.drop_rate: + x = self.pos_drop(x) + + if self.norm_stem: + x = self.norm_stem(x) + + thw = [T, H, W] + for blk in self.blocks: + x, thw = blk(x, thw) + + x = self.norm(x) + if self.cls_embed_on: + x = x[:, 0] + else: + x = x.mean(1) + + x = self.head(x) + return x From a0ea7d1f399bc998e8791ab1676507e103f4abb5 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 16:16:40 +0530 Subject: [PATCH 31/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 66412537..53fbcd13 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -39,6 +39,7 @@ def __init__(self, cfg): mode = "conv" self.cls_embed_on = True self.sep_pos_embed = False + norm_stem = False norm_layer = partial(nn.LayerNorm, eps=1e-6) self.num_classes = num_classes self.patch_embed = stem_helper.PatchEmbed( @@ -94,10 +95,10 @@ def __init__(self, cfg): for i in range(len(cfg.MVIT.HEAD_MUL)): head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1] - pool_q = [[] for i in range(cfg.MVIT.DEPTH)] - pool_kv = [[] for i in range(cfg.MVIT.DEPTH)] - stride_q = [[] for i in range(cfg.MVIT.DEPTH)] - stride_kv = [[] for i in range(cfg.MVIT.DEPTH)] + pool_q = [[] for i in range(depth)] + pool_kv = [[] for i in range(depth)] + stride_q = [[] for i in range(depth)] + stride_kv = [[] for i in range(depth)] for i in range(len(cfg.MVIT.POOL_Q_STRIDE)): stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][ @@ -136,7 +137,7 @@ def __init__(self, cfg): for s in cfg.MVIT.POOL_KV_STRIDE[i][1:] ] - self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None + self.norm_stem = norm_layer(embed_dim) if norm_stem else None self.blocks = nn.ModuleList() From c2bffb88934128bb5d22615c96f476301c1428fb Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 4 Jan 2022 16:35:06 +0530 Subject: [PATCH 32/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 36 ++++----------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 53fbcd13..cccad44e 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -1,3 +1,4 @@ +import math import torch import torch.nn as nn @@ -202,38 +203,13 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - @torch.jit.ignore - def no_weight_decay(self): - if self.cfg.MVIT.ZERO_DECAY_POS_CLS: - if self.sep_pos_embed: - if self.cls_embed_on: - return { - "pos_embed_spatial", - "pos_embed_temporal", - "pos_embed_class", - "cls_token", - } - else: - return { - "pos_embed_spatial", - "pos_embed_temporal", - "pos_embed_class", - } - else: - if self.cls_embed_on: - return {"pos_embed", "cls_token"} - else: - return {"pos_embed"} - else: - return {} - - def forward(self, x): + def forward(self,spatial_size,temporal_size, x): x = x[0] x = self.patch_embed(x) - - T = self.cfg.DATA.NUM_FRAMES // self.patch_stride[0] - H = self.cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1] - W = self.cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2] + + T = temporal_size // self.patch_stride[0] + H = spatial_size // self.patch_stride[1] + W = spatial_size // 
self.patch_stride[2] B, N, C = x.shape if self.cls_embed_on: From d616f0c7308cff0564cedd21ca60aed778514e23 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 12:53:18 +0530 Subject: [PATCH 33/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 58 +++++++++++---------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index cccad44e..089ee9e3 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -42,13 +42,18 @@ def __init__(self, cfg): self.sep_pos_embed = False norm_stem = False norm_layer = partial(nn.LayerNorm, eps=1e-6) + DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + POOL_KVQ_KERNEL: [3, 3, 3] + POOL_KV_STRIDE_ADAPTIVE: [1, 8, 8] + POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] self.num_classes = num_classes self.patch_embed = stem_helper.PatchEmbed( dim_in=in_chans, dim_out=embed_dim, - kernel=cfg.MVIT.PATCH_KERNEL, - stride=cfg.MVIT.PATCH_STRIDE, - padding=cfg.MVIT.PATCH_PADDING, + kernel=(3, 7, 7), + stride=(2, 4, 4), + padding=(1, 3, 3), conv_2d=use_2d_patch, ) self.input_dims = [temporal_size, spatial_size, spatial_size] @@ -91,60 +96,57 @@ def __init__(self, cfg): self.pos_drop = nn.Dropout(p=self.drop_rate) dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1) - for i in range(len(cfg.MVIT.DIM_MUL)): - dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1] - for i in range(len(cfg.MVIT.HEAD_MUL)): - head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1] + for i in range(len(DIM_MUL)): + dim_mul[DIM_MUL[i][0]] = DIM_MUL[i][1] + for i in range(len(HEAD_MUL)): + head_mul[HEAD_MUL[i][0]] = HEAD_MUL[i][1] pool_q = [[] for i in range(depth)] pool_kv = [[] for i in range(depth)] stride_q = [[] for i in range(depth)] stride_kv = [[] for i in range(depth)] - for i in range(len(cfg.MVIT.POOL_Q_STRIDE)): - stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][ + for i in range(len(POOL_Q_STRIDE)): + stride_q[POOL_Q_STRIDE[i][0]] = POOL_Q_STRIDE[i][ 1: ] if cfg.MVIT.POOL_KVQ_KERNEL is not None: - pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL + pool_q[POOL_Q_STRIDE[i][0]] = POOL_KVQ_KERNEL else: - pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [ - s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:] + pool_q[POOL_Q_STRIDE[i][0]] = [ + s + 1 if s > 1 else s for s in POOL_Q_STRIDE[i][1:] ] # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE. 
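+ # Whenever a stage pools Q, the seed K/V stride is divided by that stage's Q
+ # stride, so the resolution that K and V are pooled to stays roughly constant.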
- if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None: - _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE - cfg.MVIT.POOL_KV_STRIDE = [] - for i in range(cfg.MVIT.DEPTH): + if POOL_KV_STRIDE_ADAPTIVE is not None: + _stride_kv = POOL_KV_STRIDE_ADAPTIVE + POOL_KV_STRIDE = [] + for i in range(depth): if len(stride_q[i]) > 0: _stride_kv = [ max(_stride_kv[d] // stride_q[i][d], 1) for d in range(len(_stride_kv)) ] - cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv) + POOL_KV_STRIDE.append([i] + _stride_kv) - for i in range(len(cfg.MVIT.POOL_KV_STRIDE)): - stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[ + for i in range(len(POOL_KV_STRIDE)): + stride_kv[POOL_KV_STRIDE[i][0]] = POOL_KV_STRIDE[ i ][1:] - if cfg.MVIT.POOL_KVQ_KERNEL is not None: + if POOL_KVQ_KERNEL is not None: pool_kv[ - cfg.MVIT.POOL_KV_STRIDE[i][0] - ] = cfg.MVIT.POOL_KVQ_KERNEL + POOL_KV_STRIDE[i][0] + ] = KVQ_KERNEL else: - pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [ + pool_kv[POOL_KV_STRIDE[i][0]] = [ s + 1 if s > 1 else s - for s in cfg.MVIT.POOL_KV_STRIDE[i][1:] + for s in POOL_KV_STRIDE[i][1:] ] self.norm_stem = norm_layer(embed_dim) if norm_stem else None self.blocks = nn.ModuleList() - if cfg.MODEL.ACT_CHECKPOINT: - validate_checkpoint_wrapper_import(checkpoint_wrapper) - for i in range(depth): num_heads = round_width(num_heads, head_mul[i]) embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads) @@ -180,7 +182,7 @@ def __init__(self, cfg): self.head = head_helper.TransformerBasicHead( embed_dim, num_classes, - dropout_rate=cfg.MODEL.DROPOUT_RATE, + dropout_rate=self.drop_rate, act_func=cfg.MODEL.HEAD_ACT, ) if self.sep_pos_embed: From ebbc04b13a2dc8487a5ee2e22ef3dce23d13b1fe Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:06:43 +0530 Subject: [PATCH 34/67] Update vformer/encoder/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/encoder/multiscale.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vformer/encoder/multiscale.py b/vformer/encoder/multiscale.py index 6e4f9506..11cd13f0 100644 --- a/vformer/encoder/multiscale.py +++ b/vformer/encoder/multiscale.py @@ -4,7 +4,9 @@ from timm.models.layers import DropPath from .nn import FeedForward as Mlp from ..attention import MultiScaleAttention +from ..utils import ENCODER_REGISTRY +@ENCODER_REGISTRY.register() class MultiScaleBlock(nn.Module): """ Multiscale Attention Block From f57e04a820d940725b9e08521aef4b150d060d23 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:06:49 +0530 Subject: [PATCH 35/67] Update vformer/attention/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/attention/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 637d6750..21360376 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -52,7 +52,7 @@ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): tensor = tensor.squeeze(1) return tensor, thw_shape - +@ATTENTION_REGISTRY.register() class MultiScaleAttention(nn.Module): """ Multiscale Attention From 290e79fe4e1f5b18386a20e95e72d1a02b4381fd Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:15:34 +0530 Subject: [PATCH 36/67] Update multiscale.py --- 
vformer/models/classification/multiscale.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 089ee9e3..ead43b0f 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -2,10 +2,10 @@ import torch import torch.nn as nn -from vformer.common import BaseClassificationModel -from vformer.encoder.embedding.patch_multiscale import PatchEmbed -from vformer.decoder.mlp import MLPDecoder -from vformer.encoder.multiscale import MultiScaleBlock +from ...common import BaseClassificationModel +from ...decoder import MLPDecoder +from ...encoder import MultiScaleBlock, PatchEmbed +from ...utils import MODEL_REGISTRY @MODEL_REGISTRY.register() class MultiScaleViT(BaseClassificationModel): From c9d4044b3bfb37995317f33a3dcce2c9957a6812 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:17:09 +0530 Subject: [PATCH 37/67] Update __init__.py --- vformer/attention/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vformer/attention/__init__.py b/vformer/attention/__init__.py index 04f69528..b62b13cd 100644 --- a/vformer/attention/__init__.py +++ b/vformer/attention/__init__.py @@ -2,3 +2,4 @@ from .spatial import SpatialAttention from .vanilla import VanillaSelfAttention from .window import WindowAttention +from .multiscale import MultiScaleAttention From 60d9e976d180a6a547b8b8beb0c38125a348bbdd Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:17:58 +0530 Subject: [PATCH 38/67] Update multiscale.py --- vformer/attention/multiscale.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vformer/attention/multiscale.py b/vformer/attention/multiscale.py index 21360376..6c5eab22 100644 --- a/vformer/attention/multiscale.py +++ b/vformer/attention/multiscale.py @@ -1,6 +1,7 @@ import numpy import torch import torch.nn as nn +from ..utils import ATTENTION_REGISTRY def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None): """ From a0d20632ade6caad04e20aa935830b281c9a3a49 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:18:40 +0530 Subject: [PATCH 39/67] Update __init__.py --- vformer/encoder/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vformer/encoder/__init__.py b/vformer/encoder/__init__.py index 8688a84a..764b3060 100644 --- a/vformer/encoder/__init__.py +++ b/vformer/encoder/__init__.py @@ -4,3 +4,4 @@ from .pyramid import PVTEncoder from .swin import SwinEncoder, SwinEncoderBlock from .vanilla import VanillaEncoder +from .multiscale import MultiScaleBlock From 930c596155d4642226ca14ecdeaabf0e99053264 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 13:37:35 +0530 Subject: [PATCH 40/67] Update test_attention.py --- tests/test_attention.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 2b38692d..7dd2877a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -51,6 +51,21 @@ def test_CrossAttention(): assert out.shape == test_tensor1.shape del attention +def test_MultiScaleAttention(): + + test_tensor1 = torch.randn(96,8,56,56) + test_tensor2 = torch.randn(768,8,14,14) + + attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=192) + out = 
attention(test_tensor1) + assert out.shape == (192,8,28,28) + del attention + + attention = ATTENTION_REGISTRY.get("VanillaSelfAttention")(dim=768) + out = attention(test_tensor2) + assert out.shape == (768,8,14,14) + del attention + def test_SpatialAttention(): From 0e619710ce50712924a5085eb6c5c7d16ee9e1cc Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 14:08:17 +0530 Subject: [PATCH 41/67] Update test_encoder.py --- tests/test_encoder.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index cae8386a..e3104a39 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -56,7 +56,22 @@ def test_SwinEncoder(): out = encoder_block(test_tensor) assert out.shape == test_tensor.shape - +def test_MultiScaleBlock(): + + test_tensor1 = torch.randn(96,8,56,56) + encoder1 = ENCODER_REGISTRY.get("VanillaEncoder")(dim=192) + out1 = encoder1(test_tensor) + assert out1.shape == (192,8,28,28) + + test_tensor2 = torch.randn(768,8,14,14) + encoder2 = ENCODER_REGISTRY.get("VanillaEncoder")(dim=768) + out2 = encoder2(test_tensor) + assert out2.shape == (768,8,14,14) # shape remains same + + + del encoder1, encoder2, test_tensor1, test_tensor2 + + def test_PVTEncoder(): test_tensor = torch.randn(4, 3136, 64) From 313b2d0a97b93c76fbec10b927a002c25b408c10 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 15:12:45 +0530 Subject: [PATCH 42/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 71 +++++++++++---------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index ead43b0f..7390e256 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -15,45 +15,52 @@ class MultiScaleViT(BaseClassificationModel): Parameters ---------- """ -def __init__(self, cfg): +def __init__(self, + spatial_size = 224, + pool_first = False, + temporal_size = 8, + in_chans = 3, + use_2d_patch = False, + patch_stride = [2,4,4], + num_classes = 400 + embed_dim = 96 + num_heads = 1 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 + depth = 16 + drop_path_rate = 0.1 + mode = "conv" + cls_embed_on = True + sep_pos_embed = False + norm_stem = False + norm_layer = partial(nn.LayerNorm, eps=1e-6) + patch_kernel = (3, 7, 7) + patch_stride = (2, 4, 4) + patch_padding = (1, 3, 3) + DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] + POOL_KVQ_KERNEL: [3, 3, 3] + POOL_KV_STRIDE_ADAPTIVE: [1, 8, 8] + POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] + ): super().__init__() - # Get parameters. - pool_first = False - # Prepare input. - spatial_size = 224 - temporal_size = 8 - in_chans = 3 - use_2d_patch = False - self.patch_stride = [2, 4, 4] + self.patch_stride = patch_stride if use_2d_patch: self.patch_stride = [1] + self.patch_stride - # Prepare output. 
- num_classes = 400 - embed_dim = 96 - # Prepare backbone - num_heads = 1 - mlp_ratio = 4.0 - qkv_bias = True - self.drop_rate = 0.0 - depth = 16 - drop_path_rate = 0.1 - mode = "conv" - self.cls_embed_on = True - self.sep_pos_embed = False - norm_stem = False - norm_layer = partial(nn.LayerNorm, eps=1e-6) - DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] - HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] - POOL_KVQ_KERNEL: [3, 3, 3] - POOL_KV_STRIDE_ADAPTIVE: [1, 8, 8] - POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] + + self.drop_rate = DROPOUT_RATE + + self.cls_embed_on = cls_embed_on + self.sep_pos_embed = sep_pos_embed + self.num_classes = num_classes self.patch_embed = stem_helper.PatchEmbed( dim_in=in_chans, dim_out=embed_dim, - kernel=(3, 7, 7), - stride=(2, 4, 4), - padding=(1, 3, 3), + kernel=patch_kernel, + stride=patch_stride, + padding=patch_padding, conv_2d=use_2d_patch, ) self.input_dims = [temporal_size, spatial_size, spatial_size] From 8a9a06ca0d61ff2fd8b6636a803a6b28a358b676 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 15:19:00 +0530 Subject: [PATCH 43/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 7390e256..c1c772f9 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -49,7 +49,7 @@ def __init__(self, if use_2d_patch: self.patch_stride = [1] + self.patch_stride - self.drop_rate = DROPOUT_RATE + self.drop_rate = drop_rate self.cls_embed_on = cls_embed_on self.sep_pos_embed = sep_pos_embed From 36010fda11b14e9e59b745f3907a413498ba1b27 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 15:21:38 +0530 Subject: [PATCH 44/67] Update test_models.py --- tests/test_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 5995aaf1..3fb6ccc1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -159,7 +159,15 @@ def test_CrossVit(): assert out.shape == (2, 10) del model +def test_MultiScale(): + model = MODEL_REGISTRY.get("MultiScaleViT")() + out = model(img_3channels_224) + assert out.shape == (8, 400) + del model + + + def test_pvt(): # classification model = MODEL_REGISTRY.get("PVTClassification")( From 558dcffaaa7ad411dfea0dba1b4fb36bf0f4583f Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Fri, 14 Jan 2022 21:20:41 +0530 Subject: [PATCH 45/67] update --- tests/test_attention.py | 3 ++- tests/test_encoder.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 7dd2877a..4ad2a010 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -55,9 +55,10 @@ def test_MultiScaleAttention(): test_tensor1 = torch.randn(96,8,56,56) test_tensor2 = torch.randn(768,8,14,14) + thw = [2,2,2] attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=192) - out = attention(test_tensor1) + out = attention(test_tensor1, thw) assert out.shape == (192,8,28,28) del attention diff --git a/tests/test_encoder.py b/tests/test_encoder.py index e3104a39..b02d8c46 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -59,12 +59,12 @@ def test_SwinEncoder(): def test_MultiScaleBlock(): test_tensor1 = 
torch.randn(96,8,56,56) - encoder1 = ENCODER_REGISTRY.get("VanillaEncoder")(dim=192) + encoder1 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=192) out1 = encoder1(test_tensor) assert out1.shape == (192,8,28,28) test_tensor2 = torch.randn(768,8,14,14) - encoder2 = ENCODER_REGISTRY.get("VanillaEncoder")(dim=768) + encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=768) out2 = encoder2(test_tensor) assert out2.shape == (768,8,14,14) # shape remains same From e555aa017a46f3222594d2a2e667d475370091ea Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Sat, 15 Jan 2022 01:03:22 +0530 Subject: [PATCH 46/67] Update vformer/encoder/embedding/patch_multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/encoder/embedding/patch_multiscale.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vformer/encoder/embedding/patch_multiscale.py b/vformer/encoder/embedding/patch_multiscale.py index fc1b2f05..b0654590 100644 --- a/vformer/encoder/embedding/patch_multiscale.py +++ b/vformer/encoder/embedding/patch_multiscale.py @@ -1,3 +1,5 @@ +import torch.nn as nn + class PatchEmbed(nn.Module): """ arameters From 6fac3a77d7693e435b6c83d11825dbe5078bf188 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 00:16:48 +0530 Subject: [PATCH 47/67] Update tests --- tests/test_attention.py | 15 +++++---------- tests/test_encoder.py | 18 ++++++++++-------- vformer/encoder/multiscale.py | 1 + vformer/models/classification/multiscale.py | 2 +- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 4ad2a010..ba747ec0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -53,18 +53,13 @@ def test_CrossAttention(): def test_MultiScaleAttention(): - test_tensor1 = torch.randn(96,8,56,56) - test_tensor2 = torch.randn(768,8,14,14) + test_tensor1 = torch.randn(96,56,56) + test_tensor2 = torch.randn(768,14,14) thw = [2,2,2] - attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=192) - out = attention(test_tensor1, thw) - assert out.shape == (192,8,28,28) - del attention - - attention = ATTENTION_REGISTRY.get("VanillaSelfAttention")(dim=768) - out = attention(test_tensor2) - assert out.shape == (768,8,14,14) + attention = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=56) + out,_ = attention(test_tensor1, thw) + assert out.shape == (96,56,56) del attention diff --git a/tests/test_encoder.py b/tests/test_encoder.py index b02d8c46..f38c5ce0 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -58,15 +58,17 @@ def test_SwinEncoder(): def test_MultiScaleBlock(): - test_tensor1 = torch.randn(96,8,56,56) - encoder1 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=192) - out1 = encoder1(test_tensor) - assert out1.shape == (192,8,28,28) + thw = [1,5,11] + test_tensor1 = torch.randn(96,56,56) + encoder1 = ENCODER_REGISTRY.get("MultiScaleBlock")(56,56,8) + out1, _ = encoder1(test_tensor1, thw) + assert out1.shape == (96,56,56) - test_tensor2 = torch.randn(768,8,14,14) - encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(dim=768) - out2 = encoder2(test_tensor) - assert out2.shape == (768,8,14,14) # shape remains same + thw = [1,13,1] + test_tensor2 = torch.randn(768,14,14) + encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(14,14,14) + out2, _ = encoder2(test_tensor2, thw) + assert out2.shape == (768,14,14) # shape remains same del encoder1, encoder2, test_tensor1, test_tensor2 
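For orientation, here is a minimal usage sketch of the registry machinery these tests rely on. It assumes the patched vformer package above is importable; the DummyHead class is purely illustrative and not part of the library.

import torch
import torch.nn as nn

from vformer.utils import ATTENTION_REGISTRY, MODEL_REGISTRY

# Look up a registered attention class by name and run one forward pass.
# With the default (1, 1, 1) kernels and strides nothing is pooled, so the
# output keeps the input shape, exactly as the test above asserts.
attn = ATTENTION_REGISTRY.get("MultiScaleAttention")(dim=56)
x = torch.randn(96, 56, 56)
out, _ = attn(x, [2, 2, 2])
assert out.shape == x.shape

# New modules can also be registered from user code, either through the
# decorator form used throughout these patches or by passing obj/name directly.
@MODEL_REGISTRY.register()
class DummyHead(nn.Module):  # illustrative only
    def forward(self, x):
        return x

assert "DummyHead" in MODEL_REGISTRY
print(MODEL_REGISTRY.get_list())

Name-based lookup is what lets the tests build every variant through *_REGISTRY.get(...) without importing each class directly.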
diff --git a/vformer/encoder/multiscale.py b/vformer/encoder/multiscale.py index 11cd13f0..6199ac8a 100644 --- a/vformer/encoder/multiscale.py +++ b/vformer/encoder/multiscale.py @@ -5,6 +5,7 @@ from .nn import FeedForward as Mlp from ..attention import MultiScaleAttention from ..utils import ENCODER_REGISTRY +from ..attention.multiscale import attention_pool @ENCODER_REGISTRY.register() class MultiScaleBlock(nn.Module): diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index c1c772f9..f3cea62e 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -15,7 +15,7 @@ class MultiScaleViT(BaseClassificationModel): Parameters ---------- """ -def __init__(self, + def __init__(self, spatial_size = 224, pool_first = False, temporal_size = 8, From 7cc4ceabb028db893551eeb9b0cdf97d294c3e33 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 00:48:02 +0530 Subject: [PATCH 48/67] Update tests --- tests/test_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index f38c5ce0..1e38dc39 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -66,9 +66,9 @@ def test_MultiScaleBlock(): thw = [1,13,1] test_tensor2 = torch.randn(768,14,14) - encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(14,14,14) + encoder2 = ENCODER_REGISTRY.get("MultiScaleBlock")(14,28,7) out2, _ = encoder2(test_tensor2, thw) - assert out2.shape == (768,14,14) # shape remains same + assert out2.shape == (768,14,28) # shape remains same del encoder1, encoder2, test_tensor1, test_tensor2 From 077e2e78c22a6a709e7ea92b8e47446fade6a097 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 07:56:41 +0530 Subject: [PATCH 49/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index f3cea62e..2814bd42 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -35,14 +35,13 @@ def __init__(self, sep_pos_embed = False norm_stem = False norm_layer = partial(nn.LayerNorm, eps=1e-6) - patch_kernel = (3, 7, 7) - patch_stride = (2, 4, 4) - patch_padding = (1, 3, 3) - DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] - HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]] - POOL_KVQ_KERNEL: [3, 3, 3] - POOL_KV_STRIDE_ADAPTIVE: [1, 8, 8] - POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] + patch_kernel = (3, 7, 7), + patch_padding = (1, 3, 3), + DIM_MUL=[[1, 2.0], [3, 2.0], [14, 2.0]], + HEAD_MUL= [[1, 2.0], [3, 2.0], [14, 2.0]], + POOL_KVQ_KERNEL= [3, 3, 3], + POOL_KV_STRIDE_ADAPTIVE=[1, 8, 8], + POOL_Q_STRIDE= [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]], ): super().__init__() self.patch_stride = patch_stride From 6fea53d2133e76ec54ef91f47323d834f5c90ded Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 07:56:50 +0530 Subject: [PATCH 50/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 14 +++++++------- 1 file 
changed, 7 insertions(+), 7 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 2814bd42..e1f5013c 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -22,13 +22,13 @@ def __init__(self, in_chans = 3, use_2d_patch = False, patch_stride = [2,4,4], - num_classes = 400 - embed_dim = 96 - num_heads = 1 - mlp_ratio = 4.0 - qkv_bias = True - drop_rate = 0.0 - depth = 16 + num_classes = 400, + embed_dim = 96, + num_heads = 1, + mlp_ratio = 4.0, + qkv_bias = True, + drop_rate = 0.0, + depth = 16, drop_path_rate = 0.1 mode = "conv" cls_embed_on = True From e43a8581a0ae4faad61b06e0cc15648750c6658b Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 07:57:07 +0530 Subject: [PATCH 51/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index e1f5013c..697099fd 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -1,7 +1,7 @@ import math import torch import torch.nn as nn - +from functools import partial from ...common import BaseClassificationModel from ...decoder import MLPDecoder from ...encoder import MultiScaleBlock, PatchEmbed From 372aa69a65505c71f9aafbd55bc4528b4ad5207c Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 07:57:20 +0530 Subject: [PATCH 52/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 697099fd..edd3bcb0 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -116,7 +116,7 @@ def __init__(self, stride_q[POOL_Q_STRIDE[i][0]] = POOL_Q_STRIDE[i][ 1: ] - if cfg.MVIT.POOL_KVQ_KERNEL is not None: + if POOL_KVQ_KERNEL is not None: pool_q[POOL_Q_STRIDE[i][0]] = POOL_KVQ_KERNEL else: pool_q[POOL_Q_STRIDE[i][0]] = [ From 5ad17bd5f85d1d3c956bb91002a32c9ba3b3e491 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 07:57:28 +0530 Subject: [PATCH 53/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index edd3bcb0..6997e0f2 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -178,7 +178,7 @@ def __init__(self, has_cls_embed=self.cls_embed_on, pool_first=pool_first, ) - if cfg.MODEL.ACT_CHECKPOINT: + if ACT_CHECKPOINT: attention_block = checkpoint_wrapper(attention_block) self.blocks.append(attention_block) From 8b0b82e7a5a0ab1b73bfc9a1d5984560c00b8d24 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 17 Jan 2022 
07:57:36 +0530 Subject: [PATCH 54/67] Update vformer/models/classification/multiscale.py Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 6997e0f2..c2845660 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -189,7 +189,7 @@ def __init__(self, embed_dim, num_classes, dropout_rate=self.drop_rate, - act_func=cfg.MODEL.HEAD_ACT, + act_func=HEAD_ACT, ) if self.sep_pos_embed: trunc_normal_(self.pos_embed_spatial, std=0.02) From a0a2e180e06e3022f59995e703cdccfdaf3b3f15 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 19:45:37 +0530 Subject: [PATCH 55/67] Import trunc_normal_ --- vformer/models/classification/multiscale.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index c2845660..4cc543d4 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -6,6 +6,7 @@ from ...decoder import MLPDecoder from ...encoder import MultiScaleBlock, PatchEmbed from ...utils import MODEL_REGISTRY +from timm.models.layers import trunc_normal_ @MODEL_REGISTRY.register() class MultiScaleViT(BaseClassificationModel): From 14ddfb835882accafc8c509c266d644b78ce9dc9 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 20:05:55 +0530 Subject: [PATCH 56/67] Create multiscale.py --- vformer/utils/multiscale.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 vformer/utils/multiscale.py diff --git a/vformer/utils/multiscale.py b/vformer/utils/multiscale.py new file mode 100644 index 00000000..2bf018d7 --- /dev/null +++ b/vformer/utils/multiscale.py @@ -0,0 +1,29 @@ +import logging + +logger = logging.get_logger(__name__) + +def get_logger(name): + """ + Retrieve the logger with the specified name or, if name is None, return a + logger which is the root logger of the hierarchy. + Args: + name (string): name of the logger. 
+ """ + return logging.getLogger(name) + + +def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): + if not multiplier: + return width + width *= multiplier + min_width = min_width or divisor + if verbose: + logger.info(f"min width {min_width}") + logger.info(f"width {width} divisor {divisor}") + logger.info(f"other {int(width + divisor / 2) // divisor * divisor}") + + width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) + if width_out < 0.9 * width: + width_out += divisor + return int(width_out) + From 5b603a22e0cdf663633d2602f5f27fee0a7626c1 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 20:08:33 +0530 Subject: [PATCH 57/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 4cc543d4..a9b106d0 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -6,6 +6,7 @@ from ...decoder import MLPDecoder from ...encoder import MultiScaleBlock, PatchEmbed from ...utils import MODEL_REGISTRY +from ...utils.multiscale import round_width from timm.models.layers import trunc_normal_ @MODEL_REGISTRY.register() @@ -55,7 +56,7 @@ def __init__(self, self.sep_pos_embed = sep_pos_embed self.num_classes = num_classes - self.patch_embed = stem_helper.PatchEmbed( + self.patch_embed = PatchEmbed( dim_in=in_chans, dim_out=embed_dim, kernel=patch_kernel, From c40f8442901b9f02005846872f370fdbacd0e7a6 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 20:09:42 +0530 Subject: [PATCH 58/67] Update multiscale.py --- vformer/utils/multiscale.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vformer/utils/multiscale.py b/vformer/utils/multiscale.py index 2bf018d7..265da991 100644 --- a/vformer/utils/multiscale.py +++ b/vformer/utils/multiscale.py @@ -1,7 +1,5 @@ import logging -logger = logging.get_logger(__name__) - def get_logger(name): """ Retrieve the logger with the specified name or, if name is None, return a @@ -11,6 +9,7 @@ def get_logger(name): """ return logging.getLogger(name) +logger = logging.get_logger(__name__) def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): if not multiplier: From 9c147153ced9a39482fec4c68361633569eae2ee Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 22:29:33 +0530 Subject: [PATCH 59/67] Update __init__.py --- vformer/models/classification/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vformer/models/classification/__init__.py b/vformer/models/classification/__init__.py index e6daa1ad..59f5a599 100644 --- a/vformer/models/classification/__init__.py +++ b/vformer/models/classification/__init__.py @@ -4,3 +4,4 @@ from .pyramid import PVTClassification, PVTClassificationV2 from .swin import SwinTransformer from .vanilla import VanillaViT +from .multiscale import MultiScaleViT From 3947d2c0f95204b38b122f2beb632c82197dd243 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 22:31:51 +0530 Subject: [PATCH 60/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vformer/models/classification/multiscale.py 
b/vformer/models/classification/multiscale.py index a9b106d0..9e21258f 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -31,12 +31,12 @@ def __init__(self, qkv_bias = True, drop_rate = 0.0, depth = 16, - drop_path_rate = 0.1 - mode = "conv" - cls_embed_on = True - sep_pos_embed = False - norm_stem = False - norm_layer = partial(nn.LayerNorm, eps=1e-6) + drop_path_rate = 0.1, + mode = "conv", + cls_embed_on = True, + sep_pos_embed = False, + norm_stem = False, + norm_layer = partial(nn.LayerNorm, eps=1e-6), patch_kernel = (3, 7, 7), patch_padding = (1, 3, 3), DIM_MUL=[[1, 2.0], [3, 2.0], [14, 2.0]], From 014c6e53e495f5c6ad527b6d3f18606cae669e87 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 22:34:10 +0530 Subject: [PATCH 61/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 9e21258f..562e6c5b 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -4,7 +4,8 @@ from functools import partial from ...common import BaseClassificationModel from ...decoder import MLPDecoder -from ...encoder import MultiScaleBlock, PatchEmbed +from ...encoder import MultiScaleBlock +from ...encoder.embedding import MultiScaleBlockPatchEmbed from ...utils import MODEL_REGISTRY from ...utils.multiscale import round_width from timm.models.layers import trunc_normal_ From 85fd77133e96e6d3d53c534f84de873c70ea5dda Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 22:35:22 +0530 Subject: [PATCH 62/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 562e6c5b..ab41c56d 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -5,7 +5,7 @@ from ...common import BaseClassificationModel from ...decoder import MLPDecoder from ...encoder import MultiScaleBlock -from ...encoder.embedding import MultiScaleBlockPatchEmbed +from ...encoder.embedding import PatchEmbed from ...utils import MODEL_REGISTRY from ...utils.multiscale import round_width from timm.models.layers import trunc_normal_ From b43eadb7f1fd45ab77cc9f328a70cdc9d7ef9c9e Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:25:59 +0530 Subject: [PATCH 63/67] Update --- vformer/models/classification/multiscale.py | 85 +++++++++++---------- vformer/utils/multiscale.py | 49 +++++++++++- 2 files changed, 91 insertions(+), 43 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index ab41c56d..8d64d37b 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -5,11 +5,11 @@ from ...common import BaseClassificationModel from ...decoder import MLPDecoder from ...encoder import MultiScaleBlock -from ...encoder.embedding import PatchEmbed +from ...encoder.embedding.patch_multiscale import PatchEmbed from ...utils import MODEL_REGISTRY -from ...utils.multiscale import round_width +from ...utils.multiscale import round_width,TransformerBasicHead 
from timm.models.layers import trunc_normal_ - +from fairscale.nn.checkpoint import checkpoint_wrapper @MODEL_REGISTRY.register() class MultiScaleViT(BaseClassificationModel): """ @@ -18,38 +18,39 @@ class MultiScaleViT(BaseClassificationModel): Parameters ---------- """ - def __init__(self, - spatial_size = 224, - pool_first = False, - temporal_size = 8, - in_chans = 3, - use_2d_patch = False, - patch_stride = [2,4,4], - num_classes = 400, - embed_dim = 96, - num_heads = 1, - mlp_ratio = 4.0, - qkv_bias = True, - drop_rate = 0.0, - depth = 16, - drop_path_rate = 0.1, - mode = "conv", - cls_embed_on = True, - sep_pos_embed = False, - norm_stem = False, - norm_layer = partial(nn.LayerNorm, eps=1e-6), - patch_kernel = (3, 7, 7), - patch_padding = (1, 3, 3), - DIM_MUL=[[1, 2.0], [3, 2.0], [14, 2.0]], - HEAD_MUL= [[1, 2.0], [3, 2.0], [14, 2.0]], - POOL_KVQ_KERNEL= [3, 3, 3], - POOL_KV_STRIDE_ADAPTIVE=[1, 8, 8], - POOL_Q_STRIDE= [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]], - ): - super().__init__() - self.patch_stride = patch_stride + def __init__( + self, + img_size = 224, + pool_first = False, + temporal_size = 8, + in_chans = 3, + use_2d_patch = False, + patch_size = [2,4,4], + num_classes = 400, + embed_dim = 96, + num_heads = 1, + mlp_ratio = 4.0, + qkv_bias = True, + drop_rate = 0.0, + depth = 16, + drop_path_rate = 0.1, + mode = "conv", + cls_embed_on = True, + sep_pos_embed = False, + norm_stem = False, + norm_layer = partial(nn.LayerNorm, eps=1e-6), + patch_kernel = (3, 7, 7), + patch_padding = (1, 3, 3), + DIM_MUL=[[1, 2.0], [3, 2.0], [14, 2.0]], + HEAD_MUL= [[1, 2.0], [3, 2.0], [14, 2.0]], + POOL_KVQ_KERNEL= [3, 3, 3], + POOL_KV_STRIDE_ADAPTIVE=[1, 8, 8], + POOL_Q_STRIDE= [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]], + ): + super(MultiScaleViT, self).__init__() + self.patch_size = patch_size if use_2d_patch: - self.patch_stride = [1] + self.patch_stride + self.patch_size = [1] + self.patch_size self.drop_rate = drop_rate @@ -61,14 +62,14 @@ def __init__(self, dim_in=in_chans, dim_out=embed_dim, kernel=patch_kernel, - stride=patch_stride, + stride=patch_size, padding=patch_padding, conv_2d=use_2d_patch, ) - self.input_dims = [temporal_size, spatial_size, spatial_size] + self.input_dims = [temporal_size, img_size, img_size] assert self.input_dims[1] == self.input_dims[2] self.patch_dims = [ - self.input_dims[i] // self.patch_stride[i] + self.input_dims[i] // self.patch_size[i] for i in range(len(self.input_dims)) ] num_patches = math.prod(self.patch_dims) @@ -188,7 +189,7 @@ def __init__(self, embed_dim = dim_out self.norm = norm_layer(embed_dim) - self.head = head_helper.TransformerBasicHead( + self.head = TransformerBasicHead( embed_dim, num_classes, dropout_rate=self.drop_rate, @@ -214,13 +215,13 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward(self,spatial_size,temporal_size, x): + def forward(self,img_size,temporal_size, x): x = x[0] x = self.patch_embed(x) - T = temporal_size // self.patch_stride[0] - H = spatial_size // self.patch_stride[1] - W = spatial_size // self.patch_stride[2] + T = temporal_size // self.patch_size[0] + H = img_size // self.patch_size[1] + W = img_size // self.patch_size[2] B, N, C = x.shape if self.cls_embed_on: diff --git a/vformer/utils/multiscale.py b/vformer/utils/multiscale.py index 265da991..d6118f57 100644 --- a/vformer/utils/multiscale.py +++ b/vformer/utils/multiscale.py @@ -1,4 +1,5 @@ import logging +import torch.nn as nn def get_logger(name): """ @@ -9,7 +10,7 @@ def get_logger(name): 
""" return logging.getLogger(name) -logger = logging.get_logger(__name__) +logger = get_logger(__name__) def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): if not multiplier: @@ -26,3 +27,49 @@ def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): width_out += divisor return int(width_out) +class TransformerBasicHead(nn.Module): + """ + BasicHead. No pool. + """ + + def __init__( + self, + dim_in, + num_classes, + dropout_rate=0.0, + act_func="softmax", + ): + """ + Perform linear projection and activation as head for tranformers. + Args: + dim_in (int): the channel dimension of the input to the head. + num_classes (int): the channel dimensions of the output to the head. + dropout_rate (float): dropout rate. If equal to 0.0, perform no + dropout. + act_func (string): activation function to use. 'softmax': applies + softmax on the output. 'sigmoid': applies sigmoid on the output. + """ + super(TransformerBasicHead, self).__init__() + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + self.projection = nn.Linear(dim_in, num_classes, bias=True) + + # Softmax for evaluation and testing. + if act_func == "softmax": + self.act = nn.Softmax(dim=1) + elif act_func == "sigmoid": + self.act = nn.Sigmoid() + else: + raise NotImplementedError( + "{} is not supported as an activation" + "function.".format(act_func) + ) + + def forward(self, x): + if hasattr(self, "dropout"): + x = self.dropout(x) + x = self.projection(x) + + if not self.training: + x = self.act(x) + return x \ No newline at end of file From 161ca3f170a4fa7c5da1766f3d127e6adb0398f3 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Mon, 24 Jan 2022 15:15:42 +0530 Subject: [PATCH 64/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 8d64d37b..5203dd38 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -11,7 +11,7 @@ from timm.models.layers import trunc_normal_ from fairscale.nn.checkpoint import checkpoint_wrapper @MODEL_REGISTRY.register() -class MultiScaleViT(BaseClassificationModel): +class MultiScaleViT(nn.Module): """ Implementation of 'Multiscale Vision Transformers' https://arxiv.org/abs/2104.11227 From beed6eab3fa8c6eefd2e220cff2c365966d45d5e Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:43:23 +0530 Subject: [PATCH 65/67] remove logger --- vformer/utils/multiscale.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/vformer/utils/multiscale.py b/vformer/utils/multiscale.py index d6118f57..21b15536 100644 --- a/vformer/utils/multiscale.py +++ b/vformer/utils/multiscale.py @@ -1,26 +1,9 @@ -import logging import torch.nn as nn - -def get_logger(name): - """ - Retrieve the logger with the specified name or, if name is None, return a - logger which is the root logger of the hierarchy. - Args: - name (string): name of the logger. 
- """ - return logging.getLogger(name) - -logger = get_logger(__name__) - -def round_width(width, multiplier, min_width=1, divisor=1, verbose=False): +def round_width(width, multiplier, min_width=1, divisor=1): if not multiplier: return width width *= multiplier min_width = min_width or divisor - if verbose: - logger.info(f"min width {min_width}") - logger.info(f"width {width} divisor {divisor}") - logger.info(f"other {int(width + divisor / 2) // divisor * divisor}") width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) if width_out < 0.9 * width: @@ -72,4 +55,4 @@ def forward(self, x): if not self.training: x = self.act(x) - return x \ No newline at end of file + return x From 4926fed42046c6d5669b5b73191034b3f9992c93 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:45:03 +0530 Subject: [PATCH 66/67] Update multiscale.py --- vformer/models/classification/multiscale.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 5203dd38..2c9a9855 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -9,7 +9,6 @@ from ...utils import MODEL_REGISTRY from ...utils.multiscale import round_width,TransformerBasicHead from timm.models.layers import trunc_normal_ -from fairscale.nn.checkpoint import checkpoint_wrapper @MODEL_REGISTRY.register() class MultiScaleViT(nn.Module): """ @@ -182,8 +181,6 @@ def __init__( has_cls_embed=self.cls_embed_on, pool_first=pool_first, ) - if ACT_CHECKPOINT: - attention_block = checkpoint_wrapper(attention_block) self.blocks.append(attention_block) embed_dim = dim_out From 209b5108df8228c7039b3b4fda907dd8a16a7e07 Mon Sep 17 00:00:00 2001 From: Rishav Mukherji <72412977+Amapocho@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:52:55 +0530 Subject: [PATCH 67/67] Update HEAD_ACT --- vformer/models/classification/multiscale.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vformer/models/classification/multiscale.py b/vformer/models/classification/multiscale.py index 2c9a9855..ca337a55 100644 --- a/vformer/models/classification/multiscale.py +++ b/vformer/models/classification/multiscale.py @@ -45,6 +45,7 @@ def __init__( POOL_KVQ_KERNEL= [3, 3, 3], POOL_KV_STRIDE_ADAPTIVE=[1, 8, 8], POOL_Q_STRIDE= [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]], + HEAD_ACT = "softmax", ): super(MultiScaleViT, self).__init__() self.patch_size = patch_size @@ -145,7 +146,7 @@ def __init__( if POOL_KVQ_KERNEL is not None: pool_kv[ POOL_KV_STRIDE[i][0] - ] = KVQ_KERNEL + ] = POOL_KVQ_KERNEL else: pool_kv[POOL_KV_STRIDE[i][0]] = [ s + 1 if s > 1 else s