-
Notifications
You must be signed in to change notification settings - Fork 0
/
framework.py
59 lines (50 loc) · 1.51 KB
/
framework.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from . import yolo
from . import yolov2
from . import vanilla
from os.path import basename
class framework(object):
    """Base network framework; also the fallback for unrecognised config types.

    Subclasses swap in model-specific hooks by overriding the class
    attributes below. This base version wires in the ``vanilla`` module.
    """

    constructor = vanilla.constructor
    loss = vanilla.train.loss

    def __init__(self, meta, FLAGS):
        """Record the model name in ``meta`` and run ``self.constructor``.

        Args:
            meta: mapping read for ``'model'`` (a file path); the derived
                model name is written back into it as ``meta['name']``.
            FLAGS: passed through to ``self.constructor`` unchanged.
        """
        # Drop the directory and the final extension:
        # 'cfg/tiny.v2.cfg' -> 'tiny.v2'. The previous
        # '.'.join(name.split('.')[:-1]) collapsed an extension-less name
        # such as 'yolo' to ''. rpartition keeps such a name intact and
        # gives the same result as before for every name containing a dot.
        model = basename(meta['model'])
        stem, dot, _ = model.rpartition('.')
        meta['name'] = stem if dot else model
        self.constructor(meta, FLAGS)

    def is_inp(self, file_name):
        """Accept every file name as input; subclasses may override this filter."""
        return True
class YOLO(framework):
    """Framework registered for ``[detection]`` configs.

    Every hook is re-exported from the ``yolo`` package; nothing is
    implemented locally.
    """

    # hooks from the yolo package root / yolo.train
    constructor = yolo.constructor
    loss = yolo.train.loss

    # hooks from yolo.data
    parse = yolo.data.parse
    shuffle = yolo.data.shuffle
    _batch = yolo.data._batch

    # hooks from yolo.predict
    preprocess = yolo.predict.preprocess
    postprocess = yolo.predict.postprocess
    resize_input = yolo.predict.resize_input
    findboxes = yolo.predict.findboxes
    process_box = yolo.predict.process_box

    # hooks from yolo.misc
    is_inp = yolo.misc.is_inp
    profile = yolo.misc.profile
class YOLOv2(framework):
    """Framework registered for ``[region]`` configs.

    Mixes hooks from the ``yolov2`` package (loss, shuffling, batching,
    postprocessing, box finding) with hooks reused from ``yolo``.
    Unlike ``YOLO``, it does not define ``profile``.
    """

    # hooks reused from the yolo package
    constructor = yolo.constructor
    parse = yolo.data.parse
    preprocess = yolo.predict.preprocess
    resize_input = yolo.predict.resize_input
    process_box = yolo.predict.process_box
    is_inp = yolo.misc.is_inp

    # hooks specific to the yolov2 package
    shuffle = yolov2.data.shuffle
    _batch = yolov2.data._batch
    loss = yolov2.train.loss
    postprocess = yolov2.predict.postprocess
    findboxes = yolov2.predict.findboxes
# Framework factory registry: maps the ``meta['type']`` tag to the framework
# class that handles it. Tags not listed here fall back to the base
# ``framework`` class in ``create_framework``.
types = {
    '[detection]': YOLO,
    '[region]': YOLOv2,
}
def create_framework(meta, FLAGS):
    """Instantiate the framework class registered for ``meta['type']``.

    Types missing from the ``types`` registry fall back to the base
    ``framework`` class.

    Args:
        meta: mapping that must contain ``'type'`` (a ``KeyError`` is raised
            otherwise); it is also passed to the framework constructor.
        FLAGS: passed to the framework constructor unchanged.

    Returns:
        An instance of the selected framework class.
    """
    framework_cls = types.get(meta['type'], framework)
    return framework_cls(meta, FLAGS)