Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit e65cd41

Browse files
leezushishirb126szha
authored
[PERFORMANCE] Improve vocab lookup performance by working with a dict() directly (#1382) (#1385)
Co-authored-by: Sheng Zha <[email protected]> Co-authored-by: shishirb126 <[email protected]> Co-authored-by: Sheng Zha <[email protected]>
1 parent 3fbe961 commit e65cd41

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/gluonnlp/vocab/vocab.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from .. import _constants as C
3232
from .. import embedding as emb
33-
from ..data.utils import Counter, DefaultLookupDict, count_tokens
33+
from ..data.utils import Counter, count_tokens
3434

3535
UNK_IDX = 0
3636
_DEPR_PAD = object()
@@ -219,10 +219,7 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
219219
# Set up idx_to_token and token_to_idx based on presence of unknown token
220220
self._unknown_token = unknown_token
221221
self._idx_to_token = [unknown_token] if unknown_token else []
222-
if unknown_token:
223-
self._token_to_idx = DefaultLookupDict(UNK_IDX)
224-
else:
225-
self._token_to_idx = {}
222+
self._token_to_idx = dict()
226223

227224
# Handle special tokens
228225
special_tokens = []
@@ -267,10 +264,6 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
267264

268265
if token_to_idx:
269266
self._sort_index_according_to_user_specification(token_to_idx)
270-
if unknown_token:
271-
self._token_to_idx._default = \
272-
self._token_to_idx[unknown_token] # pytype: disable=not-writable
273-
274267

275268
def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size,
276269
min_freq):
@@ -395,9 +388,17 @@ def __getitem__(self, tokens):
395388
"""
396389

397390
if not isinstance(tokens, (list, tuple)):
398-
return self._token_to_idx[tokens]
391+
if self._unknown_token:
392+
unknown_token_idx = self._token_to_idx[self._unknown_token]
393+
return self._token_to_idx.get(tokens, unknown_token_idx)
394+
else:
395+
return self._token_to_idx[tokens]
399396
else:
400-
return [self._token_to_idx[token] for token in tokens]
397+
if self._unknown_token:
398+
unknown_token_idx = self._token_to_idx[self._unknown_token]
399+
return [self._token_to_idx.get(token, unknown_token_idx) for token in tokens]
400+
else:
401+
return [self._token_to_idx[token] for token in tokens]
401402

402403
def __len__(self):
403404
return len(self._idx_to_token)

0 commit comments

Comments
 (0)