Skip to content

Commit 23a54b4

Browse files
Vozfqubvel
andauthored
Genet from timm (#344)
* gernet from regnet * basic gernet * depth set to 5, and requirements+import update * docs * Fix summary error * remove input size * manet fix and test with latest timm Co-authored-by: Pavel Yakubovskiy <[email protected]>
1 parent f91cc59 commit 23a54b4

File tree

6 files changed

+162
-5
lines changed

6 files changed

+162
-5
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
python -m pip install codecov pytest mock
3030
pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
3131
pip install .
32+
pip install -U git+https://github.com/rwightman/pytorch-image-models
3233
- name: Test
3334
run: |
3435
python -m pytest -s tests

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 106 available encoders
15+
- 109 available encoders
1616
- All encoders have pre-trained weights for faster and better convergence
1717

1818
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -188,6 +188,19 @@ The following is a list of supported encoders in the SMP. Select the appropriate
188188
</div>
189189
</details>
190190

191+
<details>
192+
<summary style="margin-left: 25px;">GERNet</summary>
193+
<div style="margin-left: 25px;">
194+
195+
|Encoder |Weights |Params, M |
196+
|--------------------------------|:------------------------------:|:------------------------------:|
197+
|timm-gernet_s |imagenet |6M |
198+
|timm-gernet_m |imagenet |18M |
199+
|timm-gernet_l |imagenet |28M |
200+
201+
</div>
202+
</details>
203+
191204
<details>
192205
<summary style="margin-left: 25px;">SE-Net</summary>
193206
<div style="margin-left: 25px;">

docs/encoders.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ RegNet(x/y)
136136
| timm-regnety\_320 | imagenet | 141M |
137137
+---------------------+------------+-------------+
138138

139+
GERNet
140+
~~~~~~
141+
142+
+-------------------------+------------+-------------+
143+
| Encoder | Weights | Params, M |
144+
+=========================+============+=============+
145+
| timm-gernet\_s | imagenet | 6M |
146+
+-------------------------+------------+-------------+
147+
| timm-gernet\_m | imagenet | 18M |
148+
+-------------------------+------------+-------------+
149+
| timm-gernet\_l | imagenet | 28M |
150+
+-------------------------+------------+-------------+
151+
139152
SE-Net
140153
~~~~~~
141154

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
from .timm_res2net import timm_res2net_encoders
1818
from .timm_regnet import timm_regnet_encoders
1919
from .timm_sknet import timm_sknet_encoders
20+
try:
21+
from .timm_gernet import timm_gernet_encoders
22+
except ImportError as e:
23+
timm_gernet_encoders = {}
24+
print("Current timm version doesn't support GERNet."
25+
"If GERNet support is needed please update timm")
26+
2027
from ._preprocessing import preprocess_input
2128

2229
encoders = {}
@@ -36,6 +43,7 @@
3643
encoders.update(timm_res2net_encoders)
3744
encoders.update(timm_regnet_encoders)
3845
encoders.update(timm_sknet_encoders)
46+
encoders.update(timm_gernet_encoders)
3947

4048

4149
def get_encoder(name, in_channels=3, depth=5, weights=None):
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from timm.models import ByobCfg, BlocksCfg, ByobNet
2+
3+
from ._base import EncoderMixin
4+
import torch.nn as nn
5+
6+
7+
class GERNetEncoder(ByobNet, EncoderMixin):
8+
def __init__(self, out_channels, depth=5, **kwargs):
9+
super().__init__(**kwargs)
10+
self._depth = depth
11+
self._out_channels = out_channels
12+
self._in_channels = 3
13+
14+
del self.head
15+
16+
def get_stages(self):
17+
return [
18+
nn.Identity(),
19+
self.stem,
20+
self.stages[0],
21+
self.stages[1],
22+
self.stages[2],
23+
nn.Sequential(self.stages[3], self.stages[4], self.final_conv)
24+
]
25+
26+
def forward(self, x):
27+
stages = self.get_stages()
28+
29+
features = []
30+
for i in range(self._depth + 1):
31+
x = stages[i](x)
32+
features.append(x)
33+
34+
return features
35+
36+
def load_state_dict(self, state_dict, **kwargs):
37+
state_dict.pop("head.fc.weight")
38+
state_dict.pop("head.fc.bias")
39+
super().load_state_dict(state_dict, **kwargs)
40+
41+
42+
regnet_weights = {
43+
'timm-gernet_s': {
44+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth',
45+
},
46+
'timm-gernet_m': {
47+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth',
48+
},
49+
'timm-gernet_l': {
50+
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
51+
},
52+
}
53+
54+
pretrained_settings = {}
55+
for model_name, sources in regnet_weights.items():
56+
pretrained_settings[model_name] = {}
57+
for source_name, source_url in sources.items():
58+
pretrained_settings[model_name][source_name] = {
59+
"url": source_url,
60+
'input_range': [0, 1],
61+
'mean': [0.485, 0.456, 0.406],
62+
'std': [0.229, 0.224, 0.225],
63+
'num_classes': 1000
64+
}
65+
66+
timm_gernet_encoders = {
67+
'timm-gernet_s': {
68+
'encoder': GERNetEncoder,
69+
"pretrained_settings": pretrained_settings["timm-gernet_s"],
70+
'params': {
71+
'out_channels': (3, 13, 48, 48, 384, 1920),
72+
'cfg': ByobCfg(
73+
blocks=(
74+
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
75+
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
76+
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
77+
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
78+
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
79+
),
80+
stem_chs=13,
81+
num_features=1920,
82+
)
83+
},
84+
},
85+
'timm-gernet_m': {
86+
'encoder': GERNetEncoder,
87+
"pretrained_settings": pretrained_settings["timm-gernet_m"],
88+
'params': {
89+
'out_channels': (3, 32, 128, 192, 640, 2560),
90+
'cfg': ByobCfg(
91+
blocks=(
92+
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
93+
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
94+
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
95+
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
96+
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
97+
),
98+
stem_chs=32,
99+
num_features=2560,
100+
)
101+
},
102+
},
103+
'timm-gernet_l': {
104+
'encoder': GERNetEncoder,
105+
"pretrained_settings": pretrained_settings["timm-gernet_l"],
106+
'params': {
107+
'out_channels': (3, 32, 128, 192, 640, 2560),
108+
'cfg': ByobCfg(
109+
blocks=(
110+
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
111+
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
112+
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
113+
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
114+
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
115+
),
116+
stem_chs=32,
117+
num_features=2560,
118+
)
119+
},
120+
},
121+
}

segmentation_models_pytorch/manet/decoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,19 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True,
5656
use_batchnorm=use_batchnorm,
5757
)
5858
)
59+
reduced_channels = max(1, skip_channels // reduction)
5960
self.SE_ll = nn.Sequential(
6061
nn.AdaptiveAvgPool2d(1),
61-
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
62+
nn.Conv2d(skip_channels, reduced_channels, 1),
6263
nn.ReLU(inplace=True),
63-
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
64+
nn.Conv2d(reduced_channels, skip_channels, 1),
6465
nn.Sigmoid(),
6566
)
6667
self.SE_hl = nn.Sequential(
6768
nn.AdaptiveAvgPool2d(1),
68-
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
69+
nn.Conv2d(skip_channels, reduced_channels, 1),
6970
nn.ReLU(inplace=True),
70-
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
71+
nn.Conv2d(reduced_channels, skip_channels, 1),
7172
nn.Sigmoid(),
7273
)
7374
self.conv1 = md.Conv2dReLU(

0 commit comments

Comments
 (0)