@@ -2708,6 +2708,28 @@ def fn(value):
27082708
27092709 return torch .tensor (tree_map (fn , image .tolist ())).to (dtype = output_dtype , device = image .device )
27102710
2711+ def _get_dtype_conversion_atol (self , input_dtype , output_dtype , scale ):
2712+ is_uint16_to_uint8 = input_dtype == torch .uint16 and output_dtype == torch .uint8
2713+ is_uint8_to_uint16 = input_dtype == torch .uint8 and output_dtype == torch .uint16
2714+ changes_type_class = output_dtype .is_floating_point != input_dtype .is_floating_point
2715+
2716+ in_bits = torch .iinfo (input_dtype ).bits if not input_dtype .is_floating_point else None
2717+ out_bits = torch .iinfo (output_dtype ).bits if not output_dtype .is_floating_point else None
2718+ expands_bits = in_bits is not None and out_bits is not None and out_bits > in_bits
2719+
2720+ if is_uint16_to_uint8 :
2721+ atol = 255
2722+ elif is_uint8_to_uint16 and not scale :
2723+ atol = 255
2724+ elif expands_bits and not scale :
2725+ atol = 1
2726+ elif changes_type_class :
2727+ atol = 1
2728+ else :
2729+ atol = 0
2730+
2731+ return atol
2732+
27112733 @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .float64 , torch .uint8 , torch .uint16 ])
27122734 @pytest .mark .parametrize ("output_dtype" , [torch .float32 , torch .float64 , torch .uint8 , torch .uint16 ])
27132735 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
@@ -2732,54 +2754,16 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_
27322754 input = make_input (dtype = input_dtype , device = device )
27332755 out = F .to_dtype (input , dtype = output_dtype , scale = scale )
27342756
2735- if isinstance (input , torch .Tensor ):
2736- expected = self .reference_convert_dtype_image_tensor (input , dtype = output_dtype , scale = scale )
2737- if input_dtype .is_floating_point and not output_dtype .is_floating_point and scale :
2738- torch .testing .assert_close (out , expected , atol = 1 , rtol = 0 )
2739- else :
2740- torch .testing .assert_close (out , expected )
2741- else : # cvcuda.Tensor
2757+ if make_input == make_image_cvcuda :
27422758 expected = self .reference_convert_dtype_image_tensor (
27432759 F .cvcuda_to_tensor (input ), dtype = output_dtype , scale = scale
27442760 )
27452761 out = F .cvcuda_to_tensor (out )
2746- # there are some differences in dtype conversion between torchvision and cvcuda
2747- # due to different rounding behavior when converting between types with different bit widths
2748- # Check if we're converting to a type with more bits (without scaling)
2749- in_bits = torch .iinfo (input_dtype ).bits if not input_dtype .is_floating_point else None
2750- out_bits = torch .iinfo (output_dtype ).bits if not output_dtype .is_floating_point else None
2751-
2752- if scale :
2753- if input_dtype .is_floating_point and not output_dtype .is_floating_point :
2754- # float -> int with scaling: allow for rounding differences
2755- torch .testing .assert_close (out , expected , atol = 1 , rtol = 0 )
2756- elif input_dtype == torch .uint16 and output_dtype == torch .uint8 :
2757- # uint16 -> uint8 with scaling: allow large differences
2758- torch .testing .assert_close (out , expected , atol = 255 , rtol = 0 )
2759- else :
2760- torch .testing .assert_close (out , expected )
2761- else :
2762- if in_bits is not None and out_bits is not None and out_bits > in_bits :
2763- # uint to larger uint without scaling: allow large differences due to bit expansion
2764- if input_dtype == torch .uint8 and output_dtype == torch .uint16 :
2765- torch .testing .assert_close (out , expected , atol = 255 , rtol = 0 )
2766- else :
2767- torch .testing .assert_close (out , expected , atol = 1 , rtol = 0 )
2768- elif not input_dtype .is_floating_point and not output_dtype .is_floating_point :
2769- # uint to uint without scaling (same or smaller bits): allow for rounding
2770- if input_dtype == torch .uint16 and output_dtype == torch .uint8 :
2771- # uint16 -> uint8 can have large differences due to bit reduction
2772- torch .testing .assert_close (out , expected , atol = 255 , rtol = 0 )
2773- else :
2774- torch .testing .assert_close (out , expected )
2775- elif input_dtype .is_floating_point and not output_dtype .is_floating_point :
2776- # float -> uint without scaling: allow for rounding differences
2777- torch .testing .assert_close (out , expected , atol = 1 , rtol = 0 )
2778- elif not input_dtype .is_floating_point and output_dtype .is_floating_point :
2779- # uint -> float without scaling: allow for rounding differences
2780- torch .testing .assert_close (out , expected , atol = 1 , rtol = 0 )
2781- else :
2782- torch .testing .assert_close (out , expected )
2762+ else :
2763+ expected = self .reference_convert_dtype_image_tensor (input , dtype = output_dtype , scale = scale )
2764+
2765+ atol = self ._get_dtype_conversion_atol (input_dtype , output_dtype , scale )
2766+ torch .testing .assert_close (out , expected , rtol = 0 , atol = atol )
27832767
27842768 def was_scaled (self , inpt ):
27852769 # this assumes the target dtype is float
0 commit comments