@@ -115,6 +115,7 @@ def __init__(self,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.

@@ -194,6 +195,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, the same low-rank projection is
+        shared for keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -234,6 +237,7 @@ def __init__(self,
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ def build(self, input_shape):
           dtype=tf.float32,
       )
     if self._linformer_dim is not None:
-      # Current implementation uses the same weights for keys and values.
-      # TODO(akandoor): Explore using different weights for keys and values.
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None, self._linformer_dim),
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
@@ -444,6 +450,8 @@ def get_config(self):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -499,6 +507,8 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor

+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-      key_value = tf.transpose(key_value, [0, 2, 1])
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)

     if self._norm_first:
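
Outside the diff itself, a minimal sketch of what the non-shared path computes may help review: the single `EinsumDense` now emits `2 * linformer_dim` columns, and the first and second halves become the low-rank key and value projections. All sizes and variable names below are illustrative; only the einsum equation and the transpose/slice pattern mirror the changed lines above.

```python
# Illustrative only -- not part of the diff. Shapes are made up; the einsum
# equation and the slice/transpose pattern follow the change above.
import tensorflow as tf

batch, seq_len, hidden, linformer_dim = 2, 128, 16, 32

# Project the (transposed) sequence axis down to 2 * linformer_dim so the two
# halves can serve as separate key and value projections.
lowrank_kv_projection = tf.keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, 2 * linformer_dim))

key_value = tf.random.normal([batch, seq_len, hidden])
projected = lowrank_kv_projection(tf.transpose(key_value, [0, 2, 1]))

# Non-shared case: first half -> keys, second half -> values, each transposed
# back to [batch, linformer_dim, hidden].
key = tf.transpose(projected[:, :, :linformer_dim], [0, 2, 1])    # [2, 32, 16]
value = tf.transpose(projected[:, :, linformer_dim:], [0, 2, 1])  # [2, 32, 16]
print(key.shape, value.shape)
```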
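For completeness, a hedged usage sketch of the new flag. The import path and the required constructor arguments (`num_attention_heads`, `inner_dim`, `inner_activation`) are assumptions about the surrounding Model Garden code, not something this diff shows.

```python
# Hypothetical usage -- import path and required arguments are assumed.
from official.nlp.modeling import layers

block = layers.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    linformer_dim=256,                     # low-rank length for keys/values
    linformer_shared_kv_projection=False,  # new flag: separate K/V halves
)
# Both Linformer options now round-trip through get_config()/from_config().
restored = layers.TransformerEncoderBlock.from_config(block.get_config())
```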