Skip to content

Commit

Permalink
Add equivalence test
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Aug 29, 2024
1 parent 2afe096 commit 29175dc
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 6 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/llava.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] LlavaConfig

## LlavaImageProcessor

[[autodoc]] LlavaImageProcessor
- preprocess

## LlavaProcessor

[[autodoc]] LlavaProcessor
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,7 @@
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.llava"].append("LlavaImageProcessor")
_import_structure["models.llava_next"].append("LlavaNextImageProcessor")
_import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor")
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
Expand Down Expand Up @@ -5979,6 +5980,7 @@
LayoutLMv3ImageProcessor,
)
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.llava import LlavaImageProcessor
from .models.llava_next import LlavaNextImageProcessor
from .models.llava_next_video import LlavaNextVideoImageProcessor
from .models.mask2former import Mask2FormerImageProcessor
Expand Down
18 changes: 17 additions & 1 deletion src/transformers/models/llava/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
Expand All @@ -33,6 +33,14 @@
"LlavaPreTrainedModel",
]

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_llava"] = ["LlavaImageProcessor"]


if TYPE_CHECKING:
from .configuration_llava import LlavaConfig
Expand All @@ -49,6 +57,14 @@
LlavaPreTrainedModel,
)

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_llava import LlavaImageProcessor

else:
import sys

Expand Down
66 changes: 62 additions & 4 deletions src/transformers/models/llava/image_processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
get_resize_output_image_size,
pad,
resize,
to_channel_dimension_format,
)
Expand All @@ -32,6 +34,7 @@
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
Expand All @@ -55,7 +58,7 @@ class LlavaImageProcessor(BaseImageProcessor):
Constructs a LLaVa image processor.
Args:
do_pad (`bool`, *optional*, defaults to `True`):
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image to a square. Can be overridden by `do_pad` in the `preprocess` method.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
Expand Down Expand Up @@ -95,7 +98,7 @@ class LlavaImageProcessor(BaseImageProcessor):

def __init__(
self,
do_pad: bool = True,
do_pad: bool = False,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
Expand Down Expand Up @@ -154,10 +157,13 @@ def __init__(
# `shortest_edge` key.
delattr(self, "use_square_size")

def pad_to_square(image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0) -> Image.Image:
def pad_to_square_original(
self, image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
) -> Image.Image:
"""
Pads an image to make it square.
"""
print("Image size:", image.size)
width, height = image.size
if width == height:
return image
Expand All @@ -170,6 +176,51 @@ def pad_to_square(image: Image.Image, background_color: Union[int, Tuple[int, in
result.paste(image, ((height - width) // 2, 0))
return result

def pad_to_square(
self,
image: np.array,
background_color: Union[int, Tuple[int, int, int]] = 0,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Pads an image to make it square.
Args:
image (`np.ndarray`):
The image to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)

if height == width:
return image

max_dim = max(height, width)
pad_height = max_dim - height
pad_width = max_dim - width

padding = (
(pad_height // 2, pad_height - pad_height // 2),
(pad_width // 2, pad_width - pad_width // 2),
)

return pad(
image=image,
padding=padding,
mode=PaddingMode.CONSTANT,
constant_values=background_color,
input_data_format=input_data_format,
)

def resize(
self,
image: np.ndarray,
Expand Down Expand Up @@ -347,7 +398,14 @@ def preprocess(
input_data_format = infer_channel_dimension_format(images[0])

if do_pad:
images = [self.pad_to_square(image=image) for image in images]
images = [
self.pad_to_square(
image=image,
background_color=tuple(int(x * 255) for x in self.image_mean),
input_data_format=input_data_format,
)
for image in images
]

if do_resize:
images = [
Expand Down
22 changes: 22 additions & 0 deletions src/transformers/models/llava/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import requests
from PIL import Image

from transformers import LlavaImageProcessor


processor = LlavaImageProcessor()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

result1 = processor.pad_to_square_original(image=image)
result1 = np.array(result1)

print("Shape of result1:", result1.shape)

result2 = processor.pad_to_square(np.array(image))

assert result1.shape == result2.shape

np.testing.assert_allclose(result1, result2)
2 changes: 1 addition & 1 deletion src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .utils import is_torch_xla_available, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_vision_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class LlavaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class LlavaNextImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

Expand Down

0 comments on commit 29175dc

Please sign in to comment.