diff --git a/pyproject.toml b/pyproject.toml index c0db79961a8..d89bcf55f14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "importlib_resources==6.4.0", "docstring_parser==0.16", # CLI help-formatter "rich_argparse==1.4.0", # CLI help-formatter + "einops==0.7.0", ] [project.optional-dependencies] diff --git a/src/otx/algo/action_classification/__init__.py b/src/otx/algo/action_classification/__init__.py index d9f3b9c7159..e880cde9494 100644 --- a/src/otx/algo/action_classification/__init__.py +++ b/src/otx/algo/action_classification/__init__.py @@ -3,9 +3,17 @@ # """Module for OTX action classification models.""" -from .backbones import OTXMoViNet -from .heads import MoViNetHead +from .backbones import MoViNetBackbone, X3DBackbone +from .heads import MoViNetHead, X3DHead from .openvino_model import OTXOVActionCls -from .recognizers import MoViNetRecognizer, OTXRecognizer3D +from .recognizers import BaseRecognizer, MoViNetRecognizer -__all__ = ["OTXOVActionCls", "OTXRecognizer3D", "OTXMoViNet", "MoViNetHead", "MoViNetRecognizer"] +__all__ = [ + "OTXOVActionCls", + "BaseRecognizer", + "MoViNetBackbone", + "MoViNetHead", + "MoViNetRecognizer", + "X3DBackbone", + "X3DHead", +] diff --git a/src/otx/algo/action_classification/backbones/__init__.py b/src/otx/algo/action_classification/backbones/__init__.py index 044d3165fa4..ac25ee20a51 100644 --- a/src/otx/algo/action_classification/backbones/__init__.py +++ b/src/otx/algo/action_classification/backbones/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Custom backbones for action classification.""" -from .movinet import OTXMoViNet +from .movinet import MoViNetBackbone +from .x3d import X3DBackbone -__all__ = ["OTXMoViNet"] +__all__ = ["MoViNetBackbone", "X3DBackbone"] diff --git a/src/otx/algo/action_classification/backbones/movinet.py b/src/otx/algo/action_classification/backbones/movinet.py index 70b0a2975a1..cc7430af220 100644 --- a/src/otx/algo/action_classification/backbones/movinet.py +++ b/src/otx/algo/action_classification/backbones/movinet.py @@ -1,7 +1,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Code modified from: https://github.com/Atze00/MoViNet-pytorch/blob/main/movinets/models.py.""" +# Copyright (c) OpenMMLab. All rights reserved. +"""Code modified from: https://github.com/Atze00/MoViNet-pytorch/blob/main/movinets/models.py.""" from __future__ import annotations from collections import OrderedDict @@ -10,8 +11,7 @@ import torch import torch.nn.functional as F # noqa: N812 from einops import rearrange -from mmaction.models import MODELS -from mmengine.config import Config +from omegaconf.dictconfig import DictConfig from torch import Tensor, nn from torch.nn.modules.utils import _pair, _triple @@ -438,7 +438,7 @@ class BasicBneck(nn.Module): """Basic bottleneck block of MoViNet network. Args: - cfg (Config): Configuration object containing block's hyperparameters. + cfg (DictConfig): configuration object containing block's hyperparameters. tf_like (bool): A boolean indicating whether to use TensorFlow like convolution padding or not. conv_type (str): A string indicating the type of convolutional layer to use. 
@@ -460,7 +460,7 @@ class BasicBneck(nn.Module): def __init__( self, - cfg: Config, + cfg: DictConfig, tf_like: bool, conv_type: str, norm_layer: Callable[..., nn.Module] | None = None, @@ -543,11 +543,11 @@ def forward(self, x: Tensor) -> Tensor: return residual + self.alpha * x -class MoViNet(nn.Module): +class MoViNetBackboneBase(nn.Module): """MoViNet class used for video classification. Args: - cfg (Config): Configuration object containing network's hyperparameters. + cfg (DictConfig): configuration object containing network's hyperparameters. conv_type (str, optional): A string indicating the type of convolutional layer to use. Can be "2d" or "3d". Defaults to "3d". tf_like (bool, optional): A boolean indicating whether to use TensorFlow like @@ -569,7 +569,7 @@ class MoViNet(nn.Module): def __init__( self, - cfg: Config, + cfg: DictConfig, conv_type: str = "3d", tf_like: bool = False, ) -> None: @@ -650,70 +650,69 @@ def init_weights(self) -> None: self.apply(self._init_weights) -@MODELS.register_module() -class OTXMoViNet(MoViNet): +class MoViNetBackbone(MoViNetBackboneBase): """MoViNet wrapper class for OTX.""" - def __init__(self, **kwargs): - cfg = Config() + def __init__(self, **kwargs) -> None: + cfg = DictConfig({}) cfg.name = "A0" - cfg.conv1 = Config() - OTXMoViNet.fill_conv(cfg.conv1, 3, 8, (1, 3, 3), (1, 2, 2), (0, 1, 1)) + cfg.conv1 = DictConfig({}) + MoViNetBackbone.fill_conv(cfg.conv1, 3, 8, (1, 3, 3), (1, 2, 2), (0, 1, 1)) cfg.blocks = [ - [Config()], - [Config() for _ in range(3)], - [Config() for _ in range(3)], - [Config() for _ in range(4)], - [Config() for _ in range(4)], + [DictConfig({})], + [DictConfig({}) for _ in range(3)], + [DictConfig({}) for _ in range(3)], + [DictConfig({}) for _ in range(4)], + [DictConfig({}) for _ in range(4)], ] # block 2 - OTXMoViNet.fill_se_config(cfg.blocks[0][0], 8, 8, 24, (1, 5, 5), (1, 2, 2), (0, 2, 2), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[0][0], 8, 8, 24, (1, 5, 5), (1, 2, 2), (0, 2, 2), (0, 1, 1)) # block 3 - OTXMoViNet.fill_se_config(cfg.blocks[1][0], 8, 32, 80, (3, 3, 3), (1, 2, 2), (1, 0, 0), (0, 0, 0)) - OTXMoViNet.fill_se_config(cfg.blocks[1][1], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[1][2], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[1][0], 8, 32, 80, (3, 3, 3), (1, 2, 2), (1, 0, 0), (0, 0, 0)) + MoViNetBackbone.fill_se_config(cfg.blocks[1][1], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[1][2], 32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) # block 4 - OTXMoViNet.fill_se_config(cfg.blocks[2][0], 32, 56, 184, (5, 3, 3), (1, 2, 2), (2, 0, 0), (0, 0, 0)) - OTXMoViNet.fill_se_config(cfg.blocks[2][1], 56, 56, 112, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[2][2], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[2][0], 32, 56, 184, (5, 3, 3), (1, 2, 2), (2, 0, 0), (0, 0, 0)) + MoViNetBackbone.fill_se_config(cfg.blocks[2][1], 56, 56, 112, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[2][2], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) # block 5 - OTXMoViNet.fill_se_config(cfg.blocks[3][0], 56, 56, 184, (5, 3, 3), (1, 1, 1), (2, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[3][1], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[3][2], 
56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[3][3], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[3][0], 56, 56, 184, (5, 3, 3), (1, 1, 1), (2, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[3][1], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[3][2], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[3][3], 56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)) # block 6 - OTXMoViNet.fill_se_config(cfg.blocks[4][0], 56, 104, 384, (5, 3, 3), (1, 2, 2), (2, 1, 1), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[4][1], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[4][2], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) - OTXMoViNet.fill_se_config(cfg.blocks[4][3], 104, 104, 344, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[4][0], 56, 104, 384, (5, 3, 3), (1, 2, 2), (2, 1, 1), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[4][1], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[4][2], 104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) + MoViNetBackbone.fill_se_config(cfg.blocks[4][3], 104, 104, 344, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)) - cfg.conv7 = Config() - OTXMoViNet.fill_conv(cfg.conv7, 104, 480, (1, 1, 1), (1, 1, 1), (0, 0, 0)) + cfg.conv7 = DictConfig({}) + MoViNetBackbone.fill_conv(cfg.conv7, 104, 480, (1, 1, 1), (1, 1, 1), (0, 0, 0)) - cfg.dense9 = Config({"hidden_dim": 2048}) + cfg.dense9 = DictConfig({"hidden_dim": 2048}) super().__init__(cfg) @staticmethod def fill_se_config( - conf: Config, + conf: DictConfig, input_channels: int, out_channels: int, expanded_channels: int, - kernel_size: tuple[int, int], - stride: tuple[int, int], - padding: tuple[int, int], - padding_avg: tuple[int, int], + kernel_size: tuple[int, int, int], + stride: tuple[int, int, int], + padding: tuple[int, int, int], + padding_avg: tuple[int, int, int], ) -> None: - """Set the values of a given Config object to SE module. + """Set the values of a given DictConfig object to SE module. Args: - conf (Config): The Config object to be updated. + conf (DictConfig): The DictConfig object to be updated. input_channels (int): The number of input channels. out_channels (int): The number of output channels. expanded_channels (int): The number of channels after expansion in the basic block. @@ -727,7 +726,7 @@ def fill_se_config( """ conf.expanded_channels = expanded_channels conf.padding_avg = padding_avg - OTXMoViNet.fill_conv( + MoViNetBackbone.fill_conv( conf, input_channels, out_channels, @@ -738,17 +737,17 @@ def fill_se_config( @staticmethod def fill_conv( - conf: Config, + conf: DictConfig, input_channels: int, out_channels: int, - kernel_size: tuple[int, int], - stride: tuple[int, int], - padding: tuple[int, int], + kernel_size: tuple[int, int, int], + stride: tuple[int, int, int], + padding: tuple[int, int, int], ) -> None: - """Set the values of a given Config object to conv layer. + """Set the values of a given DictConfig object to conv layer. Args: - conf (Config): The Config object to be updated. + conf (DictConfig): The DictConfig object to be updated. input_channels (int): The number of input channels. out_channels (int): The number of output channels. 
kernel_size (tuple[int]): The size of the kernel. diff --git a/src/otx/algo/action_classification/backbones/x3d.py b/src/otx/algo/action_classification/backbones/x3d.py new file mode 100644 index 00000000000..11805b52679 --- /dev/null +++ b/src/otx/algo/action_classification/backbones/x3d.py @@ -0,0 +1,546 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. + +"""X3D backbone implementation.""" +from __future__ import annotations + +import math + +import torch.utils.checkpoint as cp +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm + +from otx.algo.modules.activation import Swish, build_activation_layer +from otx.algo.modules.conv_module import ConvModule +from otx.algo.utils.mmengine_utils import load_checkpoint +from otx.algo.utils.weight_init import constant_init, kaiming_init + + +class SEModule(nn.Module): + """Implementation of SqueezeExcitation module.""" + + def __init__(self, channels: int, reduction: float): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool3d(1) + self.bottleneck = self._round_width(channels, reduction) + self.fc1 = nn.Conv3d(channels, self.bottleneck, kernel_size=1, padding=0) + self.relu = nn.ReLU() + self.fc2 = nn.Conv3d(self.bottleneck, channels, kernel_size=1, padding=0) + self.sigmoid = nn.Sigmoid() + + @staticmethod + def _round_width(width: int, multiplier: float, min_width: int = 8, divisor: int = 8) -> int: + """Round width of filters based on width multiplier.""" + width = int(width * multiplier) + min_width = min_width or divisor + width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) + if width_out < 0.9 * width: + width_out += divisor + return int(width_out) + + def forward(self, x: Tensor) -> Tensor: + """Defines the computation performed at every call. + + Args: + x (Tensor): The input data. + + Returns: + Tensor: The output of the module. + """ + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class BlockX3D(nn.Module): + """BlockX3D 3d building block for X3D. + + Args: + inplanes (int): Number of channels for the input in first conv3d layer. + planes (int): Number of channels produced by some norm/conv3d layers. + outplanes (int): Number of channels produced by final the conv3d layer. + spatial_stride (int): Spatial stride in the conv3d layer. Default: 1. + downsample (nn.Module | None): Downsample layer. Default: None. + se_ratio (float | None): The reduction ratio of squeeze and excitation + unit. If set as None, it means not using SE unit. Default: None. + use_swish (bool): Whether to use swish as the activation function + before and after the 3x3x3 conv. Default: True. + conv_cfg (dict): Config dict for convolution layer. + Default: ``dict(type='Conv3d')``. + norm_cfg (dict): Config for norm layers. required keys are ``type``, + Default: ``dict(type='BN3d')``. + act_cfg (dict): Config dict for activation layer. + Default: ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
+ """ + + def __init__( + self, + inplanes: int, + planes: int, + outplanes: int, + spatial_stride: int = 1, + downsample: nn.Module | None = None, + se_ratio: float | None = None, + use_swish: bool = True, + conv_cfg: dict | None = None, + norm_cfg: dict | None = None, + act_cfg: dict | None = None, + with_cp: bool = False, + ): + super().__init__() + + self.inplanes = inplanes + self.planes = planes + self.outplanes = outplanes + self.spatial_stride = spatial_stride + self.downsample = downsample + self.se_ratio = se_ratio + self.use_swish = use_swish + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.act_cfg_swish = Swish() + self.with_cp = with_cp + + self.conv1 = ConvModule( + in_channels=inplanes, + out_channels=planes, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ) + # Here we use the channel-wise conv + self.conv2 = ConvModule( + in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=(1, self.spatial_stride, self.spatial_stride), + padding=1, + groups=planes, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None, + ) + + self.swish = Swish() + + self.conv3 = ConvModule( + in_channels=planes, + out_channels=outplanes, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None, + ) + + if self.se_ratio is not None: + self.se_module = SEModule(planes, self.se_ratio) + + self.relu = build_activation_layer(self.act_cfg) if self.act_cfg else build_activation_layer({}) + + def forward(self, x: Tensor) -> Tensor: + """Defines the computation performed at every call.""" + + def _inner_forward(x: Tensor) -> Tensor: + """Forward wrapper for utilizing checkpoint.""" + identity = x + + out = self.conv1(x) + out = self.conv2(out) + if self.se_ratio is not None: + out = self.se_module(out) + + out = self.swish(out) + + out = self.conv3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + return out + identity + + out = cp.checkpoint(_inner_forward, x) if self.with_cp and x.requires_grad else _inner_forward(x) + return self.relu(out) + + +# We do not support initialize with 2D pretrain weight for X3D +class X3DBackbone(nn.Module): + """X3D backbone. https://arxiv.org/pdf/2004.04730.pdf. + + Args: + gamma_w (float): Global channel width expansion factor. Default: 1. + gamma_b (float): Bottleneck channel width expansion factor. Default: 1. + gamma_d (float): Network depth expansion factor. Default: 1. + pretrained (str | None): Name of pretrained model. Default: None. + in_channels (int): Channel num of input features. Default: 3. + num_stages (int): Resnet stages. Default: 4. + spatial_strides (Sequence[int]): + Spatial strides of residual blocks of each stage. + Default: ``(1, 2, 2, 2)``. + frozen_stages (int): Stages to be frozen (all param fixed). If set to + -1, it means not freezing any parameters. Default: -1. + se_style (str): The style of inserting SE modules into BlockX3D, 'half' + denotes insert into half of the blocks, while 'all' denotes insert + into all blocks. Default: 'half'. + se_ratio (float | None): The reduction ratio of squeeze and excitation + unit. If set as None, it means not using SE unit. Default: 1 / 16. + use_swish (bool): Whether to use swish as the activation function + before and after the 3x3x3 conv. Default: True. + conv_cfg (dict): Config for conv layers. required keys are ``type`` + Default: ``dict(type='Conv3d')``. 
+ norm_cfg (dict): Config for norm layers. required keys are ``type`` and + ``requires_grad``. + Default: ``dict(type='BN3d', requires_grad=True)``. + act_cfg (dict): Config dict for activation layer. + Default: ``dict(type='ReLU', inplace=True)``. + norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze + running stats (mean and var). Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): + Whether to use zero initialization for residual block, + Default: True. + kwargs (dict, optional): Key arguments for "make_res_layer". + """ + + def __init__( + self, + gamma_w: float = 1.0, + gamma_b: float = 1.0, + gamma_d: float = 1.0, + pretrained: str | None = None, + in_channels: int = 3, + num_stages: int = 4, + spatial_strides: tuple[int, int, int, int] = (2, 2, 2, 2), + frozen_stages: int = -1, + se_style: str = "half", + se_ratio: float = 1 / 16, + use_swish: bool = True, + conv_cfg: dict | None = None, + norm_cfg: dict | None = None, + act_cfg: dict | None = None, + norm_eval: bool = False, + with_cp: bool = False, + zero_init_residual: bool = True, + **kwargs, + ): + super().__init__() + self.gamma_w = gamma_w + self.gamma_b = gamma_b + self.gamma_d = gamma_d + + self.pretrained = pretrained + self.in_channels = in_channels + # Hard coded, can be changed by gamma_w + self.base_channels = 24 + self.stage_blocks = [1, 2, 5, 3] + + # apply parameters gamma_w and gamma_d + self.base_channels = self._round_width(self.base_channels, self.gamma_w) + + self.stage_blocks = [self._round_repeats(x, self.gamma_d) for x in self.stage_blocks] + + self.num_stages = num_stages + if num_stages < 1 or num_stages > 4: + msg = "num_stages for X3DBackbone should be 1<=num_stages<=4." + raise ValueError(msg) + self.spatial_strides = spatial_strides + if len(spatial_strides) != num_stages: + msg = "number of spatial_strides should be same to num_stages." + raise ValueError(msg) + self.frozen_stages = frozen_stages + + self.se_style = se_style + if self.se_style not in ["all", "half"]: + msg = f"se_style should be 'all' or 'half', but got {self.se_style}." + raise ValueError(msg) + self.se_ratio = se_ratio + if self.se_ratio and self.se_ratio <= 0: + msg = f"se_ratio should be larger than 0, but got {self.se_ratio}." 
+ raise ValueError(msg) + self.use_swish = use_swish + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + self.block = BlockX3D + self.stage_blocks = self.stage_blocks[:num_stages] + self.layer_inplanes = self.base_channels + self._make_stem_layer() + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + spatial_stride = spatial_strides[i] + inplanes = self.base_channels * 2**i + planes = int(inplanes * self.gamma_b) + + res_layer = self.make_res_layer( + self.block, + self.layer_inplanes, + inplanes, + planes, + num_blocks, + spatial_stride=spatial_stride, + se_style=self.se_style, + se_ratio=self.se_ratio, + use_swish=self.use_swish, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + act_cfg=self.act_cfg, + with_cp=with_cp, + **kwargs, + ) + self.layer_inplanes = inplanes + layer_name = f"layer{i + 1}" + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self.feat_dim = self.base_channels * 2 ** (len(self.stage_blocks) - 1) + self.conv5 = ConvModule( + self.feat_dim, + int(self.feat_dim * self.gamma_b), + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ) + self.feat_dim = int(self.feat_dim * self.gamma_b) + + @staticmethod + def _round_width(width: int, multiplier: float, min_depth: int = 8, divisor: int = 8) -> int: + """Round width of filters based on width multiplier.""" + if not multiplier: + return width + + width = int(width * multiplier) + min_depth = min_depth or divisor + new_filters = max(min_depth, int(width + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * width: + new_filters += divisor + return int(new_filters) + + @staticmethod + def _round_repeats(repeats: int, multiplier: float) -> int: + """Round number of layers based on depth multiplier.""" + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + # the module is parameterized with gamma_b + # no temporal_stride + def make_res_layer( + self, + block: nn.Module, + layer_inplanes: int, + inplanes: int, + planes: int, + blocks: int, + spatial_stride: int = 1, + se_style: str = "half", + se_ratio: float | None = None, + use_swish: bool = True, + norm_cfg: dict | None = None, + act_cfg: dict | None = None, + conv_cfg: dict | None = None, + with_cp: bool = False, + **kwargs, + ) -> nn.Module: + """Build residual layer for ResNet3D. + + Args: + block (nn.Module): Residual module to be built. + layer_inplanes (int): Number of channels for the input feature + of the res layer. + inplanes (int): Number of channels for the input feature in each + block, which equals to base_channels * gamma_w. + planes (int): Number of channels for the output feature in each + block, which equals to base_channel * gamma_w * gamma_b. + blocks (int): Number of residual blocks. + spatial_stride (int): Spatial strides in residual and conv layers. + Default: 1. + se_style (str): The style of inserting SE modules into BlockX3D, + 'half' denotes insert into half of the blocks, while 'all' + denotes insert into all blocks. Default: 'half'. + se_ratio (float | None): The reduction ratio of squeeze and + excitation unit. If set as None, it means not using SE unit. + Default: None. + use_swish (bool): Whether to use swish as the activation function + before and after the 3x3x3 conv. Default: True. + conv_cfg (dict | None): Config for norm layers. 
Default: None. + norm_cfg (dict | None): Config for norm layers. Default: None. + act_cfg (dict | None): Config for activate layers. Default: None. + with_cp (bool | None): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + + Returns: + nn.Module: A residual layer for the given config. + """ + downsample = None + if spatial_stride != 1 or layer_inplanes != inplanes: + downsample = ConvModule( + layer_inplanes, + inplanes, + kernel_size=1, + stride=(1, spatial_stride, spatial_stride), + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + ) + + use_se = [False] * blocks + if self.se_style == "all": + use_se = [True] * blocks + elif self.se_style == "half": + use_se = [i % 2 == 0 for i in range(blocks)] + else: + raise NotImplementedError + + layers = [] + layers.append( + block( + layer_inplanes, + planes, + inplanes, + spatial_stride=spatial_stride, + downsample=downsample, + se_ratio=se_ratio if use_se[0] else None, + use_swish=use_swish, + norm_cfg=norm_cfg, + conv_cfg=conv_cfg, + act_cfg=act_cfg, + with_cp=with_cp, + **kwargs, + ), + ) + + for i in range(1, blocks): + layers.append( # noqa: PERF401 + block( + inplanes, + planes, + inplanes, + spatial_stride=1, + se_ratio=se_ratio if use_se[i] else None, + use_swish=use_swish, + norm_cfg=norm_cfg, + conv_cfg=conv_cfg, + act_cfg=act_cfg, + with_cp=with_cp, + **kwargs, + ), + ) + + return nn.Sequential(*layers) + + def _make_stem_layer(self) -> None: + """Construct the stem layers consists of a conv+norm+act module and a pooling layer.""" + self.conv1_s = ConvModule( + self.in_channels, + self.base_channels, + kernel_size=(1, 3, 3), + stride=(1, 2, 2), + padding=(0, 1, 1), + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None, + ) + self.conv1_t = ConvModule( + self.base_channels, + self.base_channels, + kernel_size=(5, 1, 1), + stride=(1, 1, 1), + padding=(2, 0, 0), + groups=self.base_channels, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ) + + def _freeze_stages(self) -> None: + """Prevent all the parameters from being optimized before ``self.frozen_stages``.""" + if self.frozen_stages >= 0: + self.conv1_s.eval() + self.conv1_t.eval() + for param in self.conv1_s.parameters(): + param.requires_grad = False + for param in self.conv1_t.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f"layer{i}") + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self) -> None: + """Initiate the parameters either from existing checkpoint or from scratch.""" + if isinstance(self.pretrained, str): + load_checkpoint(self, self.pretrained, strict=False) + + elif self.pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv3d): + kaiming_init(m) + elif isinstance(m, _BatchNorm): + constant_init(m, 1) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, BlockX3D): + constant_init(m.conv3.bn, 0) + else: + msg = "pretrained must be a str or None" + raise TypeError(msg) + + def forward(self, x: Tensor) -> Tensor: + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The feature of the input + samples extracted by the backbone. 
+        """
+        x = self.conv1_s(x)
+        x = self.conv1_t(x)
+        for layer_name in self.res_layers:
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+        return self.conv5(x)
+
+    def train(self, mode: bool = True) -> None:
+        """Set the optimization status when training."""
+        super().train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
diff --git a/src/otx/algo/action_classification/heads/__init__.py b/src/otx/algo/action_classification/heads/__init__.py
index dee22c34408..0dbde770b87 100644
--- a/src/otx/algo/action_classification/heads/__init__.py
+++ b/src/otx/algo/action_classification/heads/__init__.py
@@ -3,5 +3,6 @@
 """Custom heads for action classification."""
 
 from .movinet_head import MoViNetHead
+from .x3d_head import X3DHead
 
-__all__ = ["MoViNetHead"]
+__all__ = ["MoViNetHead", "X3DHead"]
diff --git a/src/otx/algo/action_classification/heads/base_head.py b/src/otx/algo/action_classification/heads/base_head.py
new file mode 100644
index 00000000000..354f4c41fec
--- /dev/null
+++ b/src/otx/algo/action_classification/heads/base_head.py
@@ -0,0 +1,221 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# mypy: disable-error-code="attr-defined"
+
+"""Base classification head for video recognition."""
+from __future__ import annotations
+
+from abc import abstractmethod
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from otx.algo.action_classification.utils.data_sample import ActionDataSample
+from otx.algo.modules.base_module import BaseModule
+
+
+class BaseHead(BaseModule):
+    """Base class for OTX action classification heads.
+
+    Args:
+        num_classes (int): Number of classes to be classified.
+        in_channels (int): Number of channels in input feature.
+        loss_cls (nn.Module): Loss module, e.g. ``nn.CrossEntropyLoss``.
+        topk (tuple[int, int]): Top-k values used for training accuracy. Default: (1, 5).
+        average_clips (str | None): Averaging type over multiple clips, one of
+            'score', 'prob' or None. Default: None.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        in_channels: int,
+        loss_cls: nn.Module,
+        topk: tuple[int, int] = (1, 5),
+        average_clips: str | None = None,
+    ):
+        super().__init__()  # Call the initializer of BaseModule
+
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.loss_cls = loss_cls
+        self.average_clips = average_clips
+
+        if not isinstance(topk, (int, tuple)):
+            msg = "`topk` should be an int or a tuple of ints"
+            raise TypeError(msg)
+
+        if any(_topk <= 0 for _topk in topk):
+            msg = "Top-k should be larger than 0"
+            raise ValueError(msg)
+
+        self.topk = topk
+
+    @abstractmethod
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        """Defines the computation performed at every call."""
+        raise NotImplementedError
+
+    def loss(
+        self,
+        feats: torch.Tensor | tuple[torch.Tensor],
+        data_samples: list[ActionDataSample],
+        **kwargs,
+    ) -> dict:
+        """Perform forward propagation of head and loss calculation on the features of the upstream network.
+
+        Args:
+            feats (torch.Tensor | tuple[torch.Tensor]): Features from
+                upstream network.
+            data_samples (list[:obj:`ActionDataSample`]): The batch
+                data samples.
+
+        Returns:
+            dict: A dictionary of loss components.
+ """ + cls_scores = self(feats, **kwargs) + return self.loss_by_feat(cls_scores, data_samples) + + def loss_by_feat(self, cls_scores: torch.Tensor, data_samples: list[ActionDataSample]) -> dict: + """Calculate the loss based on the features extracted by the head. + + Args: + cls_scores (torch.Tensor): Classification prediction results of + all class, has shape (batch_size, num_classes). + data_samples (list[:obj:`ActionDataSample`]): The batch + data samples. + + Returns: + dict: A dictionary of loss components. + """ + label_list = [x.gt_label for x in data_samples] + labels: torch.Tensor = torch.stack(label_list).to(cls_scores.device).squeeze() + + losses = {} + if labels.shape == torch.Size([]): + labels = labels.unsqueeze(0) + elif labels.dim() == 1 and labels.size()[0] == self.num_classes and cls_scores.size()[0] == 1: + # Fix a bug when training with soft labels and batch size is 1. + # When using soft labels, `labels` and `cls_score` share the same + # shape. + labels = labels.unsqueeze(0) + + if cls_scores.size() != labels.size(): + top_k_acc = self._top_k_accuracy( + cls_scores.detach().cpu().numpy(), + labels.detach().cpu().numpy(), + self.topk, + ) + for k, a in zip(self.topk, top_k_acc): + losses[f"top{k}_acc"] = torch.tensor(a, device=cls_scores.device) + + loss_cls = self.loss_cls(cls_scores, labels) + # loss_cls may be dictionary or single tensor + if isinstance(loss_cls, dict): + losses.update(loss_cls) + else: + losses["loss_cls"] = loss_cls + return losses + + def predict( + self, + feats: torch.Tensor | tuple[torch.Tensor], + data_samples: list[ActionDataSample], + **kwargs, + ) -> list[ActionDataSample]: + """Perform forward propagation of head and predict recognition results on the features of the upstream network. + + Args: + feats (torch.Tensor | tuple[torch.Tensor]): Features from + upstream network. + data_samples (list[:obj:`ActionDataSample`]): The batch + data samples. + + Returns: + list[:obj:`ActionDataSample`]: Recognition results wrapped + by :obj:`ActionDataSample`. + """ + cls_scores = self(feats, **kwargs) + return self.predict_by_feat(cls_scores, data_samples) + + def predict_by_feat(self, cls_scores: torch.Tensor, data_samples: list[ActionDataSample]) -> list[ActionDataSample]: + """Transform a batch of output features extracted from the head into prediction results. + + Args: + cls_scores (torch.Tensor): Classification scores, has a shape + (B*num_segs, num_classes) + data_samples (list[:obj:`ActionDataSample`]): The + annotation data of every samples. It usually includes + information such as `gt_label`. + + Returns: + List[:obj:`ActionDataSample`]: Recognition results wrapped + by :obj:`ActionDataSample`. + """ + num_segs = cls_scores.shape[0] // len(data_samples) + cls_scores = self.average_clip(cls_scores, num_segs=num_segs) + pred_labels = cls_scores.argmax(dim=-1, keepdim=True).detach() + + for data_sample, score, pred_label in zip(data_samples, cls_scores, pred_labels): + data_sample.set_pred_score(score) + data_sample.set_pred_label(pred_label) + return data_samples + + def average_clip(self, cls_scores: torch.Tensor, num_segs: int = 1) -> torch.Tensor: + """Averaging class scores over multiple clips. + + Using different averaging types ('score' or 'prob' or None, + which defined in test_cfg) to computed the final averaged + class score. Only called in test mode. + + Args: + cls_scores (torch.Tensor): Class scores to be averaged. + num_segs (int): Number of clips for each input sample. + + Returns: + torch.Tensor: Averaged class scores. 
+ """ + if self.average_clips not in ["score", "prob", None]: + msg = f"{self.average_clips} is not supported. Currently supported ones are ['score', 'prob', None]" + raise ValueError(msg) + + batch_size = cls_scores.shape[0] + cls_scores = cls_scores.view((batch_size // num_segs, num_segs) + cls_scores.shape[1:]) + + if self.average_clips is None: + return cls_scores + + if self.average_clips == "prob": + cls_scores = nn.functional.softmax(cls_scores, dim=2).mean(dim=1) + elif self.average_clips == "score": + cls_scores = cls_scores.mean(dim=1) + + return cls_scores + + @staticmethod + def _top_k_accuracy(scores: list[np.ndarray], labels: list[int], topk: tuple[int, int] = (1, 5)) -> list[float]: + """Calculate top k accuracy score. + + Args: + scores (list[np.ndarray]): Prediction scores for each class. + labels (list[int]): Ground truth labels. + topk (tuple[int]): K value for top_k_accuracy. Default: (1, ). + + Returns: + list[float]: Top k accuracy score for each k. + """ + res = [] + labels = np.array(labels)[:, np.newaxis] + for k in topk: + max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1] + match_array = np.logical_or.reduce(max_k_preds == labels, axis=1) + topk_acc_score = match_array.sum() / match_array.shape[0] + res.append(topk_acc_score) + + return res diff --git a/src/otx/algo/action_classification/heads/movinet_head.py b/src/otx/algo/action_classification/heads/movinet_head.py index 80d257202fe..87dfdb18510 100644 --- a/src/otx/algo/action_classification/heads/movinet_head.py +++ b/src/otx/algo/action_classification/heads/movinet_head.py @@ -1,18 +1,17 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Custom MoViNet Head for video recognition.""" +# Copyright (c) OpenMMLab. All rights reserved. +"""Custom MoViNet Head for video recognition.""" from __future__ import annotations -from mmaction.models import MODELS -from mmaction.models.heads.base import BaseHead from torch import Tensor, nn from otx.algo.action_classification.backbones.movinet import ConvBlock3D +from otx.algo.action_classification.heads.base_head import BaseHead from otx.algo.utils.weight_init import normal_init -@MODELS.register_module() class MoViNetHead(BaseHead): """Classification head for MoViNet. @@ -22,7 +21,8 @@ class MoViNetHead(BaseHead): hidden_dim (int): Number of channels in hidden layer. tf_like (bool): If True, uses TensorFlow-style padding. Default: False. conv_type (str): Type of convolutional layer. Default: '3d'. - loss_cls (dict): Config for building loss. Default: dict(type='CrossEntropyLoss'). + loss_cls (nn.module): Loss class like CrossEntropyLoss. + topk (tuple[int, int]): Top-K training loss calculation. Default: (1, 5). spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. dropout_ratio (float): Probability of dropout layer. Default: 0.5. init_std (float): Standard deviation for initialization. Default: 0.1. 
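# --- illustrative usage sketch (not part of the patch) ---------------------------------
# A minimal example of how the refactored MoViNet head is now wired together with plain
# torch modules instead of mmaction registry configs; it mirrors `MoViNet._build_model()`
# further down in this diff. The class count (400) is an illustrative assumption.
from torch import nn

from otx.algo.action_classification.backbones.movinet import MoViNetBackbone
from otx.algo.action_classification.heads.movinet_head import MoViNetHead
from otx.algo.action_classification.recognizers.movinet_recognizer import MoViNetRecognizer

model = MoViNetRecognizer(
    backbone=MoViNetBackbone(),
    cls_head=MoViNetHead(
        num_classes=400,          # illustrative; OTX passes label_info.num_classes here
        in_channels=480,          # MoViNet-A0 feature channels, as configured in this PR
        hidden_dim=2048,
        loss_cls=nn.CrossEntropyLoss(),  # an nn.Module now replaces the old loss_cls dict
    ),
)
# ----------------------------------------------------------------------------------------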
@@ -33,12 +33,20 @@ def __init__(
         num_classes: int,
         in_channels: int,
         hidden_dim: int,
-        loss_cls: dict,
+        loss_cls: nn.Module,
+        topk: tuple[int, int] = (1, 5),
         tf_like: bool = False,
         conv_type: str = "3d",
         average_clips: str | None = None,
     ):
-        super().__init__(num_classes, in_channels, loss_cls, average_clips=average_clips)
+        super().__init__(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            loss_cls=loss_cls,
+            topk=topk,
+            average_clips=average_clips,
+        )  # Call the initializer of BaseHead
+
         self.init_std = 0.1
         self.classifier = nn.Sequential(
             ConvBlock3D(
diff --git a/src/otx/algo/action_classification/heads/x3d_head.py b/src/otx/algo/action_classification/heads/x3d_head.py
new file mode 100644
index 00000000000..6e78c7ada83
--- /dev/null
+++ b/src/otx/algo/action_classification/heads/x3d_head.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) OpenMMLab. All rights reserved.
+
+"""X3D head implementation."""
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from otx.algo.action_classification.heads.base_head import BaseHead
+from otx.algo.utils.weight_init import normal_init
+
+
+class X3DHead(BaseHead):
+    """Classification head for X3D.
+
+    Args:
+        num_classes (int): Number of classes to be classified.
+        in_channels (int): Number of channels in input feature.
+        hidden_dim (int): Number of channels in the hidden fc layer.
+        loss_cls (nn.Module): Loss module, e.g. ``nn.CrossEntropyLoss``.
+        spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
+        dropout_ratio (float): Probability of dropout layer. Default: 0.5.
+        init_std (float): Std value for initialization. Default: 0.01.
+        fc1_bias (bool): If the first fc layer has bias. Default: False.
+        average_clips (str | None): Averaging type over multiple clips. Default: None.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        in_channels: int,
+        hidden_dim: int,
+        loss_cls: nn.Module,
+        spatial_type: str = "avg",
+        dropout_ratio: float = 0.5,
+        init_std: float = 0.01,
+        fc1_bias: bool = False,
+        average_clips: str | None = None,
+    ) -> None:
+        super().__init__(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            loss_cls=loss_cls,
+            average_clips=average_clips,
+        )  # Call the initializer of BaseHead
+
+        self.spatial_type = spatial_type
+        self.dropout_ratio = dropout_ratio
+        self.init_std = init_std
+        if self.dropout_ratio != 0:
+            self.dropout = nn.Dropout(p=self.dropout_ratio)
+        else:
+            self.dropout = None
+
+        self.fc1_bias = fc1_bias
+
+        self.fc1 = nn.Linear(self.in_channels, hidden_dim, bias=self.fc1_bias)
+        self.fc2 = nn.Linear(hidden_dim, self.num_classes)
+
+        self.relu = nn.ReLU()
+
+        self.pool = None
+        if self.spatial_type == "avg":
+            self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        elif self.spatial_type == "max":
+            self.pool = nn.AdaptiveMaxPool3d((1, 1, 1))
+        else:
+            raise NotImplementedError
+
+    def init_weights(self) -> None:
+        """Initiate the parameters from scratch."""
+        normal_init(self.fc1, std=self.init_std)
+        normal_init(self.fc2, std=self.init_std)
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        """Defines the computation performed at every call.
+
+        Args:
+            x (Tensor): The input data.
+
+        Returns:
+            Tensor: The classification scores for input samples.
+        """
+        # [N, in_channels, T, H, W]
+        if self.pool is None:
+            msg = "pool for X3DHead should be given."
+ raise ValueError(msg) + + x = self.pool(x) + # [N, in_channels, 1, 1, 1] + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + x = self.fc1(x) + # [N, 2048] + x = self.relu(x) + + if self.dropout is not None: + x = self.dropout(x) + + # [N, num_classes] + return self.fc2(x) diff --git a/src/otx/algo/action_classification/mmconfigs/movinet.yaml b/src/otx/algo/action_classification/mmconfigs/movinet.yaml deleted file mode 100644 index da082719f11..00000000000 --- a/src/otx/algo/action_classification/mmconfigs/movinet.yaml +++ /dev/null @@ -1,26 +0,0 @@ -load_from: https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA0_statedict_v3?raw=true -backbone: - type: OTXMoViNet -cls_head: - average_clips: prob - in_channels: 480 - hidden_dim: 2048 - num_classes: 400 - loss_cls: - type: CrossEntropyLoss - loss_weight: 1.0 - type: MoViNetHead -data_preprocessor: - format_shape: NCTHW - mean: - - 0.0 - - 0.0 - - 0.0 - std: - - 255.0 - - 255.0 - - 255.0 - type: ActionDataPreprocessor -test_cfg: null -train_cfg: null -type: MoViNetRecognizer diff --git a/src/otx/algo/action_classification/mmconfigs/x3d.yaml b/src/otx/algo/action_classification/mmconfigs/x3d.yaml deleted file mode 100644 index 9535ae539a6..00000000000 --- a/src/otx/algo/action_classification/mmconfigs/x3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -load_from: https://download.openmmlab.com/mmaction/recognition/x3d/facebook/x3d_m_facebook_16x5x1_kinetics400_rgb_20201027-3f42382a.pth -backbone: - gamma_b: 2.25 - gamma_d: 2.2 - gamma_w: 1 - type: X3D -cls_head: - average_clips: prob - dropout_ratio: 0.5 - fc1_bias: false - in_channels: 432 - num_classes: 400 - spatial_type: avg - type: X3DHead -data_preprocessor: - format_shape: NCTHW - mean: - - 114.75 - - 114.75 - - 114.75 - std: - - 57.38 - - 57.38 - - 57.38 - type: ActionDataPreprocessor -test_cfg: null -train_cfg: null -type: OTXRecognizer3D diff --git a/src/otx/algo/action_classification/movinet.py b/src/otx/algo/action_classification/movinet.py index ef02368a6d4..9e6863f90aa 100644 --- a/src/otx/algo/action_classification/movinet.py +++ b/src/otx/algo/action_classification/movinet.py @@ -7,10 +7,15 @@ from typing import TYPE_CHECKING -from otx.algo.utils.mmconfig import read_mmconfig +from torch import nn + +from otx.algo.action_classification.backbones.movinet import MoViNetBackbone +from otx.algo.action_classification.heads.movinet_head import MoViNetHead +from otx.algo.action_classification.recognizers.movinet_recognizer import MoViNetRecognizer +from otx.algo.utils.mmengine_utils import load_checkpoint from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.metrics.accuracy import MultiClassClsMetricCallable -from otx.core.model.action_classification import MMActionCompatibleModel +from otx.core.model.action_classification import OTXActionClsModel from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import LabelInfoTypes @@ -21,7 +26,7 @@ from otx.core.metrics import MetricCallable -class MoViNet(MMActionCompatibleModel): +class MoViNet(OTXActionClsModel): """MoViNet Model.""" def __init__( @@ -32,16 +37,36 @@ def __init__( metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, ) -> None: - config = read_mmconfig("movinet") + self.load_from = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA0_statedict_v3?raw=true" super().__init__( label_info=label_info, - config=config, 
optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) + def _create_model(self) -> nn.Module: + model = self._build_model(num_classes=self.label_info.num_classes) + model.init_weights() + self.classification_layers = self.get_classification_layers(prefix="model.") + + if self.load_from is not None: + load_checkpoint(model, self.load_from, map_location="cpu") + + return model + + def _build_model(self, num_classes: int) -> nn.Module: + return MoViNetRecognizer( + backbone=MoViNetBackbone(), + cls_head=MoViNetHead( + num_classes=num_classes, + in_channels=480, + hidden_dim=2048, + loss_cls=nn.CrossEntropyLoss(), + ), + ) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_action_ckpt(state_dict, add_prefix) diff --git a/src/otx/algo/action_classification/recognizers/__init__.py b/src/otx/algo/action_classification/recognizers/__init__.py index 196fb947b20..2f4e962a085 100644 --- a/src/otx/algo/action_classification/recognizers/__init__.py +++ b/src/otx/algo/action_classification/recognizers/__init__.py @@ -4,6 +4,6 @@ """Custom 3D recognizers for OTX.""" from .movinet_recognizer import MoViNetRecognizer -from .recognizer import OTXRecognizer3D +from .recognizer import BaseRecognizer -__all__ = ["OTXRecognizer3D", "MoViNetRecognizer"] +__all__ = ["BaseRecognizer", "MoViNetRecognizer"] diff --git a/src/otx/algo/action_classification/recognizers/movinet_recognizer.py b/src/otx/algo/action_classification/recognizers/movinet_recognizer.py index 7b167391d8e..f34e73cffcb 100644 --- a/src/otx/algo/action_classification/recognizers/movinet_recognizer.py +++ b/src/otx/algo/action_classification/recognizers/movinet_recognizer.py @@ -1,20 +1,18 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""MoViNet Recognizer for OTX compatibility.""" +"""MoViNet Recognizer for OTX compatibility.""" import functools -from mmaction.models import MODELS from torch import nn -from otx.algo.action_classification.recognizers.recognizer import OTXRecognizer3D +from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer -@MODELS.register_module() -class MoViNetRecognizer(OTXRecognizer3D): +class MoViNetRecognizer(BaseRecognizer): """MoViNet recognizer model framework for OTX compatibility.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) # Hooks for redirect state_dict load/save self._register_state_dict_hook(self.state_dict_hook) diff --git a/src/otx/algo/action_classification/recognizers/recognizer.py b/src/otx/algo/action_classification/recognizers/recognizer.py index 386752a6f15..f71e8d0e041 100644 --- a/src/otx/algo/action_classification/recognizers/recognizer.py +++ b/src/otx/algo/action_classification/recognizers/recognizer.py @@ -1,20 +1,243 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) OpenMMLab. All rights reserved. + """Custom 3D recognizer for OTX.""" +from __future__ import annotations + +from typing import Any import torch -from mmaction.models import MODELS -from mmaction.models.recognizers import Recognizer3D +from otx.algo.action_classification.utils.data_sample import ActionDataSample +from otx.algo.modules.base_module import BaseModule -@MODELS.register_module() -class OTXRecognizer3D(Recognizer3D): + +class BaseRecognizer(BaseModule): """Custom 3d recognizer class for OTX. 
This is for patching forward function during export procedure. """ + def __init__( + self, + backbone: torch.Module, + cls_head: torch.Module, + neck: torch.Module | None = None, + test_cfg: dict | None = None, + ) -> None: + super().__init__() + + self.backbone = backbone + self.cls_head = cls_head + if neck is not None: + self.neck = neck + self.test_cfg = test_cfg + + @property + def with_neck(self) -> bool: + """bool: whether the recognizer has a neck.""" + return hasattr(self, "neck") and self.neck is not None + + @property + def with_cls_head(self) -> bool: + """bool: whether the recognizer has a cls_head.""" + return hasattr(self, "cls_head") and self.cls_head is not None + + def extract_feat( + self, + inputs: torch.Tensor, + stage: str = "neck", + data_samples: list[ActionDataSample] | None = None, + test_mode: bool = False, + ) -> tuple: + """Extract features of different stages. + + Args: + inputs (torch.Tensor): The input data. + stage (str): Which stage to output the feature. + Defaults to ``'neck'``. + data_samples (list[:obj:`ActionDataSample`], optional): Action data + samples, which are only needed in training. Defaults to None. + test_mode (bool): Whether in test mode. Defaults to False. + + Returns: + torch.Tensor: The extracted features. + dict: A dict recording the kwargs for downstream + pipeline. These keys are usually included: + ``loss_aux``. + """ + # Record the kwargs required by `loss` and `predict` + loss_predict_kwargs = {} + + num_segs = inputs.shape[1] + # [N, num_crops, C, T, H, W] -> + # [N * num_crops, C, T, H, W] + # `num_crops` is calculated by: + # 1) `twice_sample` in `SampleFrames` + # 2) `num_sample_positions` in `DenseSampleFrames` + # 3) `ThreeCrop/TenCrop` in `test_pipeline` + # 4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1` + inputs = inputs.view((-1,) + inputs.shape[2:]) + + # Check settings of test + if test_mode: + if self.test_cfg is not None: + loss_predict_kwargs["fcn_test"] = self.test_cfg.get("fcn_test", False) + if self.test_cfg is not None and self.test_cfg.get("max_testing_views", False): + max_testing_views = self.test_cfg.get("max_testing_views") + if not isinstance(max_testing_views, int): + msg = "max_testing_views should be 'int'" + raise TypeError(msg) + + total_views = inputs.shape[0] + if num_segs != total_views: + msg = "max_testing_views is only compatible with batch_size == 1" + raise ValueError(msg) + view_ptr = 0 + feats = [] + while view_ptr < total_views: + batch_imgs = inputs[view_ptr : view_ptr + max_testing_views] + feat = self.backbone(batch_imgs) + if self.with_neck: + feat, _ = self.neck(feat) + feats.append(feat) + view_ptr += max_testing_views + + def recursively_cat( + feats: torch.Tensor | list[Any] | tuple[Any, ...], + ) -> tuple[torch.Tensor, ...]: + # recursively traverse feats until it's a tensor, + # then concat + out_feats: list[torch.Tensor] = [] + for e_idx, elem in enumerate(feats[0]): + batch_elem = [feat[e_idx] for feat in feats] + if not isinstance(elem, torch.Tensor): + batch_elem = recursively_cat(batch_elem) # type: ignore[assignment] + else: + batch_elem = torch.cat(batch_elem) + out_feats.append(batch_elem) + + return tuple(out_feats) + + x = recursively_cat(feats) if isinstance(feats[0], tuple) else torch.cat(feats) + else: + x = self.backbone(inputs) + if self.with_neck: + x, _ = self.neck(x) + + return x, loss_predict_kwargs + + # Return features extracted through backbone + x = self.backbone(inputs) + if stage == "backbone": + return x, loss_predict_kwargs + + loss_aux 
= {} + if self.with_neck: + x, loss_aux = self.neck(x, data_samples=data_samples) + + # Return features extracted through neck + loss_predict_kwargs["loss_aux"] = loss_aux + if stage == "neck": + return x, loss_predict_kwargs + + # Return raw logits through head. + x = self.cls_head(x, **loss_predict_kwargs) + return x, loss_predict_kwargs + + def forward( + self, + inputs: torch.Tensor, + data_samples: list[ActionDataSample] | None = None, + mode: str = "tensor", + **kwargs, + ) -> dict[str, torch.Tensor] | list[ActionDataSample] | tuple[torch.Tensor] | torch.Tensor: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: + + - ``tensor``: Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - ``predict``: Forward and return the predictions, which are fully + processed to a list of :obj:`ActionDataSample`. + - ``loss``: Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[``ActionDataSample], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``tensor``. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of ``ActionDataSample``. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == "predict": + return self.predict(inputs, data_samples, **kwargs) + if mode == "loss": + return self.loss(inputs, data_samples, **kwargs) + if mode == "tensor": + return self._forward(inputs, **kwargs) + + msg = f"Invalid mode '{mode}'. Only supports loss, predict and tensor mode" + raise RuntimeError(msg) + + def loss(self, inputs: torch.Tensor, data_samples: list[ActionDataSample] | None, **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): Raw Inputs of the recognizer. + These should usually be mean centered and std scaled. + data_samples (List[``ActionDataSample``]): The batch + data samples. It usually includes information such + as ``gt_label``. + + Returns: + dict: A dictionary of loss components. + """ + feats, loss_kwargs = self.extract_feat(inputs, data_samples=data_samples) + + # loss_aux will be a empty dict if `self.with_neck` is False. + loss_aux = loss_kwargs.get("loss_aux", {}) + loss_cls = self.cls_head.loss(feats, data_samples, **loss_kwargs) + return self._merge_dict(loss_cls, loss_aux) + + def predict( + self, + inputs: torch.Tensor, + data_samples: list[ActionDataSample] | None, + **kwargs, + ) -> list[ActionDataSample]: + """Predict results from a batch of inputs and data samples with postprocessing. + + Args: + inputs (torch.Tensor): Raw Inputs of the recognizer. + These should usually be mean centered and std scaled. + data_samples (List[``ActionDataSample``]): The batch + data samples. It usually includes information such + as ``gt_label``. + + Returns: + List[``ActionDataSample``]: Return the recognition results. + The returns value is ``ActionDataSample``, which usually contains + ``pred_scores``. And the ``pred_scores`` usually contains + following keys. 
+ + - item (torch.Tensor): Classification scores, has a shape + (num_classes, ) + """ + feats, predict_kwargs = self.extract_feat(inputs, test_mode=True) + return self.cls_head.predict(feats, data_samples, **predict_kwargs) + def _forward(self, inputs: torch.Tensor, stage: str = "backbone", **kwargs) -> torch.Tensor: """Network forward process for export procedure. @@ -27,3 +250,31 @@ def _forward(self, inputs: torch.Tensor, stage: str = "backbone", **kwargs) -> t cls_scores = self.cls_head(feats, **predict_kwargs) num_segs = cls_scores.shape[0] // inputs.shape[1] return self.cls_head.average_clip(cls_scores, num_segs=num_segs) + + @staticmethod + def _merge_dict(*args) -> dict: + """Merge all dictionaries into one dictionary. + + If pytorch version >= 1.8, ``merge_dict`` will be wrapped + by ``torch.fx.wrap``, which will make ``torch.fx.symbolic_trace`` skip + trace ``merge_dict``. + + Note: + If a function needs to be traced by ``torch.fx.symbolic_trace``, + but inevitably needs to use ``update`` method of ``dict``(``update`` + is not traceable). It should use ``merge_dict`` to replace + ``xxx.update``. + + Args: + *args: dictionary needs to be merged. + + Returns: + dict: Merged dict from args + """ + output = {} + for item in args: + if not isinstance(item, dict): + msg = f"all arguments of merge_dict should be a dict, but got {type(item)}" + raise TypeError(msg) + output.update(item) + return output diff --git a/src/otx/algo/action_classification/utils/data_sample.py b/src/otx/algo/action_classification/utils/data_sample.py new file mode 100644 index 00000000000..b05d7f1088e --- /dev/null +++ b/src/otx/algo/action_classification/utils/data_sample.py @@ -0,0 +1,143 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. + +# mypy: disable-error-code="attr-defined" + +"""Implementation of action data sample.""" +from __future__ import annotations + +from typing import Sequence, Union + +import numpy as np +import torch +from otx.algo.utils.mmengine_utils import InstanceData + +LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int] +SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence, dict] + + +def format_label(value: LABEL_TYPE) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The formatted label tensor. + """ + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not isinstance(value, str): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + msg = f"Type {type(value)} is not an available label type." + raise TypeError(msg) + + return value + + +def format_score(value: SCORE_TYPE) -> torch.Tensor | dict: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | dict): + Score values or dict of scores values. + + Returns: + :obj:`torch.Tensor` | dict: The formatted scores. 
+ """ + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not isinstance(value, str): + value = torch.tensor(value).float() + elif isinstance(value, dict): + for k, v in value.items(): + value[k] = format_score(v) + elif not isinstance(value, torch.Tensor): + msg = f"Type {type(value)} is not an available label type." + raise TypeError(msg) + + return value + + +class ActionDataSample(InstanceData): + """A data interface for action data that supports Tensor-like and dict-like operations.""" + + def set_gt_label(self, value: LABEL_TYPE) -> ActionDataSample: + """Set `gt_label``.""" + self.set_field(format_label(value), "gt_label", dtype=torch.Tensor) + return self + + def set_pred_label(self, value: LABEL_TYPE) -> ActionDataSample: + """Set ``pred_label``.""" + self.set_field(format_label(value), "pred_label", dtype=torch.Tensor) + return self + + def set_pred_score(self, value: SCORE_TYPE) -> ActionDataSample: + """Set score of ``pred_label``.""" + score = format_score(value) + self.set_field(score, "pred_score") + if hasattr(self, "num_classes"): + assert ( # noqa: S101 + len(score) == self.num_classes + ), f"The length of score {len(score)} should be equal to the num_classes {self.num_classes}." + else: + self.set_field(name="num_classes", value=len(score), field_type="metainfo") + return self + + @property + def proposals(self) -> InstanceData: + """Property of `proposals`.""" + return self._proposals + + @proposals.setter + def proposals(self, value) -> None: # noqa: ANN001 + """Setter of `proposals`.""" + self.set_field(value, "_proposals", dtype=InstanceData) + + @proposals.deleter + def proposals(self) -> None: + """Deleter of `proposals`.""" + del self._proposals + + @property + def gt_instances(self) -> InstanceData: + """Property of `gt_instances`.""" + return self._gt_instances + + @gt_instances.setter + def gt_instances(self, value) -> None: # noqa: ANN001 + """Setter of `gt_instances`.""" + self.set_field(value, "_gt_instances", dtype=InstanceData) + + @gt_instances.deleter + def gt_instances(self) -> None: + """Deleter of `gt_instances`.""" + del self._gt_instances + + @property + def features(self) -> InstanceData: + """Setter of `features`.""" + return self._features + + @features.setter + def features(self, value) -> None: # noqa: ANN001 + """Setter of `features`.""" + self.set_field(value, "_features", dtype=InstanceData) + + @features.deleter + def features(self) -> None: + """Deleter of `features`.""" + del self._features diff --git a/src/otx/algo/action_classification/x3d.py b/src/otx/algo/action_classification/x3d.py index 1d415ac2f2d..dbb6cb0f490 100644 --- a/src/otx/algo/action_classification/x3d.py +++ b/src/otx/algo/action_classification/x3d.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # """X3D model implementation.""" @@ -6,10 +6,15 @@ from typing import TYPE_CHECKING -from otx.algo.utils.mmconfig import read_mmconfig +from torch import nn + +from otx.algo.action_classification.backbones.x3d import X3DBackbone +from otx.algo.action_classification.heads.x3d_head import X3DHead +from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer +from otx.algo.utils.mmengine_utils import load_checkpoint from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.metrics.accuracy import MultiClassClsMetricCallable -from otx.core.model.action_classification import MMActionCompatibleModel 
+from otx.core.model.action_classification import OTXActionClsModel from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import LabelInfoTypes @@ -20,7 +25,7 @@ from otx.core.metrics import MetricCallable -class X3D(MMActionCompatibleModel): +class X3D(OTXActionClsModel): """X3D Model.""" def __init__( @@ -31,15 +36,48 @@ def __init__( metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, ) -> None: - config = read_mmconfig("x3d") + self.load_from = "https://download.openmmlab.com/mmaction/recognition/x3d/facebook/x3d_m_facebook_16x5x1_kinetics400_rgb_20201027-3f42382a.pth" super().__init__( label_info=label_info, - config=config, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) + self.mean = (114.75, 114.75, 114.75) + self.std = (57.38, 57.38, 57.38) + + def _create_model(self) -> nn.Module: + model = self._build_model(num_classes=self.label_info.num_classes) + model.init_weights() + self.classification_layers = self.get_classification_layers(prefix="model.") + + if self.load_from is not None: + load_checkpoint(model, self.load_from, map_location="cpu") + + return model + + def _build_model(self, num_classes: int) -> nn.Module: + return BaseRecognizer( + backbone=X3DBackbone( + gamma_b=2.25, + gamma_d=2.2, + gamma_w=1, + conv_cfg={"type": "Conv3d"}, + norm_cfg={"type": "BN3d", "requires_grad": True}, + act_cfg={"type": "ReLU", "inplace": True}, + ), + cls_head=X3DHead( + num_classes=num_classes, + in_channels=432, + hidden_dim=2048, + loss_cls=nn.CrossEntropyLoss(), + spatial_type="avg", + dropout_ratio=0.5, + average_clips="prob", + fc1_bias=False, + ), + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/modules/activation.py b/src/otx/algo/modules/activation.py index 7d507650776..cc3a1e95080 100644 --- a/src/otx/algo/modules/activation.py +++ b/src/otx/algo/modules/activation.py @@ -3,6 +3,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """This implementation replaces the functionality of mmcv.cnn.bricks.activation.build_activation_layer.""" +from __future__ import annotations import copy diff --git a/src/otx/algo/modules/conv_module.py b/src/otx/algo/modules/conv_module.py index 57c379def0f..e84c6230c84 100644 --- a/src/otx/algo/modules/conv_module.py +++ b/src/otx/algo/modules/conv_module.py @@ -124,9 +124,9 @@ def __init__( self, in_channels: int | tuple[int, ...], out_channels: int, - kernel_size: int | tuple[int, int], - stride: int | tuple[int, int] = 1, - padding: int | tuple[int, int] = 0, + kernel_size: int | tuple[int, ...], + stride: int | tuple[int, ...] = 1, + padding: int | tuple[int, ...] 
= 0, dilation: int | tuple[int, int] = 1, groups: int = 1, bias: bool | str = "auto", diff --git a/src/otx/core/data/dataset/action_classification.py b/src/otx/core/data/dataset/action_classification.py index ddea3f79a0c..23391984423 100644 --- a/src/otx/core/data/dataset/action_classification.py +++ b/src/otx/core/data/dataset/action_classification.py @@ -13,7 +13,7 @@ from otx.core.data.dataset.base import OTXDataset from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsDataEntity -from otx.core.data.entity.base import ImageInfo +from otx.core.data.entity.base import ImageInfo, VideoInfo from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER from otx.core.types.image import ImageColorChannel @@ -69,6 +69,7 @@ def _get_item_impl(self, idx: int) -> ActionClsDataEntity | None: ori_shape=(0, 0), image_color_channel=self.image_color_channel, ), + video_info=VideoInfo(), labels=torch.as_tensor([ann.label for ann in label_anns]), ) diff --git a/src/otx/core/data/dataset/action_detection.py b/src/otx/core/data/dataset/action_detection.py index ac061010dcf..01ea07db1e0 100644 --- a/src/otx/core/data/dataset/action_detection.py +++ b/src/otx/core/data/dataset/action_detection.py @@ -18,7 +18,7 @@ from otx.core.data.dataset.base import OTXDataset from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetDataEntity -from otx.core.data.entity.base import ImageInfo +from otx.core.data.entity.base import ImageInfo, VideoInfo class OTXActionDetDataset(OTXDataset[ActionDetDataEntity]): @@ -41,6 +41,7 @@ def _get_item_impl(self, idx: int) -> ActionDetDataEntity | None: ) entity = ActionDetDataEntity( + video=item.media, image=img_data, img_info=ImageInfo( img_idx=idx, @@ -48,6 +49,7 @@ def _get_item_impl(self, idx: int) -> ActionDetDataEntity | None: ori_shape=img_shape, image_color_channel=self.image_color_channel, ), + video_info=VideoInfo(), bboxes=tv_tensors.BoundingBoxes( bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, diff --git a/src/otx/core/data/entity/action_classification.py b/src/otx/core/data/entity/action_classification.py index 41d9a0e0899..253869f10b7 100644 --- a/src/otx/core/data/entity/action_classification.py +++ b/src/otx/core/data/entity/action_classification.py @@ -13,6 +13,7 @@ OTXBatchPredEntity, OTXDataEntity, OTXPredEntity, + VideoInfo, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType @@ -33,6 +34,7 @@ class ActionClsDataEntity(OTXDataEntity): """ video: Video + video_info: VideoInfo labels: LongTensor def to_tv_image(self) -> ActionClsDataEntity: diff --git a/src/otx/core/data/entity/action_detection.py b/src/otx/core/data/entity/action_detection.py index b4c2f0b7d9e..007ce4f9b4a 100644 --- a/src/otx/core/data/entity/action_detection.py +++ b/src/otx/core/data/entity/action_detection.py @@ -15,11 +15,13 @@ OTXBatchPredEntity, OTXDataEntity, OTXPredEntity, + VideoInfo, ) from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType if TYPE_CHECKING: + from datumaro.components.media import Video from torch import LongTensor @@ -35,6 +37,8 @@ class ActionDetDataEntity(OTXDataEntity): proposals: Pre-calculated actor proposals. 
""" + video: Video + video_info: VideoInfo bboxes: tv_tensors.BoundingBoxes labels: LongTensor frame_path: str diff --git a/src/otx/core/data/entity/base.py b/src/otx/core/data/entity/base.py index 431e00273cd..b876a2977b6 100644 --- a/src/otx/core/data/entity/base.py +++ b/src/otx/core/data/entity/base.py @@ -21,6 +21,7 @@ from otx.core.types.task import OTXTaskType if TYPE_CHECKING: + import decord import numpy as np @@ -206,6 +207,139 @@ def __repr__(self) -> str: ) +class VideoInfo(tv_tensors.TVTensor): + """Meta info for video. + + Attributes: + clip_len: Length of a video clip. + num_clips: Number of clips for training. + frame_interval: Interval between sampled frames in a video clip. + video_reader: Decord video reader. + avg_fps: Average number of frames per seconds in a clip. + num_frames: Number of total frames in a video clip. + start_index: Start frame index. + frame_inds: Numpy array of chosen frame indices. + """ + + clip_len: int + num_clips: int + frame_interval: int + video_reader: decord.VideoReader + avg_fps: float + num_frames: int + start_index: int + frame_inds: np.ndarray + + @classmethod + def _wrap( + cls, + dummy_tensor: Tensor, + *, + clip_len: int = 8, + num_clips: int = 1, + frame_interval: int = 4, + video_reader: decord.VideoReader | None = None, + avg_fps: float = 30.0, + num_frames: int | None = None, + start_index: int = 0, + frame_inds: np.ndarray | None = None, + ) -> ImageInfo: + video_info = dummy_tensor.as_subclass(cls) + video_info.video_reader = video_reader + video_info.avg_fps = avg_fps + video_info.num_frames = num_frames + video_info.clip_len = clip_len + video_info.num_clips = num_clips + video_info.frame_interval = frame_interval + video_info.start_index = start_index + video_info.frame_inds = frame_inds + return video_info + + def __new__( # noqa: D102 + cls, + clip_len: int = 8, + num_clips: int = 1, + frame_interval: int = 4, + video_reader: decord.VideoReader | None = None, + avg_fps: float = 30.0, + num_frames: int | None = None, + start_index: int = 0, + frame_inds: np.ndarray | None = None, + ) -> VideoInfo: + return cls._wrap( + dummy_tensor=Tensor(), + clip_len=clip_len, + num_clips=num_clips, + frame_interval=frame_interval, + video_reader=video_reader, + avg_fps=avg_fps, + num_frames=num_frames, + start_index=start_index, + frame_inds=frame_inds, + ) + + @classmethod + def _wrap_output( + cls, + output: Tensor, + args: tuple[()] = (), + kwargs: Mapping[str, Any] | None = None, + ) -> VideoInfo | list[VideoInfo] | tuple[VideoInfo]: + """Wrap an output (`torch.Tensor`) obtained from PyTorch function. 
+ + For example, this function will be called when + + >>> img_info = VideoInfo(img_idx=0, img_shape=(10, 10), ori_shape=(10, 10)) + >>> `_wrap_output()` will be called after the PyTorch function `to()` is called + >>> img_info = img_info.to(device=torch.cuda) + """ + flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) + + if isinstance(output, Tensor) and not isinstance(output, VideoInfo): + video_info = next(x for x in flat_params if isinstance(x, VideoInfo)) + output = VideoInfo._wrap( + dummy_tensor=output, + clip_len=video_info.clip_len, + num_clips=video_info.num_clips, + frame_interval=video_info.frame_interval, + video_reader=video_info.video_reader, + avg_fps=video_info.avg_fps, + num_frames=video_info.num_frames, + start_index=video_info.start_index, + frame_inds=video_info.frame_inds, + ) + elif isinstance(output, (tuple, list)): + video_infos = [x for x in flat_params if isinstance(x, VideoInfo)] + output = type(output)( + VideoInfo._wrap( + dummy_tensor=dummy_tensor, + clip_len=video_info.clip_len, + num_clips=video_info.num_clips, + frame_interval=video_info.frame_interval, + video_reader=video_info.video_reader, + avg_fps=video_info.avg_fps, + num_frames=video_info.num_frames, + start_index=video_info.start_index, + frame_inds=video_info.frame_inds, + ) + for dummy_tensor, video_info in zip(output, video_infos) + ) + return output + + def __repr__(self) -> str: + return ( + "VideoInfo(" + f"clip_len={self.clip_len}, " + f"num_clips={self.num_clips}, " + f"frame_interval={self.frame_interval}, " + f"video_reader={self.video_reader}, " + f"avg_fps={self.avg_fps}, " + f"num_frames={self.num_frames}, " + f"start_index={self.start_index}, " + f"frame_inds={self.frame_inds})" + ) + + @F.register_kernel(functional=F.resize, tv_tensor_cls=ImageInfo) def _resize_image_info(image_info: ImageInfo, size: list[int], **kwargs) -> ImageInfo: # noqa: ARG001 """Register ImageInfo to TorchVision v2 resize kernel.""" diff --git a/src/otx/core/data/transform_libs/mmaction.py b/src/otx/core/data/transform_libs/mmaction.py index 8f80f8204f1..f106474fcb2 100644 --- a/src/otx/core/data/transform_libs/mmaction.py +++ b/src/otx/core/data/transform_libs/mmaction.py @@ -20,6 +20,7 @@ from otx.core.data.entity.action_classification import ActionClsDataEntity from otx.core.data.entity.action_detection import ActionDetDataEntity +from otx.core.data.entity.base import VideoInfo from otx.core.utils.config import convert_conf_to_mmconfig_dict if TYPE_CHECKING: @@ -203,8 +204,10 @@ def transform(self, results: dict) -> ActionClsDataEntity | ActionDetDataEntity: ) return ActionDetDataEntity( + video=results["__otx__"].video, image=image, img_info=image_info, + video_info=VideoInfo(), bboxes=bboxes, labels=labels, proposals=proposals, @@ -215,6 +218,7 @@ def transform(self, results: dict) -> ActionClsDataEntity | ActionDetDataEntity: video=results["__otx__"].video, image=image, img_info=image_info, + video_info=VideoInfo(), labels=labels, ) diff --git a/src/otx/core/data/transform_libs/torchvision.py b/src/otx/core/data/transform_libs/torchvision.py index 8ae4a3e7b32..6392ba4776f 100644 --- a/src/otx/core/data/transform_libs/torchvision.py +++ b/src/otx/core/data/transform_libs/torchvision.py @@ -6,18 +6,21 @@ from __future__ import annotations import copy +import io import itertools import math from inspect import isclass from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Sequence import cv2 +import decord import numpy as np import PIL.Image import torch import 
torchvision.transforms.v2 as tvt_v2 from datumaro.components.media import Video from lightning.pytorch.cli import instantiate_class +from mmengine.fileio import FileClient from numpy import random from omegaconf import DictConfig from torchvision import tv_tensors @@ -526,9 +529,8 @@ def _resize_img(self, inputs: T_OTXDataEntity) -> tuple[T_OTXDataEntity, tuple[f """Resize images with inputs.img_info.img_shape.""" scale_factor: tuple[float, float] | None = getattr(inputs.img_info, "scale_factor", None) # (H, W) if (img := getattr(inputs, "image", None)) is not None: - img = to_np_image(img) - - img_shape = img.shape[:2] # (H, W) + # for considering video case + img_shape = img[0].shape[:2] if isinstance(img, list) else img.shape[:2] scale: tuple[int, int] = self.scale or scale_size( img_shape, self.scale_factor, # type: ignore[arg-type] @@ -537,11 +539,25 @@ def _resize_img(self, inputs: T_OTXDataEntity) -> tuple[T_OTXDataEntity, tuple[f if self.keep_ratio: scale = rescale_size(img_shape, scale) # type: ignore[assignment] - # flipping `scale` is required because cv2.resize uses (W, H) - img = cv2.resize(img, scale[::-1], interpolation=CV2_INTERP_CODES[self.interpolation]) + # for considering video case + if isinstance(img, list): + for idx, im in enumerate(img): + # flipping `scale` is required because cv2.resize uses (W, H) + img[idx] = cv2.resize( + to_np_image(im), + scale[::-1], + interpolation=CV2_INTERP_CODES[self.interpolation], + ) + else: + # flipping `scale` is required because cv2.resize uses (W, H) + img = cv2.resize(to_np_image(img), scale[::-1], interpolation=CV2_INTERP_CODES[self.interpolation]) inputs.image = img - inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2]) + + if isinstance(img, list): + inputs.img_info = _resize_image_info(inputs.img_info, img[0].shape[:2]) + else: + inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2]) scale_factor = (scale[0] / img_shape[0], scale[1] / img_shape[1]) return inputs, scale_factor @@ -1077,8 +1093,12 @@ def forward(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: if (cur_dir := self._choose_direction()) is not None: # flip image - img = to_np_image(inputs.image) - img = flip_image(img, direction=cur_dir) + if isinstance(inputs.image, list): + img = inputs.image + for idx, im in enumerate(img): + img[idx] = flip_image(to_np_image(im), direction=cur_dir) + else: + img = flip_image(to_np_image(inputs.image), direction=cur_dir) inputs.image = img # flip bboxes @@ -2614,6 +2634,420 @@ def forward(self, *inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: return outputs +class FormatShape(tvt_v2.Transform): + """Format final imgs shape to the given input_format.""" + + def __init__(self, input_format: str, collapse: bool = False) -> None: + self.input_format = input_format + self.collapse = collapse + if self.input_format not in [ + "NCTHW", + "NCHW", + "NPTCHW", + ]: + msg = f"The input format {self.input_format} is invalid." + raise ValueError(msg) + + def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: + """Perform the SampleFrames loading.""" + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101 + inputs = _inputs[0] + + if not isinstance(inputs.image, np.ndarray): + inputs.image = np.array(inputs.image) + + # [M x H x W x C] + # M = 1 * N_crops * N_clips * T + if self.collapse and inputs.video_info.num_clips != 1: + msg = "num_clips should be 1." 
+ raise ValueError(msg) + + if self.input_format == "NCTHW": + imgs = inputs.image + num_clips = inputs.video_info.num_clips + clip_len = inputs.video_info.clip_len + if isinstance(clip_len, dict): + clip_len = clip_len["RGB"] + + imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:]) + # N_crops x N_clips x T x H x W x C + imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4)) + # N_crops x N_clips x C x T x H x W + imgs = imgs.reshape((-1,) + imgs.shape[2:]) + # M' x C x T x H x W + # M' = N_crops x N_clips + inputs.image = imgs + inputs.img_info.img_shape = imgs.shape + elif self.input_format == "NCHW": + imgs = inputs.image + imgs = np.transpose(imgs, (0, 3, 1, 2)) + + # M x C x H x W + inputs.image = imgs + inputs.img_info.img_shape = imgs.shape + elif self.input_format == "NPTCHW": + num_proposals = inputs.proposals.shape[0] # WJ + num_clips = inputs.video_info.num_clips + clip_len = inputs.video_info.clip_len + imgs = inputs.image + imgs = imgs.reshape((num_proposals, num_clips * clip_len) + imgs.shape[1:]) + # P x M x H x W x C + # M = N_clips x T + imgs = np.transpose(imgs, (0, 1, 4, 2, 3)) + # P x M x C x H x W + inputs.image["imgs"] = imgs + inputs.img_info.img_shape = imgs.shape + + if self.collapse: + if inputs.image.shape[0] != 1: + msg = "num_clips should be 1." + raise ValueError(msg) + inputs.image = inputs.image.squeeze(0) + inputs.img_info.img_shape = inputs.image.shape + + inputs.image = tv_tensors.Image(inputs.image) + return inputs + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(input_format='{self.input_format}')" + return repr_str + + +class DecordInit(tvt_v2.Transform): + """Using decord to initialize the video_reader.""" + + def __init__(self, io_backend: str = "disk", num_threads: int = 1, **kwargs) -> None: + self.io_backend = io_backend + self.num_threads = num_threads + self.kwargs = kwargs + self.file_client = None + + def _get_video_reader(self, filename: str) -> decord.VideoReader: + if self.file_client is None: + self.file_client = FileClient(self.io_backend, **self.kwargs) + file_obj = io.BytesIO(self.file_client.get(filename)) + return decord.VideoReader(file_obj, num_threads=self.num_threads) + + def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: + """Perform the Decord initialization.""" + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101 + inputs = _inputs[0] + + container = self._get_video_reader(inputs.video.path) + inputs.video_info.video_reader = container + inputs.video_info.num_frames = len(container) + inputs.video_info.avg_fps = container.get_avg_fps() + return inputs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(io_backend={self.io_backend}, num_threads={self.num_threads})" + + +class SampleFrames(tvt_v2.Transform): + """Sample frames from the video. + + Required Keys: + + - total_frames + - start_index + + Added Keys: + + - frame_inds + - frame_interval + - num_clips + + Args: + clip_len (int): Frames of each sampled output clip. + frame_interval (int): Temporal interval of adjacent sampled frames. + Defaults to 1. + num_clips (int): Number of clips to be sampled. Default: 1. + temporal_jitter (bool): Whether to apply temporal jittering. + Defaults to False. + twice_sample (bool): Whether to use twice sample when testing. + If set to True, it will sample frames with and without fixed shift, + which is commonly used for testing in TSM model. Defaults to False. 
+ out_of_bound_opt (str): The way to deal with out of bounds frame + indexes. Available options are 'loop', 'repeat_last'. + Defaults to 'loop'. + test_mode (bool): Store True when building test or validation dataset. + Defaults to False. + keep_tail_frames (bool): Whether to keep tail frames when sampling. + Defaults to False. + target_fps (optional, int): Convert input videos with arbitrary frame + rates to the unified target FPS before sampling frames. If + ``None``, the frame rate will not be adjusted. Defaults to + ``None``. + """ + + def __init__( + self, + clip_len: int, + frame_interval: int = 1, + num_clips: int = 1, + temporal_jitter: bool = False, + twice_sample: bool = False, + out_of_bound_opt: str = "loop", + test_mode: bool = False, + keep_tail_frames: bool = False, + target_fps: int | None = None, + **kwargs, + ) -> None: + self.clip_len = clip_len + self.frame_interval = frame_interval + self.num_clips = num_clips + self.temporal_jitter = temporal_jitter + self.twice_sample = twice_sample + self.out_of_bound_opt = out_of_bound_opt + self.test_mode = test_mode + self.keep_tail_frames = keep_tail_frames + self.target_fps = target_fps + if self.out_of_bound_opt not in ["loop", "repeat_last"]: + msg = f"out_of_bound_opt should be 'loop' or 'repeat_last', but found {self.out_of_bound_opt}." + raise ValueError(msg) + + def _get_train_clips(self, num_frames: int, ori_clip_len: float) -> np.array: + """Get clip offsets in train mode. + + It will calculate the average interval for selected frames, + and randomly shift them within offsets between [0, avg_interval]. + If the total number of frames is smaller than clips num or origin + frames length, it will return all zero indices. + + Args: + num_frames (int): Total number of frame in the video. + ori_clip_len (float): length of original sample clip. + + Returns: + np.ndarray: Sampled frame indices in train mode. + """ + if self.keep_tail_frames: + avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) + if num_frames > ori_clip_len - 1: + base_offsets = np.arange(self.num_clips) * avg_interval + clip_offsets = (base_offsets + np.random.uniform(0, avg_interval, self.num_clips)).astype(np.int32) + else: + clip_offsets = np.zeros((self.num_clips,), dtype=np.int32) + else: + avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips + + if avg_interval > 0: + base_offsets = np.arange(self.num_clips) * avg_interval + clip_offsets = base_offsets + np.random.randint(avg_interval, size=self.num_clips) + elif num_frames > max(self.num_clips, ori_clip_len): + clip_offsets = np.sort(np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips)) + elif avg_interval == 0: + ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips + clip_offsets = np.around(np.arange(self.num_clips) * ratio) + else: + clip_offsets = np.zeros((self.num_clips,), dtype=np.int32) + + return clip_offsets + + def _get_test_clips(self, num_frames: int, ori_clip_len: float) -> np.array: + """Get clip offsets in test mode. + + If the total number of frames is + not enough, it will return all zero indices. + + Args: + num_frames (int): Total number of frame in the video. + ori_clip_len (float): length of original sample clip. + + Returns: + np.ndarray: Sampled frame indices in test mode. 
+ """ + if self.clip_len == 1: # 2D recognizer + # assert self.frame_interval == 1 + avg_interval = num_frames / float(self.num_clips) + base_offsets = np.arange(self.num_clips) * avg_interval + clip_offsets = base_offsets + avg_interval / 2.0 + if self.twice_sample: + clip_offsets = np.concatenate([clip_offsets, base_offsets]) + else: # 3D recognizer + max_offset = max(num_frames - ori_clip_len, 0) + num_clips = self.num_clips * 2 if self.twice_sample else self.num_clips + if num_clips > 1: + num_segments = self.num_clips - 1 + # align test sample strategy with `PySlowFast` repo + if self.target_fps is not None: + offset_between = np.floor(max_offset / float(num_segments)) + clip_offsets = np.arange(num_clips) * offset_between + else: + offset_between = max_offset / float(num_segments) + clip_offsets = np.arange(num_clips) * offset_between + clip_offsets = np.round(clip_offsets) + else: + clip_offsets = np.array([max_offset // 2]) + return clip_offsets + + def _sample_clips(self, num_frames: int, ori_clip_len: float) -> np.array: + """Choose clip offsets for the video in a given mode. + + Args: + num_frames (int): Total number of frame in the video. + + Returns: + np.ndarray: Sampled frame indices. + """ + if self.test_mode: + clip_offsets = self._get_test_clips(num_frames, ori_clip_len) + else: + clip_offsets = self._get_train_clips(num_frames, ori_clip_len) + + return clip_offsets + + def _get_ori_clip_len(self, fps_scale_ratio: float) -> float: + """Calculate length of clip segment for different strategy. + + Args: + fps_scale_ratio (float): Scale ratio to adjust fps. + """ + if self.target_fps is not None: + # align test sample strategy with `PySlowFast` repo + ori_clip_len = self.clip_len * self.frame_interval + ori_clip_len = np.maximum(1, ori_clip_len * fps_scale_ratio) + elif self.test_mode: + ori_clip_len = (self.clip_len - 1) * self.frame_interval + 1 + else: + ori_clip_len = self.clip_len * self.frame_interval + + return ori_clip_len + + def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: + """Perform the SampleFrames loading.""" + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101 + inputs = _inputs[0] + + total_frames = inputs.video_info.num_frames + # if can't get fps, same value of `fps` and `target_fps` + # will perform nothing + fps = inputs.video_info.avg_fps + fps_scale_ratio = 1.0 if self.target_fps is None or not fps else fps / self.target_fps + ori_clip_len = self._get_ori_clip_len(fps_scale_ratio) + clip_offsets = self._sample_clips(total_frames, ori_clip_len) + + if self.target_fps: + frame_inds = clip_offsets[:, None] + np.linspace(0, ori_clip_len - 1, self.clip_len).astype(np.int32) + else: + frame_inds = clip_offsets[:, None] + np.arange(self.clip_len)[None, :] * self.frame_interval + frame_inds = np.concatenate(frame_inds) + + if self.temporal_jitter: + perframe_offsets = np.random.randint(self.frame_interval, size=len(frame_inds)) + frame_inds += perframe_offsets + + frame_inds = frame_inds.reshape((-1, self.clip_len)) + if self.out_of_bound_opt == "loop": + frame_inds = np.mod(frame_inds, total_frames) + elif self.out_of_bound_opt == "repeat_last": + safe_inds = frame_inds < total_frames + unsafe_inds = 1 - safe_inds + last_ind = np.max(safe_inds * frame_inds, axis=1) + new_inds = safe_inds * frame_inds + (unsafe_inds.T * last_ind).T + frame_inds = new_inds + else: + msg = f"out_of_bound_opt should be 'loop' or 'repeat_last', but found {self.out_of_bound_opt}." 
+ raise ValueError(msg) + + start_index = inputs.video_info.start_index + frame_inds = np.concatenate(frame_inds) + start_index + inputs.video_info.frame_inds = frame_inds.astype(np.int32) + inputs.video_info.clip_len = self.clip_len + inputs.video_info.frame_interval = self.frame_interval + inputs.video_info.num_clips = self.num_clips + return inputs + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"clip_len={self.clip_len}, " + f"frame_interval={self.frame_interval}, " + f"num_clips={self.num_clips}, " + f"temporal_jitter={self.temporal_jitter}, " + f"twice_sample={self.twice_sample}, " + f"out_of_bound_opt={self.out_of_bound_opt}, " + f"test_mode={self.test_mode})" + ) + + +class DecordDecode(tvt_v2.Transform): + """Using decord to decode the video.""" + + def __init__(self, mode: str = "accurate") -> None: + self.mode = mode + if self.mode not in ["accurate", "efficient"]: + msg = f"Decord mode should be 'accurate' or 'efficient', but found {self.mode}." + raise ValueError(msg) + + def _decord_load_frames(self, container: object, frame_inds: np.ndarray) -> list[np.ndarray]: + if self.mode == "accurate": + imgs = container.get_batch(frame_inds).asnumpy() + imgs = list(imgs) + elif self.mode == "efficient": + # This mode is faster, however it always returns I-FRAME + container.seek(0) + imgs = [] + for idx in frame_inds: + container.seek(idx) + frame = container.next() + imgs.append(frame.asnumpy()) + return imgs + + def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: + """Transform function to resize images, bounding boxes, and masks.""" + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." # noqa: S101 + inputs = _inputs[0] + + container = inputs.video_info.video_reader + + if inputs.video_info.frame_inds.ndim != 1: + inputs.video_info.frame_inds = np.squeeze(inputs.video_info.frame_inds) + + frame_inds = inputs.video_info.frame_inds + imgs = self._decord_load_frames(container, frame_inds) + + inputs.video_info.video_reader = None + del container + + inputs.image = imgs + inputs.img_info.ori_shape = imgs[0].shape[:2] + inputs.img_info.img_shape = imgs[0].shape[:2] + + # we resize the gt_bboxes and proposals to their real scale + if bboxes := getattr(inputs, "bboxes", None): + h, w = inputs.img_info.img_shape + scale_factor = np.array([w, h, w, h]) + gt_bboxes = (bboxes * scale_factor).astype(np.float32) + inputs.bboxes = gt_bboxes + if proposals := getattr(inputs, "proposals", None): + proposals = (proposals * scale_factor).astype(np.float32) + inputs.proposals = proposals + + return inputs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mode={self.mode})" + + +class Normalize3D(tvt_v2.Normalize): + """Using normalize the 3D video data.""" + + def __init__(self, mean: list[float], std: list[float], inplace: bool = False) -> None: + self.mean = torch.Tensor(mean).view(1, 3, 1, 1, 1) + self.std = torch.Tensor(std).view(1, 3, 1, 1, 1) + self.inplace = inplace + + def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None: + """Transform function to resize images, bounding boxes, and masks.""" + assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet." 
# noqa: S101 + inputs = _inputs[0] + + inputs.image = F.normalize(inputs.image, self.mean, self.std, self.inplace) + return inputs + + class TorchVisionTransformLib: """Helper to support TorchVision transforms (only V2) in OTX.""" diff --git a/src/otx/core/model/action_classification.py b/src/otx/core/model/action_classification.py index 313487dce23..f3db89c4219 100644 --- a/src/otx/core/model/action_classification.py +++ b/src/otx/core/model/action_classification.py @@ -3,6 +3,8 @@ # """Class definition for action_classification model entity used in OTX.""" +# mypy: disable-error-code="attr-defined" + from __future__ import annotations from typing import TYPE_CHECKING, Any @@ -10,6 +12,7 @@ import numpy as np import torch +from otx.algo.action_classification.utils.data_sample import ActionDataSample from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.exporter.native import OTXNativeModelExporter @@ -32,9 +35,6 @@ from otx.core.metrics import MetricCallable -# ruff: noqa: F401 - - class OTXActionClsModel(OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity]): """Base class for the action classification models used in OTX.""" @@ -46,6 +46,9 @@ def __init__( metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, ) -> None: + self.image_size = (1, 1, 3, 8, 224, 224) + self.mean = (0.0, 0.0, 0.0) + self.std = (255.0, 255.0, 255.0) super().__init__( label_info=label_info, optimizer=optimizer, @@ -74,6 +77,94 @@ def _convert_pred_entity_to_compute_metric( "target": target, } + def _customize_inputs(self, entity: ActionClsBatchDataEntity) -> dict[str, Any]: + """Convert ActionClsBatchDataEntity into mmaction model's input.""" + mmaction_inputs: dict[str, Any] = {} + + mmaction_inputs["inputs"] = entity.images + mmaction_inputs["data_samples"] = [ + ActionDataSample( + metainfo={ + "img_id": img_info.img_idx, + "img_shape": img_info.img_shape, + "ori_shape": img_info.ori_shape, + "scale_factor": img_info.scale_factor, + }, + gt_label=labels, + ) + for img_info, labels in zip(entity.imgs_info, entity.labels) + ] + + mmaction_inputs["mode"] = "loss" if self.training else "predict" + return mmaction_inputs + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: ActionClsBatchDataEntity, + ) -> ActionClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + if not isinstance(outputs, dict): + raise TypeError(outputs) + + losses = OTXBatchLossEntity() + for k, v in outputs.items(): + losses[k] = v + return losses + + scores = [] + labels = [] + + for output in outputs: + if not isinstance(output, ActionDataSample): + raise TypeError(output) + + scores.append(output.pred_score) + labels.append(output.pred_label) + + return ActionClsBatchPredEntity( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + ) + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=self.mean, + std=self.std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=False, + onnx_export_configuration=None, + output_names=None, + ) + + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model 
exportation.""" + return self.model(inputs=image, mode="tensor") + + def get_classification_layers(self, prefix: str = "model.") -> dict[str, dict[str, int]]: + """Get final classification layer information for incremental learning case.""" + sample_model_dict = self._build_model(num_classes=5).state_dict() + incremental_model_dict = self._build_model(num_classes=6).state_dict() + + classification_layers = {} + for key in sample_model_dict: + if sample_model_dict[key].shape != incremental_model_dict[key].shape: + sample_model_dim = sample_model_dict[key].shape[0] + incremental_model_dim = incremental_model_dict[key].shape[0] + stride = incremental_model_dim - sample_model_dim + num_extra_classes = 6 * sample_model_dim - 5 * incremental_model_dim + classification_layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes} + return classification_layers + class MMActionCompatibleModel(OTXActionClsModel): """Action classification model compitible for MMAction. @@ -211,8 +302,6 @@ def __init__( metric: MetricCallable = MultiClassClsMetricCallable, **kwargs, ) -> None: - from otx.algo.action_classification import openvino_model - super().__init__( model_name=model_name, model_type=model_type, diff --git a/src/otx/engine/utils/auto_configurator.py b/src/otx/engine/utils/auto_configurator.py index 896c0bd034f..5a8f15bc998 100644 --- a/src/otx/engine/utils/auto_configurator.py +++ b/src/otx/engine/utils/auto_configurator.py @@ -40,8 +40,8 @@ OTXTaskType.ROTATED_DETECTION: RECIPE_PATH / "rotated_detection" / "maskrcnn_r50.yaml", OTXTaskType.SEMANTIC_SEGMENTATION: RECIPE_PATH / "semantic_segmentation" / "litehrnet_18.yaml", OTXTaskType.INSTANCE_SEGMENTATION: RECIPE_PATH / "instance_segmentation" / "maskrcnn_r50.yaml", - OTXTaskType.ACTION_CLASSIFICATION: RECIPE_PATH / "action" / "action_classification" / "x3d.yaml", - OTXTaskType.ACTION_DETECTION: RECIPE_PATH / "action" / "action_detection" / "x3d_fastrcnn.yaml", + OTXTaskType.ACTION_CLASSIFICATION: RECIPE_PATH / "action_classification" / "x3d.yaml", + OTXTaskType.ACTION_DETECTION: RECIPE_PATH / "action_detection" / "x3d_fastrcnn.yaml", OTXTaskType.ANOMALY_CLASSIFICATION: RECIPE_PATH / "anomaly_classification" / "padim.yaml", OTXTaskType.ANOMALY_SEGMENTATION: RECIPE_PATH / "anomaly_segmentation" / "padim.yaml", OTXTaskType.ANOMALY_DETECTION: RECIPE_PATH / "anomaly_detection" / "padim.yaml", diff --git a/src/otx/recipe/action/action_classification/movinet.yaml b/src/otx/recipe/action/action_classification/movinet.yaml deleted file mode 100644 index c1d7330c4ea..00000000000 --- a/src/otx/recipe/action/action_classification/movinet.yaml +++ /dev/null @@ -1,26 +0,0 @@ -model: - class_path: otx.algo.action_classification.movinet.MoViNet - init_args: - label_info: 400 - - optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 0.0003 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.5 - patience: 2 - monitor: val/accuracy - -engine: - task: ACTION_CLASSIFICATION - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmaction_base.yaml diff --git a/src/otx/recipe/action/action_classification/x3d.yaml b/src/otx/recipe/action/action_classification/x3d.yaml deleted file mode 100644 index 4df8fb9e624..00000000000 --- a/src/otx/recipe/action/action_classification/x3d.yaml +++ /dev/null @@ -1,26 +0,0 @@ -model: - class_path: otx.algo.action_classification.x3d.X3D - init_args: - label_info: 400 - - optimizer: - class_path: 
torch.optim.AdamW - init_args: - lr: 0.0001 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.1 - patience: 1 - monitor: val/accuracy - -engine: - task: ACTION_CLASSIFICATION - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmaction_base.yaml diff --git a/src/otx/recipe/action_classification/movinet.yaml b/src/otx/recipe/action_classification/movinet.yaml new file mode 100644 index 00000000000..f378f54b539 --- /dev/null +++ b/src/otx/recipe/action_classification/movinet.yaml @@ -0,0 +1,137 @@ +model: + class_path: otx.algo.action_classification.movinet.MoViNet + init_args: + label_info: 400 + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0003 + weight_decay: 0.0001 + + scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 2 + monitor: val/accuracy + +engine: + task: ACTION_CLASSIFICATION + device: auto + +callback_monitor: val/accuracy + +data: + task: ACTION_CLASSIFICATION + config: + data_format: kinetics + mem_cache_size: 1GB + mem_cache_img_max_size: + - 500 + - 500 + image_color_channel: BGR + stack_images: True + unannotated_items_ratio: 0.0 + train_subset: + subset_name: train + transform_lib_type: TORCHVISION + to_tv_image: True + batch_size: 8 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - class_path: otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [0.0, 0.0, 0.0] + std: [255.0, 255.0, 255.0] + sampler: + class_path: torch.utils.data.RandomSampler + val_subset: + subset_name: val + transform_lib_type: TORCHVISION + batch_size: 8 + to_tv_image: True + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - class_path: otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [0.0, 0.0, 0.0] + std: [255.0, 255.0, 255.0] + test_subset: + subset_name: test + transform_lib_type: TORCHVISION + to_tv_image: True + batch_size: 8 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - 
class_path: otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [0.0, 0.0, 0.0] + std: [255.0, 255.0, 255.0] diff --git a/src/otx/recipe/action/action_classification/openvino_model.yaml b/src/otx/recipe/action_classification/openvino_model.yaml similarity index 96% rename from src/otx/recipe/action/action_classification/openvino_model.yaml rename to src/otx/recipe/action_classification/openvino_model.yaml index 95cd0766187..9095b045d43 100644 --- a/src/otx/recipe/action/action_classification/openvino_model.yaml +++ b/src/otx/recipe/action_classification/openvino_model.yaml @@ -13,7 +13,7 @@ engine: callback_monitor: val/accuracy -data: ../../_base_/data/torchvision_base.yaml +data: ../_base_/data/torchvision_base.yaml overrides: data: task: ACTION_CLASSIFICATION diff --git a/src/otx/recipe/action_classification/x3d.yaml b/src/otx/recipe/action_classification/x3d.yaml new file mode 100644 index 00000000000..cd49ef5ab19 --- /dev/null +++ b/src/otx/recipe/action_classification/x3d.yaml @@ -0,0 +1,137 @@ +model: + class_path: otx.algo.action_classification.x3d.X3D + init_args: + label_info: 400 + + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.0001 + weight_decay: 0.0001 + + scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.1 + patience: 1 + monitor: val/accuracy + +engine: + task: ACTION_CLASSIFICATION + device: auto + +callback_monitor: val/accuracy + +data: + task: ACTION_CLASSIFICATION + config: + data_format: kinetics + mem_cache_size: 1GB + mem_cache_img_max_size: + - 500 + - 500 + image_color_channel: BGR + stack_images: True + unannotated_items_ratio: 0.0 + train_subset: + subset_name: train + transform_lib_type: TORCHVISION + to_tv_image: True + batch_size: 8 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - class_path: otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [114.75, 114.75, 114.75] + std: [57.38, 57.38, 57.38] + sampler: + class_path: torch.utils.data.RandomSampler + val_subset: + subset_name: val + transform_lib_type: TORCHVISION + batch_size: 8 + to_tv_image: True + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - class_path: 
otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [114.75, 114.75, 114.75] + std: [57.38, 57.38, 57.38] + test_subset: + subset_name: test + transform_lib_type: TORCHVISION + to_tv_image: True + batch_size: 8 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.DecordInit + - class_path: otx.core.data.transform_libs.torchvision.SampleFrames + init_args: + clip_len: 8 + frame_interval: 4 + num_clips: 1 + - class_path: otx.core.data.transform_libs.torchvision.DecordDecode + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: + - 224 + - 224 + keep_ratio: false + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: otx.core.data.transform_libs.torchvision.FormatShape + init_args: + input_format: NCTHW + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: otx.core.data.transform_libs.torchvision.Normalize3D + init_args: + mean: [114.75, 114.75, 114.75] + std: [57.38, 57.38, 57.38] diff --git a/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml b/src/otx/recipe/action_detection/x3d_fastrcnn.yaml similarity index 98% rename from src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml rename to src/otx/recipe/action_detection/x3d_fastrcnn.yaml index df2a59da6d6..b1d0d5edc18 100644 --- a/src/otx/recipe/action/action_detection/x3d_fastrcnn.yaml +++ b/src/otx/recipe/action_detection/x3d_fastrcnn.yaml @@ -29,7 +29,7 @@ engine: callback_monitor: val/map_50 -data: ../../_base_/data/mmaction_base.yaml +data: ../_base_/data/mmaction_base.yaml overrides: precision: 32 data: diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 1ac1d2796d4..d09234297c1 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -220,8 +220,6 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: tmp_path_test = tmp_path / f"otx_test_{model_name}" if "_cls" in recipe: export_test_recipe = f"src/otx/recipe/classification/{task}/openvino_model.yaml" - elif "action_classification" in recipe: - export_test_recipe = f"src/otx/recipe/action/{task}/openvino_model.yaml" else: export_test_recipe = f"src/otx/recipe/{task}/openvino_model.yaml" diff --git a/tests/perf/test_action.py b/tests/perf/test_action.py index bba1ab52ea6..2ebc9bfa0e6 100644 --- a/tests/perf/test_action.py +++ b/tests/perf/test_action.py @@ -17,8 +17,8 @@ class TestPerfActionClassification(PerfTestBase): """Benchmark action classification.""" MODEL_TEST_CASES = [ # noqa: RUF012 - Benchmark.Model(task="action/action_classification", name="movinet", category="speed"), - Benchmark.Model(task="action/action_classification", name="x3d", category="accuracy"), + Benchmark.Model(task="action_classification", name="movinet", category="speed"), + Benchmark.Model(task="action_classification", name="x3d", 
category="accuracy"), ] DATASET_TEST_CASES = [ # noqa: RUF012 @@ -64,6 +64,7 @@ class TestPerfActionClassification(PerfTestBase): Benchmark.Criterion(name="train/epoch", summary="max", compare="<", margin=0.1), Benchmark.Criterion(name="train/e2e_time", summary="max", compare="<", margin=0.1), Benchmark.Criterion(name="test/accuracy", summary="max", compare=">", margin=0.1), + Benchmark.Criterion(name="export/test/accuracy", summary="max", compare=">", margin=0.1), Benchmark.Criterion(name="export/accuracy", summary="max", compare=">", margin=0.1), Benchmark.Criterion(name="optimize/accuracy", summary="max", compare=">", margin=0.1), Benchmark.Criterion(name="train/iter_time", summary="mean", compare="<", margin=0.1), @@ -107,7 +108,7 @@ class TestPerfActionDetection(PerfTestBase): """Benchmark action detection.""" MODEL_TEST_CASES = [ # noqa: RUF012 - Benchmark.Model(task="action/action_detection", name="x3d_fastrcnn", category="accuracy"), + Benchmark.Model(task="action_detection", name="x3d_fastrcnn", category="accuracy"), ] DATASET_TEST_CASES = [ # noqa: RUF012 diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py index eb9731d6324..40ff65c2a6e 100644 --- a/tests/regression/test_regression.py +++ b/tests/regression/test_regression.py @@ -825,8 +825,8 @@ def test_regression( class TestActionClassification(BaseTest): # Test case parametrization for model MODEL_TEST_CASES = [ # noqa: RUF012 - ModelTestCase(task="action/action_classification", name="x3d"), - ModelTestCase(task="action/action_classification", name="movinet"), + ModelTestCase(task="action_classification", name="x3d"), + ModelTestCase(task="action_classification", name="movinet"), ] DATASET_TEST_CASES = [ DatasetTestCase( diff --git a/tests/unit/algo/action_classification/backbones/test_movinet.py b/tests/unit/algo/action_classification/backbones/test_movinet.py index 792387d52e2..e944862d022 100644 --- a/tests/unit/algo/action_classification/backbones/test_movinet.py +++ b/tests/unit/algo/action_classification/backbones/test_movinet.py @@ -4,14 +4,14 @@ import pytest import torch -from otx.algo.action_classification.backbones.movinet import OTXMoViNet +from otx.algo.action_classification.backbones.movinet import MoViNetBackbone class TestMoViNet: @pytest.fixture() - def fxt_movinet(self) -> OTXMoViNet: - return OTXMoViNet() + def fxt_movinet(self) -> MoViNetBackbone: + return MoViNetBackbone() - def test_forward(self, fxt_movinet: OTXMoViNet) -> None: + def test_forward(self, fxt_movinet: MoViNetBackbone) -> None: x = torch.randn(1, 3, 8, 224, 224) assert fxt_movinet(x).shape == torch.Size([1, 480, 1, 1, 1]) diff --git a/tests/unit/core/data/transform_libs/test_mmaction.py b/tests/unit/core/data/transform_libs/test_mmaction.py index 0125c82882d..47bf36f3969 100644 --- a/tests/unit/core/data/transform_libs/test_mmaction.py +++ b/tests/unit/core/data/transform_libs/test_mmaction.py @@ -11,7 +11,7 @@ from otx.core.config.data import SubsetConfig from otx.core.data.entity.action_classification import ActionClsDataEntity from otx.core.data.entity.action_detection import ActionDetDataEntity -from otx.core.data.entity.base import ImageInfo +from otx.core.data.entity.base import ImageInfo, VideoInfo from otx.core.data.transform_libs.mmaction import ( LoadAnnotations, LoadVideoForClassification, @@ -32,6 +32,7 @@ class TestActionClsPipeline: def fxt_action_cls_data(self) -> dict: entity = ActionClsDataEntity( video=MockVideo(), + video_info=VideoInfo(), image=[], img_info=ImageInfo( img_idx=0, @@ 
-61,6 +62,8 @@ class TestActionDetPipelines: @pytest.fixture() def fxt_action_det_data(self, mocker) -> dict: entity = ActionDetDataEntity( + video=MockVideo(), + video_info=VideoInfo(), image=torch.randn([3, 10, 10]), img_info=ImageInfo( img_idx=0, diff --git a/tests/unit/core/data/transform_libs/test_torchvision.py b/tests/unit/core/data/transform_libs/test_torchvision.py index a8f637cbd97..1d4ba697aab 100644 --- a/tests/unit/core/data/transform_libs/test_torchvision.py +++ b/tests/unit/core/data/transform_libs/test_torchvision.py @@ -12,7 +12,7 @@ import torch from datumaro import Polygon from otx.core.data.entity.action_classification import ActionClsDataEntity -from otx.core.data.entity.base import ImageInfo, OTXDataEntity +from otx.core.data.entity.base import ImageInfo, OTXDataEntity, VideoInfo from otx.core.data.entity.detection import DetBatchDataEntity, DetDataEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegDataEntity from otx.core.data.transform_libs.torchvision import ( @@ -73,6 +73,7 @@ class TestPackVideo: def test_forward(self): entity = ActionClsDataEntity( video=MockVideo(), + video_info=VideoInfo(), image=[], img_info=ImageInfo( img_idx=0, diff --git a/tests/unit/engine/utils/test_api.py b/tests/unit/engine/utils/test_api.py index 2fe65010903..4076276bb11 100644 --- a/tests/unit/engine/utils/test_api.py +++ b/tests/unit/engine/utils/test_api.py @@ -16,8 +16,6 @@ def test_list_models_per_task(task: str) -> None: task_dir = task if task_dir.endswith("CLS"): task_dir = "classification/" + task_dir - elif task_dir.startswith("ACTION"): - task_dir = "action/" + task_dir target_dir = RECIPE_PATH / task_dir.lower() target_recipes = [str(recipe.stem) for recipe in target_dir.glob("**/*.yaml")]
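
The recipes above wire the new torchvision-native video pipeline together declaratively; the same chain can be composed directly from the classes this diff adds to otx.core.data.transform_libs.torchvision. A minimal sketch (abridged: RandomFlip and ToDtype are omitted), assuming decord and the other optional video dependencies are installed; actually running it would further require an ActionClsDataEntity whose video points at a real file, because DecordInit reads inputs.video.path:

import torchvision.transforms.v2 as tvt_v2
from otx.core.data.transform_libs.torchvision import (
    DecordDecode,
    DecordInit,
    FormatShape,
    Normalize3D,
    Resize,
    SampleFrames,
)

# Same ordering as recipe/action_classification/x3d.yaml: open the video,
# pick frame indices, decode only those frames, resize, pack to NCTHW, normalize.
x3d_video_pipeline = tvt_v2.Compose(
    [
        DecordInit(),
        SampleFrames(clip_len=8, frame_interval=4, num_clips=1),
        DecordDecode(),
        Resize(scale=(224, 224), keep_ratio=False),
        FormatShape(input_format="NCTHW"),
        Normalize3D(mean=[114.75, 114.75, 114.75], std=[57.38, 57.38, 57.38]),
    ],
)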
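
SampleFrames only computes indices. A sketch of its test-mode arithmetic under the x3d recipe's settings (clip_len=8, frame_interval=4, num_clips=1) for a hypothetical 300-frame video; the private helpers are called here purely for illustration, whereas in the real pipeline __call__ reads num_frames and avg_fps from the VideoInfo populated by DecordInit:

import numpy as np

from otx.core.data.transform_libs.torchvision import SampleFrames

sampler = SampleFrames(clip_len=8, frame_interval=4, num_clips=1, test_mode=True)
ori_clip_len = sampler._get_ori_clip_len(fps_scale_ratio=1.0)  # (8 - 1) * 4 + 1 = 29 frames spanned
offsets = sampler._get_test_clips(num_frames=300, ori_clip_len=ori_clip_len)  # single clip -> [135]
frame_inds = offsets[:, None] + np.arange(8)[None, :] * 4  # same broadcast used in __call__
print(frame_inds)  # [[135 139 143 147 151 155 159 163]]: one clip centred in the video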
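
FormatShape("NCTHW") is just a reshape plus transpose; spelled out in plain NumPy for a single clip of eight decoded 224x224 RGB frames, mirroring the NCTHW branch above:

import numpy as np

imgs = np.zeros((8, 224, 224, 3), dtype=np.float32)  # M x H x W x C, with M = N_crops * N_clips * T
num_clips, clip_len = 1, 8
imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:])  # N_crops x N_clips x T x H x W x C
imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4))                    # N_crops x N_clips x C x T x H x W
imgs = imgs.reshape((-1,) + imgs.shape[2:])                      # M' x C x T x H x W
print(imgs.shape)  # (1, 3, 8, 224, 224): the exporter's input_size without the leading batch axis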
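
The new ActionDataSample in utils/data_sample.py replaces mmaction's container and is what _customize_inputs and _customize_outputs exchange with the recognizer. A small sketch, assuming OTX's InstanceData exposes set fields as attributes in the way those methods already rely on:

import torch

from otx.algo.action_classification.utils.data_sample import ActionDataSample

sample = ActionDataSample(metainfo={"img_shape": (224, 224), "ori_shape": (224, 224)})
sample.set_gt_label(3)                  # int is normalized by format_label to LongTensor([3])
sample.set_pred_score(torch.rand(400))  # format_score keeps the tensor and records num_classes=400
sample.set_pred_label(int(sample.pred_score.argmax()))
print(sample.gt_label, sample.num_classes)  # tensor([3]) 400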
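
With the mmaction dependency removed, X3D._build_model assembles a plain BaseRecognizer from X3DBackbone and X3DHead, and forward_for_tracing drives it with mode="tensor". A rough sketch of that call path; instantiating X3D fetches the pretrained weights referenced by load_from, and the (1, 400) output shape is the expected clip-averaged score shape rather than something verified here:

import torch

from otx.algo.action_classification.x3d import X3D

otx_model = X3D(label_info=400)            # Kinetics-400 head, as in the recipe
otx_model.eval()
clip = torch.randn(1, 1, 3, 8, 224, 224)   # N x num_clips x C x T x H x W, matching the exporter input
with torch.no_grad():
    scores = otx_model.model(inputs=clip, mode="tensor")  # same call forward_for_tracing makes
print(scores.shape)                        # expected: torch.Size([1, 400])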