-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds uniform processing kwargs to paligemma. #32377
base: main
Are you sure you want to change the base?
Changes from all commits
bfe6445
ea95baf
3c89e17
d62f00b
d5a5eed
b51e05f
580ef32
77baadc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,19 +17,22 @@ | |||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||
import sys | ||||||||||||||||||||||||
from typing import List, Optional, Union | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
if sys.version_info >= (3, 11): | ||||||||||||||||||||||||
from typing import Unpack | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
from typing_extensions import Unpack | ||||||||||||||||||||||||
from ...feature_extraction_utils import BatchFeature | ||||||||||||||||||||||||
from ...image_utils import ImageInput, is_valid_image | ||||||||||||||||||||||||
from ...processing_utils import ProcessorMixin | ||||||||||||||||||||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs | ||||||||||||||||||||||||
from ...tokenization_utils_base import ( | ||||||||||||||||||||||||
AddedToken, | ||||||||||||||||||||||||
PaddingStrategy, | ||||||||||||||||||||||||
PreTokenizedInput, | ||||||||||||||||||||||||
TextInput, | ||||||||||||||||||||||||
TruncationStrategy, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
from ...utils import TensorType | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||||
|
@@ -73,6 +76,31 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): | |||||||||||||||||||||||
return f"{image_token * image_seq_len}{bos_token}{prompt}\n" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class PaliGemmaTextKwargs(TextKwargs): | ||||||||||||||||||||||||
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class PaliGemmaImagesKwargs(ImagesKwargs): | ||||||||||||||||||||||||
do_convert_rgb: Optional[bool] | ||||||||||||||||||||||||
do_thumbnail: Optional[bool] | ||||||||||||||||||||||||
do_align_long_axis: Optional[bool] | ||||||||||||||||||||||||
do_rescale: Optional[bool] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): | ||||||||||||||||||||||||
text_kwargs: PaliGemmaTextKwargs | ||||||||||||||||||||||||
image_kwargs: PaliGemmaImagesKwargs | ||||||||||||||||||||||||
_defaults = { | ||||||||||||||||||||||||
"text_kwargs": { | ||||||||||||||||||||||||
"tokenize_newline_separately": True, | ||||||||||||||||||||||||
"padding": False, | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
"image_kwargs": { | ||||||||||||||||||||||||
"data_format": "channels_first", | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class PaliGemmaProcessor(ProcessorMixin): | ||||||||||||||||||||||||
r""" | ||||||||||||||||||||||||
Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. | ||||||||||||||||||||||||
|
@@ -124,25 +152,8 @@ def __call__( | |||||||||||||||||||||||
self, | ||||||||||||||||||||||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, | ||||||||||||||||||||||||
images: ImageInput = None, | ||||||||||||||||||||||||
Comment on lines
153
to
154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two inputs should be reversed and support for backward compatibility should be added. This should be similar to what is needed for Fuyu: transformers/src/transformers/models/fuyu/processing_fuyu.py Lines 522 to 532 in aa3bc0b
|
||||||||||||||||||||||||
tokenize_newline_separately: bool = True, | ||||||||||||||||||||||||
padding: Union[bool, str, PaddingStrategy] = False, | ||||||||||||||||||||||||
truncation: Union[bool, str, TruncationStrategy] = None, | ||||||||||||||||||||||||
max_length=None, | ||||||||||||||||||||||||
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, | ||||||||||||||||||||||||
do_resize: bool = None, | ||||||||||||||||||||||||
do_normalize: bool = None, | ||||||||||||||||||||||||
image_mean: Optional[Union[float, List[float]]] = None, | ||||||||||||||||||||||||
image_std: Optional[Union[float, List[float]]] = None, | ||||||||||||||||||||||||
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 | ||||||||||||||||||||||||
input_data_format: Optional[ | ||||||||||||||||||||||||
Union[str, "ChannelDimension"] # noqa: F821 | ||||||||||||||||||||||||
] = None, | ||||||||||||||||||||||||
resample: "PILImageResampling" = None, # noqa: F821 | ||||||||||||||||||||||||
do_convert_rgb: bool = None, | ||||||||||||||||||||||||
do_thumbnail: bool = None, | ||||||||||||||||||||||||
do_align_long_axis: bool = None, | ||||||||||||||||||||||||
do_rescale: bool = None, | ||||||||||||||||||||||||
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, | ||||||||||||||||||||||||
video=None, | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can also advertise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
**kwargs: Unpack[PaliGemmaProcessorKwargs], | ||||||||||||||||||||||||
) -> BatchFeature: | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` | ||||||||||||||||||||||||
|
@@ -216,7 +227,12 @@ def __call__( | |||||||||||||||||||||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. | ||||||||||||||||||||||||
- **labels** -- Labels compatible with training if `suffix` is not None | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
output_kwargs = self._merge_kwargs( | ||||||||||||||||||||||||
PaliGemmaProcessorKwargs, | ||||||||||||||||||||||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs, | ||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
suffix = output_kwargs["text_kwargs"]["suffix"] | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If suffix is not specified as a kwargs, this will cause an error. Better to use: |
||||||||||||||||||||||||
return_token_type_ids = True if suffix is not None else False | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if images is None: | ||||||||||||||||||||||||
|
@@ -253,27 +269,14 @@ def __call__( | |||||||||||||||||||||||
|
||||||||||||||||||||||||
pixel_values = self.image_processor( | ||||||||||||||||||||||||
images, | ||||||||||||||||||||||||
do_resize=do_resize, | ||||||||||||||||||||||||
do_normalize=do_normalize, | ||||||||||||||||||||||||
return_tensors=return_tensors, | ||||||||||||||||||||||||
image_mean=image_mean, | ||||||||||||||||||||||||
image_std=image_std, | ||||||||||||||||||||||||
input_data_format=input_data_format, | ||||||||||||||||||||||||
data_format=data_format, | ||||||||||||||||||||||||
resample=resample, | ||||||||||||||||||||||||
do_convert_rgb=do_convert_rgb, | ||||||||||||||||||||||||
**kwargs["image_kwargs"], | ||||||||||||||||||||||||
)["pixel_values"] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if max_length is not None: | ||||||||||||||||||||||||
max_length += self.image_seq_length # max_length has to account for the image tokens | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if output_kwargs["text_kwargs"].get("max_length", None) is not None: | ||||||||||||||||||||||||
output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length | ||||||||||||||||||||||||
inputs = self.tokenizer( | ||||||||||||||||||||||||
input_strings, | ||||||||||||||||||||||||
text_pair=suffix, | ||||||||||||||||||||||||
return_tensors=return_tensors, | ||||||||||||||||||||||||
padding=padding, | ||||||||||||||||||||||||
max_length=max_length, | ||||||||||||||||||||||||
truncation=truncation, | ||||||||||||||||||||||||
**output_kwargs["text_kwargs"], | ||||||||||||||||||||||||
return_token_type_ids=return_token_type_ids, | ||||||||||||||||||||||||
Comment on lines
+279
to
280
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,209 @@ | ||||||||||
# Copyright 2023 The HuggingFace Team. All rights reserved. | ||||||||||
# | ||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||
# you may not use this file except in compliance with the License. | ||||||||||
# You may obtain a copy of the License at | ||||||||||
# | ||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||
# | ||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||
# See the License for the specific language governing permissions and | ||||||||||
# limitations under the License. | ||||||||||
import shutil | ||||||||||
import tempfile | ||||||||||
import unittest | ||||||||||
|
||||||||||
import numpy as np | ||||||||||
|
||||||||||
from transformers.testing_utils import require_torch, require_vision | ||||||||||
from transformers.utils import is_vision_available | ||||||||||
|
||||||||||
from ...test_processing_common import ProcessorTesterMixin | ||||||||||
|
||||||||||
|
||||||||||
if is_vision_available(): | ||||||||||
from PIL import Image | ||||||||||
|
||||||||||
from transformers import ( | ||||||||||
AutoProcessor, | ||||||||||
PaliGemmaProcessor, | ||||||||||
is_vision_available, | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
@require_vision | ||||||||||
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase): | ||||||||||
processor_class = PaliGemmaProcessor | ||||||||||
|
||||||||||
def setUp(self): | ||||||||||
self.tmpdirname = tempfile.mkdtemp() | ||||||||||
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224") | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will indeed cause a gated repo issue. you could rebuild a processor without using this repo, something like
Suggested change
Where Not sure if that's the nicest way to fix this though, any idea @zucchini-nlp @molbap ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CI token can be updated so that it can read this repo: in self.processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224") so that should not be an issue already - any idea @ydshieh here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, what is the issue here? This repo seem to be public no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait, are we talking about google/paligemma-3b-pt-224 or google/siglip-so400m-patch14-384 but both are accessible even if I am using a firefox private window |
||||||||||
processor.save_pretrained(self.tmpdirname) | ||||||||||
|
||||||||||
def get_tokenizer(self, **kwargs): | ||||||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer | ||||||||||
|
||||||||||
def get_image_processor(self, **kwargs): | ||||||||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor | ||||||||||
|
||||||||||
def tearDown(self): | ||||||||||
shutil.rmtree(self.tmpdirname) | ||||||||||
|
||||||||||
def prepare_image_inputs(self): | ||||||||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, | ||||||||||
or a list of PyTorch tensors if one specifies torchify=True. | ||||||||||
""" | ||||||||||
|
||||||||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] | ||||||||||
|
||||||||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] | ||||||||||
|
||||||||||
return image_inputs | ||||||||||
Comment on lines
+54
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm noticing more of this function across the repo and it's identical in 17 places, I think we can move it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this one works, I usually use it for image processor tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep! Let's move. I also have my own personal agenda to remove the "numpify" and "torchify" arguments which are confusing, clash and inconsistent so would be a good opportunity for that |
||||||||||
|
||||||||||
# Some kwargs tests are overriden from common tests to handle shortest_edge | ||||||||||
# and size_divisor behaviour | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_image_processor_defaults_preserved_by_image_kwargs(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
image_processor = self.get_component( | ||||||||||
"image_processor", | ||||||||||
crop_size={"shortest_edge": 234, "longest_edge": 234}, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This may not be needed anymore as the base tests were changed recently, same for other tests. Please fetch and rebase on upstream main :) |
||||||||||
) | ||||||||||
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") | ||||||||||
|
||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
input_str = "lower newer" | ||||||||||
image_input = self.prepare_image_inputs() | ||||||||||
|
||||||||||
inputs = processor(text=input_str, images=image_input) | ||||||||||
self.assertEqual(len(inputs["pixel_values"][0][0]), 234) | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_structured_kwargs_nested_from_dict(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
|
||||||||||
image_processor = self.get_component("image_processor") | ||||||||||
tokenizer = self.get_component("tokenizer") | ||||||||||
|
||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
input_str = "lower newer" | ||||||||||
image_input = self.prepare_image_inputs() | ||||||||||
|
||||||||||
# Define the kwargs for each modality | ||||||||||
all_kwargs = { | ||||||||||
"common_kwargs": {"return_tensors": "pt"}, | ||||||||||
"images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"}, | ||||||||||
"text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False}, | ||||||||||
} | ||||||||||
|
||||||||||
inputs = processor(text=input_str, images=image_input, **all_kwargs) | ||||||||||
self.assertEqual(inputs["pixel_values"].shape[2], 214) | ||||||||||
|
||||||||||
self.assertEqual(len(inputs["input_ids"][0]), 76) | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_kwargs_overrides_default_image_processor_kwargs(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
image_processor = self.get_component("image_processor", crop_size={"shortest_edge": 234}) | ||||||||||
tokenizer = self.get_component("tokenizer", max_length=117) | ||||||||||
if not tokenizer.pad_token: | ||||||||||
tokenizer.pad_token = "[TEST_PAD]" | ||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
input_str = "lower newer" | ||||||||||
image_input = self.prepare_image_inputs() | ||||||||||
inputs = processor(text=input_str, images=image_input, crop_size={"shortest_edge": 224}) | ||||||||||
self.assertEqual(len(inputs["pixel_values"][0][0]), 224) | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_unstructured_kwargs_batched(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
image_processor = self.get_component("image_processor") | ||||||||||
tokenizer = self.get_component("tokenizer") | ||||||||||
if not tokenizer.pad_token: | ||||||||||
tokenizer.pad_token = "[TEST_PAD]" | ||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
input_str = ["lower newer", "upper older longer string"] | ||||||||||
image_input = self.prepare_image_inputs() * 2 | ||||||||||
inputs = processor( | ||||||||||
text=input_str, | ||||||||||
images=image_input, | ||||||||||
return_tensors="pt", | ||||||||||
crop_size={"shortest_edge": 214}, | ||||||||||
padding="longest", | ||||||||||
max_length=76, | ||||||||||
) | ||||||||||
self.assertEqual(inputs["pixel_values"].shape[2], 214) | ||||||||||
|
||||||||||
self.assertEqual(len(inputs["input_ids"][0]), 6) | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_unstructured_kwargs(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
image_processor = self.get_component("image_processor") | ||||||||||
tokenizer = self.get_component("tokenizer") | ||||||||||
if not tokenizer.pad_token: | ||||||||||
tokenizer.pad_token = "[TEST_PAD]" | ||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
input_str = "lower newer" | ||||||||||
image_input = self.prepare_image_inputs() | ||||||||||
inputs = processor( | ||||||||||
text=input_str, | ||||||||||
images=image_input, | ||||||||||
return_tensors="pt", | ||||||||||
crop_size={"shortest_edge": 214}, | ||||||||||
padding="max_length", | ||||||||||
max_length=76, | ||||||||||
) | ||||||||||
|
||||||||||
self.assertEqual(inputs["pixel_values"].shape[2], 214) | ||||||||||
self.assertEqual(len(inputs["input_ids"][0]), 76) | ||||||||||
|
||||||||||
@require_torch | ||||||||||
@require_vision | ||||||||||
def test_structured_kwargs_nested(self): | ||||||||||
if "image_processor" not in self.processor_class.attributes: | ||||||||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}") | ||||||||||
image_processor = self.get_component("image_processor") | ||||||||||
tokenizer = self.get_component("tokenizer") | ||||||||||
if not tokenizer.pad_token: | ||||||||||
tokenizer.pad_token = "[TEST_PAD]" | ||||||||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
input_str = "lower newer" | ||||||||||
image_input = self.prepare_image_inputs() | ||||||||||
|
||||||||||
# Define the kwargs for each modality | ||||||||||
all_kwargs = { | ||||||||||
"common_kwargs": {"return_tensors": "pt"}, | ||||||||||
"images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"}, | ||||||||||
"text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False}, | ||||||||||
} | ||||||||||
|
||||||||||
inputs = processor(text=input_str, images=image_input, **all_kwargs) | ||||||||||
self.skip_processor_without_typed_kwargs(processor) | ||||||||||
|
||||||||||
self.assertEqual(inputs["pixel_values"].shape[2], 214) | ||||||||||
self.assertEqual(len(inputs["input_ids"][0]), 76) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like
tokenize_newline_separately
is not use anywhere, and it is not a defaulttext_kwargs
, so it might be best to remove it entirely?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's not used anymore and is not needed - iiuc do_thumbnail, do_align_long_axis and do_rescale neither (FYI, they are not used here)
+1 for removing it