
Commit efa9ca3

customizable zeroing indexes, model config complete
1 parent: 6c4d1df


4 files changed, +104 -41 lines


inseq/attr/feat/feature_attribution.py

Lines changed: 0 additions & 1 deletion
@@ -514,7 +514,6 @@ def attribute(
                 attributions=attribution_outputs,
                 tokenized_target_sentences=target_tokens_with_ids,
                 pad_token=self.attribution_model.pad_token,
-                has_bos_token=self.attribution_model.is_encoder_decoder,
                 attr_pos_end=attr_pos_end,
             ),
             step_attributions=attribution_outputs if output_step_attributions else None,

inseq/attr/feat/ops/value_zeroing.py

Lines changed: 13 additions & 6 deletions
@@ -294,7 +294,9 @@ def attribute(
         inputs: TensorOrTupleOfTensorsGeneric,
         additional_forward_args: TensorOrTupleOfTensorsGeneric,
         similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
-        zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+        encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+        decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+        cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
         encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
         decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
     ) -> TensorOrTupleOfTensorsGeneric:
@@ -312,7 +314,8 @@ def attribute(
                 - If a list of integers, the attention heads in the list are zeroed across all layers.
                 - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
                   the corresponding layer.
-                Default: None.
+
+                Default: None (all heads are zeroed for every layer).
             encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
                 source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for
                 encoder-decoders models. Default: None.
@@ -335,21 +338,20 @@ def attribute(
             attention_module_name=self.forward_func.config.self_attention_module,
             similarity_metric=similarity_metric,
             mode=ValueZeroingModule.DECODER.value,
-            zeroed_units_indices=zeroed_units_indices,
+            zeroed_units_indices=decoder_zeroed_units_indices,
             use_causal_mask=True,
         )
         # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
         # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
         if self.forward_func.is_encoder_decoder:
-            # TODO: Enable different encoder/decoder/cross zeroing indices
             encoder_scores = self.compute_modules_post_zeroing_similarity(
                 inputs=inputs,
                 additional_forward_args=additional_forward_args,
                 hidden_states=encoder_hidden_states,
                 attention_module_name=self.forward_func.config.self_attention_module,
                 similarity_metric=similarity_metric,
                 mode=ValueZeroingModule.ENCODER.value,
-                zeroed_units_indices=zeroed_units_indices,
+                zeroed_units_indices=encoder_zeroed_units_indices,
             )
             cross_scores = self.compute_modules_post_zeroing_similarity(
                 inputs=inputs,
@@ -359,7 +361,12 @@ def attribute(
                 attention_module_name=self.forward_func.config.cross_attention_module,
                 similarity_metric=similarity_metric,
                 mode=ValueZeroingModule.DECODER.value,
-                zeroed_units_indices=zeroed_units_indices,
+                zeroed_units_indices=cross_zeroed_units_indices,
             )
             return encoder_scores, cross_scores, decoder_scores
+        elif encoder_zeroed_units_indices is not None or cross_zeroed_units_indices is not None:
+            logger.warning(
+                "Zeroing indices for encoder and cross-attentions were specified, but the model is not an "
+                "encoder-decoder. Use `decoder_zeroed_units_indices` to parametrize zeroing for the decoder module."
+            )
         return (decoder_scores,)
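
The split into encoder/decoder/cross parameters makes it possible to zero different heads in each module of an encoder-decoder model. A minimal usage sketch, assuming the method is exposed under the name "value_zeroing" and that extra keyword arguments passed to `model.attribute` are forwarded to this `attribute` call (the model name is only an example):

import inseq

# Load an encoder-decoder model with the value zeroing attribution method.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "value_zeroing")

out = model.attribute(
    "Hello everyone, hope you're enjoying the tutorial!",
    # Zero heads 0 and 3 in every encoder self-attention layer.
    encoder_zeroed_units_indices=[0, 3],
    # Zero heads 0-3 only in layer 2 of the decoder self-attention.
    decoder_zeroed_units_indices={2: [0, 1, 2, 3]},
    # cross_zeroed_units_indices is left at its default (None): all
    # cross-attention heads are zeroed at every layer.
)
out.show()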

inseq/data/attribution.py

Lines changed: 2 additions & 3 deletions
@@ -181,7 +181,6 @@ def from_step_attributions(
         attributions: list["FeatureAttributionStepOutput"],
         tokenized_target_sentences: list[list[TokenWithId]],
         pad_token: Optional[Any] = None,
-        has_bos_token: bool = True,
         attr_pos_end: Optional[int] = None,
     ) -> list["FeatureAttributionSequenceOutput"]:
         """Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple
@@ -208,9 +207,9 @@ def from_step_attributions(
                 sources.append(drop_padding(attr.source[seq_idx], pad_token))
             curr_target = [a.target[seq_idx][0] for a in attributions]
             targets.append(drop_padding(curr_target, pad_token))
-            if has_bos_token:
+            if all(attr.prefix[seq_idx][0] == pad_token for seq_idx in range(num_sequences)):
                 tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding(
-                    tokenized_target_sentences[seq_idx], pad_token
+                    tokenized_target_sentences[seq_idx][1:], pad_token
                 )
             else:
                 tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
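
The hard-coded `has_bos_token` flag is replaced by a runtime check: the leading token of the tokenized target is preserved only when every sequence in the batch starts with the padding token (i.e. when padding doubles as the BOS/decoder start token). A toy sketch of the same logic on plain token lists, using a stand-in for inseq's `drop_padding` helper (illustrative only, not the library's actual data structures):

pad_token = "<pad>"
prefixes = [["<pad>", "▁Hello", "▁world"], ["<pad>", "▁How", "▁are", "▁you"]]

def drop_padding(tokens, pad):
    # Stand-in for the library helper: remove padding tokens from a sequence.
    return [tok for tok in tokens if tok != pad]

if all(prefix[0] == pad_token for prefix in prefixes):
    # Keep the BOS-as-pad token in first position, strip padding from the rest.
    cleaned = [p[:1] + drop_padding(p[1:], pad_token) for p in prefixes]
else:
    cleaned = [drop_padding(p, pad_token) for p in prefixes]

print(cleaned)  # [['<pad>', '▁Hello', '▁world'], ['<pad>', '▁How', '▁are', '▁you']]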

inseq/models/model_config.yaml

Lines changed: 89 additions & 31 deletions
@@ -1,11 +1,29 @@
-# AutoModelForCausalLM
+# Decoder-only models
+BioGptForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
 BloomForCausalLM:
   self_attention_module: "self_attention"
   value_vector: "value_layer"
+CodeGenForCausalLM:
+  self_attention_module: "attn"
+  value_vector: "value"
+FalconForCausalLM:
+  self_attention_module: "self_attention"
+  value_vector: "value_layer"
+GemmaForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+GPTBigCodeForCausalLM:
+  self_attention_module: "attn"
+  value_vector: "value"
+GPTJForCausalLM:
+  self_attention_module: "attn"
+  value_vector: "value"
 GPT2LMHeadModel:
   self_attention_module: "attn"
   value_vector: "value"
-OpenAIGPTLMHeadModel:
+GPTNeoForCausalLM:
   self_attention_module: "attn"
   value_vector: "value"
 GPTNeoXForCausalLM:
@@ -14,40 +32,80 @@ GPTNeoXForCausalLM:
 LlamaForCausalLM:
   self_attention_module: "self_attn"
   value_vector: "value_states"
-GPTBigCodeForCausalLM:
+MistralForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+MixtralForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+MptForCausalLM:
   self_attention_module: "attn"
-  value_vector: "value"
-CodeGenForCausalLM:
+  value_vector: "value_states"
+OpenAIGPTLMHeadModel:
   self_attention_module: "attn"
   value_vector: "value"
-# TODO
-# BioGptForCausalLM
-# GemmaForCausalLM
-# GPTNeoForCausalLM
-# GPTJForCausalLM
-# MistralForCausalLM
-# MixtralForCausalLM
-# MptForCausalLM
-# OpenLlamaForCausalLM
-# OPTForCausalLM
-# PhiForCausalLM
-# StableLmForCausalLM
-# XGLMForCausalLM
+OPTForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+PhiForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+Qwen2ForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+StableLmForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
+XGLMForCausalLM:
+  self_attention_module: "self_attn"
+  value_vector: "value_states"
 
-# AutoModelForSeq2SeqLM
+# Encoder-decoder models
+BartForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "encoder_attn"
+  value_vector: "value_states"
 MarianMTModel:
   self_attention_module: "self_attn"
   cross_attention_module: "encoder_attn"
   value_vector: "value_states"
-
-# TODO ForConditionalGeneration
-# BartForConditionalGeneration
-# FSMTForConditionalGeneration
-# LongT5ForConditionalGeneration
-# M2M100ForConditionalGeneration
-# MBartForConditionalGeneration
-# MT5ForConditionalGeneration
-# NllbMoeForConditionalGeneration
-# SeamlessM4TForTextToText
-# SeamlessM4Tv2ForTextToText
-# T5ForConditionalGeneration
+FSMTForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "encoder_attn"
+  value_vector: "v"
+M2M100ForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "encoder_attn"
+  value_vector: "value_states"
+MBartForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "encoder_attn"
+  value_vector: "value_states"
+MT5ForConditionalGeneration:
+  self_attention_module: "SelfAttention"
+  cross_attention_module: "EncDecAttention"
+  value_vector: "value_states"
+NllbMoeForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "cross_attention"
+  value_vector: "value_states"
+PegasusForConditionalGeneration:
+  self_attention_module: "self_attn"
+  cross_attention_module: "encoder_attn"
+  value_vector: "value_states"
+SeamlessM4TForTextToText:
+  self_attention_module: "self_attn"
+  cross_attention_module: "cross_attention"
+  value_vector: "value"
+SeamlessM4Tv2ForTextToText:
+  self_attention_module: "self_attn"
+  cross_attention_module: "cross_attention"
+  value_vector: "value"
+T5ForConditionalGeneration:
+  self_attention_module: "SelfAttention"
+  cross_attention_module: "EncDecAttention"
+  value_vector: "value_states"
+UMT5ForConditionalGeneration:
+  self_attention_module: "SelfAttention"
+  cross_attention_module: "EncDecAttention"
+  value_vector: "value_states"
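
Each entry maps a Hugging Face architecture class name to the attribute name of its self-attention module (plus the cross-attention module for encoder-decoders) and to the name of the variable holding the value vectors that value zeroing intercepts. A minimal sketch of reading one of these entries with PyYAML (the path and lookup below are illustrative, not inseq's actual config-loading code):

import yaml  # requires PyYAML

with open("inseq/models/model_config.yaml") as f:
    model_config = yaml.safe_load(f)

cfg = model_config["T5ForConditionalGeneration"]
print(cfg["self_attention_module"])   # SelfAttention
print(cfg["cross_attention_module"])  # EncDecAttention
print(cfg["value_vector"])            # value_states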
