VLM generate: tests can't generate image/video tokens (#33623)
gante authored Sep 20, 2024
1 parent 653eb40 commit 2fdb5e7
Showing 4 changed files with 26 additions and 14 deletions.
28 changes: 20 additions & 8 deletions tests/generation/test_utils.py
@@ -132,7 +132,7 @@ def _get_input_ids_and_config(self, batch_size=2):

         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
@@ -146,6 +146,17 @@ def _get_logits_processor_kwargs(self, do_sample=False):
                     "temperature": 0.7,
                 }
             )
+        # TODO (joao, raushan): see this comment for a long-term fix
+        # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+        # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+        if config is not None:
+            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+            if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+            if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([video_token_index])

         return logits_processor_kwargs

@@ -211,7 +222,7 @@ def _greedy_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -246,7 +257,7 @@ def _sample_generate(
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -281,7 +292,7 @@ def _beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -316,7 +327,7 @@ def _beam_sample_generate(
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -350,7 +361,7 @@ def _group_beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -385,7 +396,7 @@ def _constrained_beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -424,7 +435,7 @@ def _contrastive_generate(
             "top_k": 5,
         }

-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -2052,6 +2063,7 @@ def test_generate_methods_with_num_logits_to_keep(self):
         )
         self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

+    @is_flaky()  # assisted generation tests are flaky (minor fp ops differences)
     def test_assisted_decoding_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
4 changes: 2 additions & 2 deletions tests/models/musicgen/test_modeling_musicgen.py
@@ -300,7 +300,7 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

@@ -1485,7 +1485,7 @@ def _sample_generate(

         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

4 changes: 2 additions & 2 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -303,7 +303,7 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

@@ -1469,7 +1469,7 @@ def _sample_generate(

         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

4 changes: 2 additions & 2 deletions tests/models/whisper/test_modeling_whisper.py
@@ -411,9 +411,9 @@ def is_pipeline_test_to_skip(

         return False

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
-        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
         logits_processor_kwargs["temperature"] = 0.0
         return logits_processor_kwargs

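For context, here is a minimal sketch (not part of this commit) of the mechanism the test change relies on: passing `bad_words_ids` to `generate()` installs a `NoBadWordsLogitsProcessor` that masks the listed token ids, so they can never be produced. The checkpoint name and the banned token id below are illustrative placeholders, not values taken from the test suite.

# Illustrative only: a tiny causal LM stands in for a VLM's text decoder, and
# `banned_token_id` stands in for config.image_token_index / config.video_token_index.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-gpt2"  # small test checkpoint, used purely as an example
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Hello there", return_tensors="pt")
banned_token_id = 7  # placeholder id; the tests instead append [image_token_index] / [video_token_index]

output = model.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=False,
    bad_words_ids=[[banned_token_id]],  # same format the tests build: a list of banned token-id sequences
)

# The banned id should never appear among the newly generated tokens.
new_tokens = output[0, inputs["input_ids"].shape[1]:].tolist()
assert banned_token_id not in new_tokens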
