Skip to content

Commit

Permalink
release MobileViT, from @murufeng
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 21, 2021
1 parent 86a7302 commit b983bbe
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 56 deletions.
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,17 +554,17 @@ pred = nest(img) # (1, 1000)

<img src="./images/mbvit.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different
perspective for the global processing of information with transformers.

You can use it with the following code (ex. mobilevit_xs)

```
```python
import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
image_size=(256, 256),
image_size = (256, 256),
dims = [96, 120, 144],
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
num_classes = 1000
Expand Down Expand Up @@ -1190,6 +1190,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{mehta2021mobilevit,
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author = {Sachin Mehta and Mohammad Rastegari},
year = {2021},
eprint = {2110.02178},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.24.3',
version = '0.25.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
114 changes: 62 additions & 52 deletions vit_pytorch/mobile_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

def _make_divisible(v, divisor, min_value=None):

if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
Expand All @@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
return new_v


def Conv_BN_ReLU(inp, oup, kernel, stride=1):
def conv_bn_relu(inp, oup, kernel, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(oup),
Expand Down Expand Up @@ -63,8 +63,6 @@ class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)

self.heads = heads
self.scale = dim_head ** -0.5

Expand All @@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
)

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
Expand All @@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))

def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
Expand Down Expand Up @@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
)

def forward(self, x):
out = self.conv(x)

if self.identity:
return x + self.conv(x)
else:
return self.conv(x)
out = out + x
return out

class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size

self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)

self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)

def forward(self, x):
y = x.clone()
Expand All @@ -165,8 +165,7 @@ def forward(self, x):
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
pw=self.pw)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

# Fusion
x = self.conv3(x)
Expand All @@ -176,54 +175,65 @@ def forward(self, x):


class MobileViT(nn.Module):
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion = 4,
kernel_size = 3,
patch_size = (2, 2),
depths = (2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'

ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0

L = [2, 4, 3]

self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)

self.mv2 = nn.ModuleList([])
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

self.mvit = nn.ModuleList([])
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

self.pool = nn.AvgPool2d(ih // 32, 1)
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
init_dim, *_, last_dim = channels

self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)

self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
]))

self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
]))

self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
]))

self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)

def forward(self, x):
x = self.conv1(x)
x = self.mv2[0](x)

x = self.mv2[1](x)
x = self.mv2[2](x)
x = self.mv2[3](x)

x = self.mv2[4](x)
x = self.mvit[0](x)
for conv in self.stem:
x = conv(x)

x = self.mv2[5](x)
x = self.mvit[1](x)

x = self.mv2[6](x)
x = self.mvit[2](x)
x = self.conv2(x)

x = self.pool(x).view(-1, x.shape[1])
x = self.fc(x)
return x
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)

return self.to_logits(x)

0 comments on commit b983bbe

Please sign in to comment.