Skip to content

Commit

Permalink
Remove obsolete code in the attention layers
Browse files (browse the repository at this point in the history)
  • Loading branch information
kpot committed Dec 10, 2018
1 parent b91ea60 commit 1bbd5b2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions keras_transformer/attention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_config(self):
return config

# noinspection PyAttributeOutsideInit
def build_output_params(self, d_model, k_seq_length):
def build_output_params(self, d_model):
self.output_weights = self.add_weight(
name='output_weights',
shape=(d_model, d_model),
Expand Down Expand Up @@ -231,7 +231,6 @@ def build(self, input_shape):
'You must call this layer passing a list of two tensors'
'(for keys/values and queries)')
values_dim, query_dim = input_shape[0][-1], input_shape[1][-1]
values_seq_length = input_shape[0][-2]
if query_dim != values_dim:
raise ValueError(
f'Both keys/value and query inputs must be '
Expand All @@ -250,7 +249,7 @@ def build(self, input_shape):
self.q_weights = self.add_weight(
name='q_weights', shape=(d_model, d_model),
initializer='glorot_uniform', trainable=True)
self.build_output_params(d_model, values_seq_length)
self.build_output_params(d_model)
return super().build(input_shape)

def call(self, inputs, **kwargs):
Expand Down Expand Up @@ -291,7 +290,7 @@ class MultiHeadSelfAttention(_BaseMultiHeadAttention):
def build(self, input_shape):
if not isinstance(input_shape, tuple):
raise ValueError('Invalid input')
seq_length, d_model = input_shape[-2:]
d_model = input_shape[-1]
self.validate_model_dimensionality(d_model)
# These weights are concatenated matrices W_q, W_k and W_v which
# are, in turn, concatenated W matrices of keys, queries and values
Expand All @@ -303,7 +302,7 @@ def build(self, input_shape):
shape=(d_model, d_model * 3), # * 3 for q, k and v
initializer='glorot_uniform',
trainable=True)
self.build_output_params(d_model, seq_length)
self.build_output_params(d_model)
return super().build(input_shape)

def call(self, inputs, **kwargs):
Expand Down

0 comments on commit 1bbd5b2

Please sign in to comment.