@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
     score = jax.nn.softmax(score, axis=-1)  # (B, H, T, S)
     score = score.astype(q.dtype)  # (B, H, T, S)
     if dropout_rate > 0.:
-        score, _ = drop_out(dkey, input=score, rate=dropout_rate)  ## NOTE: normally you apply dropout here
+        score, _ = drop_out(dkey, score, rate=dropout_rate)  ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v)  # (B, T, H, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1))  # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout  # (B, T, Dq)
@@ -105,6 +105,8 @@ def run_attention_probe(
     Returns:
         output scores/probabilities, cross-attention (hidden) features
     """
+    # Two separate dkeys, one for each dropout in the two cross_attention calls
+    dkey1, dkey2 = random.split(dkey, 2)
     # encoded_image_feature: (B, hw, dim)
     #learnable_query, *_params) = params
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
@@ -116,14 +118,14 @@ def run_attention_probe(
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
         encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
-    features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
     self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
     skip = features
     if use_LN:
         features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
-    features = cross_attention(dkey, self_attn_params, features, features, None, n_heads, dropout)
+    features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
     features = features + skip
     features = features[:, 0]  # (B, 1, dim) => (B, dim)
     # MLP
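
The hunks above give each dropout site its own PRNG key: reusing a single `dkey` for both `cross_attention` calls would make the cross-attention and self-attention blocks sample the same dropout mask. Below is a minimal, self-contained sketch of that key-splitting pattern in JAX; the `drop_out` helper here is a hypothetical stand-in written only to mirror the `(output, mask)` return shape seen in this diff, not the repository's implementation.

```python
import jax.numpy as jnp
from jax import random

def drop_out(dkey, x, rate=0.1):
    """Hypothetical inverted-dropout helper returning (dropped_x, keep_mask)."""
    keep = random.bernoulli(dkey, p=1.0 - rate, shape=x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0), keep

dkey = random.PRNGKey(0)
dkey1, dkey2 = random.split(dkey, 2)  # one independent subkey per dropout site

scores = jnp.ones((2, 4, 8, 8))               # stand-in attention scores (B, H, T, S)
cross, _ = drop_out(dkey1, scores, rate=0.1)  # mask used by the cross-attention call
self_, _ = drop_out(dkey2, scores, rate=0.1)  # different mask for the self-attention call
```

Calling `random.split` once and handing out distinct subkeys is the usual JAX way to keep stochastic operations independent while staying reproducible from a single seed.
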
@@ -222,7 +224,7 @@ def __init__(
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
         assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
-        self.dkey, *subkeys = random.split(self.dkey, 25)
+        self.dkey, *subkeys = random.split(self.dkey, 26)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
         self.input_dim = input_dim
@@ -287,8 +289,12 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.0002  #0.001

+        # Finally, reserve the last subkey for the noise_key
+        self.noise_key = subkeys[24]
+
     def process(self, embeddings, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
@@ -299,7 +305,8 @@ def process(self, embeddings, dkey=None):
         return outs

     def update(self, embeddings, labels, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
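
The last two hunks also change what happens when `process` or `update` is called without a `dkey`: instead of leaving `noise_key` as `None`, they fall back to a key reserved at construction time (`subkeys[24]`, hence splitting into 26 keys instead of 25). A rough sketch of that fallback pattern is below; the class name and method bodies are illustrative assumptions, not the repository's code.

```python
from jax import random

class ProbeSketch:
    """Sketch of the key-handling pattern only; everything else is omitted."""

    def __init__(self, dkey, n_keys=26):
        # Split once at init; the extra (last) subkey is held back as a
        # default noise key for calls that do not supply a fresh dkey.
        self.dkey, *subkeys = random.split(dkey, n_keys)
        self.noise_key = subkeys[-1]

    def process(self, embeddings, dkey=None):
        noise_key = self.noise_key      # fallback: stored key
        if dkey is not None:
            dkey, *subkeys = random.split(dkey, 2)
            noise_key = subkeys[0]      # prefer a caller-supplied key
        return noise_key                # placeholder for the real forward pass
```

One consequence of this design is that the no-`dkey` path reuses the same stored key on every call, so its dropout masks are deterministic; passing a fresh `dkey` restores per-call randomness.
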