Uniformize kwargs for processors - GroundingDINO #31964

Merged · 23 commits · Aug 8, 2024
Changes from 8 commits
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/grounding-dino.md
@@ -45,19 +45,19 @@ import requests

import torch
from PIL import Image
-from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"

processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Check for cats and remote controls
text = "a cat. a remote control."

-inputs = processor(images=image, text=text, return_tensors="pt").to(device)
+inputs = processor(images=image, text=text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)

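With the `.to(device)` calls removed, the snippet above runs wherever the model loads by default. A minimal sketch of explicit device placement, reusing the names from the snippet (the `device` selection below is our own illustration, not part of the PR):

```python
import torch

# Illustrative: pick a device explicitly rather than assuming a
# predefined `device` name exists in the surrounding example.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
inputs = processor(images=image, text=text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)
```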
78 changes: 43 additions & 35 deletions src/transformers/models/grounding_dino/processing_grounding_dino.py
@@ -16,13 +16,21 @@
Processor class for Grounding DINO.
"""

-from typing import List, Optional, Tuple, Union
+import sys
+from typing import List, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available


@@ -56,6 +64,26 @@ def get_phrases_from_posmap(posmaps, input_ids):
    return token_ids


+class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": True,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "do_convert_annotations": True,
+            "do_resize": True,
+        },
+    }


class GroundingDinoProcessor(ProcessorMixin):
r"""
Constructs a Grounding DINO processor which wraps a Deformable DETR image processor and a BERT tokenizer into a
@@ -83,21 +111,9 @@ def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = True,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[GroundingDinoProcessorKwargs],
    ) -> BatchEncoding:
"""
This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
Expand All @@ -106,32 +122,24 @@ def __call__(
Please refer to the docstring of the above two methods for more information.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
raise ValueError("You must specify either text or images.")

+        output_kwargs = self._merge_kwargs(
+            GroundingDinoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )

        # Get only text
        if images is not None:
-            encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            encoding_image_processor = BatchFeature()

        if text is not None:
            text_encoding = self.tokenizer(
                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
+                **output_kwargs["text_kwargs"],
            )
        else:
            text_encoding = BatchEncoding()
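With this change, callers pass tokenizer and image-processor options as flat keyword arguments and `_merge_kwargs` routes them to the right component, falling back to the `_defaults` above where nothing is supplied. A hedged sketch of the new call pattern (the option values below are illustrative, not taken from the diff):

```python
from transformers import AutoProcessor
from PIL import Image
import requests

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Flat kwargs: `padding` and `max_length` should be routed to the tokenizer
# (text_kwargs), `size` to the image processor (images_kwargs), and
# `return_tensors` applied to both.
inputs = processor(
    images=image,
    text="a cat. a remote control.",
    padding="longest",
    max_length=64,
    size={"height": 800, "width": 800},
    return_tensors="pt",
)
```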
28 changes: 28 additions & 0 deletions src/transformers/processing_utils.py
@@ -20,6 +20,7 @@
import inspect
import json
import os
+import pathlib
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
@@ -40,6 +41,7 @@
)
from .utils import (
    PROCESSOR_NAME,
+    ExplicitEnum,
    PushToHubMixin,
    TensorType,
    add_model_info_to_auto_map,
@@ -56,6 +58,14 @@

logger = logging.get_logger(__name__)

+AnnotationType = Dict[str, Union[int, str, List[Dict]]]
+
+
+class AnnotationFormat(ExplicitEnum):
+    COCO_DETECTION = "coco_detection"
+    COCO_PANOPTIC = "coco_panoptic"


# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
transformers_module = direct_transformers_import(Path(__file__).parent)

@@ -128,6 +138,12 @@ class ImagesKwargs(TypedDict, total=False):
    class methods and docstrings.

    Attributes:
+        annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
+            List of annotations associated with the image or batch of images.
+        return_segmentation_masks (`bool`, *optional*):
+            Whether to return segmentation masks.
+        masks_path (`str` or `pathlib.Path`, *optional*):
+            Path to the directory containing the segmentation masks.
        do_resize (`bool`, *optional*):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*):
@@ -144,6 +160,8 @@
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*):
            Whether to normalize the image.
+        do_convert_annotations (`bool`, *optional*):
+            Whether to convert the annotations to the format expected by the model.
        image_mean (`float` or `List[float]`, *optional*):
            Mean to use if normalizing the image.
        image_std (`float` or `List[float]`, *optional*):
@@ -152,12 +170,19 @@
            Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
        do_center_crop (`bool`, *optional*):
            Whether to center crop the image.
+        format (`str` or `AnnotationFormat`, *optional*):
+            Format of the annotations.
        data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the output image.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image.
+        pad_size (`Dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to.
    """

+    annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
+    return_segmentation_masks: Optional[bool]
+    masks_path: Optional[Union[str, pathlib.Path]]
    do_resize: Optional[bool]
    size: Optional[Dict[str, int]]
    size_divisor: Optional[int]
@@ -166,12 +191,15 @@
    do_rescale: Optional[bool]
    rescale_factor: Optional[float]
    do_normalize: Optional[bool]
+    do_convert_annotations: Optional[bool]
    image_mean: Optional[Union[float, List[float]]]
    image_std: Optional[Union[float, List[float]]]
    do_pad: Optional[bool]
    do_center_crop: Optional[bool]
+    format: Optional[Union[str, AnnotationFormat]]
    data_format: Optional[ChannelDimension]
    input_data_format: Optional[Union[str, ChannelDimension]]
+    pad_size: Optional[Dict[str, int]]


class VideosKwargs(TypedDict, total=False):
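The new `ImagesKwargs` fields mean detection annotations can flow through a processor call in one pass. A hedged sketch of what that might look like for a COCO-detection-style annotation (the box, area, and category values below are made up for illustration, and the routing of `annotations`/`format` into `images_kwargs` is assumed from the fields added above):

```python
from transformers import AutoProcessor
from PIL import Image
import requests

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Hypothetical annotation; the keys follow the AnnotationType alias
# (Dict[str, Union[int, str, List[Dict]]]) defined above.
annotation = {
    "image_id": 39769,
    "annotations": [
        {"bbox": [10.0, 20.0, 200.0, 150.0], "category_id": 17, "area": 30000.0, "iscrowd": 0},
    ],
}

inputs = processor(
    images=image,
    text="a cat. a remote control.",
    annotations=annotation,       # assumed to be routed into images_kwargs
    format="coco_detection",      # matches AnnotationFormat.COCO_DETECTION
    return_tensors="pt",
)
```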
33 changes: 32 additions & 1 deletion tests/models/grounding_dino/test_processor_grounding_dino.py
@@ -26,6 +26,8 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available

+from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
    import torch
@@ -40,7 +42,9 @@

@require_torch
@require_vision
-class GroundingDinoProcessorTest(unittest.TestCase):
+class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = GroundingDinoProcessor
+
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

@@ -251,3 +255,30 @@ def test_model_input_names(self):
        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)

+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            crop_size={"height": 214, "width": 214},
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 11)
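The test above exercises *unstructured* (flat) kwargs. The uniformized API is also designed to accept *structured* kwargs, where per-modality options are grouped into nested dicts. A hedged sketch of the equivalent call (`processor`, `input_str`, and `image_input` as constructed in the test above; the nested-dict routing is assumed from how `_merge_kwargs` is designed):

```python
inputs = processor(
    text=input_str,
    images=image_input,
    # Per-modality dicts; assumed equivalent to the flat kwargs in the
    # test above, with `common_kwargs` applied to every component.
    text_kwargs={"padding": "longest", "max_length": 76},
    images_kwargs={
        "crop_size": {"height": 214, "width": 214},
        "size": {"height": 214, "width": 214},
    },
    common_kwargs={"return_tensors": "pt"},
)
```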