Commit 41f7c1f

No public description
PiperOrigin-RevId: 675667232
1 parent 764091c commit 41f7c1f

2 files changed: +35, -9 lines

official/nlp/modeling/layers/transformer_encoder_block.py

Lines changed: 29 additions & 8 deletions
@@ -115,6 +115,7 @@ def __init__(self,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -194,6 +195,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, projection layer is shared for
+        keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -234,6 +237,7 @@ def __init__(self,
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ def build(self, input_shape):
           dtype=tf.float32,
       )
     if self._linformer_dim is not None:
-      # Current implementation uses the same weights for keys and values.
-      # TODO(akandoor): Explore using different weights for keys and values.
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None, self._linformer_dim),
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
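For intuition, here is a minimal standalone sketch (not the Model Garden code; the sizes are made up) of what this EinsumDense projection does in the unshared case: it maps the transposed [batch, hidden, seq_len] tensor to [batch, hidden, 2 * linformer_dim], one half for keys and one for values.

import tensorflow as tf

batch, seq_len, hidden, linformer_dim = 2, 21, 16, 7
low_rank_dim = 2 * linformer_dim  # doubled because keys and values are unshared

# Same einsum equation as in the diff; the None axis is inferred from the input.
lowrank_kv_projection = tf.keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, low_rank_dim))

key_value = tf.random.normal([batch, seq_len, hidden])
kv = tf.transpose(key_value, [0, 2, 1])  # [batch, hidden, seq_len]
kv = lowrank_kv_projection(kv)           # [batch, hidden, 2 * linformer_dim]
print(kv.shape)                          # (2, 16, 14)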
@@ -444,6 +450,8 @@ def get_config(self):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -499,6 +507,8 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor
 
+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-      key_value = tf.transpose(key_value, [0, 2, 1])
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
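To see how the split plays out at attention time, here is a standalone sketch with made-up shapes, using a plain tf.keras.layers.MultiHeadAttention as a stand-in for the block's attention layer: keys and values come from separate halves of the projected tensor, and the key axis of the attention scores has length linformer_dim instead of seq_len.

import tensorflow as tf

batch, seq_len, hidden, num_heads, linformer_dim = 2, 21, 16, 4, 7

query = tf.random.normal([batch, seq_len, hidden])
# Output of the low-rank projection in the unshared case:
# [batch, hidden, 2 * linformer_dim].
key_value = tf.random.normal([batch, hidden, 2 * linformer_dim])

# Split into key/value halves and restore the [batch, length, hidden] layout.
key = tf.transpose(key_value[:, :, :linformer_dim], [0, 2, 1])
value = tf.transpose(key_value[:, :, linformer_dim:], [0, 2, 1])

attention = tf.keras.layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=hidden // num_heads)
output, scores = attention(
    query=query, key=key, value=value, return_attention_scores=True)
print(output.shape)  # (2, 21, 16)
print(scores.shape)  # (2, 4, 21, 7) -- key axis is linformer_dim, not seq_len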

official/nlp/modeling/layers/transformer_encoder_block_test.py

Lines changed: 6 additions & 1 deletion
@@ -800,7 +800,11 @@ def test_block_sparse_attention(self, use_sigmoid_attn):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
-  def test_low_rank_attention(self):
+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
     num_attention_heads = 8
     sequence_length = 21
     linformer_dim = 7
@@ -812,6 +816,7 @@ def test_low_rank_attention(self):
         inner_activation='relu',
         return_attention_scores=True,
         linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
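For reference, a hypothetical usage sketch mirroring the test above; width and inner_dim are assumed values, since the test's own choices are not visible in this diff.

import tf_keras  # Keras 2 package used by the Model Garden
from official.nlp.modeling.layers import transformer_encoder_block

sequence_length, width = 21, 80  # width is an assumption
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=128,  # assumed value
    inner_activation='relu',
    return_attention_scores=True,
    linformer_dim=7,
    linformer_shared_kv_projection=False,  # separate low-rank key/value maps
)

data_tensor = tf_keras.Input(shape=(sequence_length, width))
outputs = block(data_tensor)
# outputs[0] is the block output; outputs[1] holds the attention scores,
# whose key axis has length linformer_dim rather than sequence_length.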
