Commit 41f7c1f

No public description
PiperOrigin-RevId: 675667232
tensorflower-gardener committed Sep 17, 2024
1 parent 764091c commit 41f7c1f
Showing 2 changed files with 35 additions and 9 deletions.
37 changes: 29 additions & 8 deletions official/nlp/modeling/layers/transformer_encoder_block.py
@@ -115,6 +115,7 @@ def __init__(self,
use_sigmoid_attn=False,
sigmoid_attn_bias=None,
linformer_dim=None,
linformer_shared_kv_projection=True,
**kwargs):
"""Initializes `TransformerEncoderBlock`.
@@ -194,6 +195,8 @@ def __init__(self,
`block_sparse_attention.MultiHeadAttention`
linformer_dim: Applies low-rank factorization on keys/values as in
https://arxiv.org/pdf/2006.04768.
linformer_shared_kv_projection: If set, projection layer is shared for
keys and values.
**kwargs: keyword arguments.
"""
util.filter_kwargs(kwargs)
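
A minimal usage sketch of the new option (argument values and shapes are illustrative; assumes the layer is constructed directly from the `official.nlp.modeling.layers` package shown in the file header):

import tensorflow as tf
from official.nlp.modeling.layers import transformer_encoder_block

# Project the 21-position key/value sequence down to 7 low-rank positions,
# using separate (unshared) projection weights for keys and values.
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=128,
    inner_activation='relu',
    linformer_dim=7,
    linformer_shared_kv_projection=False,
)
data = tf.random.uniform((2, 21, 80))  # (batch, seq_len, width)
output = block(data)                   # same shape as the input: (2, 21, 80)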
@@ -234,6 +237,7 @@ def __init__(self,
self._use_sigmoid_attn = use_sigmoid_attn
self._sigmoid_attn_bias = sigmoid_attn_bias
self._linformer_dim = linformer_dim
self._linformer_shared_kv_projection = linformer_shared_kv_projection
if self._num_kv_heads is not None and self._src_block_size is not None:
raise ValueError(
"Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ def build(self, input_shape):
dtype=tf.float32,
)
if self._linformer_dim is not None:
# Current implementation uses the same weights for keys and values.
# TODO(akandoor): Explore using different weights for keys and values.
if self._linformer_shared_kv_projection:
low_rank_dim = self._linformer_dim
else:
low_rank_dim = 2 * self._linformer_dim
self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
"...bc,cd->...bd",
output_shape=(None, self._linformer_dim),
output_shape=(None, low_rank_dim),
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer
),
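
A standalone sketch of the projection built here (shapes are illustrative): with a shared projection, the EinsumDense kernel maps seq_len to linformer_dim and serves both keys and values; without sharing, the output width doubles to 2 * linformer_dim and is split later in call().

import tensorflow as tf
import tf_keras

seq_len, hidden, linformer_dim = 21, 80, 7
low_rank_dim = 2 * linformer_dim  # unshared case; linformer_dim when shared

lowrank_kv_projection = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, low_rank_dim))

# call() transposes the key/value tensor first, so the einsum contracts over
# the sequence dimension rather than the hidden dimension.
key_value = tf.random.uniform((2, seq_len, hidden))
kv_t = tf.transpose(key_value, [0, 2, 1])   # [2, hidden, seq_len]
low_rank = lowrank_kv_projection(kv_t)      # [2, hidden, 2 * linformer_dim]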
@@ -444,6 +450,8 @@ def get_config(self):
"tgt_block_size": self._tgt_block_size,
"use_sigmoid_attn": self._use_sigmoid_attn,
"sigmoid_attn_bias": self._sigmoid_attn_bias,
"linformer_dim": self._linformer_dim,
"linformer_shared_kv_projection": self._linformer_shared_kv_projection,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
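
Continuing the usage sketch above, the new `get_config()` entries let the layer be rebuilt with the same Linformer setup from its serialized config (a sketch; round-tripping follows the standard Keras get_config/from_config convention):

cfg = block.get_config()
assert cfg['linformer_dim'] == 7
assert cfg['linformer_shared_kv_projection'] is False

# A clone built from the config uses the same low-rank settings.
clone = transformer_encoder_block.TransformerEncoderBlock.from_config(cfg)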
@@ -499,6 +507,8 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
if key_value is None:
key_value = input_tensor

key = key_value
value = key_value
if self._linformer_dim is not None:
if attention_mask is not None:
# Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@
attention_mask = None
key_value = tf.transpose(key_value, [0, 2, 1])
key_value = self._lowrank_kv_projection(key_value)
key_value = tf.transpose(key_value, [0, 2, 1])

if self._linformer_shared_kv_projection:
key_value = tf.transpose(key_value, [0, 2, 1])
key = key_value
value = key_value
else:
key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
if self._return_attention_scores:
attention_output, attention_scores = self._attention_layer(
query=target_tensor,
value=key_value,
key=key,
value=value,
attention_mask=attention_mask,
return_attention_scores=True)
return_attention_scores=True,
)
else:
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
query=target_tensor,
key=key,
value=value,
attention_mask=attention_mask,
)
attention_output = self._attention_dropout(attention_output)

if self._norm_first:
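A standalone sketch of the unshared-projection path in call() above: the low-rank output of width 2 * linformer_dim is sliced into a key half and a value half, each transposed back to [batch, linformer_dim, hidden] before attention (the plain tf.einsum and random kernel stand in for the EinsumDense layer):

import tensorflow as tf

batch, seq_len, hidden, linformer_dim = 2, 21, 80, 7
key_value = tf.random.uniform((batch, seq_len, hidden))

kv_t = tf.transpose(key_value, [0, 2, 1])                  # [batch, hidden, seq_len]
kernel = tf.random.uniform((seq_len, 2 * linformer_dim))   # stand-in projection weights
low_rank = tf.einsum('bhc,cd->bhd', kv_t, kernel)          # [batch, hidden, 2*linformer_dim]

key = tf.transpose(low_rank[:, :, :linformer_dim], [0, 2, 1])    # [batch, 7, hidden]
value = tf.transpose(low_rank[:, :, linformer_dim:], [0, 2, 1])  # [batch, 7, hidden]
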
Second changed file (tests):
@@ -800,7 +800,11 @@ def test_block_sparse_attention(self, use_sigmoid_attn):
output_tensor[1].shape.as_list(), expected_attention_scores_shape
)

def test_low_rank_attention(self):
@parameterized.named_parameters(
('unshared_kv_projection', False),
('shared_kv_projection', True),
)
def test_low_rank_attention(self, shared_kv_projection):
num_attention_heads = 8
sequence_length = 21
linformer_dim = 7
@@ -812,6 +816,7 @@ def test_low_rank_attention(self):
inner_activation='relu',
return_attention_scores=True,
linformer_dim=linformer_dim,
linformer_shared_kv_projection=shared_kv_projection,
)
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf_keras.Input(shape=(sequence_length, width))
