Skip to content

Commit

Permalink
remove strip in tokenize, keep characters used in special tokens, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
itazap committed Aug 23, 2024
1 parent 844c95c commit 50500a5
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
5 changes: 4 additions & 1 deletion src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,14 +1099,17 @@ def post_processor(self):


class SiglipConverter(SpmConverter):
handle_byte_fallback = True

def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

list_normalizers = []

if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())
list_normalizers.append(normalizers.Replace(Regex(r"[" + re.escape(string.punctuation) + "]"), ""))
punctuation_to_remove = string.punctuation.replace('>', '').replace('<', '').replace('/', '')
list_normalizers.append(normalizers.Replace(Regex(r"[" + re.escape(punctuation_to_remove) + "]"), ""))
list_normalizers.extend(
[
normalizers.Replace(Regex(r"\s+"), " "),
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/siglip/tokenization_siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def __setstate__(self, d):
self.sp_model.Load(self.vocab_file)

def remove_punctuation(self, text: str) -> str:
    """Strip punctuation from *text*, keeping ``<``, ``>`` and ``/``.

    Those three characters are preserved so that special tokens such as
    ``</s>`` survive canonicalization instead of being mangled.

    Args:
        text: Input string to clean.

    Returns:
        The input with every `string.punctuation` character removed,
        except ``<``, ``>`` and ``/``.
    """
    # Build the removal table once per call; str.translate does the
    # deletion in a single C-level pass.
    punctuation_to_remove = string.punctuation.replace('>', '').replace('<', '').replace('/', '')
    return text.translate(str.maketrans("", "", punctuation_to_remove))

# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
Expand All @@ -287,7 +288,6 @@ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
else:
text = self.remove_punctuation(text)
text = re.sub(r"\s+", " ", text)
text = text.strip()

return text

Expand Down
3 changes: 0 additions & 3 deletions src/transformers/models/siglip/tokenization_siglip_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,3 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
13 changes: 8 additions & 5 deletions tests/models/siglip/test_tokenization_siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,13 @@ def test_subword_regularization_tokenizer(self):
def test_pickle_subword_regularization_tokenizer(self):
pass

# @unittest.skip(reason="SiglipTokenizer has custom lowercase logic")
# def test_added_tokens_do_lower_case(self):
# pass
@unittest.skip(reason="SiglipTokenizer has custom lowercase logic")
def test_added_tokens_do_lower_case(self):
    # Skipped: the common-test suite's assumptions about how added tokens
    # interact with do_lower_case do not hold for SiglipTokenizer's
    # custom lowercasing behavior.
    pass

@unittest.skip(reason="Sigliptokenizers trips the punctuation for chat tokens")
def test_chat_template_return_assistant_tokens_mask(self):
pass

# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization with T5->Siglip
def test_special_tokens_initialization(self):
Expand Down Expand Up @@ -383,8 +387,7 @@ def test_some_edge_cases(self):
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["</", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>"])
self.assertEqual(tokens, ["</s>", ">"])

tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
Expand Down

0 comments on commit 50500a5

Please sign in to comment.