Skip to content

Commit 736a2e6

Browse files
committed
add int -> int scaling setup for cvcuda, use bit diff for scale
1 parent d871331 commit 736a2e6

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2748,8 +2748,6 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_
27482748
pytest.xfail("float to int64 conversion is not supported")
27492749
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
27502750
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")
2751-
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale and make_input == make_image_cvcuda:
2752-
pytest.xfail("uint8 to uint16 conversion with scale is not supported in F._misc._to_dtype_cvcuda")
27532751

27542752
input = make_input(dtype=input_dtype, device=device)
27552753
out = F.to_dtype(input, dtype=output_dtype, scale=scale)

torchvision/transforms/v2/functional/_misc.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,9 @@ def _to_dtype_cvcuda(
388388
4. int -> float
389389
If scale is True, the values will be scaled to the new dtype.
390390
If scale is False, the values will not be scaled.
391-
The scale values for float -> float and int -> int are 1.0 and 0.0 respectively.
391+
The scale values for float -> float are 1.0 and 0.0 respectively.
392+
The scale values for int -> int are 2^(bit_diff) of the new dtype.
393+
Where bit_diff is the difference in the number of bits of the new dtype and the input dtype.
392394
The scale values for float -> int and int -> float are the maximum value of the new dtype.
393395
394396
Returns:
@@ -405,8 +407,13 @@ def _to_dtype_cvcuda(
405407
in_dtype_float = dtype_in.is_floating_point
406408
out_dtype_float = dtype.is_floating_point
407409

408-
if in_dtype_float == out_dtype_float:
410+
if in_dtype_float and out_dtype_float:
409411
scale_val, offset = 1.0, 0.0
412+
elif not in_dtype_float and not out_dtype_float:
413+
in_bits = torch.iinfo(dtype_in).bits
414+
out_bits = torch.iinfo(dtype).bits
415+
scale_val = float(2 ** (out_bits - in_bits))
416+
offset = 0.0
410417
elif in_dtype_float and not out_dtype_float:
411418
scale_val, offset = float(_max_value(dtype)), 0.0
412419
else:

0 commit comments

Comments
 (0)