VLM generate: tests can't generate image/video tokens #33623

Merged (2 commits, Sep 20, 2024). The diff below shows changes from 1 commit.

File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
tests/generation/test_utils.py (19 additions, 8 deletions)

```diff
@@ -132,7 +132,7 @@ def _get_input_ids_and_config(self, batch_size=2):
 
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
@@ -146,6 +146,17 @@ def _get_logits_processor_kwargs(self, do_sample=False):
                     "temperature": 0.7,
                 }
             )
+        # TODO (joao, raushan): see this comment for a long-term fix
+        # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+        # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+        if config is not None:
+            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+            if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+            if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([video_token_index])
 
         return logits_processor_kwargs
 
@@ -211,7 +222,7 @@ def _greedy_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -246,7 +257,7 @@ def _sample_generate(
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -281,7 +292,7 @@ def _beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -316,7 +327,7 @@ def _beam_sample_generate(
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -350,7 +361,7 @@ def _group_beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -385,7 +396,7 @@ def _constrained_beam_search_generate(
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
        output_generate = model.generate(
             input_ids,
@@ -424,7 +435,7 @@ def _contrastive_generate(
             "top_k": 5,
         }
 
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
```
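
For context on the mechanism the band-aid relies on: each entry in `bad_words_ids` is turned by `generate()` into a `NoBadWordsLogitsProcessor`, which pushes the listed token ids to `-inf` before the next token is chosen. A minimal sketch of that behavior, using a made-up token id and vocab size rather than any real checkpoint's values:

```python
# Minimal sketch of how a single-id `bad_words_ids` entry suppresses a token.
# `image_token_index` and the vocab size below are illustrative values only.
import torch
from transformers import NoBadWordsLogitsProcessor

image_token_index = 32000  # hypothetical placeholder id for an <image> token
processor = NoBadWordsLogitsProcessor(
    bad_words_ids=[[image_token_index]], eos_token_id=2
)

input_ids = torch.tensor([[1, 5, 9]])  # dummy prompt
scores = torch.zeros(1, 32064)         # dummy next-token logits (batch, vocab)
scores = processor(input_ids, scores)

# The banned id can now never be selected, greedily or by sampling.
assert scores[0, image_token_index] == float("-inf")
```

Single-id entries like `[image_token_index]` are banned unconditionally at every step, whereas multi-token entries are only blocked once their prefix has been generated, which is why the helper appends each placeholder id as its own one-element list.
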
tests/models/musicgen/test_modeling_musicgen.py (2 additions, 2 deletions)

```diff
@@ -300,7 +300,7 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
@@ -1485,7 +1485,7 @@ def _sample_generate(
 
         return output_generate
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
```
tests/models/musicgen_melody/test_modeling_musicgen_melody.py (2 additions, 2 deletions)

```diff
@@ -303,7 +303,7 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
@@ -1469,7 +1469,7 @@ def _sample_generate(
 
         return output_generate
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
```
tests/models/whisper/test_modeling_whisper.py (2 additions, 2 deletions)

```diff
@@ -409,9 +409,9 @@ def is_pipeline_test_to_skip(
 
         return False
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
-        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
         logits_processor_kwargs["temperature"] = 0.0
         return logits_processor_kwargs
 
```
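
The MusicGen testers above opt out of the shared kwargs entirely (they return an empty dict), while Whisper simply forwards `config` to the parent implementation. For readers skimming the main change, the guard added in tests/generation/test_utils.py boils down to the standalone helper below; this is a paraphrase for readability, not a public API, and `config` stands in for any `PretrainedConfig`-like object:

```python
# Hedged paraphrase of the band-aid added to `_get_logits_processor_kwargs`.
# Not a public API; `config` is any transformers-config-shaped object.
def banned_multimodal_token_ids(config):
    banned = []
    text_vocab_size = config.get_text_config().vocab_size
    for attr in ("image_token_index", "video_token_index"):
        token_index = getattr(config, attr, None)
        # Skip ids at or beyond the text vocab size: the LM head can never
        # emit them, and banning an out-of-range id could error inside the
        # logits processor.
        if token_index is not None and token_index < text_vocab_size:
            banned.append([token_index])
    return banned
```

The returned one-element lists extend `bad_words_ids`, so the greedy, sampling, beam-search, and contrastive test paths all pick up the ban through the same kwargs.
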