
add tests for video llava
leloykun committed Aug 16, 2024
1 parent e785a71 commit 0e8e99e
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/video_llava/processing_video_llava.py
@@ -136,6 +136,8 @@ def __call__(
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
+        # Temporary fix for "padding_side" in init_kwargs
+        _ = output_kwargs["text_kwargs"].pop("padding_side", None)
 
         data = {}
         if images is not None or videos is not None:
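Note on the added pop: the comment suggests `padding_side` leaks from the tokenizer's `init_kwargs` into the merged call-time `text_kwargs`, where it is not an accepted argument at this point, so it is dropped before the kwargs are forwarded to the tokenizer. A minimal sketch of that pattern, with a hypothetical `merged_text_kwargs` dict rather than the processor's real internals:

    # Minimal sketch of the pattern, not the processor's real internals:
    # drop a key that is valid when constructing the tokenizer but not
    # accepted when calling it.
    merged_text_kwargs = {"padding": True, "truncation": True, "padding_side": "left"}  # hypothetical
    merged_text_kwargs.pop("padding_side", None)  # default of None makes this a no-op if the key is absent
    # text_inputs = tokenizer(text, **merged_text_kwargs)  # now safe to forward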
17 changes: 17 additions & 0 deletions tests/models/video_llava/test_processor_video_llava.py
@@ -0,0 +1,17 @@
+import tempfile
+import unittest
+
+from transformers.models.video_llava.processing_video_llava import VideoLlavaProcessor
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+class VideoLlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    from_pretrained_id = "LanguageBind/Video-LLaVA-7B-hf"
+    processor_class = VideoLlavaProcessor
+    images_data_arg_name = "pixel_values_images"
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        processor = self.processor_class.from_pretrained(self.from_pretrained_id)
+        processor.save_pretrained(self.tmpdirname)
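For orientation, the mixin-driven tests roughly amount to loading the processor from the checkpoint above and checking its outputs; the `images_data_arg_name` override tells the shared tests that image features come back under `pixel_values_images` rather than the default `pixel_values`. A rough sketch of that flow, assuming the processor accepts a 4D numpy array as a single video; the prompt string and dummy shapes are illustrative, not taken from the test suite:

    # Rough sketch of what the shared processor tests exercise; not the actual test code.
    import numpy as np
    from transformers import VideoLlavaProcessor

    processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

    # Dummy inputs: an 8-frame video and a single image (shapes are illustrative).
    video = np.random.randint(0, 255, (8, 224, 224, 3), dtype=np.uint8)
    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

    inputs = processor(
        text="USER: <video> <image> What is happening? ASSISTANT:",
        videos=video,
        images=image,
        return_tensors="pt",
    )
    # Image features come back under "pixel_values_images", hence the images_data_arg_name override.
    print(sorted(inputs.keys()))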
3 changes: 1 addition & 2 deletions tests/test_processing_common.py
@@ -242,8 +242,7 @@ def test_unstructured_kwargs_batched(self):
         )
 
         self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
-
-        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), len(inputs[self.text_data_arg_name][1]))
 
     @require_torch
     @require_vision
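The removed assertion pinned the tokenized length of the first batch entry to a literal 6, which only holds for whichever tokenizer the default test checkpoint happens to use. Comparing the two batch entries instead asserts what the test actually cares about, that both sequences are padded to a common batch length, so model-specific subclasses like the new Video-LLaVA one can reuse it unchanged. A tiny illustration of the difference; the token ids below are invented:

    # Illustration only; the token ids are made up.
    batch_ids = [
        [101, 7592, 999, 102, 0, 0],         # first prompt, padded
        [101, 2023, 2003, 1037, 3231, 102],  # second, longer prompt
    ]
    assert len(batch_ids[0]) == len(batch_ids[1])  # holds for any tokenizer once the batch is padded
    # assert len(batch_ids[0]) == 6                # only holds for one specific tokenizer and prompt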
