Skip to content

Commit 4259d7f

Browse files
committed
update to standards from flip PR
1 parent 7f63b11 commit 4259d7f

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,

torchvision/transforms/v2/_misc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torchvision import transforms as _transforms, tv_tensors
1111
from torchvision.transforms.v2 import functional as F, Transform
12+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
1213

1314
from ._utils import (
1415
_parse_labels_getter,
@@ -17,11 +18,13 @@
1718
get_bounding_boxes,
1819
get_keypoints,
1920
has_any,
20-
is_cvcuda_tensor,
2121
is_pure_tensor,
2222
)
2323

2424

25+
CVCUDA_AVAILABLE = _is_cvcuda_available()
26+
27+
2528
# TODO: do we want/need to expose this?
2629
class Identity(Transform):
2730
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
@@ -268,7 +271,8 @@ class ToDtype(Transform):
268271
Default: ``False``.
269272
"""
270273

271-
_transformed_types = (torch.Tensor, is_cvcuda_tensor)
274+
if CVCUDA_AVAILABLE:
275+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
272276

273277
def __init__(
274278
self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False

torchvision/transforms/v2/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
19-
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
19+
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor
2020

2121

2222
def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
182182
chws = {
183183
tuple(get_dimensions(inpt))
184184
for inpt in flat_inputs
185-
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
185+
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
186186
}
187187
if not chws:
188188
raise TypeError("No image or video was found in the sample")
@@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210-
is_cvcuda_tensor,
210+
_is_cvcuda_tensor,
211211
),
212212
)
213213
}

0 commit comments

Comments
 (0)