
[Dict] Add extra special tokens (#2828)
* special tokens

* lint

* not decoding the right special tokens

* add a test

* get rid of tODO

* add special tokens after

* offset messed up again

* special tokens agent

* lint

* lint

* just build this into torch agent

* what did i do

* lint

* nit

* test transformer generator

* lint

* git test add special tokens

* oopsy
Emily Dinan authored Jul 16, 2020
1 parent 0292f57 commit 20cc87d
Showing 7 changed files with 230 additions and 13 deletions.
30 changes: 30 additions & 0 deletions parlai/agents/transformer/transformer.py
@@ -10,6 +10,8 @@
from parlai.core.torch_classifier_agent import TorchClassifierAgent
from parlai.core.torch_ranker_agent import TorchRankerAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.utils.misc import recursive_getattr
from parlai.utils.logging import logging

from .modules import (
    TransformerMemNetModel,
@@ -326,6 +328,34 @@ def build_model(self, states=None):
            )
        return model

    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Resize the token embeddings when we are adding extra special tokens.
        """
        # map extra special tokens carefully
        new_size = self.model.embeddings.weight.size()[0]
        orig_size = state_dict['embeddings.weight'].size()[0]
        logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
        if new_size <= orig_size:
            # new size should be greater than original size,
            # as we are adding special tokens
            raise RuntimeError(msg)

        for emb_weights in [
            'embeddings.weight',
            'encoder.embeddings.weight',
            'decoder.embeddings.weight',
        ]:
            # get new_embs
            old_embs = state_dict[emb_weights]
            new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
            # copy over old weights
            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
            # reset in state dict
            state_dict[emb_weights] = new_embs

        return state_dict

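The loop above only copies the pretrained rows; the rows appended for the new special tokens keep the fresh initialization they received when the larger model was built. A minimal standalone sketch of that copy-over (not ParlAI code; the sizes are hypothetical):

import torch

# Hypothetical sizes: a checkpoint with 8 embeddings loaded into a model built
# with 10 rows (i.e. 2 extra special tokens were added).
old_embs = torch.randn(8, 16)           # stands in for state_dict['embeddings.weight']
new_embs = torch.nn.Embedding(10, 16)   # freshly initialized, larger table
new_embs.weight.data[:8, :] = old_embs  # reuse pretrained rows; the last 2 keep their fresh init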

class TransformerClassifierAgent(TorchClassifierAgent):
"""
41 changes: 39 additions & 2 deletions parlai/core/dict.py
@@ -20,6 +20,7 @@
import json
import re
import parlai.utils.logging as logging
from typing import List

RETOK = re.compile(r'\w+|[^\w\s]|\n', re.UNICODE)

@@ -324,6 +325,38 @@ def __init__(self, opt: Opt, shared=None):
        if opt.get('dict_file'):
            self.save_path = opt['dict_file']

    def add_additional_special_tokens(self, additional_special_tokens: List[str]):
        """
        Add additional special tokens to the dictionary.
        Should only be called after initialization of the existing dictionary.
        """
        self.additional_special_tokens = additional_special_tokens

        if (
            self.additional_special_tokens
            and not self.supports_additional_special_tokens()
        ):
            raise RuntimeError(
                f'{self.tokenizer} does not currently support adding additional special tokens'
            )

        for tok in self.additional_special_tokens:
            self.add_token(tok)

        for i, tok in enumerate(self.additional_special_tokens):
            self.freq[tok] = 1000000000 + 4 + i

        if self.tokenizer == 'bytelevelbpe':
            self.bpe.add_special_tokens(self, self.additional_special_tokens)

    def supports_additional_special_tokens(self):
        """
        Indicates whether the dictionary supports additional special tokens.
        """
        # TODO: add to others
        return self.tokenizer in ['bytelevelbpe', 'split', 'space']

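For reference, a minimal usage sketch of the new API (it mirrors the test added in tests/test_dict.py below; the byte-level BPE vocab and merge paths are placeholders):

from parlai.core.dict import DictionaryAgent
from parlai.core.params import ParlaiParser

parser = ParlaiParser()
parser.set_params(
    dict_tokenizer='bytelevelbpe',
    bpe_vocab='/path/to/vocab.json',   # placeholder
    bpe_merge='/path/to/merges.txt',   # placeholder
    hf_skip_special_tokens=False,
)
opt = parser.parse_args([], print_args=False)

agent = DictionaryAgent(opt)
agent.add_additional_special_tokens(['PARTY', 'PARROT'])

# round-trips through txt2vec/vec2txt now preserve the added tokens
phrase = 'PARTY on PARROT'
assert agent.vec2txt(agent.txt2vec(phrase)) == phrase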
    def is_prebuilt(self):
        """
        Indicates whether the dictionary is fixed, and does not require building.
@@ -708,9 +741,13 @@ def vec2txt(self, vector, delimiter=' '):
            text = self.bpe.decode(tokens, vector, delimiter)
        elif self.tokenizer == 'bytelevelbpe':
            # We add special tokens in the beginning of ParlAI dict but in the
-            # end of Hugging Face dict,there is an offset of 4 between them.
+            # end of Hugging Face dict, there is an offset of #(extra tokens) between them.
+            extra_tokens = 4  # length of special tokens
            vector = [
-                idx + len(self.tok2ind) - 4 if idx < 4 else idx - 4 for idx in vector
+                self.bpe.special_tok_map[idx]
+                if idx in self.bpe.special_tok_map
+                else idx - extra_tokens
+                for idx in vector
            ]
            tokens = [self[int(idx)] for idx in vector]
            text = self.bpe.decode(tokens, vector, delimiter)
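To illustrate the remapping above with hypothetical sizes: ParlAI keeps its four special tokens (null, start, end, unk) at the front of the dict, while the Hugging Face tokenizer appends them at the end of its vocab, so decoding has to translate indices:

# Suppose an 8-token BPE vocab plus the 4 ParlAI special tokens.
# ParlAI ids:       0-3 -> specials, 4-11 -> BPE tokens
# Hugging Face ids: 0-7 -> BPE tokens, 8-11 -> the same specials (appended at the end)
special_tok_map = {0: 8, 1: 9, 2: 10, 3: 11}  # ParlAI id -> HF id, as built by add_special_tokens
extra_tokens = 4

parlai_vector = [1, 5, 7, 2]  # start token, two BPE tokens, end token
hf_vector = [
    special_tok_map[idx] if idx in special_tok_map else idx - extra_tokens
    for idx in parlai_vector
]
assert hf_vector == [9, 1, 3, 10]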
58 changes: 50 additions & 8 deletions parlai/core/torch_agent.py
@@ -653,6 +653,12 @@ def add_cmdline_args(cls, argparser):
            choices=[None, 'end'],
            help='Add special token to the end of history encoding.',
        )
        agent.add_argument(
            '--special-tok-lst',
            type=str,
            default=None,
            help='Comma separated list of special tokens',
        )
        # GPU arguments
        # these gpu options are all mutually exclusive, and should error if the
        # user tries to present multiple of them
@@ -801,11 +807,28 @@ def build_dictionary(self):
        place to do it.
        """
        d = self.dictionary_class()(self.opt)
        self.special_toks = self._get_special_tokens()
        if self.special_toks:
            d.add_additional_special_tokens(self.special_toks)

        if self.opt.get('person_tokens'):
            d[self.P1_TOKEN] = 999_999_999
            d[self.P2_TOKEN] = 999_999_998
        return d

    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Must define this for your agent if you wish to add additional special tokens.
        Must make a call to resize the token embeddings and load the model state dict
        with the resized token embeddings.
        """
        raise NotImplementedError(
            'If you are intending to add special tokens to an already pretrained model, '
            'you must write the function `_resize_token_embeddings` for your specific '
            'agent.'
        )

    def _get_init_model(self, opt: Opt, shared):
        """
        Get model file to initialize with.
@@ -845,6 +868,16 @@ def _get_init_model(self, opt: Opt, shared):

        return init_model, is_finetune

    def _get_special_tokens(self) -> List[str]:
        """
        Return list of special tokens.
        Made easily overridable for special cases.
        """
        if self.opt.get('special_tok_lst') is not None:
            return self.opt['special_tok_lst'].split(',')
        return []

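As a small hedged example of what `--special-tok-lst` expects: the value is split on commas with no further parsing, so tokens containing commas or stray whitespace would need different handling.

# Mirrors _get_special_tokens above.
special_tok_lst = 'PARTY,PARROT'  # e.g. passed as --special-tok-lst PARTY,PARROT
assert special_tok_lst.split(',') == ['PARTY', 'PARROT']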
    @abstractmethod
    def build_model(self):
        """
@@ -878,6 +911,10 @@ def init_optim(self, params, optim_states=None, saved_optim_type=None):
            type of optimizer being loaded, if changed will skip loading
            optimizer states
        """
        if hasattr(self, 'resized_embeddings') and self.resized_embeddings:
            optim_states = None
            logging.warn('Not loading optimizer due to resize in token embeddings')

        opt = self.opt

        # set up optimizer args
@@ -1810,14 +1847,19 @@ def load_state_dict(self, state_dict):
        except RuntimeError as msg:
            msg_ = str(msg)
            if 'size mismatch' in msg_ and 'embedding' in msg_:
-                raise RuntimeError(
-                    f'{msg_}\n'
-                    '-----------------\n'
-                    'Could not load the model due to a size mismatch in the '
-                    'embeddings. A common reason for this is trying to load '
-                    'a model trained with fp16 but loaded without fp16. Try '
-                    'adding --fp16 true or --force-fp16-tokens true.'
-                )
+                if hasattr(self, 'special_toks') and len(self.special_toks) > 0:
+                    state_dict = self._resize_token_embeddings(state_dict, msg_)
+                    self.model.load_state_dict(state_dict)
+                    self.resized_embeddings = True  # make note that we resized here
+                else:
+                    raise RuntimeError(
+                        f'{msg_}\n'
+                        '-----------------\n'
+                        'Could not load the model due to a size mismatch in the '
+                        'embeddings. A common reason for this is trying to load '
+                        'a model trained with fp16 but loaded without fp16. Try '
+                        'adding --fp16 true or --force-fp16-tokens true.'
+                    )
            else:
                raise

31 changes: 28 additions & 3 deletions parlai/utils/bpe.py
@@ -122,6 +122,13 @@ def add_cmdline_args(argparser):
            hidden=True,
            help='add prefix space before encoding',
        )
        parser.add_argument(
            '--hf-skip-special-tokens',
            hidden=True,
            type='bool',
            default=True,
            help='do not decode special tokens with bytelevelbpe',
        )
        return parser

    @final
@@ -689,7 +696,9 @@ class HuggingFaceBpeHelper(BPEHelper):
    def __init__(self, opt: Opt, shared: TShared = None):
        super().__init__(opt, shared)
        # Default true for HF
        self.special_tok_map = {}  # map from HF
        self.add_prefix_space = opt.get('bpe_add_prefix_space', True)
        self.skip_special_tokens = opt.get('hf_skip_special_tokens', True)
        if self.add_prefix_space is None:
            self.add_prefix_space = True
        if opt.get('dict_loaded'):
@@ -769,9 +778,24 @@ def helper_decode(
        :return text:
            decoded text
        """
-        text = self.tokenizer.decode(token_ids)
+        text = self.tokenizer.decode(
+            token_ids, skip_special_tokens=self.skip_special_tokens
+        )

        return text

    def add_special_tokens(self, dict_agent, special_tokens: List[str]):
        """
        Add special tokens to the tokenizer and dict_agent.
        """
        logging.info(f'adding the following special tokens: {special_tokens}')
        self.tokenizer.add_special_tokens(special_tokens)  # add to HF

        for tok in special_tokens:
            parlai_key = dict_agent[tok]
            hf_key = self.tokenizer.token_to_id(tok)
            self.special_tok_map[parlai_key] = hf_key

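A rough sketch of the underlying Hugging Face `tokenizers` calls used here (the vocab and merges paths are placeholders; the exact ids depend on the vocab):

from tokenizers import ByteLevelBPETokenizer

tokenizer = ByteLevelBPETokenizer('/path/to/vocab.json', '/path/to/merges.txt')  # placeholder files
before = tokenizer.get_vocab_size()
tokenizer.add_special_tokens(['PARTY', 'PARROT'])  # new tokens are appended to the HF vocab
hf_key = tokenizer.token_to_id('PARTY')            # typically == before, i.e. at the end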
    def sync_with_dict(self, dict_agent):
        """
        Sync the dictionary agent with Hugging Face tokenizer's BPE dict.
@@ -784,8 +808,9 @@ def sync_with_dict(self, dict_agent):
            dict_agent.end_token,
            dict_agent.unk_token,
        ]
-        self.tokenizer.add_special_tokens(special_tokens)
-        for i in range(self.tokenizer.get_vocab_size() - 4):
+        self.add_special_tokens(dict_agent, special_tokens)
+
+        for i in range(self.tokenizer.get_vocab_size() - len(special_tokens)):
            token = self.tokenizer.id_to_token(i)
            dict_agent.add_token(token)
        # We don't have access to the hugging face word frequency table,
12 changes: 12 additions & 0 deletions parlai/utils/misc.py
@@ -10,6 +10,7 @@
from collections import deque, OrderedDict
from typing import Union, Optional, Set, Any, Dict, List, Tuple
from datetime import timedelta
import functools
import math
import time
import re
@@ -752,3 +753,14 @@ def error_once(msg: str) -> None:
    if msg not in _seen_logs:
        _seen_logs.add(msg)
        logging.error(msg)


def recursive_getattr(obj, attr, *args):
    """
    Recursive call to getattr for nested attributes.
    """

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split('.'))
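A quick usage example of the helper (the nested module is a hypothetical stand-in for something like `model.encoder.embeddings`):

import torch.nn as nn
from parlai.utils.misc import recursive_getattr

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Module()
        self.encoder.embeddings = nn.Embedding(10, 16)

toy = Toy()
emb = recursive_getattr(toy, 'encoder.embeddings')  # equivalent to toy.encoder.embeddings
assert emb.weight.shape == (10, 16)
assert recursive_getattr(toy, 'decoder.embeddings', None) is None  # default threads through each hop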
25 changes: 25 additions & 0 deletions tests/test_dict.py
@@ -379,6 +379,31 @@ def test_save_reload(self):
        )
        assert da2.txt2vec("hello") == da.txt2vec("hello")

    def test_add_special_tokens(self):
        """
        Add a list of special tokens to the dictionary.
        """
        special_toks_lst = ['MY', 'NAME', 'IS', 'EMILY']
        # create Dictionary Agent
        parser = ParlaiParser()
        parser.set_params(
            dict_tokenizer='bytelevelbpe',
            bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
            bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
            hf_skip_special_tokens=False,
        )
        opt = parser.parse_args([], print_args=False)

        agent = DictionaryAgent(opt)
        agent.add_additional_special_tokens(special_toks_lst)

        self.assertEqual(agent.additional_special_tokens, special_toks_lst)
        phrases = ['Hi what is up EMILY', 'What IS your NAME', 'That is MY dog']
        for phrase in phrases:
            vec = agent.txt2vec(phrase)
            text = agent.vec2txt(vec)
            self.assertEqual(phrase, text)


class TestBuildDict(unittest.TestCase):
    def _run_test(self, opt):
46 changes: 46 additions & 0 deletions tests/test_transformers.py
@@ -13,6 +13,8 @@
import parlai.utils.testing as testing_utils
from parlai.core.agents import create_agent
from parlai.core.opt import Opt
from tests.test_dict import DEFAULT_BYTELEVEL_BPE_VOCAB, DEFAULT_BYTELEVEL_BPE_MERGE
from parlai.core.params import ParlaiParser


class TestTransformerRanker(unittest.TestCase):
@@ -674,6 +676,50 @@ def test_temperature(self):
            )
        )

    def test_resize_embeddings(self):
        # train original model
        with testing_utils.tempdir() as tmpdir:
            model_file = os.path.join(tmpdir, 'model_file')
            _, _ = testing_utils.train_model(
                dict(
                    model='transformer/generator',
                    task='integration_tests:short_fixed',
                    n_layers=1,
                    n_encoder_layers=2,
                    n_decoder_layers=4,
                    num_epochs=1,
                    dict_tokenizer='bytelevelbpe',
                    bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
                    bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
                    bpe_add_prefix_space=False,
                    model_file=model_file,
                    save_after_valid=True,
                )
            )

            # now create agent with special tokens
            parser = ParlaiParser()
            parser.set_params(
                model='transformer/generator',
                task='integration_tests:short_fixed',
                n_layers=1,
                n_encoder_layers=2,
                n_decoder_layers=4,
                dict_tokenizer='bytelevelbpe',
                bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
                bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
                bpe_add_prefix_space=False,
                model_file=model_file,
                save_after_valid=True,
                special_tok_lst='PARTY,PARROT',
            )
            opt = parser.parse_args([], print_args=False)
            agent = create_agent(opt)
            # assert that the embeddings were resized
            assert agent.resized_embeddings
            # assert model has special tokens
            self.assertEqual(agent.special_toks, ['PARTY', 'PARROT'])


class TestClassifier(unittest.TestCase):
"""

