Skip to content

Commit d871331

Browse files
committed
simplify to_dtype testing
1 parent 973e058 commit d871331

File tree

2 files changed

+38
-61
lines changed

2 files changed

+38
-61
lines changed

test/test_transforms_v2.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,6 +2708,28 @@ def fn(value):
27082708

27092709
return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)
27102710

2711+
def _get_dtype_conversion_atol(self, input_dtype, output_dtype, scale):
# Pick the absolute tolerance (atol) for comparing a dtype conversion result
# against the reference implementation. The caller passes the returned value
# to torch.testing.assert_close(out, expected, rtol=0, atol=atol).
#
# input_dtype / output_dtype: torch dtypes of the source and converted image.
# scale: the same flag that was passed to F.to_dtype.
# Returns an int: 255, 1, or 0 (0 means an exact match is required).
#
# NOTE(review): the branching that this helper replaces attributed the
# non-zero tolerances to rounding differences between torchvision and cvcuda;
# presumably that is still the motivation -- confirm.
2712+
is_uint16_to_uint8 = input_dtype == torch.uint16 and output_dtype == torch.uint8
2713+
is_uint8_to_uint16 = input_dtype == torch.uint8 and output_dtype == torch.uint16
2714+
# True for int -> float and for float -> int conversions.
changes_type_class = output_dtype.is_floating_point != input_dtype.is_floating_point
2715+
2716+
# Bit widths are only looked up for non-floating-point dtypes; None marks a float dtype.
in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
2717+
out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
2718+
# True only for an integer -> wider integer conversion.
expands_bits = in_bits is not None and out_bits is not None and out_bits > in_bits
2719+
2720+
# Branch order matters: the first matching case wins.
# uint16 -> uint8 gets the widest tolerance regardless of `scale`.
if is_uint16_to_uint8:
2721+
atol = 255
2722+
# uint8 -> uint16 gets the widest tolerance only when no scaling is applied.
elif is_uint8_to_uint16 and not scale:
2723+
atol = 255
2724+
# Any other unscaled integer widening tolerates an off-by-one difference.
elif expands_bits and not scale:
2725+
atol = 1
2726+
# int <-> float conversions tolerate an off-by-one difference, scaled or not.
elif changes_type_class:
2727+
atol = 1
2728+
# Everything else (e.g. float -> float, same-dtype) must match exactly.
else:
2729+
atol = 0
2730+
2731+
return atol
2732+
27112733
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
27122734
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
27132735
@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -2732,54 +2754,16 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_
27322754
input = make_input(dtype=input_dtype, device=device)
27332755
out = F.to_dtype(input, dtype=output_dtype, scale=scale)
27342756

2735-
if isinstance(input, torch.Tensor):
2736-
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
2737-
if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
2738-
torch.testing.assert_close(out, expected, atol=1, rtol=0)
2739-
else:
2740-
torch.testing.assert_close(out, expected)
2741-
else: # cvcuda.Tensor
2757+
if make_input == make_image_cvcuda:
27422758
expected = self.reference_convert_dtype_image_tensor(
27432759
F.cvcuda_to_tensor(input), dtype=output_dtype, scale=scale
27442760
)
27452761
out = F.cvcuda_to_tensor(out)
2746-
# there are some differences in dtype conversion between torchvision and cvcuda
2747-
# due to different rounding behavior when converting between types with different bit widths
2748-
# Check if we're converting to a type with more bits (without scaling)
2749-
in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
2750-
out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
2751-
2752-
if scale:
2753-
if input_dtype.is_floating_point and not output_dtype.is_floating_point:
2754-
# float -> int with scaling: allow for rounding differences
2755-
torch.testing.assert_close(out, expected, atol=1, rtol=0)
2756-
elif input_dtype == torch.uint16 and output_dtype == torch.uint8:
2757-
# uint16 -> uint8 with scaling: allow large differences
2758-
torch.testing.assert_close(out, expected, atol=255, rtol=0)
2759-
else:
2760-
torch.testing.assert_close(out, expected)
2761-
else:
2762-
if in_bits is not None and out_bits is not None and out_bits > in_bits:
2763-
# uint to larger uint without scaling: allow large differences due to bit expansion
2764-
if input_dtype == torch.uint8 and output_dtype == torch.uint16:
2765-
torch.testing.assert_close(out, expected, atol=255, rtol=0)
2766-
else:
2767-
torch.testing.assert_close(out, expected, atol=1, rtol=0)
2768-
elif not input_dtype.is_floating_point and not output_dtype.is_floating_point:
2769-
# uint to uint without scaling (same or smaller bits): allow for rounding
2770-
if input_dtype == torch.uint16 and output_dtype == torch.uint8:
2771-
# uint16 -> uint8 can have large differences due to bit reduction
2772-
torch.testing.assert_close(out, expected, atol=255, rtol=0)
2773-
else:
2774-
torch.testing.assert_close(out, expected)
2775-
elif input_dtype.is_floating_point and not output_dtype.is_floating_point:
2776-
# float -> uint without scaling: allow for rounding differences
2777-
torch.testing.assert_close(out, expected, atol=1, rtol=0)
2778-
elif not input_dtype.is_floating_point and output_dtype.is_floating_point:
2779-
# uint -> float without scaling: allow for rounding differences
2780-
torch.testing.assert_close(out, expected, atol=1, rtol=0)
2781-
else:
2782-
torch.testing.assert_close(out, expected)
2762+
else:
2763+
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
2764+
2765+
atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale)
2766+
torch.testing.assert_close(out, expected, rtol=0, atol=atol)
27832767

27842768
def was_scaled(self, inpt):
27852769
# this assumes the target dtype is float

torchvision/transforms/v2/functional/_misc.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -400,24 +400,17 @@ def _to_dtype_cvcuda(
400400
dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype]
401401
cvc_dtype = _torch_to_cvcuda_dtypes[dtype]
402402

403-
if not scale:
404-
return cvcuda.convertto(inpt, dtype=cvc_dtype)
405-
406403
scale_val, offset = 1.0, 0.0
407-
in_dtype_float = dtype_in.is_floating_point
408-
out_dtype_float = dtype.is_floating_point
409-
410-
# four cases for the scaling setup
411-
# 1. float -> float
412-
# 2. int -> int
413-
# 3. float -> int
414-
# 4. int -> float
415-
if in_dtype_float == out_dtype_float:
416-
scale_val, offset = 1.0, 0.0
417-
elif in_dtype_float and not out_dtype_float:
418-
scale_val, offset = float(_max_value(dtype)), 0.0
419-
else:
420-
scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0
404+
if scale:
405+
in_dtype_float = dtype_in.is_floating_point
406+
out_dtype_float = dtype.is_floating_point
407+
408+
if in_dtype_float == out_dtype_float:
409+
scale_val, offset = 1.0, 0.0
410+
elif in_dtype_float and not out_dtype_float:
411+
scale_val, offset = float(_max_value(dtype)), 0.0
412+
else:
413+
scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0
421414

422415
return cvcuda.convertto(
423416
inpt,

0 commit comments

Comments
 (0)