From f4c3df9121d17991366d5afca8431666d102ddca Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 16:59:30 +0200 Subject: [PATCH] Update test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2f8e60c79151e9..803942cb8c9841 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1868,6 +1868,9 @@ def test_new_cache_format(self, num_beams, do_sample): if config.is_encoder_decoder: cache_cls = EncoderDecoderCache past_key_values = cache_cls(DynamicCache(), DynamicCache()) + elif getattr(self.config, "sliding_window", None) is not None: + cache_cls = DynamicSlidingWindowCache + past_key_values = cache_cls(self.config.sliding_window) else: cache_cls = DynamicCache past_key_values = cache_cls()