Skip to content

Commit

Permalink
fix missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Sep 20, 2024
1 parent 2fdb5e7 commit e4d4429
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
from transformers.generation.utils import _speculative_sampling


@pytest.mark.generate
class GenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
Expand Down Expand Up @@ -2035,6 +2034,7 @@ def test_generate_compile_fullgraph(self):
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())

@pytest.mark.generate
def test_generate_methods_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
Expand Down Expand Up @@ -2063,6 +2063,7 @@ def test_generate_methods_with_num_logits_to_keep(self):
)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

@pytest.mark.generate
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
Expand Down

0 comments on commit e4d4429

Please sign in to comment.