Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniformize kwargs for processors - GroundingDINO #31964

Merged
merged 23 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/grounding-dino.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ import requests

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Check for cats and remote controls
text = "a cat. a remote control."

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
inputs = processor(images=image, text=text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def _set_gradient_checkpointing(self, module, value=False):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.

Indices can be obtained using [`AutoTokenizer`]. See [`GroundingDinoTokenizer.__call__`] for details.
Indices can be obtained using [`AutoTokenizer`]. See [`BertTokenizer.__call__`] for details.

token_type_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
Expand Down
94 changes: 58 additions & 36 deletions src/transformers/models/grounding_dino/processing_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,35 @@
Processor class for Grounding DINO.
"""

from typing import List, Optional, Tuple, Union
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, is_torch_available
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack

from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import ExplicitEnum, TensorType, is_torch_available


if is_torch_available():
import torch

AnnotationType = Dict[str, Union[int, str, List[Dict]]]


class AnnotationFormat(ExplicitEnum):
COCO_DETECTION = "coco_detection"
COCO_PANOPTIC = "coco_panoptic"

SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved

def get_phrases_from_posmap(posmaps, input_ids):
"""Get token ids of phrases from posmaps and input_ids.
Expand Down Expand Up @@ -56,6 +72,32 @@ def get_phrases_from_posmap(posmaps, input_ids):
return token_ids


class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
do_convert_annotations: Optional[bool]
format: Optional[Union[str, AnnotationFormat]]


class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
images_kwargs: GroundingDinoImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": True,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
}


