diff --git a/detectron/core/config.py b/detectron/core/config.py index 3ca943a68..7a12e652b 100644 --- a/detectron/core/config.py +++ b/detectron/core/config.py @@ -441,6 +441,11 @@ # (e.g., 'generalized_rcnn', 'mask_rcnn', ...) __C.MODEL.TYPE = '' +# Detection model helper class to use +# +# Allows to apply custom DetectionModelHelper implementation +__C.MODEL.MODEL_HELPER_CLASS = 'detectron.modeling.detector.DetectionModelHelper' + # The backbone conv body to use # The string must match a function that is imported in modeling.model_builder # (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN diff --git a/detectron/modeling/model_builder.py b/detectron/modeling/model_builder.py index 25ab21770..b3f181123 100644 --- a/detectron/modeling/model_builder.py +++ b/detectron/modeling/model_builder.py @@ -113,7 +113,16 @@ def create(model_type_func, train=False, gpu_id=0): targeted to a specific GPU by specifying gpu_id. This is used by optimizer.build_data_parallel_model() during test time. """ - model = DetectionModelHelper( + parts = cfg.MODEL.MODEL_HELPER_CLASS.split('.') + try: + module_name = '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + model_helper_class = getattr(module, parts[-1]) + except (IndexError, ImportError, AttributeError): + logger.error('Failed to find model helper: %s', model_helper_class) + raise + + model = model_helper_class( name=model_type_func, train=train, num_classes=cfg.MODEL.NUM_CLASSES, @@ -145,7 +154,13 @@ def get_func(func_name): return globals()[parts[0]] # Otherwise, assume we're referencing a module under modeling module_name = 'detectron.modeling.' + '.'.join(parts[:-1]) - module = importlib.import_module(module_name) + try: + module = importlib.import_module(module_name) + except ImportError: + # Finally check if we're referencing a module from the environment + module_name = '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + logger.debug('Using function %s from the environment', func_name) return getattr(module, parts[-1]) except Exception: logger.error('Failed to find function: {}'.format(func_name))