Skip to content

Commit

Permalink
remove special tokens that are not special
Browse files Browse the repository at this point in the history
  • Loading branch information
andimarafioti committed Aug 15, 2024
1 parent 22c5ccd commit d9dbd6b
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import sys
from typing import TYPE_CHECKING, List, Optional, Union
import re

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
Expand Down Expand Up @@ -144,9 +145,10 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 169, ch
self.image_token = AddedToken("<image>", normalized=False, special=True)
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
self.global_img_token = "<global-img>"
self.global_img_token_id = tokenizer.convert_tokens_to_ids(self.global_img_token)
self.image_seq_len = image_seq_len

self._regex_to_remove_extra_special_tokens = re.compile(r'(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+')

tokens_to_add = {
"additional_special_tokens": [
self.fake_image_token,
Expand Down Expand Up @@ -343,14 +345,17 @@ def batch_decode(self, *args, **kwargs):
This method forwards all its arguments to Idefics3TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
return [self._regex_to_remove_extra_special_tokens.sub("<image>", s) for s in batched_decode_output]

def decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to Idefics3TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
    the docstring of this method for more information.

    The decoded string is then post-processed with `self._regex_to_remove_extra_special_tokens`: every match of that
    pattern is replaced by a single `"<image>"` placeholder.

    Returns:
        `str`: The decoded text with each match of the pattern replaced by `"<image>"`.
    """
    decode_output = self.tokenizer.decode(*args, **kwargs)
    # The raw tokenizer output must not be returned directly: the substitution below is the
    # whole point of this wrapper, and an early return would leave it unreachable.
    return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)


@property
def model_input_names(self):
Expand Down

0 comments on commit d9dbd6b

Please sign in to comment.