Skip to content

Commit 6c4d1df

Browse files
committed
VZ working for encoder-decoder
1 parent ea78f50 commit 6c4d1df

File tree

11 files changed

+151
-77
lines changed

11 files changed

+151
-77
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
88
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
9+
- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
10+
- Added `rollout` (`inseq.data.aggregation_functions.RolloutAggregationFunction`) aggregation function for `SequenceAttributionAggregator` class ([#173](https://github.com/inseq-team/inseq/pull/173)).
11+
- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).
912

1013
## 🔧 Fixes & Refactoring
1114

@@ -26,4 +29,5 @@
2629

2730
## 💥 Breaking Changes
2831

29-
*No changes*
32+
- If `attention` is used as attribution method in `model.attribute`, `step_scores` cannot be extracted at the same time since the method does not require iterating over the full sequence anymore. ([#173](https://github.com/inseq-team/inseq/pull/173)) As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution).
33+
- BOS is always included in target-side attribution and generated sequences if present. ([#173](https://github.com/inseq-team/inseq/pull/173))

inseq/attr/feat/attribution_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,15 @@ def extract_args(
144144
def get_source_target_attributions(
145145
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
146146
is_encoder_decoder: bool,
147+
has_sequence_scores: bool = False,
147148
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
148149
if isinstance(attr, tuple):
149150
if is_encoder_decoder:
150-
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
151+
if has_sequence_scores:
152+
return (attr[0], attr[1], attr[2])
153+
else:
154+
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
151155
else:
152-
return (None, attr[0])
156+
return (None, None, attr[0]) if has_sequence_scores else (None, attr[0])
153157
else:
154158
return (attr, None) if is_encoder_decoder else (None, attr)

inseq/attr/feat/ops/value_zeroing.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ class ValueZeroingSimilarityMetric(Enum):
4545
class ValueZeroingModule(Enum):
4646
DECODER = "decoder"
4747
ENCODER = "encoder"
48-
CROSS = "cross"
4948

5049

5150
class ValueZeroing(InseqAttribution):
@@ -155,20 +154,26 @@ def compute_modules_post_zeroing_similarity(
155154
inputs: TensorOrTupleOfTensorsGeneric,
156155
additional_forward_args: TensorOrTupleOfTensorsGeneric,
157156
hidden_states: MultiLayerEmbeddingsTensor,
157+
attention_module_name: str,
158+
attributed_seq_len: Optional[int] = None,
158159
similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
159160
mode: str = ValueZeroingModule.DECODER.value,
160161
zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
161-
threshold: float = 1e-5,
162+
min_score_threshold: float = 1e-5,
163+
use_causal_mask: bool = False,
162164
) -> MultiLayerScoreTensor:
163165
"""Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block.
164166
165167
Args:
166168
modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for.
167169
hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean
168170
counterparts when computing the similarity.
169-
similarity_scores_shape (:obj:`torch.Size`): The shape of the similarity scores tensor to be returned.
171+
attention_module_name (:obj:`str`): The name of the attention module to zero the values for.
172+
attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed
173+
to be the same as the length of the hidden states.
170174
similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine".
171175
mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder".
176+
172177
zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
173178
`Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
174179
that should be zeroed to compute corrupted states.
@@ -179,18 +184,25 @@ def compute_modules_post_zeroing_similarity(
179184
- If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
180185
the corresponding layer. Any missing layer will not be zeroed.
181186
Default: None.
187+
min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the
188+
similarity. Default: 1e-5.
189+
use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores Default: False.
182190
183191
Returns:
184192
:obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layer]`` containing distances
185193
(1 - similarity score) between original and corrupted states for each layer.
186194
"""
187195
if mode == ValueZeroingModule.DECODER.value:
188196
modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder())
189-
batch_size = hidden_states.size(0)
190-
num_layers = len(modules)
191-
sequence_length = hidden_states.size(2)
197+
elif mode == ValueZeroingModule.ENCODER.value:
198+
modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder())
192199
else:
193200
raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.")
201+
if attributed_seq_len is None:
202+
attributed_seq_len = hidden_states.size(2)
203+
batch_size = hidden_states.size(0)
204+
generated_seq_len = hidden_states.size(2)
205+
num_layers = len(modules)
194206

195207
# Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the
196208
# embedding layer, and we are only interested in the transformer blocks outputs.
@@ -199,7 +211,7 @@ def compute_modules_post_zeroing_similarity(
199211
}
200212
# Scores for every layer of the model
201213
all_scores = torch.ones(
202-
batch_size, num_layers, sequence_length, sequence_length, device=hidden_states.device
214+
batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device
203215
) * float("nan")
204216

205217
# Hooks:
@@ -218,11 +230,11 @@ def compute_modules_post_zeroing_similarity(
218230
modules[block_idx].register_forward_hook(states_extract_and_patch_hook)
219231
)
220232
# Zeroing is done for every token in the sequence separately (O(n) complexity)
221-
for token_idx in range(sequence_length):
233+
for token_idx in range(attributed_seq_len):
222234
value_zeroing_hook_handles: list[RemovableHandle] = []
223235
# Value zeroing hooks are registered for every token separately since they are token-dependent
224236
for block_idx, block in enumerate(modules):
225-
attention_module = block.get_submodule(self.forward_func.config.attention_module)
237+
attention_module = block.get_submodule(attention_module_name)
226238
if isinstance(zeroed_units_indices, dict):
227239
if block_idx not in zeroed_units_indices:
228240
continue
@@ -259,19 +271,22 @@ def compute_modules_post_zeroing_similarity(
259271
for block_idx in range(len(modules)):
260272
similarity_scores = self.SIMILARITY_METRICS[similarity_metric](
261273
self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx]
262-
)[:, token_idx:]
263-
all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores
274+
)
275+
if use_causal_mask:
276+
all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:]
277+
else:
278+
all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores
264279
self.corrupted_block_output_states = {}
265280
for handle in states_extraction_hook_handles:
266281
handle.remove()
267282
self.clean_block_output_states = {}
268-
all_scores = torch.where(all_scores < threshold, torch.zeros_like(all_scores), all_scores)
283+
all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores)
269284
# Normalize scores to sum to 1
270-
per_token_sum_score = all_scores.sum(dim=-1, keepdim=True)
285+
per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True)
271286
per_token_sum_score[per_token_sum_score == 0] = 1
272287
all_scores = all_scores / per_token_sum_score
273288

274-
# Final shape: [batch_size, seq_len, seq_len, num_layers]
289+
# Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers]
275290
return all_scores.permute(0, 3, 2, 1)
276291

277292
def attribute(
@@ -312,18 +327,39 @@ def attribute(
312327
f"Similarity metric {similarity_metric} not available."
313328
f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
314329
)
330+
315331
decoder_scores = self.compute_modules_post_zeroing_similarity(
316332
inputs=inputs,
317333
additional_forward_args=additional_forward_args,
318334
hidden_states=decoder_hidden_states,
335+
attention_module_name=self.forward_func.config.self_attention_module,
319336
similarity_metric=similarity_metric,
320337
mode=ValueZeroingModule.DECODER.value,
321338
zeroed_units_indices=zeroed_units_indices,
339+
use_causal_mask=True,
322340
)
323-
return decoder_scores
324341
# Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
325342
# Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
326-
# if is_encoder_decoder:
327-
# encoder_hidden_states = torch.stack(outputs.encoder_hidden_states)
328-
# encoder = self.forward_func.get_encoder()
329-
# encoder_stack = find_block_stack(encoder)
343+
if self.forward_func.is_encoder_decoder:
344+
# TODO: Enable different encoder/decoder/cross zeroing indices
345+
encoder_scores = self.compute_modules_post_zeroing_similarity(
346+
inputs=inputs,
347+
additional_forward_args=additional_forward_args,
348+
hidden_states=encoder_hidden_states,
349+
attention_module_name=self.forward_func.config.self_attention_module,
350+
similarity_metric=similarity_metric,
351+
mode=ValueZeroingModule.ENCODER.value,
352+
zeroed_units_indices=zeroed_units_indices,
353+
)
354+
cross_scores = self.compute_modules_post_zeroing_similarity(
355+
inputs=inputs,
356+
additional_forward_args=additional_forward_args,
357+
hidden_states=decoder_hidden_states,
358+
attributed_seq_len=encoder_hidden_states.size(2),
359+
attention_module_name=self.forward_func.config.cross_attention_module,
360+
similarity_metric=similarity_metric,
361+
mode=ValueZeroingModule.DECODER.value,
362+
zeroed_units_indices=zeroed_units_indices,
363+
)
364+
return encoder_scores, cross_scores, decoder_scores
365+
return (decoder_scores,)

inseq/attr/feat/perturbation_attribution.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,25 @@ def attribute_step(
144144
attribution_args: dict[str, Any] = {},
145145
) -> MultiDimensionalFeatureAttributionStepOutput:
146146
attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
147-
source_attributions, target_attributions = get_source_target_attributions(
148-
attr, self.attribution_model.is_encoder_decoder
147+
encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions(
148+
attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True
149149
)
150+
sequence_scores = {}
151+
if self.attribution_model.is_encoder_decoder:
152+
if len(attribute_fn_main_args["inputs"]) > 1:
153+
target_attributions = decoder_self_scores.to("cpu")
154+
else:
155+
target_attributions = None
156+
sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
157+
sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
158+
return MultiDimensionalFeatureAttributionStepOutput(
159+
source_attributions=decoder_cross_scores.to("cpu"),
160+
target_attributions=target_attributions,
161+
sequence_scores=sequence_scores,
162+
_num_dimensions=1, # num_layers
163+
)
150164
return MultiDimensionalFeatureAttributionStepOutput(
151-
source_attributions=source_attributions,
152-
target_attributions=target_attributions,
165+
source_attributions=None,
166+
target_attributions=decoder_self_scores,
153167
_num_dimensions=1, # num_layers
154168
)

inseq/data/attribution.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,11 @@ def from_step_attributions(
209209
curr_target = [a.target[seq_idx][0] for a in attributions]
210210
targets.append(drop_padding(curr_target, pad_token))
211211
if has_bos_token:
212-
tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][1:]
213-
tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
212+
tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding(
213+
tokenized_target_sentences[seq_idx], pad_token
214+
)
215+
else:
216+
tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
214217
if attr_pos_end is None:
215218
attr_pos_end = max(len(t) for t in tokenized_target_sentences)
216219
for seq_idx in range(num_sequences):
@@ -238,8 +241,6 @@ def from_step_attributions(
238241
[att.target_attributions for att in attributions], padding_dims=[1]
239242
)
240243
for seq_id in range(num_sequences):
241-
if has_bos_token:
242-
target_attributions[seq_id] = target_attributions[seq_id][1:, ...]
243244
start_idx = max(pos_start) - pos_start[seq_id]
244245
end_idx = start_idx + len(tokenized_target_sentences[seq_id])
245246
target_attributions[seq_id] = target_attributions[seq_id][

inseq/models/model_config.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from dataclasses import dataclass
33
from pathlib import Path
4+
from typing import Optional
45

56
import yaml
67

@@ -12,18 +13,23 @@ class ModelConfig:
1213
"""Configuration used by the methods for which the attribute ``use_model_config=True``.
1314
1415
Args:
15-
attention_module (:obj:`str`):
16-
The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in
17-
transformers). Can be identified by looking at the name of the attribute instantiating the attention module
16+
self_attention_module (:obj:`str`):
17+
The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in
18+
transformers). Can be identified by looking at the name of the self-attention module attribute
1819
in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
20+
cross_attention_module (:obj:`str`):
21+
The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models
22+
in transformers). Can be identified by looking at the name of the cross-attention module attribute
23+
in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
1924
value_vector (:obj:`str`):
2025
The name of the variable in the forward pass of the attention module containing the value vector
2126
(e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
2227
the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
2328
"""
2429

25-
attention_module: str
30+
self_attention_module: str
2631
value_vector: str
32+
cross_attention_module: Optional[str] = None
2733

2834

2935
MODEL_CONFIGS = {

inseq/models/model_config.yaml

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,53 @@
1+
# AutoModelForCausalLM
2+
BloomForCausalLM:
3+
self_attention_module: "self_attention"
4+
value_vector: "value_layer"
15
GPT2LMHeadModel:
2-
attention_module: "attn"
6+
self_attention_module: "attn"
37
value_vector: "value"
48
OpenAIGPTLMHeadModel:
5-
attention_module: "attn"
9+
self_attention_module: "attn"
610
value_vector: "value"
711
GPTNeoXForCausalLM:
8-
attention_module: "attention"
12+
self_attention_module: "attention"
913
value_vector: "value"
10-
BloomForCausalLM:
11-
attention_module: "self_attention"
12-
value_vector: "value_layer"
1314
LlamaForCausalLM:
14-
attention_module: "self_attn"
15+
self_attention_module: "self_attn"
1516
value_vector: "value_states"
1617
GPTBigCodeForCausalLM:
17-
attention_module: "attn"
18+
self_attention_module: "attn"
1819
value_vector: "value"
1920
CodeGenForCausalLM:
20-
attention_module: "attn"
21+
self_attention_module: "attn"
2122
value_vector: "value"
22-
23-
# TODO ForCausalLM
23+
# TODO
24+
# BioGptForCausalLM
25+
# GemmaForCausalLM
2426
# GPTNeoForCausalLM
2527
# GPTJForCausalLM
28+
# MistralForCausalLM
29+
# MixtralForCausalLM
30+
# MptForCausalLM
31+
# OpenLlamaForCausalLM
2632
# OPTForCausalLM
33+
# PhiForCausalLM
34+
# StableLmForCausalLM
2735
# XGLMForCausalLM
28-
# BioGptForCausalLM
29-
# XLNetLMHeadModel
36+
37+
# AutoModelForSeq2SeqLM
38+
MarianMTModel:
39+
self_attention_module: "self_attn"
40+
cross_attention_module: "encoder_attn"
41+
value_vector: "value_states"
3042

3143
# TODO ForConditionalGeneration
3244
# BartForConditionalGeneration
33-
# BlenderbotForConditionalGeneration
34-
# T5ForConditionalGeneration
35-
# MarianMTModel
36-
# LongT5ForConditionalGeneration
3745
# FSMTForConditionalGeneration
46+
# LongT5ForConditionalGeneration
3847
# M2M100ForConditionalGeneration
3948
# MBartForConditionalGeneration
40-
# PegasusForConditionalGeneration
41-
# ProphetNetForConditionalGeneration
42-
# LEDForConditionalGeneration
43-
# BigBirdPegasusForConditionalGeneration
44-
# PLBartForConditionalGeneration
45-
# SwitchTransformerForConditionalGeneration
49+
# MT5ForConditionalGeneration
4650
# NllbMoeForConditionalGeneration
51+
# SeamlessM4TForTextToText
52+
# SeamlessM4Tv2ForTextToText
53+
# T5ForConditionalGeneration

tests/attr/feat/test_feature_attribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu
7575
"orig_tgt": "I soldati della pace ONU",
7676
"contrast_tgt": "Le forze militari di pace delle Nazioni Unite",
7777
"alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]],
78-
"aligned_tgts": ["▁Le → ▁I", "▁forze → ▁soldati", "▁di → ▁della", "▁pace", "▁Nazioni → ▁ONU", "</s>"],
78+
"aligned_tgts": ["<pad>", "▁Le → ▁I", "▁forze → ▁soldati", "▁di → ▁della", "▁pace", "▁Nazioni → ▁ONU", "</s>"],
7979
}
8080
out = saliency_mt_model_larger.attribute(
8181
aligned["src"],

0 commit comments

Comments
 (0)