Skip to content

Commit 09b97ba

Browse files
committed
[detection] Use ml_hub for building gitub repo urls and github release url
[tests] make deploy test non essential
1 parent 92be5eb commit 09b97ba

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

ml/vision/models/detection/detector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
2+
import random
23
from abc import abstractmethod
34
from pathlib import Path
45

56
import torch as th
7+
from torch import nn
68
import torchvision.transforms as T
79

8-
from ml import nn, random, logging
10+
from ml import logging
911
from ...datasets import coco
1012

1113
COLORS91 = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(coco.COCO91_CLASSES))]

ml/vision/models/detection/detr/model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from pathlib import Path
33

44
import torch
5-
from ml import io, nn, hub, logging
6-
from ml.nn import functional as F
5+
from torch import hub
6+
7+
import ml.hub as ml_hub
8+
from ml import io, logging
79

810
GITHUB_DETR = dict(
911
owner='facebookresearch',
@@ -29,10 +31,10 @@
2931
def github(tag='main', deformable=False):
3032
if deformable:
3133
tag = TAGS_DEFORMABLE_DETR[tag]
32-
return hub.github(owner=GITHUB_DEFORMABLE_DETR['owner'], project=GITHUB_DEFORMABLE_DETR['project'], tag=tag)
34+
return ml_hub.github(owner=GITHUB_DEFORMABLE_DETR['owner'], project=GITHUB_DEFORMABLE_DETR['project'], tag=tag)
3335
else:
3436
tag = TAGS_DETR[tag]
35-
return hub.github(owner=GITHUB_DETR['owner'], project=GITHUB_DETR['project'], tag=tag)
37+
return ml_hub.github(owner=GITHUB_DETR['owner'], project=GITHUB_DETR['project'], tag=tag)
3638

3739
def from_pretrained(chkpt, model_dir=None, force_reload=False, **kwargs):
3840
# TODO naming for custom checkpoints
@@ -65,7 +67,7 @@ def from_pretrained(chkpt, model_dir=None, force_reload=False, **kwargs):
6567
# GitHub Release
6668
owner = kwargs.get('owner', GITHUB_DETR['owner'])
6769
proj = kwargs.get('project', GITHUB_DETR['project'])
68-
url = hub.github_release_url(owner, proj, tag, chkpt)
70+
url = ml_hub.github_release_url(owner, proj, tag, chkpt)
6971
chkpt = hub.load_state_dict_from_url(url,
7072
model_dir=model_dir,
7173
map_location=torch.device('cpu'),

ml/vision/models/detection/yolox/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
from pathlib import Path
44

55
import torch
6+
from torch import hub
7+
8+
from ml import logging
9+
import ml.hub as ml_hub
610

7-
from ml import hub, logging
811

912
GITHUB_YOLOX = dict(
1013
owner='Megvii-BaseDetection',
@@ -18,7 +21,7 @@
1821

1922
def github(tag='main', deformable=False):
2023
tag = TAGS_YOLOX[tag]
21-
return hub.github(owner=GITHUB_YOLOX['owner'], project=GITHUB_YOLOX['project'], tag=tag)
24+
return ml_hub.github(owner=GITHUB_YOLOX['owner'], project=GITHUB_YOLOX['project'], tag=tag)
2225

2326
def custom_forward(self, x, targets=None):
2427
# fpn output content features of [dark3, dark4, dark5]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def readme():
189189
license='BSD-3',
190190
classifiers=[
191191
'License :: OSI Approved :: BSD License',
192-
'Operating System :: macOS/Ubuntu 16.04+',
192+
'Operating System :: macOS/Ubuntu 18.04+',
193193
'Development Status :: 1 - Alpha',
194194
'Intended Audience :: Developers',
195195
'Intended Audience :: Education',

tests/test_yolox_deploy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def detector(tag, dev):
6262
detector.eval()
6363
return detector.to(dev)
6464

65-
@pytest.mark.essential
65+
# @pytest.mark.essential
6666
@pytest.mark.parametrize("B", [1])
6767
@pytest.mark.parametrize("shape", [(640, 640)])
6868
def test_deploy_onnx(benchmark, name, batch, detector, dev, B, shape):
@@ -161,7 +161,7 @@ def preprocess(image_path, *shape):
161161
for torch_feats, feats in zip(torch_features, features):
162162
th.testing.assert_close(torch_feats.float(), feats.float(), rtol=1e-03, atol=4e-04)
163163

164-
@pytest.mark.essential
164+
# @pytest.mark.essential
165165
@pytest.mark.parametrize("B", [8])
166166
@pytest.mark.parametrize("batch_preprocess", [True, False])
167167
@pytest.mark.parametrize('fp16', [True, False])

0 commit comments

Comments
 (0)