Skip to content

Commit

Permalink
[Fix] Resize interpolation difference CV2 and Pillow to come closer w…
Browse files Browse the repository at this point in the history
…ith docTR (#22)
  • Loading branch information
felixdittrich92 authored Jul 8, 2024
1 parent 03a1277 commit a2b1e9b
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 64 deletions.
14 changes: 7 additions & 7 deletions onnxtr/models/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def sample_transforms(self, x: np.ndarray) -> np.ndarray:
if x.dtype not in (np.uint8, np.float32):
raise TypeError("unsupported data type for numpy.ndarray")
x = shape_translate(x, "HWC")
# Data type & 255 division
if x.dtype == np.uint8:
x = x.astype(np.float32) / 255.0

# Resizing
x = self.resize(x)
# Data type & 255 division
if x.dtype == np.uint8 or np.max(x) > 1:
x = x.astype(np.float32) / 255.0

return x

Expand All @@ -95,13 +96,12 @@ def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]:
raise TypeError("unsupported data type for numpy.ndarray")
x = shape_translate(x, "BHWC")

# Data type & 255 division
if x.dtype == np.uint8:
x = x.astype(np.float32) / 255.0
# Resizing
if (x.shape[1], x.shape[2]) != self.resize.output_size:
x = np.array([self.resize(sample) for sample in x])

# Data type & 255 division
if x.dtype == np.uint8:
x = x.astype(np.float32) / 255.0
batches = [x]

elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
Expand Down
79 changes: 33 additions & 46 deletions onnxtr/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from typing import Tuple, Union

import cv2
import numpy as np
from PIL import Image, ImageOps

__all__ = ["Resize", "Normalize"]

Expand All @@ -17,64 +17,51 @@ class Resize:
def __init__(
self,
size: Union[int, Tuple[int, int]],
interpolation=cv2.INTER_LINEAR,
interpolation=Image.Resampling.BILINEAR,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
) -> None:
super().__init__()
self.size = size
self.size = size if isinstance(size, tuple) else (size, size)
self.interpolation = interpolation
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.output_size = size if isinstance(size, tuple) else (size, size)

if not isinstance(self.size, (int, tuple, list)):
raise AssertionError("size should be either a tuple, a list or an int")
if not isinstance(self.size, (tuple, int)):
raise AssertionError("size should be either a tuple or an int")

def __call__(
self,
img: np.ndarray,
) -> np.ndarray:
if img.ndim == 3:
h, w = img.shape[0:2]
else:
h, w = img.shape[1:3]
sh, sw = self.size if isinstance(self.size, tuple) else (self.size, self.size)
def __call__(self, img: np.ndarray) -> np.ndarray:
img = (img * 255).astype(np.uint8) if img.dtype != np.uint8 else img
h, w = img.shape[:2] if img.ndim == 3 else img.shape[1:3]
sh, sw = self.size

# Calculate aspect ratio of the image
aspect = w / h
if not self.preserve_aspect_ratio:
return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))

# Compute scaling and padding sizes
if self.preserve_aspect_ratio:
if aspect > 1: # Horizontal image
new_w = sw
new_h = int(sw / aspect)
elif aspect < 1: # Vertical image
new_h = sh
new_w = int(sh * aspect)
else: # Square image
new_h, new_w = sh, sw

img_resized = cv2.resize(img, (new_w, new_h), interpolation=self.interpolation)

# Calculate padding
pad_top = max((sh - new_h) // 2, 0)
pad_bottom = max(sh - new_h - pad_top, 0)
pad_left = max((sw - new_w) // 2, 0)
pad_right = max(sw - new_w - pad_left, 0)

# Pad the image
img_resized = cv2.copyMakeBorder( # type: ignore[call-overload]
img_resized, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_CONSTANT, value=0
)

# Ensure the image matches the target size by resizing it again if needed
img_resized = cv2.resize(img_resized, (sw, sh), interpolation=self.interpolation)
actual_ratio = h / w
target_ratio = sh / sw

if target_ratio == actual_ratio:
return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))

