Skip to content

Commit

Permalink
fixes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andimarafioti committed Aug 15, 2024
1 parent d9dbd6b commit 4756044
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 30 deletions.
16 changes: 16 additions & 0 deletions docs/source/en/model_doc/idefics3.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## Idefics3Config

[[autodoc]] Idefics3Config


## Idefics3Model

[[autodoc]] Idefics3Model
- forward

## Idefics3ForConditionalGeneration

[[autodoc]] Idefics3ForConditionalGeneration
- forward


## Idefics3ImageProcessor
[[autodoc]] Idefics3ImageProcessor
- preprocess
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/idefics3/configuration_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ class Idefics3Config(PretrainedConfig):
The scale factor for the image encoder.
pad_token_id (`int`, *optional*, defaults to 128002):
The id of the padding token.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum length of the input sequence.
Example:
```python
Expand All @@ -178,7 +176,6 @@ def __init__(
text_config=None,
scale_factor=2,
pad_token_id=128_002,
max_position_embeddings=131_072,
**kwargs,
):
self.image_token_id = image_token_id
Expand All @@ -199,7 +196,6 @@ def __init__(
elif text_config is None:
logger.info("text_config is None, using default text config")
text_config = CONFIG_MAPPING["llama"](
max_position_embeddings=max_position_embeddings,
rms_norm_eps=1e-5,
pad_token_id=pad_token_id,
tie_word_embeddings=False,
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/models/idefics3/image_processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,6 @@ class Idefics3ImageProcessor(BaseImageProcessor):
do_pad (`bool`, *optional*, defaults to `True`):
Whether or not to pad the images to the largest height and width in the batch and number of images per
sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
vision_encoder_max_size (`int`, *optional*, defaults to `364`):
Maximum size of the images accepted by the vision encoder. The images are split into patches of this size.
"""

model_input_names = ["pixel_values"]
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
Processor class for Idefics3.
"""

import sys
from typing import TYPE_CHECKING, List, Optional, Union
import re
import sys
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_image_prompt_string(
class Idefics3ImagesKwargs(ImagesKwargs, total=False):
image_seq_len: Optional[int]
return_row_col_info: Optional[bool]
max_image_size: Optional[dict[str, int]]
max_image_size: Optional[Dict[str, int]]


class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 169, ch
self.global_img_token = "<global-img>"
self.image_seq_len = image_seq_len

self._regex_to_remove_extra_special_tokens = re.compile(r'(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+')
self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")

tokens_to_add = {
"additional_special_tokens": [
Expand Down Expand Up @@ -356,7 +356,6 @@ def decode(self, *args, **kwargs):
decode_output = self.tokenizer.decode(*args, **kwargs)
return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)


@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
Expand Down
17 changes: 12 additions & 5 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4817,35 +4817,42 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics3ForConditionalGeneration(metaclass=DummyObject):
class Idefics2Model(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics2Model(metaclass=DummyObject):
class Idefics2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics2PreTrainedModel(metaclass=DummyObject):
class Idefics2Processor(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics3PreTrainedModel(metaclass=DummyObject):
class Idefics3ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics2Processor(metaclass=DummyObject):
class Idefics3Model(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class Idefics3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
Expand Down
18 changes: 10 additions & 8 deletions tests/models/idefics3/test_modeling_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@

from transformers import (
AutoProcessor,
Idefics3Config,
Idefics3ForConditionalGeneration,
Idefics3Model,
is_torch_available,
is_vision_available,
)
Expand All @@ -38,6 +35,12 @@

if is_torch_available():
import torch

from transformers import (
Idefics3Config,
Idefics3ForConditionalGeneration,
Idefics3Model,
)
else:
is_torch_greater_or_equal_than_2_0 = False

Expand Down Expand Up @@ -483,13 +486,13 @@ def tearDown(self):
torch.cuda.empty_cache()

@slow
@unittest.skip("Test hits OOM on CI - Same as idefics2 - https://github.com/huggingface/transformers/issues/32288")
def test_integration_test(self):
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.to(torch_device)

# Create inputs
text = "<image>In this image, we see"
Expand All @@ -500,16 +503,15 @@ def test_integration_test(self):
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

# Batch affects generated text. Single batch output: ['In this image, we see the Statue of Liberty in the foreground and']
expected_generated_text = "In this image, we see the Statue of Liberty, the New York City"
expected_generated_text = "<image>In this image, we see the Statue of Liberty, which is located on Liberty"
self.assertEqual(generated_texts[0], expected_generated_text)

@slow
@require_bitsandbytes
def test_integration_test_4bit(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3", load_in_4bit=True, device_map="auto"
"HuggingFaceM4/Idefics3-8B-Llama3", load_in_4bit=True
)

# Create pixel inputs
Expand All @@ -520,5 +522,5 @@ def test_integration_test_4bit(self):
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
expected_generated_text = "<image>In this image, we see the Statue of Liberty, trees, buildings, water"
self.assertEqual(generated_texts[0], expected_generated_text)
11 changes: 5 additions & 6 deletions tests/models/idefics3/test_processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setUp(self):
self.bos_token_id = processor.tokenizer.convert_tokens_to_ids(self.bos_token)
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(self.image_token)
self.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(self.fake_image_token)
self.global_img_token_id = processor.global_img_token_id
self.global_img_tokens_id = processor.tokenizer(self.global_img_token, add_special_tokens=False)["input_ids"]
self.padding_token_id = processor.tokenizer.pad_token_id
self.image_seq_len = processor.image_seq_len

Expand Down Expand Up @@ -96,7 +96,7 @@ def get_splitted_image_expected_tokens(self, processor, image_rows, image_cols):
] # add double newline, as it gets its own token
text_split_images += (
[self.fake_image_token_id]
+ [self.global_img_token_id]
+ self.global_img_tokens_id
+ [self.image_token_id] * self.image_seq_len
+ [self.fake_image_token_id]
)
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_process_interleaved_images_prompts_no_image_splitting(self):

# fmt: off
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [[self.bos_token_id] + [self.fake_image_token_id] + [self.global_img_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
expected_input_ids = [[self.bos_token_id] + [self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
self.assertEqual(inputs["pixel_values"].shape, (1, 1, 3, 1092, 1456))
Expand All @@ -147,7 +147,7 @@ def test_process_interleaved_images_prompts_no_image_splitting(self):
# fmt: off
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
image_tokens = [self.fake_image_token_id] + [self.global_img_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
image_tokens = [self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
expected_input_ids_1 = [self.bos_token_id] + image_tokens + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = [self.bos_token_id] + 2 * image_tokens + tokenized_sentence_2["input_ids"]
# Pad the first input to match the second input
Expand Down Expand Up @@ -424,7 +424,6 @@ def test_unstructured_kwargs_batched(self):

input_str = ["<image>lower newer", "<image>upper older longer string"]
image_input = self.prepare_image_inputs()
print(image_input)
inputs = processor(
text=input_str,
images=[image_input, image_input],
Expand All @@ -436,7 +435,7 @@ def test_unstructured_kwargs_batched(self):

self.assertEqual(inputs["pixel_values"].shape[2], 3)
self.assertEqual(inputs["pixel_values"].shape[3], 214)
self.assertEqual(len(inputs["input_ids"][0]), 88)
self.assertEqual(len(inputs["input_ids"][0]), 91)

# We need to overwrite this test to adapt it to our processor.
@require_torch
Expand Down
1 change: 1 addition & 0 deletions utils/check_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"SeamlessM4Tv2TextToUnitModel",
"SeamlessM4Tv2CodeHifiGan",
"SeamlessM4Tv2TextToUnitForConditionalGeneration",
"Idefics3VisionTransformer",
]

# Update this list for models that are not tested with a comment explaining the reason it should not be.
Expand Down

0 comments on commit 4756044

Please sign in to comment.