Skip to content

Commit

Permalink
First draft
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Aug 28, 2024
1 parent 74e19e8 commit 7d93bb5
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def convert_blip2_checkpoint(
else:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

tokenizer.model_input_names = ["input_ids", "attention_mask"]

if "itm" in model_name:
eos_token_id = None
else:
Expand Down
12 changes: 2 additions & 10 deletions src/transformers/pipelines/zero_shot_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwar
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
the call may block forever.
tokenizer_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the tokenizer.
Return:
A list of dictionaries containing one entry per proposed label. Each dictionary contains the
following keys:
Expand All @@ -109,16 +106,14 @@ def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwar
"""
return super().__call__(images, **kwargs)

def _sanitize_parameters(self, tokenizer_kwargs=None, **kwargs):
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
if tokenizer_kwargs is not None:
preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

return preprocess_params, {}, {}

def preprocess(
    self,
    image,
    candidate_labels=None,
    hypothesis_template="This is a photo of {}.",
    timeout=None,
):
    """Turn one image + candidate labels into model-ready inputs.

    Args:
        image: an image, URL, or path accepted by `load_image`.
        candidate_labels: labels to score; each is formatted into
            `hypothesis_template` to build the text side of the pair.
        hypothesis_template: format string with one `{}` placeholder.
        timeout: max seconds to wait when fetching a remote image.

    Returns:
        A dict of image-processor tensors plus `candidate_labels` and
        `text_inputs` (a single-element list holding the tokenized
        hypothesis sentences).
    """
    image = load_image(image, timeout=timeout)
    inputs = self.image_processor(images=[image], return_tensors=self.framework)
    if self.framework == "pt":
        inputs = inputs.to(self.torch_dtype)
    inputs["candidate_labels"] = candidate_labels
    sequences = [hypothesis_template.format(x) for x in candidate_labels]
    # SigLIP was trained with fixed-length padding; other models use dynamic padding.
    padding = "max_length" if self.model.config.model_type == "siglip" else True
    text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding)
    inputs["text_inputs"] = [text_inputs]
    return inputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ def test_blip2_model_pt(self):
output = image_classifier(
image,
candidate_labels=["2 cats", "a plane", "a remote"],
tokenizer_kwargs={"return_token_type_ids": False},
)

self.assertEqual(
Expand All @@ -308,7 +307,6 @@ def test_blip2_model_pt(self):
[image] * 5,
candidate_labels=["2 cats", "a plane", "a remote"],
batch_size=2,
tokenizer_kwargs={"return_token_type_ids": False},
)

self.assertEqual(
Expand Down

0 comments on commit 7d93bb5

Please sign in to comment.