diff --git a/tripy/nvtripy/frontend/module/conv/base.py b/tripy/nvtripy/frontend/module/conv/base.py index 812e08fef..ddc7defa1 100644 --- a/tripy/nvtripy/frontend/module/conv/base.py +++ b/tripy/nvtripy/frontend/module/conv/base.py @@ -14,11 +14,12 @@ # limitations under the License. from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union from nvtripy import utils from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend.dimension_size import DimensionSize from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.ops import utils as op_utils @@ -39,12 +40,12 @@ class ConvBase(Module): def __init__( self, - in_channels: int, - out_channels: int, - kernel_dims: Sequence[int], + in_channels: Union[int, DimensionSize], + out_channels: Union[int, DimensionSize], + kernel_dims: Union[Sequence[int], Sequence[DimensionSize]], padding: Sequence[Sequence[int]] = None, stride: Sequence[int] = None, - groups: int = None, + groups: Union[int, DimensionSize] = None, dilation: Sequence[int] = None, bias: bool = True, dtype: datatype.dtype = datatype.float32, diff --git a/tripy/nvtripy/frontend/module/conv/conv.py b/tripy/nvtripy/frontend/module/conv/conv.py index 2c322fdf8..59e99ee91 100644 --- a/tripy/nvtripy/frontend/module/conv/conv.py +++ b/tripy/nvtripy/frontend/module/conv/conv.py @@ -17,10 +17,11 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from nvtripy import export from nvtripy.common import datatype +from nvtripy.frontend.dimension_size import DimensionSize from nvtripy.frontend.module.conv.base import ConvBase from nvtripy.frontend.module.conv.utils import conv_deconv_helper from nvtripy.frontend.module.parameter import DefaultParameter @@ -116,13 +117,13 @@ class Conv(ConvBase): def __init__( self, - in_channels: int, - out_channels: int, - kernel_dims: Sequence[int], + in_channels: Union[int, DimensionSize], + out_channels: Union[int, DimensionSize], + kernel_dims: Union[Sequence[int], Sequence[DimensionSize]], stride: Optional[Sequence[int]] = None, padding: Optional[Sequence[Tuple[int, int]]] = None, dilation: Optional[Sequence[int]] = None, - groups: Optional[int] = None, + groups: Optional[Union[int, DimensionSize]] = None, bias: bool = True, dtype: datatype.dtype = datatype.float32, ) -> None: