Commit 3a6adcb

supporting overriding validate_ends_with_eos in tokenizer.__call__ (#128)
The following changes were made to allow this:
1. In __init__: validate_ends_with_eos was changed from str to bool, and a new argument eos was added with default "<EOS>". Default behavior is maintained; a search of all callers of __init__ found none that passed validate_ends_with_eos.
2. A new boolean argument validate_ends_with_eos was added to tokenizer.__call__, with default None. If not None, it overrides self._validate_ends_with_eos for that call.
The change makes it possible to omit <EOS> from the decoder input so that it has the same length as the label. This allows cropping redundant padding in a batch in the decoder-input and label fields as well.
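To make the intent concrete, here is a minimal usage sketch of the new per-call override. The op class name (TokenizerOp) and the sample keys are placeholders, not taken from this commit; only the argument names come from the diffs below:

# Placeholder op class and keys; only the new arguments are from the commit.
tokenizer_op = TokenizerOp(
    tokenizer_path=tokenizer_path,
    pad_token="<PAD>",
    validate_ends_with_eos=True,  # now a bool (was the EOS string itself)
    eos="<EOS>",                  # the token moved into its own argument
)

# Labels still end with <EOS>, so the instance-level default applies:
sample = tokenizer_op(sample, key_in="data.label")

# Decoder input deliberately omits <EOS>; skip the check for this call only:
sample = tokenizer_op(
    sample,
    key_in="data.decoder_input",
    validate_ends_with_eos=False,  # overrides self._validate_ends_with_eos once
)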
1 parent 7fc56d9 commit 3a6adcb

5 files changed: +39 −16
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from pathlib import Path
+
+
+# The directory containing this file
+CONFIG_DIRPATH = Path(__file__).parent
+
+
+def get_modular_tokenizer_config_dirpath() -> str:
+    return str(CONFIG_DIRPATH.resolve())
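The new helper resolves the directory of its own module at import time and returns it as an absolute path. A minimal usage sketch; the import path below is an assumption inferred from the repository layout (the new file's path is not shown in this capture), only the function name comes from the diff:

# Assumed import path -- adjust to wherever the new module actually lives.
from fusedrug.data.tokenizer.modulartokenizer import (
    get_modular_tokenizer_config_dirpath,
)

config_dir = get_modular_tokenizer_config_dirpath()  # absolute path, as str
print(config_dir)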

fusedrug/data/tokenizer/modulartokenizer/modular_tokenizer.py

Lines changed: 3 additions & 0 deletions

@@ -776,6 +776,9 @@ def set_field(tokenizers_info_cfg: List, name: str, key: str, val: Any) -> List:
         key="json_path",
         val=config_out_path,
     )
+    tokenizer_dir = os.path.dirname(write_out_path)
+    if not os.path.exists(tokenizer_dir):
+        os.mkdir(tokenizer_dir)
     tokenizer_inst.save(write_out_path)
     tokenizer_config_overall = {
         "tokenizers_info": tokenizers_info_cfg,

fusedrug/data/tokenizer/ops/fast_tokenizer_ops.py

Lines changed: 12 additions & 7 deletions

@@ -22,7 +22,8 @@ def __init__(
         max_size: int = None,
         pad_token: str = None,
         pad_type_id: str = None,
-        validate_ends_with_eos: Optional[str] = "<EOS>",
+        validate_ends_with_eos: Optional[bool] = True,
+        eos: Optional[str] = "<EOS>",
         verbose: bool = False,
         **kwargs: dict,
     ):
@@ -54,11 +55,12 @@ def __init__(
         )

         self._validate_ends_with_eos = validate_ends_with_eos
+        self._eos = eos

-        if self._validate_ends_with_eos is not None:
-            if self._validate_ends_with_eos not in vocab.keys():
+        if self._validate_ends_with_eos:
+            if self._eos not in vocab.keys():
                 raise Exception(
-                    f"Could not find eos token = {validate_ends_with_eos} in {tokenizer_json}. You can disable the validation by setting validate_ends_with_eos=None"
+                    f"Could not find eos token = {self._eos} in {tokenizer_json}. You can disable the validation by setting validate_ends_with_eos=False"
                 )

         self._pad_id = pad_id
@@ -171,6 +173,7 @@ def __call__(
         key_out_tokens_ids: str = None,
         key_out_attention_mask: str = None,
         convert_attention_mask_to_bool: bool = True,
+        validate_ends_with_eos: Optional[bool] = None,
     ) -> NDict:
         # if self._verbose:
         #     print(
@@ -182,11 +185,13 @@ def __call__(
            raise Exception(
                f"Expected key_in={key_in} to point to a string, and instead got a {type(data_str)}. value={data_str}"
            )
+        if validate_ends_with_eos is None:
+            validate_ends_with_eos = self._validate_ends_with_eos

-        if self._validate_ends_with_eos is not None:
-            if not data_str.rstrip().endswith(self._validate_ends_with_eos):
+        if validate_ends_with_eos:
+            if not data_str.rstrip().endswith(self._eos):
                 raise Exception(
-                    f"self._validate_ends_with_eos was set to {self._validate_ends_with_eos}, but about to encode a string that does not end with it. The str was: {data_str}"
+                    f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str was: {data_str}"
                 )

         encoded = self._tokenizer.encode(data_str)
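Why the per-call override matters for batching: with teacher forcing, the decoder input is the label shifted right, so if the label keeps <EOS> and the decoder input drops it, the two fields end up the same length and redundant padding can be cropped identically in both. A toy illustration (all token ids invented):

BOS, EOS = 1, 2  # invented special-token ids, for illustration only

label      = [12, 47, 33, EOS]   # target sequence keeps <EOS>
decoder_in = [BOS, 12, 47, 33]   # shifted right, <EOS> omitted

# Equal lengths mean batch-level padding can be cropped the same way
# for the decoder-input field and the label field:
assert len(label) == len(decoder_in)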

fusedrug/data/tokenizer/ops/modular_tokenizer_ops.py

Lines changed: 13 additions & 7 deletions

@@ -21,7 +21,8 @@ def __init__(
         max_size: Union[int, None] = None,
         pad_token: Union[str, None] = None,
         pad_type_id: Union[int, None] = None,
-        validate_ends_with_eos: Optional[str] = "<EOS>",
+        validate_ends_with_eos: Optional[bool] = True,
+        eos: Optional[str] = "<EOS>",
         verbose: Optional[bool] = False,
         **kwargs: Any,
     ) -> None:
@@ -53,12 +54,13 @@ def __init__(
         )

         self._validate_ends_with_eos = validate_ends_with_eos
+        self._eos = eos

-        if self._validate_ends_with_eos is not None:
-            eos_id = self._tokenizer.token_to_id(self._validate_ends_with_eos)
+        if self._validate_ends_with_eos:
+            eos_id = self._tokenizer.token_to_id(self._eos)
             if eos_id is None:
                 raise Exception(
-                    f"Could not find eos token = {validate_ends_with_eos} in {tokenizer_path}. You can disable the validation by setting validate_ends_with_eos=None"
+                    f"Could not find eos token = {self._eos} in {tokenizer_path}. You can disable the validation by setting validate_ends_with_eos=False"
                 )

         self._pad_id = pad_id
@@ -190,6 +192,7 @@ def __call__(
         max_seq_len: Optional[int] = None,
         on_unknown: Optional[str] = "warn",
         verbose: Optional[int] = 1,
+        validate_ends_with_eos: Optional[bool] = None,
     ) -> NDict:
         """_summary_

@@ -207,6 +210,7 @@ def __call__(
             on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to "warn".
             verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
                 with full data. Defaults to 1.
+            validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos

         Raises:
             Exception: _description_
@@ -222,15 +226,17 @@ def __call__(
            raise Exception(
                f"Expected key_in={key_in} to point to a list of inputs or string with builtin tokenizer hints, and instead got a {type(data)}. value={data}"
            )
+        if validate_ends_with_eos is None:
+            validate_ends_with_eos = self._validate_ends_with_eos

-        if self._validate_ends_with_eos is not None:
+        if validate_ends_with_eos:
             if isinstance(data, str):
                 last_seq = data
             else:
                 last_seq = data[-1].input_string
-            if not last_seq.rstrip().endswith(self._validate_ends_with_eos):
+            if not last_seq.rstrip().endswith(self._eos):
                 raise Exception(
-                    f"self._validate_ends_with_eos was set to {self._validate_ends_with_eos}, but about to encode a string that does not end with it. The str end was: {last_seq}"
+                    f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str end was: {last_seq}"
                 )

         if isinstance(data, str):
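For existing callers of either op, the old string-valued argument maps onto the new pair as follows; a sketch with a placeholder op name (both changed ops follow the same pattern):

# "TokenizerOp" stands in for either changed op class.
# Before this commit, the token doubled as the flag:
#     TokenizerOp(..., validate_ends_with_eos="<EOS>")
# The same behavior is now spelled:
op = TokenizerOp(tokenizer_path=path, validate_ends_with_eos=True, eos="<EOS>")
# Disabling the check changes from validate_ends_with_eos=None to:
op_unchecked = TokenizerOp(tokenizer_path=path, validate_ends_with_eos=False)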

fusedrug/data/tokenizer/ops/test_modular_tokenizer_ops.py

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 import hydra
 from omegaconf import DictConfig, OmegaConf

+
 from typing import Dict, Optional, Any
 import pytorch_lightning as pl
 from fuse.utils import NDict
@@ -142,7 +143,7 @@ def main(cfg: DictConfig) -> None:
         tokenizer_path=cfg_raw["data"]["tokenizer"]["out_path"],
         max_size=global_max_len,
         pad_token="<PAD>",
-        validate_ends_with_eos="<EOS>",
+        validate_ends_with_eos=True,
     )
     test_tokenizer_op(
         tokenizer_op_inst=mod_tokenizer_op,
@@ -154,7 +155,7 @@ def main(cfg: DictConfig) -> None:
         tokenizer_path=cfg_raw["data"]["tokenizer"]["out_path"],
         max_size=global_max_len,
         pad_token="<PAD>",
-        validate_ends_with_eos="<EOS>",
+        validate_ends_with_eos=True,
     )
     test_tokenizer_op(
         tokenizer_op_inst=mod_tokenizer_op,
