
Commit b5ded8f

[v0.8.x][BUGFIX] avoid using dict for attention cell parameter creation (#1051)

* avoid using dict for attention cell parameter creation

* fix

* Update test_sequence_sampler.py

* Update cache.py
eric-haibin-lin authored Dec 14, 2019
1 parent 88fb92e commit b5ded8f
Showing 3 changed files with 6 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/gluonnlp/model/attention_cell.py
@@ -209,8 +209,8 @@ def __init__(self, base_cell, query_units, key_units, value_units, num_heads, us
         self._base_cell = base_cell
         self._num_heads = num_heads
         self._use_bias = use_bias
-        units = {'query': query_units, 'key': key_units, 'value': value_units}
-        for name, unit in units.items():
+        units = [('query', query_units), ('key', key_units), ('value', value_units)]
+        for name, unit in units:
             if unit % self._num_heads != 0:
                 raise ValueError(
                     'In MultiHeadAttetion, the {name}_units should be divided exactly'
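Why a list of tuples: Gluon registers parameters in the order the corresponding layers are created, and on older Python versions plain-dict iteration order is not guaranteed, so building the query/key/value projections from a dict could produce a parameter layout that varies between runs. The sketch below illustrates the pattern only; TinyQKVProjections and the proj_* attribute names are hypothetical and not part of GluonNLP, and the loop body is a simplified stand-in for the real attention cell.

from mxnet.gluon import nn

class TinyQKVProjections(nn.Block):
    """Toy block: creates one Dense projection per (name, units) pair, in list order."""

    def __init__(self, query_units, key_units, value_units, num_heads, use_bias=True, **kwargs):
        super(TinyQKVProjections, self).__init__(**kwargs)
        units = [('query', query_units), ('key', key_units), ('value', value_units)]
        with self.name_scope():
            for name, unit in units:
                if unit % num_heads != 0:
                    raise ValueError('{}_units must be divisible by num_heads'.format(name))
                # Child blocks (and thus their parameters) are registered in list order,
                # so the parameter layout is identical on every run and Python version.
                setattr(self, 'proj_' + name,
                        nn.Dense(unit, use_bias=use_bias, flatten=False, prefix='{}_'.format(name)))

net = TinyQKVProjections(query_units=8, key_units=8, value_units=8, num_heads=2)
net.initialize()
print(list(net.collect_params().keys()))  # stable ordering: query_*, key_*, value_*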
8 changes: 2 additions & 6 deletions src/gluonnlp/model/train/cache.py
@@ -73,17 +73,13 @@ def __init__(self, lm_model, vocab_size, window, theta, lambdas, **kwargs):
         with self.name_scope():
             self.lm_model = lm_model
 
-    def save_parameters(self, filename, deduplicate=False): # pylint: disable=arguments-differ
+    def save_parameters(self, filename): # pylint: disable=arguments-differ
         """Save parameters to file.
 
         filename : str
             Path to file.
-        deduplicate : bool, default False
-            If True, save shared parameters only once. Otherwise, if a Block
-            contains multiple sub-blocks that share parameters, each of the
-            shared parameters will be separately saved for every sub-block.
         """
-        self.lm_model.save_parameters(filename, deduplicate=deduplicate)
+        self.lm_model.save_parameters(filename)
 
     def load_parameters(self, filename, ctx=mx.cpu()): # pylint: disable=arguments-differ
         """Load parameters from file.
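For reference, a hedged usage sketch of the narrowed signature: CacheCell.save_parameters now takes only a filename and forwards it to the wrapped language model. The get_cache_model call and its hyperparameter values below are assumed from the GluonNLP 0.8.x API and are illustrative, not taken from this commit.

import mxnet as mx
import gluonnlp as nlp

# Wrap a pretrained language model in a CacheCell (placeholder hyperparameters).
cache_cell = nlp.model.train.get_cache_model(name='awd_lstm_lm_1150',
                                             dataset_name='wikitext-2',
                                             window=2000, theta=0.6, lambdas=0.2,
                                             ctx=mx.cpu())

cache_cell.save_parameters('cache_cell.params')               # deduplicate is no longer accepted
cache_cell.load_parameters('cache_cell.params', ctx=mx.cpu())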
4 changes: 2 additions & 2 deletions tests/unittest/test_sequence_sampler.py
@@ -53,8 +53,8 @@ def context_free_distribution(step_input, states):
     true_dist = dist.softmax().asnumpy()
     assert_allclose(true_dist, np.array(emp_dist), atol=0.01, rtol=0.1)
 
-# temporarily disabled model.HybridBeamSearchSampler test
-# due to https://github.com/dmlc/gluon-nlp/issues/706
+
+@pytest.mark.skip(reason='https://github.com/dmlc/gluon-nlp/issues/1020')
 @pytest.mark.seed(1)
 @pytest.mark.parametrize('hybridize', [False, True])
 @pytest.mark.parametrize('sampler_cls', [model.BeamSearchSampler, model.HybridBeamSearchSampler])
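The test change replaces a "temporarily disabled" comment with an explicit skip marker, so the affected test is still collected and reported as skipped, along with its reason. A standalone sketch of the same pattern follows; the test name and body are placeholders, not the project's test.

import pytest

# A skipped test shows up in the pytest report with its reason, unlike a test
# that is disabled only by a nearby comment.
@pytest.mark.skip(reason='https://github.com/dmlc/gluon-nlp/issues/1020')
@pytest.mark.parametrize('hybridize', [False, True])
def test_hybrid_beam_search_sampler(hybridize):
    assert hybridize in (False, True)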
