fix bug in attention probe dropout, fix bug in None noise_key passed in the probing jit function, add splitting of noise_keys for the two dropouts in the two cross-attention calls
rxng8 committed Mar 7, 2025
1 parent 23e8c84 commit 695e9d8
Showing 1 changed file with 13 additions and 6 deletions.
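For readers outside the codebase, the core of the fix is a JAX PRNG detail: feeding the same key to two dropout applications makes their noise correlated (identical whenever the shapes match), whereas splitting the key first gives each site independent randomness. A standalone sketch of that difference (plain JAX, not ngclearn's drop_out):

import jax.numpy as jnp
from jax import random

def dropout_mask(key, shape, rate=0.5):
    # Bernoulli keep-mask of the kind used inside a typical dropout layer
    return random.bernoulli(key, 1.0 - rate, shape)

key = random.PRNGKey(0)

# Reusing one key: both sites drop exactly the same positions.
print(jnp.array_equal(dropout_mask(key, (2, 4)), dropout_mask(key, (2, 4))))  # True

# Splitting first: each site gets its own, independent mask.
k1, k2 = random.split(key, 2)
print(jnp.array_equal(dropout_mask(k1, (2, 4)), dropout_mask(k2, (2, 4))))  # almost surely False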
ngclearn/utils/analysis/attentive_probe.py (19 changes: 13 additions & 6 deletions)
@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
score = score.astype(q.dtype) # (B, H, T, S)
if dropout_rate > 0.:
- score, _ = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+ score, _ = drop_out(dkey, score, rate=dropout_rate) ## NOTE: normally you apply dropout here
attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, T, H, E)
attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
return attention @ Wout + bout # (B, T, Dq)
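For context, this hunk only changes how drop_out is invoked; the surrounding computation is standard scaled dot-product attention with dropout applied to the softmaxed scores. A minimal self-contained sketch of that step, using a simple inverted-dropout helper as a stand-in for ngclearn's drop_out:

import jax
import jax.numpy as jnp
from jax import random

def simple_dropout(key, x, rate):
    # inverted dropout: zero entries with probability `rate`, rescale the survivors
    keep = random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def attention_with_score_dropout(key, q, k, v, dropout_rate=0.1):
    # q: (B, H, T, E); k, v: (B, H, S, E)
    score = jnp.einsum("BHTE,BHSE->BHTS", q, k) / jnp.sqrt(q.shape[-1])
    score = jax.nn.softmax(score, axis=-1)                # (B, H, T, S)
    if dropout_rate > 0.:
        score = simple_dropout(key, score, dropout_rate)  # dropout on the attention weights
    return jnp.einsum("BHTS,BHSE->BHTE", score, v)        # (B, H, T, E)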
@@ -105,6 +105,8 @@ def run_attention_probe(
Returns:
output scores/probabilities, cross-attention (hidden) features
"""
+ # Two separate dkeys for each dropout in two cross attention
+ dkey1, dkey2 = random.split(dkey, 2)
# encoded_image_feature: (B, hw, dim)
#learnable_query, *_params) = params
learnable_query, Wq, bq, Wk, bk, Wv, bv, Wout, bout,\
@@ -116,14 +118,14 @@
if use_LN_input:
learnable_query = layer_normalize(learnable_query, ln_in_mu, ln_in_scale)
encodings = layer_normalize(encodings, ln_in_mu2, ln_in_scale2)
- features = cross_attention(dkey, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
+ features = cross_attention(dkey1, cross_attn_params, learnable_query, encodings, mask, n_heads, dropout)
# Perform a single self-attention block here
# Self-Attention
self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts)
skip = features
if use_LN:
features = layer_normalize(features, Wlnattn_mu, Wlnattn_scale)
- features = cross_attention(dkey, self_attn_params, features, features, None, n_heads, dropout)
+ features = cross_attention(dkey2, self_attn_params, features, features, None, n_heads, dropout)
features = features + skip
features = features[:, 0] # (B, 1, dim) => (B, dim)
# MLP
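Taken together, the two edits in this hunk route dkey1 into the cross-attention dropout and dkey2 into the self-attention dropout, instead of reusing one key for both. A toy single-head sketch of that control flow (illustrative only, omitting the masks, layer norms, and multi-head projections of the real function):

import jax
import jax.numpy as jnp
from jax import random

def toy_attn(key, q, kv, dropout_rate=0.1):
    # toy single-head attention with dropout on the attention weights
    score = jax.nn.softmax(q @ kv.T / jnp.sqrt(q.shape[-1]), axis=-1)
    keep = random.bernoulli(key, 1.0 - dropout_rate, score.shape)
    score = jnp.where(keep, score / (1.0 - dropout_rate), 0.0)
    return score @ kv

def probe_forward(dkey, learnable_query, encodings, dropout_rate=0.1):
    dkey1, dkey2 = random.split(dkey, 2)                               # one key per dropout site
    feats = toy_attn(dkey1, learnable_query, encodings, dropout_rate)  # cross-attention
    feats = feats + toy_attn(dkey2, feats, feats, dropout_rate)        # self-attention + skip
    return feats[0]                                                    # (1, dim) => (dim,)

out = probe_forward(random.PRNGKey(42),
                    jnp.ones((1, 8)),                           # learnable query: (T=1, dim)
                    random.normal(random.PRNGKey(1), (16, 8)))  # encodings: (S, dim)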
@@ -222,7 +224,7 @@ def __init__(
super().__init__(dkey, batch_size, **kwargs)
assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."
assert learnable_query_dim % num_heads == 0, f"`learnable_query_dim` must be divisible by `num_heads`. Got {learnable_query_dim} and {num_heads}."
- self.dkey, *subkeys = random.split(self.dkey, 25)
+ self.dkey, *subkeys = random.split(self.dkey, 26)
self.num_heads = num_heads
self.source_seq_length = source_seq_length
self.input_dim = input_dim
@@ -287,8 +289,12 @@ def __init__(
self.optim_params = adam.adam_init(self.probe_params)
self.eta = 0.0002 #0.001

+ # Finally, the dkey for the noise_key
+ self.noise_key = subkeys[24]
+
def process(self, embeddings, dkey=None):
- noise_key = None
+ # noise_key = None
+ noise_key = self.noise_key
if dkey is not None:
dkey, *subkeys = random.split(dkey, 2)
noise_key = subkeys[0]
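This hunk and the next implement the companion fix for the None noise_key: one extra subkey is reserved at construction time (the split count changed from 25 to 26 in the earlier hunk) and used as the default whenever the caller passes no key, so the jit-compiled probing function never sees noise_key=None. A stripped-down sketch of the pattern (a hypothetical ProbeSketch class, not ngclearn's actual AttentiveProbe):

import jax.numpy as jnp
from jax import random

class ProbeSketch:
    def __init__(self, dkey):
        self.dkey, *subkeys = random.split(dkey, 26)  # 25 subkeys; the last one is reserved
        self.noise_key = subkeys[24]                  # dedicated default dropout key

    def process(self, embeddings, dkey=None):
        noise_key = self.noise_key        # never None, so a jit'd kernel always gets a real key
        if dkey is not None:
            dkey, sub = random.split(dkey, 2)
            noise_key = sub
        return embeddings, noise_key      # stand-in for the real jit'd probing call

probe = ProbeSketch(random.PRNGKey(0))
_, k_default = probe.process(jnp.ones((2, 4)))                    # falls back to the stored key
_, k_fresh = probe.process(jnp.ones((2, 4)), random.PRNGKey(7))   # uses a caller-supplied key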
@@ -299,7 +305,8 @@ def process(self, embeddings, dkey=None):
return outs

def update(self, embeddings, labels, dkey=None):
- noise_key = None
+ # noise_key = None
+ noise_key = self.noise_key
if dkey is not None:
dkey, *subkeys = random.split(dkey, 2)
noise_key = subkeys[0]
