Commit 695e9d8

fix bug in attention-probe dropout, fix bug where a None noise_key was passed into the jitted probing function, and split the noise_key so that the dropout in each of the two cross-attention calls gets its own key
1 parent 23e8c84 commit 695e9d8
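
For context, a minimal JAX sketch (not part of this commit) of why the two dropout sites need their own keys: reusing a single PRNG key produces the identical Bernoulli mask for same-shaped inputs, whereas splitting the key yields independent masks.

    from jax import random

    key = random.PRNGKey(0)
    shape = (2, 4, 8, 8)   # e.g. a (B, H, T, S) block of attention scores
    keep = 0.9             # keep-probability for dropout

    # Reusing the same key twice gives the *same* mask both times.
    mask_a = random.bernoulli(key, keep, shape)
    mask_b = random.bernoulli(key, keep, shape)
    print(bool((mask_a == mask_b).all()))   # True

    # Splitting the key gives two independent masks.
    k1, k2 = random.split(key, 2)
    mask_1 = random.bernoulli(k1, keep, shape)
    mask_2 = random.bernoulli(k2, keep, shape)
    print(bool((mask_1 == mask_2).all()))   # almost surely False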

File tree

1 file changed: +13 -6 lines changed

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 13 additions & 6 deletions
@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
     score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
     score = score.astype(q.dtype) # (B, H, T, S)
     if dropout_rate > 0.:
-        score, _ = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+        score, _ = drop_out(dkey, score, rate=dropout_rate) ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, T, H, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)
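
For reference, a rough sketch of what dropout on the softmaxed scores does, under an inverted-dropout assumption (ngclearn's drop_out signature and scaling may differ):

    import jax
    import jax.numpy as jnp
    from jax import random

    def score_dropout(key, score, rate):
        # Zero a random subset of attention weights and rescale the
        # survivors so each weight's expected value is unchanged.
        keep = 1.0 - rate
        mask = random.bernoulli(key, keep, score.shape)
        return jnp.where(mask, score / keep, 0.0)

    key = random.PRNGKey(42)
    score = jax.nn.softmax(random.normal(key, (2, 4, 1, 16)), axis=-1)  # (B, H, T, S)
    dropped = score_dropout(random.split(key)[0], score, rate=0.1)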
@@ -105,6 +105,8 @@ def run_attention_probe(
     Returns:
         output scores/probabilities, cross-attention (hidden) features
     """
+    # Two separate dkeys for each dropout in two cross attention
+    dkey1, dkey2 = random.split(dkey, 2)
     # encoded_image_feature: (B, hw, dim)
     #learnable_query, *_params) = params
     learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
@@ -116,14 +118,14 @@ def run_attention_probe(
     if use_LN_input:
         learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
         encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
-    features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+    features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
     # Perform a single self-attention block here
     # Self-Attention
     self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
     skip = features
     if use_LN:
         features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
-    features = cross_attention(dkey, self_attn_params, features, features, None, n_heads, dropout)
+    features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
     features = features + skip
     features = features[:, 0] # (B, 1, dim) => (B, dim)
     # MLP
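
Schematically, the key flow after these two hunks looks like the toy sketch below (placeholder names, single head, no projections; not ngclearn's cross_attention): each stochastic block consumes its own subkey, and self-attention is simply cross-attention with both inputs being the same tensor.

    import jax
    import jax.numpy as jnp
    from jax import random

    def toy_attention(dkey, q_in, kv_in, rate=0.1):
        # Minimal single-head attention with dropout on the scores.
        score = jax.nn.softmax(q_in @ kv_in.T, axis=-1)
        mask = random.bernoulli(dkey, 1.0 - rate, score.shape)
        return jnp.where(mask, score / (1.0 - rate), 0.0) @ kv_in

    def toy_probe(dkey, query, encodings):
        dkey1, dkey2 = random.split(dkey, 2)            # one subkey per dropout site
        feats = toy_attention(dkey1, query, encodings)  # cross-attention
        feats = toy_attention(dkey2, feats, feats)      # self-attention: x1 == x2
        return feats

    key = random.PRNGKey(0)
    out = toy_probe(key, random.normal(key, (1, 8)), random.normal(key, (16, 8)))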
@@ -222,7 +224,7 @@ def __init__(
         super().__init__(dkey, batch_size, **kwargs)
         assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
         assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
-        self.dkey, *subkeys = random.split(self.dkey, 25)
+        self.dkey, *subkeys = random.split(self.dkey, 26)
         self.num_heads = num_heads
         self.source_seq_length = source_seq_length
         self.input_dim = input_dim
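
The bump from 25 to 26 is key bookkeeping: random.split(self.dkey, 26) returns 26 keys, the first of which becomes the new self.dkey, leaving 25 subkeys (indices 0 through 24), so the subkeys[24] used below exists. A quick check:

    from jax import random

    dkey = random.PRNGKey(0)
    dkey, *subkeys = random.split(dkey, 26)
    print(len(subkeys))   # 25, so subkeys[24] is the last valid index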
@@ -287,8 +289,12 @@ def __init__(
         self.optim_params = adam.adam_init(self.probe_params)
         self.eta = 0.0002 #0.001
 
+        # Finally, the dkey for the noise_key
+        self.noise_key = subkeys[24]
+
     def process(self, embeddings, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
@@ -299,7 +305,8 @@ def process(self, embeddings, dkey=None):
         return outs
 
     def update(self, embeddings, labels, dkey=None):
-        noise_key = None
+        # noise_key = None
+        noise_key = self.noise_key
         if dkey is not None:
             dkey, *subkeys = random.split(dkey, 2)
             noise_key = subkeys[0]
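
The pattern adopted in both process and update, sketched generically below (class and method names are placeholders, not ngclearn's API): fall back to a key stored at construction time so the jitted code path never receives a None noise_key, and override it only when a fresh dkey is supplied.

    import jax.numpy as jnp
    from jax import random

    class NoisyProbe:
        def __init__(self, seed=0):
            self.noise_key = random.PRNGKey(seed)   # stored fallback key

        def forward(self, x, dkey=None):
            noise_key = self.noise_key              # never None inside the jitted path
            if dkey is not None:
                dkey, *subkeys = random.split(dkey, 2)
                noise_key = subkeys[0]
            # stand-in for the dropout that actually consumes noise_key
            return x + 0.01 * random.normal(noise_key, x.shape)

    probe = NoisyProbe()
    y1 = probe.forward(jnp.zeros((2, 3)))                           # uses the stored key
    y2 = probe.forward(jnp.zeros((2, 3)), dkey=random.PRNGKey(1))   # uses a fresh subkey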
