Uniformize kwargs for processors - GroundingDINO (#31964)
* fix typo

* uniform kwargs

* make style

* add comments

* remove return_tensors

* remove common_kwargs from processor since it propagates

* make style

* return_token_type_ids to True

* revert the default images_kwargs since it does not accept any value in the image processor

* revert processing_utils.py

* make style

* add molbap's commit

* fix typo

* fix common processor

* remain

* Revert "add molbap's commit"

This reverts commit a476c6e.

* add unsync PR

* revert

* make CI happy

* nit

* import AnnotationFormat
SangbumChoi committed Aug 8, 2024
1 parent e28784f commit d3b3551
Showing 4 changed files with 243 additions and 40 deletions.
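In effect, tokenizer and image-processor options can now be passed in a single processor call and are routed to the right sub-processor automatically. A minimal sketch of the unified API (the checkpoint id comes from the test file below; the image URL and prompt are placeholder assumptions, not taken from this diff):

```python
import requests
from PIL import Image

from transformers import GroundingDinoProcessor

processor = GroundingDinoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# Flat ("unstructured") kwargs: padding/max_length go to the tokenizer,
# size goes to the image processor, return_tensors applies to both.
inputs = processor(
    images=image,
    text="a cat. a remote control.",
    return_tensors="pt",
    padding="max_length",
    max_length=64,
    size={"height": 480, "width": 480},
)
```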
2 changes: 1 addition & 1 deletion src/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -1580,7 +1580,7 @@ def _set_gradient_checkpointing(self, module, value=False):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`GroundingDinoTokenizer.__call__`] for details.
Indices can be obtained using [`AutoTokenizer`]. See [`BertTokenizer.__call__`] for details.
token_type_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
89 changes: 53 additions & 36 deletions src/transformers/models/grounding_dino/processing_grounding_dino.py
@@ -16,20 +16,32 @@
Processor class for Grounding DINO.
"""

from typing import List, Optional, Tuple, Union
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...image_utils import AnnotationFormat, ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack

from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available


if is_torch_available():
import torch


AnnotationType = Dict[str, Union[int, str, List[Dict]]]


def get_phrases_from_posmap(posmaps, input_ids):
"""Get token ids of phrases from posmaps and input_ids.
@@ -56,6 +68,31 @@ def get_phrases_from_posmap(posmaps, input_ids):
return token_ids


class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
do_convert_annotations: Optional[bool]
format: Optional[Union[str, AnnotationFormat]]


class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: GroundingDinoImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": True,
"return_length": False,
"verbose": True,
}
}


class GroundingDinoProcessor(ProcessorMixin):
r"""
Constructs a Grounding DINO processor which wraps a Deformable DETR image processor and a BERT tokenizer into a
@@ -83,21 +120,9 @@ def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = True,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
audio=None,
videos=None,
**kwargs: Unpack[GroundingDinoProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
@@ -106,32 +131,24 @@
Please refer to the docstring of the above two methods for more information.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
raise ValueError("You must specify either text or images.")

output_kwargs = self._merge_kwargs(
GroundingDinoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

# Get only text
if images is not None:
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
encoding_image_processor = BatchFeature()

if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["text_kwargs"],
)
else:
text_encoding = BatchEncoding()
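The same options can also be grouped per modality, which is the structured form that `_merge_kwargs` documents in the next file. A sketch assuming the `processor` and `image` from the earlier example:

```python
# Structured kwargs: grouped by modality; common_kwargs apply to every modality.
inputs = processor(
    images=image,
    text="a cat. a remote control.",
    images_kwargs={"size": {"height": 480, "width": 480}},
    text_kwargs={"padding": "max_length", "max_length": 64},
    common_kwargs={"return_tensors": "pt"},
)
```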
4 changes: 2 additions & 2 deletions src/transformers/processing_utils.py
@@ -736,12 +736,12 @@ def _merge_kwargs(
The order of operations is as follows:
1) kwargs passed as before have highest priority to preserve BC.
```python
high_priority_kwargs = {"crop_size" = (224, 224), "padding" = "max_length"}
high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"}
processor(..., **high_priority_kwargs)
```
2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
```python
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": (224, 224)})
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}})
```
3) kwargs passed during instantiation of a modality processor have third priority.
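To make the precedence concrete, here is a hedged sketch of call-site kwargs (priority 1) overriding a kwarg baked in at tokenizer instantiation; the 117/76 values mirror the tests below:

```python
from PIL import Image

from transformers import AutoTokenizer, GroundingDinoImageProcessor, GroundingDinoProcessor

# The tokenizer carries max_length=117 as an init kwarg (lower priority).
tokenizer = AutoTokenizer.from_pretrained("IDEA-Research/grounding-dino-base", max_length=117)
processor = GroundingDinoProcessor(
    image_processor=GroundingDinoImageProcessor(), tokenizer=tokenizer
)

image = Image.new("RGB", (640, 480))  # dummy image, just to exercise the call
inputs = processor(
    text="lower newer", images=image, padding="max_length", max_length=76, return_tensors="pt"
)
assert inputs["input_ids"].shape[-1] == 76  # call-site max_length wins over the init kwarg
```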
188 changes: 187 additions & 1 deletion tests/models/grounding_dino/test_processor_grounding_dino.py
@@ -26,6 +26,8 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
import torch
@@ -40,7 +42,10 @@

@require_torch
@require_vision
class GroundingDinoProcessorTest(unittest.TestCase):
class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
from_pretrained_id = "IDEA-Research/grounding-dino-base"
processor_class = GroundingDinoProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

@@ -63,6 +68,13 @@ def setUp(self):
with open(self.image_processor_file, "w", encoding="utf-8") as fp:
json.dump(image_processor_map, fp)

image_processor = GroundingDinoImageProcessor()
tokenizer = BertTokenizer.from_pretrained(self.from_pretrained_id)

processor = GroundingDinoProcessor(image_processor, tokenizer)

processor.save_pretrained(self.tmpdirname)

self.batch_size = 7
self.num_queries = 5
self.embed_dim = 5
@@ -251,3 +263,177 @@ def test_model_input_names(self):
inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size={"height": 234, "width": 234})
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(
text=input_str, images=image_input, return_tensors="pt", padding="max_length", max_length=112
)
self.assertEqual(len(inputs["input_ids"][0]), 112)

@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt", padding="max_length")
self.assertEqual(len(inputs["input_ids"][0]), 117)

@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)
