|
| 1 | +import pytest |
| 2 | +from icevision.all import * |
| 3 | + |
| 4 | + |
| 5 | +@pytest.mark.parametrize( |
| 6 | + "ds, model_type", |
| 7 | + [ |
| 8 | + ( |
| 9 | + "fridge_ds", |
| 10 | + models.mmdet.retinanet, |
| 11 | + ), |
| 12 | + ], |
| 13 | +) |
| 14 | +class TestTimmBackbones: |
| 15 | + def dls_model(self, ds, model_type, samples_source, request): |
| 16 | + train_ds, valid_ds = request.getfixturevalue(ds) |
| 17 | + train_dl = model_type.train_dl(train_ds, batch_size=2) |
| 18 | + valid_dl = model_type.valid_dl(valid_ds, batch_size=2) |
| 19 | + |
| 20 | + # backbone = model_type.backbones.mmdet.resnet50_fpn_1x() |
| 21 | + backbone = model_type.backbones.timm.mobilenet.mobilenetv3_large_100 |
| 22 | + backbone.config_path = samples_source / backbone.config_path |
| 23 | + |
| 24 | + model = model_type.model(backbone=backbone, num_classes=5) |
| 25 | + |
| 26 | + return train_dl, valid_dl, model |
| 27 | + |
| 28 | + def test_mmdet_bbox_models_fastai(self, ds, model_type, samples_source, request): |
| 29 | + train_dl, valid_dl, model = self.dls_model( |
| 30 | + ds, model_type, samples_source, request |
| 31 | + ) |
| 32 | + |
| 33 | + learn = model_type.fastai.learner( |
| 34 | + dls=[train_dl, valid_dl], model=model, splitter=fastai.trainable_params |
| 35 | + ) |
| 36 | + learn.fine_tune(1, 3e-4) |
| 37 | + |
| 38 | + def test_mmdet_bbox_models_light(self, ds, model_type, samples_source, request): |
| 39 | + train_dl, valid_dl, model = self.dls_model( |
| 40 | + ds, model_type, samples_source, request |
| 41 | + ) |
| 42 | + |
| 43 | + class LitModel(model_type.lightning.ModelAdapter): |
| 44 | + def configure_optimizers(self): |
| 45 | + return Adam(self.parameters(), lr=1e-4) |
| 46 | + |
| 47 | + light_model = LitModel(model) |
| 48 | + trainer = pl.Trainer( |
| 49 | + max_epochs=1, |
| 50 | + weights_summary=None, |
| 51 | + num_sanity_val_steps=0, |
| 52 | + logger=False, |
| 53 | + checkpoint_callback=False, |
| 54 | + ) |
| 55 | + trainer.fit(light_model, train_dl, valid_dl) |
0 commit comments