@@ -45,7 +45,6 @@ class ValueZeroingSimilarityMetric(Enum):
 class ValueZeroingModule(Enum):
     DECODER = "decoder"
     ENCODER = "encoder"
-    CROSS = "cross"
 
 
 class ValueZeroing(InseqAttribution):
@@ -155,20 +154,26 @@ def compute_modules_post_zeroing_similarity(
         inputs: TensorOrTupleOfTensorsGeneric,
         additional_forward_args: TensorOrTupleOfTensorsGeneric,
         hidden_states: MultiLayerEmbeddingsTensor,
+        attention_module_name: str,
+        attributed_seq_len: Optional[int] = None,
         similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
         mode: str = ValueZeroingModule.DECODER.value,
         zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
-        threshold: float = 1e-5,
+        min_score_threshold: float = 1e-5,
+        use_causal_mask: bool = False,
     ) -> MultiLayerScoreTensor:
         """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block.
 
         Args:
             modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for.
             hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean
                 counterparts when computing the similarity.
-            similarity_scores_shape (:obj:`torch.Size`): The shape of the similarity scores tensor to be returned.
+            attention_module_name (:obj:`str`): The name of the attention module to zero the values for.
+            attributed_seq_len (:obj:`int`, optional): The length of the sequence to attribute. If not specified, it
+                is assumed to be the same as the length of the hidden states.
             similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine".
             mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder".
+
             zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
                 `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
                 that should be zeroed to compute corrupted states.
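Note (editor's sketch, not part of the diff): `SIMILARITY_METRICS` is a registry mapping names such as `"cosine"` to callables that compare clean and corrupted states position by position; the scoring loop further down indexes the result as `[:, token_idx:]`, so each callable plausibly returns one score per token. A minimal sketch of the cosine entry, assuming inputs of shape `[batch_size, seq_len, hidden_size]`:

```python
import torch
import torch.nn.functional as F

def cosine_similarity_per_token(
    clean_states: torch.Tensor,      # [batch_size, seq_len, hidden_size]
    corrupted_states: torch.Tensor,  # [batch_size, seq_len, hidden_size]
) -> torch.Tensor:
    # One similarity per position; the caller turns it into a distance
    # via ``1 - similarity`` in the scoring loop below.
    return F.cosine_similarity(clean_states, corrupted_states, dim=-1)
```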
@@ -179,18 +184,25 @@ def compute_modules_post_zeroing_similarity(
                 - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
                   the corresponding layer. Any missing layer will not be zeroed.
                 Default: None.
+            min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the
+                similarity. Default: 1e-5.
+            use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to the zeroing scores. Default: False.
 
         Returns:
             :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layers]`` containing distances
                 (1 - similarity score) between original and corrupted states for each layer.
         """
         if mode == ValueZeroingModule.DECODER.value:
             modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder())
-            batch_size = hidden_states.size(0)
-            num_layers = len(modules)
-            sequence_length = hidden_states.size(2)
+        elif mode == ValueZeroingModule.ENCODER.value:
+            modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder())
         else:
             raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.")
+        if attributed_seq_len is None:
+            attributed_seq_len = hidden_states.size(2)
+        batch_size = hidden_states.size(0)
+        generated_seq_len = hidden_states.size(2)
+        num_layers = len(modules)
 
         # Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the
         # embedding layer, and we are only interested in the transformer blocks outputs.
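To make the accepted `zeroed_units_indices` formats concrete, an illustrative (hypothetical) set of values; the tuple form is plausibly a `(start, end)` head range:

```python
# Illustrative values only; semantics per the docstring above.
zeroed_units_indices = 0                  # a single head index, applied to every layer
zeroed_units_indices = (0, 4)             # plausibly a (start, end) range of heads
zeroed_units_indices = [0, 2, 5]          # an explicit list of head indices
zeroed_units_indices = {1: [0, 3], 6: 2}  # per-layer mapping; layers not listed
                                          # (e.g. 0, 2-5) are left untouched
```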
@@ -199,7 +211,7 @@ def compute_modules_post_zeroing_similarity(
         }
         # Scores for every layer of the model
         all_scores = torch.ones(
-            batch_size, num_layers, sequence_length, sequence_length, device=hidden_states.device
+            batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device
         ) * float("nan")
 
         # Hooks:
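Design note: seeding `all_scores` with NaN marks the cells that the zeroing loop never writes, which is what makes the later switch to `nansum` safe. `torch.full` would express the same intent more directly; an equivalent sketch:

```python
import torch

batch_size, num_layers, generated_seq_len, attributed_seq_len = 2, 4, 5, 3
all_scores = torch.full(
    (batch_size, num_layers, generated_seq_len, attributed_seq_len),
    float("nan"),
)
# Cells the zeroing loop never writes stay NaN and are skipped by nansum.
```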
@@ -218,11 +230,11 @@ def compute_modules_post_zeroing_similarity(
                 modules[block_idx].register_forward_hook(states_extract_and_patch_hook)
             )
         # Zeroing is done for every token in the sequence separately (O(n) complexity)
-        for token_idx in range(sequence_length):
+        for token_idx in range(attributed_seq_len):
             value_zeroing_hook_handles: list[RemovableHandle] = []
             # Value zeroing hooks are registered for every token separately since they are token-dependent
             for block_idx, block in enumerate(modules):
-                attention_module = block.get_submodule(self.forward_func.config.attention_module)
+                attention_module = block.get_submodule(attention_module_name)
                 if isinstance(zeroed_units_indices, dict):
                     if block_idx not in zeroed_units_indices:
                         continue
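The token-dependent value-zeroing hook itself is defined elsewhere in this file and not shown in the diff. As a rough, hypothetical sketch of the idea, assuming the attention module exposes a value projection named `v_proj` (the real name varies by architecture and is resolved via the model config):

```python
import torch

def make_value_zeroing_hook(token_idx: int):
    # Forward hook for a value projection: zero one token's value vector so
    # its content cannot be mixed into any other position's representation.
    def hook(module: torch.nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
        corrupted = output.clone()     # [batch_size, seq_len, hidden_size]
        corrupted[:, token_idx] = 0.0  # hypothetical: zero all heads for this token
        return corrupted
    return hook

# Hypothetical registration on one block's value projection:
# handle = attention_module.v_proj.register_forward_hook(make_value_zeroing_hook(token_idx))
```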
@@ -259,19 +271,22 @@ def compute_modules_post_zeroing_similarity(
             for block_idx in range(len(modules)):
                 similarity_scores = self.SIMILARITY_METRICS[similarity_metric](
                     self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx]
-                )[:, token_idx:]
-                all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores
+                )
+                if use_causal_mask:
+                    all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:]
+                else:
+                    all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores
             self.corrupted_block_output_states = {}
         for handle in states_extraction_hook_handles:
             handle.remove()
         self.clean_block_output_states = {}
-        all_scores = torch.where(all_scores < threshold, torch.zeros_like(all_scores), all_scores)
+        all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores)
         # Normalize scores to sum to 1
-        per_token_sum_score = all_scores.sum(dim=-1, keepdim=True)
+        per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True)
         per_token_sum_score[per_token_sum_score == 0] = 1
         all_scores = all_scores / per_token_sum_score
 
-        # Final shape: [batch_size, seq_len, seq_len, num_layers]
+        # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers]
         return all_scores.permute(0, 3, 2, 1)
 
     def attribute(
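Why the new `use_causal_mask` branch matters: in a decoder, positions before `token_idx` cannot attend to it, so only rows from `token_idx` onward are written; encoder and cross-attention scores fill the whole column. A small self-contained demo of the fill pattern and the NaN-aware normalization (all values made up):

```python
import torch

batch, layers, gen_len, attr_len = 1, 1, 4, 4
scores = torch.full((batch, layers, gen_len, attr_len), float("nan"))
use_causal_mask = True

for token_idx in range(attr_len):
    sim = torch.rand(batch, gen_len)  # stand-in for per-position similarities
    if use_causal_mask:
        # Decoder: only positions >= token_idx can be affected by zeroing it.
        scores[:, 0, token_idx:, token_idx] = 1 - sim[:, token_idx:]
    else:
        scores[:, 0, :, token_idx] = 1 - sim

# NaN-aware normalization: unwritten upper-triangle cells stay NaN and are
# ignored, instead of dragging every row's sum toward NaN.
sums = scores.nansum(dim=-1, keepdim=True)
sums[sums == 0] = 1
normalized = scores / sums
print(normalized[0, 0])  # each row sums to ~1 over its filled (causal) entries
```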
@@ -312,18 +327,39 @@ def attribute(
                 f"Similarity metric {similarity_metric} not available."
                 f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
             )
+
         decoder_scores = self.compute_modules_post_zeroing_similarity(
             inputs=inputs,
             additional_forward_args=additional_forward_args,
             hidden_states=decoder_hidden_states,
+            attention_module_name=self.forward_func.config.self_attention_module,
             similarity_metric=similarity_metric,
             mode=ValueZeroingModule.DECODER.value,
             zeroed_units_indices=zeroed_units_indices,
+            use_causal_mask=True,
         )
-        return decoder_scores
         # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
         # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
-        # if is_encoder_decoder:
-        #     encoder_hidden_states = torch.stack(outputs.encoder_hidden_states)
-        #     encoder = self.forward_func.get_encoder()
-        #     encoder_stack = find_block_stack(encoder)
+        if self.forward_func.is_encoder_decoder:
+            # TODO: Enable different encoder/decoder/cross zeroing indices
+            encoder_scores = self.compute_modules_post_zeroing_similarity(
+                inputs=inputs,
+                additional_forward_args=additional_forward_args,
+                hidden_states=encoder_hidden_states,
+                attention_module_name=self.forward_func.config.self_attention_module,
+                similarity_metric=similarity_metric,
+                mode=ValueZeroingModule.ENCODER.value,
+                zeroed_units_indices=zeroed_units_indices,
+            )
+            cross_scores = self.compute_modules_post_zeroing_similarity(
+                inputs=inputs,
+                additional_forward_args=additional_forward_args,
+                hidden_states=decoder_hidden_states,
+                attributed_seq_len=encoder_hidden_states.size(2),
+                attention_module_name=self.forward_func.config.cross_attention_module,
+                similarity_metric=similarity_metric,
+                mode=ValueZeroingModule.DECODER.value,
+                zeroed_units_indices=zeroed_units_indices,
+            )
+            return encoder_scores, cross_scores, decoder_scores
+        return (decoder_scores,)
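With this change `attribute` always returns a tuple, so callers unpack uniformly. A hedged usage sketch; the call arguments and the `value_zeroing` instance are assumed here, only the return shapes follow from the diff:

```python
# Decoder-only model: a 1-tuple.
(decoder_scores,) = value_zeroing.attribute(
    inputs,
    additional_forward_args=additional_forward_args,
)

# Encoder-decoder model: encoder self-attention, cross-attention, and
# decoder self-attention distances, in that order.
encoder_scores, cross_scores, decoder_scores = value_zeroing.attribute(
    inputs,
    additional_forward_args=additional_forward_args,
)
```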