class GroundingDinoProcessor(ProcessorMixin):
r"""
Constructs a Grounding DINO processor which wraps a Deformable DETR image processor and a BERT tokenizer into a
Expand Down Expand Up @@ -83,21 +125,9 @@ def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = True,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
audio=None,
videos=None,
**kwargs: Unpack[GroundingDinoProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
Expand All @@ -106,32 +136,24 @@ def __call__(
Please refer to the docstring of the above two methods for more information.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
raise ValueError("You must specify either text or images.")

output_kwargs = self._merge_kwargs(
GroundingDinoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

# Get only text
if images is not None:
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
encoding_image_processor = BatchFeature()

if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["text_kwargs"],
)
else:
text_encoding = BatchEncoding()
Expand Down
83 changes: 36 additions & 47 deletions src/transformers/processing_utils.py
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class methods and docstrings.
Standard deviation to use if normalizing the image.
do_pad (`bool`, *optional*):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to.
do_center_crop (`bool`, *optional*):
Whether to center crop the image.
data_format (`ChannelDimension` or `str`, *optional*):
Expand All @@ -170,6 +172,7 @@ class methods and docstrings.
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
do_center_crop: Optional[bool]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
Expand Down Expand Up @@ -321,7 +324,6 @@ class ProcessorMixin(PushToHubMixin):
feature_extractor_class = None
tokenizer_class = None
_auto_class = None
valid_kwargs: List[str] = []

# args have to match the attributes class attribute
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -700,15 +702,14 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
chat_template = kwargs.pop("chat_template", None)

# We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
# If we don't pop, some specific kwargs will raise a warning
# Unlike image processors or feature extractors whose `__init__` accept `kwargs`, processor don't have `kwargs`.
# We have to pop up some unused (but specific) arguments to make it work.
if "processor_class" in processor_dict:
del processor_dict["processor_class"]

if "auto_map" in processor_dict:
del processor_dict["auto_map"]

unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
processor = cls(*args, **processor_dict)
if chat_template is not None:
setattr(processor, "chat_template", chat_template)
Expand All @@ -718,7 +719,6 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
if hasattr(processor, key):
setattr(processor, key, kwargs.pop(key))

kwargs.update(unused_kwargs)
logger.info(f"Processor {processor}")
if return_unused_kwargs:
return processor, kwargs
Expand All @@ -736,12 +736,12 @@ def _merge_kwargs(
The order of operations is as follows:
1) kwargs passed as before have highest priority to preserve BC.
```python
high_priority_kwargs = {"crop_size" = (224, 224), "padding" = "max_length"}
high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"}
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
processor(..., **high_priority_kwargs)
```
2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
```python
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": (224, 224)}})
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}})
```
3) kwargs passed during instantiation of a modality processor have fourth priority.
```python
Expand Down Expand Up @@ -796,38 +796,40 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
if modality_key in tokenizer_init_kwargs:
default_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key]
# now defaults kwargs are updated with the tokenizers defaults.
# pass defaults to output dictionary
output_kwargs.update(default_kwargs)

# gather common kwargs and remove them from individual kwargs if present
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
common_kwargs = {}
for key, value in kwargs.items():
if key == "common_kwargs":
for common_key, common_value in value.items():
common_kwargs[common_key] = common_value
elif key in ["text_kwargs", "images_kwargs", "audio_kwargs", "videos_kwargs"]:
pass
elif (
key not in ModelProcessorKwargs.__annotations__["text_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["images_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["audio_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["videos_kwargs"].__annotations__
):
common_kwargs[key] = value

# ensure common kwargs are propagated to all relevant modalities
for key, value in common_kwargs.items():
for modality in output_kwargs:
if modality != "common_kwargs":
output_kwargs[modality][key] = value

# remove common kwargs from the kwargs to process the rest
kwargs = {k: v for k, v in kwargs.items() if k not in common_kwargs}

# update modality kwargs with passed kwargs
non_modality_kwargs = set(kwargs) - set(output_kwargs)
for modality in output_kwargs:
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
# check if we received a structured kwarg dict or not to handle it correctly
if modality in kwargs:
kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
# check if this key was passed as a flat kwarg.
if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
raise ValueError(
f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
)
elif modality_key in kwargs:
kwarg_value = kwargs.pop(modality_key, "__empty__")
else:
kwarg_value = "__empty__"
if kwarg_value != "__empty__":
output_kwargs[modality][modality_key] = kwarg_value
# if something remains in kwargs, it belongs to common after flattening
if set(kwargs) & set(default_kwargs):
# here kwargs is dictionary-based since it shares keys with default set
[output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()]
else:
# here it's a flat dict
output_kwargs["common_kwargs"].update(kwargs)

# all modality-specific kwargs are updated with common kwargs
for modality in output_kwargs:
output_kwargs[modality].update(output_kwargs["common_kwargs"])
if modality_key in kwargs:
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
output_kwargs[modality][modality_key] = kwargs[modality_key]
elif modality in kwargs and modality_key in kwargs[modality]:
output_kwargs[modality][modality_key] = kwargs[modality][modality_key]
return output_kwargs

@classmethod
Expand Down Expand Up @@ -943,19 +945,6 @@ def model_input_names(self):
first_attribute = getattr(self, self.attributes[0])
return getattr(first_attribute, "model_input_names", None)

@staticmethod
def validate_init_kwargs(processor_config, valid_kwargs):
kwargs_from_config = processor_config.keys()
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(
f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
)
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs

SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
Expand Down
41 changes: 40 additions & 1 deletion tests/models/grounding_dino/test_processor_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
import torch
Expand All @@ -40,7 +42,10 @@

@require_torch
@require_vision
class GroundingDinoProcessorTest(unittest.TestCase):
class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
from_pretrained_id = "IDEA-Research/grounding-dino-base"
processor_class = GroundingDinoProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

Expand All @@ -63,6 +68,13 @@ def setUp(self):
with open(self.image_processor_file, "w", encoding="utf-8") as fp:
json.dump(image_processor_map, fp)

image_processor = GroundingDinoImageProcessor()
tokenizer = BertTokenizer.from_pretrained(self.from_pretrained_id)

processor = GroundingDinoProcessor(image_processor, tokenizer)

processor.save_pretrained(self.tmpdirname)

self.batch_size = 7
self.num_queries = 5
self.embed_dim = 5
Expand Down Expand Up @@ -251,3 +263,30 @@ def test_model_input_names(self):
inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)
Loading