Uniformize kwargs for processors - GroundingDINO #31964

Merged
merged 23 commits on Aug 8, 2024
Changes from 3 commits
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/grounding-dino.md
@@ -45,19 +45,19 @@ import requests
 
 import torch
 from PIL import Image
-from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 
 model_id = "IDEA-Research/grounding-dino-tiny"
 
 processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
 
 image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 image = Image.open(requests.get(image_url, stream=True).raw)
 # Check for cats and remote controls
 text = "a cat. a remote control."
 
-inputs = processor(images=image, text=text, return_tensors="pt").to(device)
+inputs = processor(images=image, text=text, return_tensors="pt")
 with torch.no_grad():
     outputs = model(**inputs)
 
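For reviewers, here is a minimal sketch of how the updated docs snippet reads end to end after this hunk. The final post-processing call is not part of this diff; it is reproduced as an assumption from the surrounding, unchanged docs, only to make the sketch self-contained:

```python
import requests

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Check for cats and remote controls
text = "a cat. a remote control."

inputs = processor(images=image, text=text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Assumed from the unchanged part of the docs: convert raw logits/boxes into
# thresholded detections in image coordinates.
results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.4,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]],
)
```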
76 changes: 45 additions & 31 deletions src/transformers/models/grounding_dino/processing_grounding_dino.py
@@ -16,13 +16,21 @@
 Processor class for Grounding DINO.
 """
 
+import sys
 from typing import List, Optional, Tuple, Union
 
 from ...image_processing_utils import BatchFeature
 from ...image_transforms import center_to_corners_format
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
 from ...utils import TensorType, is_torch_available
 
 
@@ -56,6 +64,26 @@ def get_phrases_from_posmap(posmaps, input_ids):
     return token_ids
 
 
+class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "truncation": None,
+            "max_length": None,
+            "stride": 0,
+            "pad_to_multiple_of": None,
+            "return_attention_mask": None,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": True,
+            "return_length": False,
+            "verbose": True,
+        }
+    }
+

 class GroundingDinoProcessor(ProcessorMixin):
     r"""
     Constructs a Grounding DINO processor which wraps a Deformable DETR image processor and a BERT tokenizer into a
@@ -83,21 +111,8 @@ def __call__(
         self,
         images: ImageInput = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = True,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        **kwargs: Unpack[GroundingDinoProcessorKwargs],
     ) -> BatchEncoding:
         """
         This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
@@ -108,30 +123,29 @@ def __call__(
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")
 
+        output_kwargs = self._merge_kwargs(
+            GroundingDinoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        # BC for explicit return_tensors
+        if "return_tensors" in output_kwargs["common_kwargs"]:
+            return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
+
         # Get only text
         if images is not None:
-            encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
+            encoding_image_processor = self.image_processor(
+                images, return_tensors=return_tensors, **output_kwargs["images_kwargs"]
+            )
         else:
             encoding_image_processor = BatchFeature()
 
         if text is not None:
             text_encoding = self.tokenizer(
                 text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
+                **output_kwargs["text_kwargs"],
             )
         else:
             text_encoding = BatchEncoding()
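After this refactor, the tokenizer arguments that used to be spelled out in the `__call__` signature are plain keyword arguments, validated against `GroundingDinoProcessorKwargs` and routed into `output_kwargs["text_kwargs"]`. A minimal call-site sketch under that assumption (the specific values below are illustrative, not taken from the PR):

```python
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(
    images=image,
    text="a cat. a remote control.",
    padding="max_length",  # overrides the padding=False entry in _defaults
    max_length=64,         # overrides the max_length=None entry in _defaults
    return_tensors="pt",   # popped from common_kwargs and shared with the image processor
)
print(sorted(inputs.keys()))  # e.g. attention_mask, input_ids, pixel_mask, pixel_values, token_type_ids
```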
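For reviewers unfamiliar with the uniformized-kwargs machinery, `self._merge_kwargs` can be thought of as a layered update with increasing priority: the `_defaults` on `GroundingDinoProcessorKwargs`, then the tokenizer's init-time kwargs, then whatever the caller passes to `__call__`. A rough, self-contained illustration of that expected precedence (this is not the actual `ProcessorMixin` implementation):

```python
def merge_text_kwargs(defaults, tokenizer_init_kwargs, call_kwargs):
    """Toy model of the precedence _merge_kwargs is expected to apply to text_kwargs."""
    merged = dict(defaults)  # lowest priority: the class-level _defaults
    merged.update(           # middle priority: overrides set when the tokenizer was created
        {k: v for k, v in tokenizer_init_kwargs.items() if k in merged}
    )
    merged.update(call_kwargs)  # highest priority: kwargs passed to __call__
    return merged


defaults = {"padding": False, "max_length": None, "return_token_type_ids": True}
print(merge_text_kwargs(defaults, {"padding": "longest"}, {"max_length": 64}))
# -> {'padding': 'longest', 'max_length': 64, 'return_token_type_ids': True}
```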