From 0e8e99e214018e63591334a5daa7949537e6a22b Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 22:16:30 +0800 Subject: [PATCH] add tests for video llava --- .../video_llava/processing_video_llava.py | 2 ++ .../video_llava/test_processor_video_llava.py | 17 +++++++++++++++++ tests/test_processing_common.py | 3 +-- 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 tests/models/video_llava/test_processor_video_llava.py diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index 35eab8bdc14060..774c4003f3cb0f 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -136,6 +136,8 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + # Temporary fix for "padding_side" in init_kwargs + _ = output_kwargs["text_kwargs"].pop("padding_side", None) data = {} if images is not None or videos is not None: diff --git a/tests/models/video_llava/test_processor_video_llava.py b/tests/models/video_llava/test_processor_video_llava.py new file mode 100644 index 00000000000000..9ddc84a6bcb944 --- /dev/null +++ b/tests/models/video_llava/test_processor_video_llava.py @@ -0,0 +1,17 @@ +import tempfile +import unittest + +from transformers.models.video_llava.processing_video_llava import VideoLlavaProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class VideoLlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "LanguageBind/Video-LLaVA-7B-hf" + processor_class = VideoLlavaProcessor + images_data_arg_name = "pixel_values_images" + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index
ec1e211872e667..185e079bded800 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -242,8 +242,7 @@ def test_unstructured_kwargs_batched(self): ) self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214) - - self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6) + self.assertEqual(len(inputs[self.text_data_arg_name][0]), len(inputs[self.text_data_arg_name][1])) @require_torch @require_vision