|
30 | 30 |
|
31 | 31 | from .. import _constants as C
|
32 | 32 | from .. import embedding as emb
|
33 |
| -from ..data.utils import Counter, DefaultLookupDict, count_tokens |
| 33 | +from ..data.utils import Counter, count_tokens |
34 | 34 |
|
35 | 35 | UNK_IDX = 0
|
36 | 36 | _DEPR_PAD = object()
|
@@ -219,10 +219,7 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
|
219 | 219 | # Set up idx_to_token and token_to_idx based on presence of unknown token
|
220 | 220 | self._unknown_token = unknown_token
|
221 | 221 | self._idx_to_token = [unknown_token] if unknown_token else []
|
222 |
| - if unknown_token: |
223 |
| - self._token_to_idx = DefaultLookupDict(UNK_IDX) |
224 |
| - else: |
225 |
| - self._token_to_idx = {} |
| 222 | + self._token_to_idx = dict() |
226 | 223 |
|
227 | 224 | # Handle special tokens
|
228 | 225 | special_tokens = []
|
@@ -267,10 +264,6 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
|
267 | 264 |
|
268 | 265 | if token_to_idx:
|
269 | 266 | self._sort_index_according_to_user_specification(token_to_idx)
|
270 |
| - if unknown_token: |
271 |
| - self._token_to_idx._default = \ |
272 |
| - self._token_to_idx[unknown_token] # pytype: disable=not-writable |
273 |
| - |
274 | 267 |
|
275 | 268 | def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size,
|
276 | 269 | min_freq):
|
@@ -395,9 +388,17 @@ def __getitem__(self, tokens):
|
395 | 388 | """
|
396 | 389 |
|
397 | 390 | if not isinstance(tokens, (list, tuple)):
|
398 |
| - return self._token_to_idx[tokens] |
| 391 | + if self._unknown_token: |
| 392 | + unknown_token_idx = self._token_to_idx[self._unknown_token] |
| 393 | + return self._token_to_idx.get(tokens, unknown_token_idx) |
| 394 | + else: |
| 395 | + return self._token_to_idx[tokens] |
399 | 396 | else:
|
400 |
| - return [self._token_to_idx[token] for token in tokens] |
| 397 | + if self._unknown_token: |
| 398 | + unknown_token_idx = self._token_to_idx[self._unknown_token] |
| 399 | + return [self._token_to_idx.get(token, unknown_token_idx) for token in tokens] |
| 400 | + else: |
| 401 | + return [self._token_to_idx[token] for token in tokens] |
401 | 402 |
|
402 | 403 | def __len__(self):
|
403 | 404 | return len(self._idx_to_token)
|
|
0 commit comments