Skip to content

Commit

Permalink
don't hardcode arg names
Browse files Browse the repository at this point in the history
  • Loading branch information
leloykun committed Aug 24, 2024
1 parent e57e988 commit ed0e8aa
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tests/test_processing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
@require_torch
class ProcessorTesterMixin:
processor_class = None
text_data_arg_name = "input_ids"
images_data_arg_name = "pixel_values"

def prepare_processor_dict(self):
return {}
Expand Down Expand Up @@ -136,7 +138,7 @@ def test_tokenizer_defaults_preserved_by_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 117)

@require_torch
@require_vision
Expand All @@ -153,7 +155,7 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 234)

@require_vision
@require_torch
Expand All @@ -171,7 +173,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 112)

@require_torch
@require_vision
Expand All @@ -188,7 +190,7 @@ def test_kwargs_overrides_default_image_processor_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224], crop_size=(224, 224))
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 224)

@require_torch
@require_vision
Expand All @@ -213,8 +215,8 @@ def test_unstructured_kwargs(self):
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)

@require_torch
@require_vision
Expand All @@ -239,9 +241,9 @@ def test_unstructured_kwargs_batched(self):
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6)

@require_torch
@require_vision
Expand Down Expand Up @@ -292,9 +294,9 @@ def test_structured_kwargs_nested(self):
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)

@require_torch
@require_vision
Expand All @@ -321,9 +323,9 @@ def test_structured_kwargs_nested_from_dict(self):
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)


class MyProcessor(ProcessorMixin):
Expand Down

0 comments on commit ed0e8aa

Please sign in to comment.