From 41f7c1f8ac0502619e468d530bf149bce5190898 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 17 Sep 2024 12:31:37 -0700
Subject: [PATCH] No public description

PiperOrigin-RevId: 675667232
---
 .../layers/transformer_encoder_block.py      | 37 +++++++++++++++----
 .../layers/transformer_encoder_block_test.py |  7 +++-
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/official/nlp/modeling/layers/transformer_encoder_block.py b/official/nlp/modeling/layers/transformer_encoder_block.py
index c33b81a3d5..bbe7f47bab 100644
--- a/official/nlp/modeling/layers/transformer_encoder_block.py
+++ b/official/nlp/modeling/layers/transformer_encoder_block.py
@@ -115,6 +115,7 @@ def __init__(self,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -194,6 +195,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If `True`, the low-rank projection layer
+        is shared between keys and values; if `False`, each gets its own.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -234,6 +237,7 @@ def __init__(self,
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ def build(self, input_shape):
           dtype=tf.float32,
       )
     if self._linformer_dim is not None:
-      # Current implementation uses the same weights for keys and values.
-      # TODO(akandoor): Explore using different weights for keys and values.
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None, self._linformer_dim),
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
@@ -444,6 +450,8 @@ def get_config(self):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -499,6 +507,8 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
 
     if key_value is None:
       key_value = input_tensor
+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-      key_value = tf.transpose(key_value, [0, 2, 1])
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
          attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
diff --git a/official/nlp/modeling/layers/transformer_encoder_block_test.py b/official/nlp/modeling/layers/transformer_encoder_block_test.py
index 7eccca5a49..a6c5097f5a 100644
--- a/official/nlp/modeling/layers/transformer_encoder_block_test.py
+++ b/official/nlp/modeling/layers/transformer_encoder_block_test.py
@@ -800,7 +800,11 @@ def test_block_sparse_attention(self, use_sigmoid_attn):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
-  def test_low_rank_attention(self):
+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
     num_attention_heads = 8
     sequence_length = 21
     linformer_dim = 7
@@ -812,6 +816,7 @@ def test_low_rank_attention(self):
         inner_activation='relu',
         return_attention_scores=True,
         linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
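
A minimal usage sketch for the new flag, assuming the layer is imported from
the file touched above and that the constructor arguments follow the updated
test (num_attention_heads=8, sequence_length=21, linformer_dim=7);
inner_dim=2048, width=80 and the random input are illustrative values, not
taken from this change.

    import tensorflow as tf

    from official.nlp.modeling.layers import transformer_encoder_block

    # Linformer-style low-rank attention with separate key/value projections.
    block = transformer_encoder_block.TransformerEncoderBlock(
        num_attention_heads=8,
        inner_dim=2048,
        inner_activation='relu',
        linformer_dim=7,                       # keys/values projected from seq_len down to 7
        linformer_shared_kv_projection=False,  # the flag added in this change
    )

    # [batch, seq_len, width]; width must be divisible by num_attention_heads.
    inputs = tf.random.uniform((2, 21, 80))
    outputs = block(inputs)  # -> shape (2, 21, 80)

When the flag is False, build() doubles the low-rank projection width to
2 * linformer_dim and call() splits the projected tensor into separate key and
value inputs for the attention layer; when True (the default), the previous
shared-projection behaviour is unchanged.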