Skip to content
Merged
6 changes: 4 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,7 +2128,8 @@ def get_params(alpha: list[float], sigma: list[float], size: list[int]) -> Tenso
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], sigma)
dx = dx * alpha[0] / size[0]
# normalize horizontal displacement by width (size[1])
dx = dx * alpha[0] / size[1]

dy = torch.rand([1, 1] + size) * 2 - 1
if sigma[1] > 0.0:
Expand All @@ -2137,7 +2138,8 @@ def get_params(alpha: list[float], sigma: list[float], size: list[int]) -> Tenso
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], sigma)
dy = dy * alpha[1] / size[1]
# normalize vertical displacement by height (size[0])
dy = dy * alpha[1] / size[0]
return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2

def forward(self, tensor: Tensor) -> Tensor:
Expand Down
10 changes: 5 additions & 5 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,25 +1061,25 @@ def __init__(
self._fill = _setup_fill_arg(fill)

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
size = list(query_size(flat_inputs))
height, width = query_size(flat_inputs)

dx = torch.rand([1, 1] + size) * 2 - 1
dx = torch.rand(1, 1, height, width) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dx = dx * self.alpha[0] / width

dy = torch.rand([1, 1] + size) * 2 - 1
dy = torch.rand(1, 1, height, width) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
dy = dy * self.alpha[1] / height
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)

Expand Down
Loading