Skip to content

Commit 20f7fd5

Browse files
committed
added tests for mobilnetv3 timm backbone
1 parent ee3c8e1 commit 20f7fd5

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
from icevision.all import *
3+
4+
5+
@pytest.mark.parametrize(
6+
"ds, model_type",
7+
[
8+
(
9+
"fridge_ds",
10+
models.mmdet.retinanet,
11+
),
12+
],
13+
)
14+
class TestTimmBackbones:
15+
def dls_model(self, ds, model_type, samples_source, request):
16+
train_ds, valid_ds = request.getfixturevalue(ds)
17+
train_dl = model_type.train_dl(train_ds, batch_size=2)
18+
valid_dl = model_type.valid_dl(valid_ds, batch_size=2)
19+
20+
# backbone = model_type.backbones.mmdet.resnet50_fpn_1x()
21+
backbone = model_type.backbones.timm.mobilenet.mobilenetv3_large_100
22+
backbone.config_path = samples_source / backbone.config_path
23+
24+
model = model_type.model(backbone=backbone, num_classes=5)
25+
26+
return train_dl, valid_dl, model
27+
28+
def test_mmdet_bbox_models_fastai(self, ds, model_type, samples_source, request):
29+
train_dl, valid_dl, model = self.dls_model(
30+
ds, model_type, samples_source, request
31+
)
32+
33+
learn = model_type.fastai.learner(
34+
dls=[train_dl, valid_dl], model=model, splitter=fastai.trainable_params
35+
)
36+
learn.fine_tune(1, 3e-4)
37+
38+
def test_mmdet_bbox_models_light(self, ds, model_type, samples_source, request):
39+
train_dl, valid_dl, model = self.dls_model(
40+
ds, model_type, samples_source, request
41+
)
42+
43+
class LitModel(model_type.lightning.ModelAdapter):
44+
def configure_optimizers(self):
45+
return Adam(self.parameters(), lr=1e-4)
46+
47+
light_model = LitModel(model)
48+
trainer = pl.Trainer(
49+
max_epochs=1,
50+
weights_summary=None,
51+
num_sanity_val_steps=0,
52+
logger=False,
53+
checkpoint_callback=False,
54+
)
55+
trainer.fit(light_model, train_dl, valid_dl)

0 commit comments

Comments
 (0)