Skip to content

Commit 8113046

Browse files
committed
update interface
1 parent bf0c736 commit 8113046

File tree

6 files changed

+19
-20
lines changed

6 files changed

+19
-20
lines changed

src/otx/algo/detection/atss.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
torch_compile=torch_compile,
5757
tile_config=tile_config,
5858
)
59+
breakpoint()
5960
self.tile_image_size = tile_image_size
6061

6162
@property

src/otx/algo/instance_segmentation/heads/custom_roi_head.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,6 @@ def bbox_loss(self, x: tuple[Tensor], sampling_results: list[SamplingResult], ba
548548

549549
class CustomConvFCBBoxHead(Shared2FCBBoxHead, ClassIncrementalMixin):
550550
"""CustomConvFCBBoxHead class for OTX."""
551-
# checked
552551

553552
def loss_and_target(
554553
self,

src/otx/cli/cli.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -331,18 +331,16 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None:
331331
# For num_classes update, Model and Metric are instantiated separately.
332332
model_config = self.config[self.subcommand].pop("model")
333333

334-
input_size = self.config["train"]["engine"].get("input_size")
335-
if input_size is not None:
336-
if isinstance(input_size, int):
337-
input_size = (input_size, input_size)
338-
self.config["train"]["data"]["input_size"] = input_size
339-
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size
340-
341334
# Instantiate the things that don't need to special handling
342335
self.config_init = self.parser.instantiate_classes(self.config)
343336
self.workspace = self.get_config_value(self.config_init, "workspace")
344337
self.datamodule = self.get_config_value(self.config_init, "data")
345338

339+
if (input_size := self.datamodule.input_size) is not None:
340+
if isinstance(input_size, int):
341+
input_size = (input_size, input_size)
342+
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size
343+
346344
# Instantiate the model and needed components
347345
self.model = self.instantiate_model(model_config=model_config)
348346

src/otx/core/data/module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,22 @@ def __init__(
6363
auto_num_workers: bool = False,
6464
device: DeviceType = DeviceType.auto,
6565
input_size: int | tuple[int, int] | None = None,
66+
adaptive_input_size: bool = False,
6667
) -> None:
6768
"""Constructor."""
6869
super().__init__()
6970
self.task = task
7071
self.data_format = data_format
7172
self.data_root = data_root
7273

74+
if adaptive_input_size:
75+
print("adaptive_input_size works")
76+
7377
if input_size is not None:
7478
for subset_cfg in [train_subset, val_subset, test_subset, unlabeled_subset]:
7579
if subset_cfg.input_size is None:
7680
subset_cfg.input_size = input_size
81+
self.input_size = input_size
7782

7883
self.train_subset = train_subset
7984
self.val_subset = val_subset

src/otx/engine/engine.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def __init__(
122122
checkpoint: PathLike | None = None,
123123
device: DeviceType = DeviceType.auto,
124124
num_devices: int = 1,
125-
input_size: Sequence[int] | int | None = None,
126125
**kwargs,
127126
):
128127
"""Initializes the OTX Engine.
@@ -147,17 +146,8 @@ def __init__(
147146
data_root=data_root,
148147
task=datamodule.task if datamodule is not None else task,
149148
model_name=None if isinstance(model, OTXModel) else model,
150-
input_size=input_size,
151149
)
152150

153-
if input_size is not None:
154-
if isinstance(datamodule, OTXDataModule) and datamodule.input_size != input_size:
155-
msg = "Data module is already initialized. Input size will be ignored to data module."
156-
logging.warning(msg)
157-
if isinstance(model, OTXModel) and model.input_size != input_size:
158-
msg = "Model is already initialized. Input size will be ignored to model."
159-
logging.warning(msg)
160-
161151
self._datamodule: OTXDataModule | None = (
162152
datamodule if datamodule is not None else self._auto_configurator.get_datamodule()
163153
)
@@ -169,6 +159,7 @@ def __init__(
169159
if isinstance(model, OTXModel)
170160
else self._auto_configurator.get_model(
171161
label_info=self._datamodule.label_info if self._datamodule is not None else None,
162+
input_size=self._datamodule.input_size,
172163
)
173164
)
174165

src/otx/engine/utils/auto_configurator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
],
6666
"common_semantic_segmentation_with_subset_dirs": [OTXTaskType.SEMANTIC_SEGMENTATION],
6767
"kinetics": [OTXTaskType.ACTION_CLASSIFICATION],
68-
"mvtec_classification": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION],
68+
"mvtec": [OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION],
6969
}
7070

7171
OVMODEL_PER_TASK = {
@@ -245,7 +245,7 @@ def get_datamodule(self) -> OTXDataModule | None:
245245
**data_config,
246246
)
247247

248-
def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None) -> OTXModel:
248+
def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes | None = None, input_size: Sequence[int] | None = None) -> OTXModel:
249249
"""Retrieves the OTXModel instance based on the provided model name and meta information.
250250
251251
Args:
@@ -278,6 +278,11 @@ def get_model(self, model_name: str | None = None, label_info: LabelInfoTypes |
278278

279279
model_config = deepcopy(self.config["model"])
280280

281+
if input_size is not None:
282+
if isinstance(input_size, int):
283+
input_size = (input_size, input_size)
284+
model_config["init_args"]["input_size"] = tuple(model_config["init_args"]["input_size"][:-2]) + input_size
285+
281286
model_cls = get_model_cls_from_config(Namespace(model_config))
282287

283288
if should_pass_label_info(model_cls):

0 commit comments

Comments
 (0)