Commit 483d5d8: fixes tests

andimarafioti committed Aug 16, 2024
1 parent d9dbd6b

Showing 9 changed files with 110 additions and 69 deletions.
16 changes: 16 additions & 0 deletions docs/source/en/model_doc/idefics3.md
@@ -40,6 +40,22 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


+ ## Idefics3Config
+
+ [[autodoc]] Idefics3Config
+
+
+ ## Idefics3Model
+
+ [[autodoc]] Idefics3Model
+     - forward
+
+ ## Idefics3ForConditionalGeneration
+
+ [[autodoc]] Idefics3ForConditionalGeneration
+     - forward
+

## Idefics3ImageProcessor
[[autodoc]] Idefics3ImageProcessor
- preprocess
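For readers of the new doc sections, a minimal usage sketch may help. It is assembled from this commit's integration tests, so treat the checkpoint name, prompt format, and image URL as assumptions rather than documented API:

```python
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
from transformers.image_utils import load_image

# Checkpoint and prompt mirror the updated integration tests; the image URL is assumed.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
model = Idefics3ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/Idefics3-8B-Llama3",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # shard across available devices, as the updated test does
)

image = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
inputs = processor(text="<image>In this image, we see", images=[image], return_tensors="pt").to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```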
4 changes: 0 additions & 4 deletions src/transformers/models/idefics3/configuration_idefics3.py
@@ -152,8 +152,6 @@ class Idefics3Config(PretrainedConfig):
The scale factor for the image encoder.
pad_token_id (`int`, *optional*, defaults to 128002):
The id of the padding token.
- max_position_embeddings (`int`, *optional*, defaults to 131072):
-     The maximum length of the input sequence.
Example:
```python
@@ -178,7 +176,6 @@ def __init__(
text_config=None,
scale_factor=2,
pad_token_id=128_002,
- max_position_embeddings=131_072,
**kwargs,
):
self.image_token_id = image_token_id
@@ -199,7 +196,6 @@
elif text_config is None:
logger.info("text_config is None, using default text config")
text_config = CONFIG_MAPPING["llama"](
- max_position_embeddings=max_position_embeddings,
rms_norm_eps=1e-5,
pad_token_id=pad_token_id,
tie_word_embeddings=False,
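With `max_position_embeddings` gone from `Idefics3Config`, the value belongs on the nested text config instead. A minimal sketch of the equivalent construction, assuming `text_config` accepts a config object and reusing the defaults the config builds internally:

```python
from transformers import Idefics3Config, LlamaConfig

# Values mirror the removed defaults (131_072, rms_norm_eps=1e-5, pad_token_id=128_002);
# treat them as illustrative.
text_config = LlamaConfig(
    max_position_embeddings=131_072,
    rms_norm_eps=1e-5,
    pad_token_id=128_002,
    tie_word_embeddings=False,
)
config = Idefics3Config(text_config=text_config, scale_factor=2, pad_token_id=128_002)
```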
2 changes: 0 additions & 2 deletions src/transformers/models/idefics3/image_processing_idefics3.py
@@ -385,8 +385,6 @@ class Idefics3ImageProcessor(BaseImageProcessor):
do_pad (`bool`, *optional*, defaults to `True`):
Whether or not to pad the images to the largest height and width in the batch and number of images per
sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
- vision_encoder_max_size (`int`, *optional*, defaults to `364`):
-     Maximum size of the images accepted by the vision encoder. The images are split into patches of this size.
"""

model_input_names = ["pixel_values"]
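The `do_pad` docstring above fully determines the padded output shape. An illustrative re-implementation of that padding contract in plain NumPy (not the library's code):

```python
import numpy as np

def pad_batch(batch):
    """Pad a list (per sample) of lists of (C, H, W) images to
    (batch_size, max_num_images, C, max_height, max_width)."""
    max_n = max(len(images) for images in batch)
    max_h = max(im.shape[1] for images in batch for im in images)
    max_w = max(im.shape[2] for images in batch for im in images)
    channels = batch[0][0].shape[0]
    out = np.zeros((len(batch), max_n, channels, max_h, max_w), dtype=np.float32)
    for i, images in enumerate(batch):
        for j, im in enumerate(images):
            out[i, j, :, : im.shape[1], : im.shape[2]] = im
    return out

batch = [
    [np.ones((3, 224, 224))],                          # sample with one image
    [np.ones((3, 364, 280)), np.ones((3, 120, 364))],  # sample with two images
]
print(pad_batch(batch).shape)  # (2, 2, 3, 364, 364)
```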
14 changes: 8 additions & 6 deletions src/transformers/models/idefics3/processing_idefics3.py
@@ -16,9 +16,9 @@
Processor class for Idefics3.
"""

- import sys
- from typing import TYPE_CHECKING, List, Optional, Union
+ import re
+ import sys
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
@@ -93,7 +93,7 @@ def get_image_prompt_string(
class Idefics3ImagesKwargs(ImagesKwargs, total=False):
image_seq_len: Optional[int]
return_row_col_info: Optional[bool]
- max_image_size: Optional[dict[str, int]]
+ max_image_size: Optional[Dict[str, int]]


class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
@@ -111,6 +111,9 @@ class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
}


+ Idefics3ProcessorKwargs.__annotations__["images_kwargs"] = Idefics3ImagesKwargs  # python 3.8 compatibility


class Idefics3Processor(ProcessorMixin):
r"""
Constructs a Idefics3 processor which wraps a LLama tokenizer and Idefics3 image processor into a single processor.
@@ -147,7 +150,7 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 169, ch
self.global_img_token = "<global-img>"
self.image_seq_len = image_seq_len

- self._regex_to_remove_extra_special_tokens = re.compile(r'(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+')
+ self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")

tokens_to_add = {
"additional_special_tokens": [
@@ -158,7 +161,7 @@
}
tokenizer.add_special_tokens(tokens_to_add)

- super().__init__(image_processor, tokenizer, chat_template=chat_template)
+ super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

def _extract_images_from_prompts(self, prompts):
prompt_images = []
@@ -356,7 +359,6 @@ def decode(self, *args, **kwargs):
decode_output = self.tokenizer.decode(*args, **kwargs)
return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)
-

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
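The quote-style change to the regex above is behavior-neutral. A small sketch of what the pattern does in `decode` (the example string is assumed, but the pattern is copied verbatim):

```python
import re

regex = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")

# Raw tokenizer output keeps the per-tile placeholder tokens; decode() collapses
# each run of them into a single "<image>".
decoded = "<row_1_col_1><row_1_col_2>\n<global-img>\nIn this image, we see"
print(regex.sub("<image>", decoded))  # -> "<image>In this image, we see"
```

This collapsing is plausibly why the expected generations in the updated tests below now begin with `<image>`.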
11 changes: 11 additions & 0 deletions src/transformers/processing_utils.py
@@ -291,6 +291,17 @@ class ModelProcessorKwargs(ProcessingKwargs, total=False):
}
```
+ For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs,
+ you need to manually update the __annotations__ dictionary. This can be done as follows:
+ ```python
+ class CustomProcessorKwargs(ProcessingKwargs, total=False):
+     images_kwargs: CustomImagesKwargs
+ CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs  # python 3.8 compatibility
+ ```
"""

common_kwargs: CommonKwargs = {
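A self-contained sketch of the Python 3.8 workaround the new docstring describes (the class names here are illustrative, not the library's):

```python
from typing import TypedDict  # available since Python 3.8

class ImagesKwargs(TypedDict, total=False):
    do_pad: bool

class CustomImagesKwargs(ImagesKwargs, total=False):
    max_image_size: dict

class ProcessingKwargs(TypedDict, total=False):
    images_kwargs: ImagesKwargs

class CustomProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: CustomImagesKwargs

# On Python 3.8 the overriding annotation may not land in __annotations__,
# so it is patched in by hand, exactly as processing_idefics3.py does above:
CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs
```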
17 changes: 12 additions & 5 deletions src/transformers/utils/dummy_pt_objects.py
@@ -4817,35 +4817,42 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class Idefics3ForConditionalGeneration(metaclass=DummyObject):
+ class Idefics2Model(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class Idefics2Model(metaclass=DummyObject):
+ class Idefics2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class Idefics2PreTrainedModel(metaclass=DummyObject):
+ class Idefics2Processor(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class Idefics3PreTrainedModel(metaclass=DummyObject):
+ class Idefics3ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class Idefics2Processor(metaclass=DummyObject):
+ class Idefics3Model(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


+ class Idefics3PreTrainedModel(metaclass=DummyObject):
+     _backends = ["torch"]
+
+     def __init__(self, *args, **kwargs):
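For context on why these placeholder classes exist: when torch is not installed, transformers exposes dummies that raise a clear error on first use. A simplified sketch of the pattern (reduced from the real `DummyObject`/`requires_backends` utilities, not a verbatim copy):

```python
def requires_backends(obj, backends):
    # Reduced from transformers.utils: raise a readable error naming the backend.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the {', '.join(backends)} backend(s).")

class DummyObject(type):
    """Metaclass: any non-private attribute access on the class raises."""
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)

class Idefics3Model(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

# Both instantiation and classmethod access fail fast with an actionable message:
# Idefics3Model()                      -> ImportError mentioning the torch backend
# Idefics3Model.from_pretrained("x")   -> ImportError mentioning the torch backend
```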
23 changes: 14 additions & 9 deletions tests/models/idefics3/test_modeling_idefics3.py
@@ -23,13 +23,10 @@

from transformers import (
AutoProcessor,
-     Idefics3Config,
-     Idefics3ForConditionalGeneration,
-     Idefics3Model,
is_torch_available,
is_vision_available,
)
- from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+ from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_multi_gpu, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -38,6 +35,12 @@

if is_torch_available():
import torch

+     from transformers import (
+         Idefics3Config,
+         Idefics3ForConditionalGeneration,
+         Idefics3Model,
+     )
else:
is_torch_greater_or_equal_than_2_0 = False

@@ -483,13 +486,13 @@ def tearDown(self):
torch.cuda.empty_cache()

@slow
+ @require_torch_multi_gpu
def test_integration_test(self):
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3",
torch_dtype=torch.bfloat16,
device_map="auto",
)
- model.to(torch_device)

# Create inputs
text = "<image>In this image, we see"
@@ -500,16 +503,18 @@ def test_integration_test_4bit(self):
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

- # Batch affects generated text. Single batch output: ['In this image, we see the Statue of Liberty in the foreground and']
- expected_generated_text = "In this image, we see the Statue of Liberty, the New York City"
+ expected_generated_text = "<image>In this image, we see the Statue of Liberty, which is located on Liberty"
self.assertEqual(generated_texts[0], expected_generated_text)

@slow
@require_bitsandbytes
+ @require_torch_multi_gpu
def test_integration_test_4bit(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3", load_in_4bit=True, device_map="auto"
"HuggingFaceM4/Idefics3-8B-Llama3",
load_in_4bit=True,
device_map="auto",
)

# Create pixel inputs
Expand All @@ -520,5 +525,5 @@ def test_integration_test_4bit(self):
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
expected_generated_text = "<image>In this image, we see the Statue of Liberty, trees, buildings, water"
self.assertEqual(generated_texts[0], expected_generated_text)