if actual_ratio > target_ratio:
tmp_size = (int(sh / actual_ratio), sh)
else:
# Resize the image without preserving aspect ratio
img_resized = cv2.resize(img, (sw, sh), interpolation=self.interpolation)
tmp_size = (sw, int(sw * actual_ratio))

img_resized = Image.fromarray(img).resize(tmp_size, resample=self.interpolation)
pad_left = pad_top = 0
pad_right = sw - img_resized.width
pad_bottom = sh - img_resized.height

if self.symmetric_pad:
pad_left = pad_right // 2
pad_right -= pad_left
pad_top = pad_bottom // 2
pad_bottom -= pad_top

return img_resized
img_resized = ImageOps.expand(img_resized, (pad_left, pad_top, pad_right, pad_bottom))
return np.array(img_resized)

def __repr__(self) -> str:
interpolate_str = self.interpolation
Expand Down
8 changes: 5 additions & 3 deletions onnxtr/utils/fonts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

import logging
import platform
from typing import Optional
from typing import Optional, Union

from PIL import ImageFont

__all__ = ["get_font"]


def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
def get_font(
font_family: Optional[str] = None, font_size: int = 13
) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
"""Resolves a compatible ImageFont for the system
Args:
Expand All @@ -29,7 +31,7 @@ def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFon
try:
font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
except OSError: # pragma: no cover
font = ImageFont.load_default()
font = ImageFont.load_default() # type: ignore[assignment]
logging.warning(
"unable to load recommended font family. Loading default PIL font,"
"font size issues may be expected."
Expand Down
2 changes: 1 addition & 1 deletion tests/common/test_models_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, e
assert all(isinstance(b, np.ndarray) for b in out)
assert all(b.dtype == np.float32 for b in out)
assert all(b.shape[1:3] == output_size for b in out)
assert all(np.all(np.abs(b - expected_value) < 1e-6) for b in out)
assert all(np.all(b == expected_value) for b in out)
assert len(repr(processor).split("\n")) == 4
14 changes: 7 additions & 7 deletions tests/common/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,31 @@ def test_resize():
input_t = np.ones((64, 64, 3), dtype=np.float32)
out = transfo(input_t)

assert np.all(out == 1)
assert np.all(out == 255)
assert out.shape[:2] == output_size
assert repr(transfo) == f"Resize(output_size={output_size}, interpolation='1')"
assert repr(transfo) == f"Resize(output_size={output_size}, interpolation='2')"

transfo = Resize(output_size, preserve_aspect_ratio=True)
input_t = np.ones((32, 64, 3), dtype=np.float32)
out = transfo(input_t)

assert out.shape[:2] == output_size
assert not np.all(out == 1)
assert not np.all(out == 255)
# Asymetric padding
assert np.all(out[-1] == 0) and np.all(out[0] == 0)
assert np.all(out[-1] == 0) and np.all(out[0] == 255)

# Symetric padding
transfo = Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True)
assert repr(transfo) == (
f"Resize(output_size={output_size}, interpolation='1', " f"preserve_aspect_ratio=True, symmetric_pad=True)"
f"Resize(output_size={output_size}, interpolation='2', " f"preserve_aspect_ratio=True, symmetric_pad=True)"
)
out = transfo(input_t)
assert out.shape[:2] == output_size
# symetric padding
assert np.all(out[-1] == 0) and np.all(out[0] == 0)

# Inverse aspect ratio
input_t = np.ones((3, 64, 32), dtype=np.float32)
input_t = np.ones((64, 32, 3), dtype=np.float32)
out = transfo(input_t)

assert not np.all(out == 1)
Expand All @@ -43,7 +43,7 @@ def test_resize():
# Same aspect ratio
output_size = (32, 128)
transfo = Resize(output_size, preserve_aspect_ratio=True)
out = transfo(np.ones((3, 16, 64), dtype=np.float32))
out = transfo(np.ones((16, 64, 3), dtype=np.float32))
assert out.shape[:2] == output_size


Expand Down

0 comments on commit a2b1e9b

Please sign in to comment.