Add OcclusionAttribution and LimeAttribution #145
Conversation
Hello @nfelnlp, thank you for submitting a PR! We will respond as soon as possible.
inseq/attr/feat/occlusion.py
Outdated
self,
batch: EncoderDecoderBatch,
target_ids: TargetIdsTensor,
attribute_target: bool = False,
I don't know why I thought that custom attribution functions were only applicable to gradient-based methods, but in theory there should be no problem in using them for captum.attr.Occlusion, since it accepts a forward_func as input. Parameters should be updated here to reflect the changes merged from #138.
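For context, a minimal self-contained sketch (not the inseq wrapper itself; the toy embedding, classifier, and scoring function are illustrative assumptions) showing that captum.attr.Occlusion only needs a forward_func, so any custom scoring function can be plugged in:

```python
import torch
from captum.attr import Occlusion

embedding = torch.nn.Embedding(100, 16)
classifier = torch.nn.Linear(16, 2)

def forward_func(input_embeds: torch.Tensor) -> torch.Tensor:
    # Any callable returning a score per batch element works here,
    # e.g. the probability of the generated target token.
    return classifier(input_embeds.mean(dim=1))

occlusion = Occlusion(forward_func)
inputs = embedding(torch.randint(0, 100, (1, 7)))  # [batch, seq_len, hidden]
attributions = occlusion.attribute(
    inputs,
    sliding_window_shapes=(1, inputs.shape[-1]),  # occlude one full token embedding at a time
    target=0,
)
```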
if "sliding_window_shapes" not in attribution_args: | ||
# Sliding window shapes is defined as a tuple | ||
# First entry is between 1 and length of input | ||
# Second entry is given by the max length of the underlying model |
I'm a bit puzzled by the second entry: the max length you take here is the max generation length that the model can handle, but my understanding was that this would be the hidden_size from the model config, to ensure that there is no partial masking of token embeddings. Could you clarify?
You're right, I accidentally took the first attribute in self.attribution_model that happened to have 512 as its size. That was careless. The second entry should rather be based on the embedding size, right? Does accessing it via self.attribution_model.get_embedding_layer() make sense?
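A sketch of that idea (the helper name is hypothetical, and it assumes get_embedding_layer() returns an nn.Embedding-like module, which is an assumption about the inseq API):

```python
def default_sliding_window_shapes(attribution_model) -> tuple:
    # Occlude one token at a time, covering its full embedding vector,
    # so that no token embedding is ever partially masked.
    embed_dim = attribution_model.get_embedding_layer().embedding_dim
    return (1, embed_dim)
```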
def __post_init__(self):
    super().__post_init__()
    self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions
Do perturbation attributions have shape [attributed_text_length, generated_text_length, hidden_size] like the ones generated by gradient methods? sum_normalize_attributions casts the 3D tensor above to a 2D tensor for visualization, but I thought that for occlusion this wouldn't be needed. If that is indeed not the case, then we would not need a specific class for PerturbationFeatureAttribution methods and could simply stick to the base FeatureAttributionSequenceOutput and FeatureAttributionStepOutput for the moment.
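For reference, a rough sketch of what an aggregation like sum_normalize_attributions is expected to do (the actual inseq implementation may differ): collapse the hidden dimension of a [source_len, target_len, hidden_size] tensor and normalize each generation step so scores are comparable across source positions.

```python
import torch

def sum_normalize(attributions: torch.Tensor) -> torch.Tensor:
    # attributions: [source_len, target_len, hidden_size]
    summed = attributions.sum(dim=-1)  # -> [source_len, target_len]
    # L2-normalize each generation step (column) so that scores across
    # source tokens are comparable.
    return summed / summed.norm(p=2, dim=0, keepdim=True)
```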
Yes, when I left out this aggregation step from the __post_init__, I had a 3D tensor that resulted in a shape violation here:

inseq/inseq/data/aggregator.py, lines 169 to 172 in 000aac5:
if attr.source_attributions is not None:
    assert len(attr.source_attributions.shape) == 2
if attr.target_attributions is not None:
    assert len(attr.target_attributions.shape) == 2
I assume this will apply to other perturbation methods as well.
I agree with you: let's have the return objects be FeatureAttributionSequenceOutput and FeatureAttributionStepOutput for now.
.pre-commit-config.yaml
Outdated
@@ -5,7 +5,7 @@ default_stages: [commit, push]

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.4.0
This should already be in the main branch if you merged main!
Sorry, this was an accident. Thanks for pointing it out!
logger = logging.getLogger(__name__)


class PerturbationMethodRegistry(FeatureAttribution, Registry):
Let's call it PerturbationAttribution to keep it consistent with GradientAttribution. We might want to bulk change them to add the Registry specification at a later time though!
class PerturbationMethodRegistry(FeatureAttribution, Registry):
    """Occlusion-based attribution methods."""
Change to """Perturbation-based attribution method registry."""
It's not clear to me yet which dimensions the original and perturbed tensor should have. Also, I wasn't sure how to apply the perturbation (mask) to the 3D tensor, especially in cases with a batch size larger than 1.
Update on LIME: it makes sense to handle 2D tensors one at a time, so I thought of having a loop like this:

attrs = []
for b, batch in enumerate(attribute_fn_main_args['inputs'][0]):
    single_input = tuple(
        inp[b] if isinstance(inp, torch.Tensor) else inp
        for inp in attribute_fn_main_args['inputs']
    )
    single_additional_forward_args = tuple(
        arg[b] if isinstance(arg, torch.Tensor) else arg
        for arg in attribute_fn_main_args['additional_forward_args']
    )
    single_attribute_fn_main_args = {
        'inputs': single_input,
        'additional_forward_args': single_additional_forward_args,
    }
    single_attr = self.method.attribute(
        **single_attribute_fn_main_args,
        **attribution_args,
    )
    attrs.append(single_attr)
attr = torch.stack(attrs, dim=0)

Would you recommend following this path, or should I handle unpacking the batch in other ways, @gsarti? Maybe overriding the … ?

The error message I get after the above change (putting the examples in a batch through …) is:

ValueError: not enough values to unpack (expected 2, got 1)

Traceback (most recent call last):
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/attr/_core/lime.py", line 479, in attribute
model_out = self._evaluate_batch(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/attr/_core/lime.py", line 535, in _evaluate_batch
model_out = _run_forward(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/captum/_utils/common.py", line 456, in _run_forward
output = forward_func(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nfel/PycharmProjects/inseq/inseq/models/encoder_decoder.py", line 273, in forward
output = self.get_forward_output(
File "/home/nfel/PycharmProjects/inseq/inseq/models/encoder_decoder.py", line 248, in get_forward_output
return self.model(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 1444, in forward
outputs = self.model(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 1224, in forward
encoder_outputs = self.encoder(
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 751, in forward
embed_pos = self.embed_positions(input_shape)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/nfel/.cache/pypoetry/virtualenvs/inseq-wcIi0sem-py3.10/lib/python3.10/site-packages/transformers/models/marian/modeling_marian.py", line 136, in forward
bsz, seq_len = input_ids_shape[:2]
ValueError: not enough values to unpack (expected 2, got 1)
Process finished with exit code 1

The underlying model (Marian, judging from the traceback) apparently expects inputs that keep the batch dimension, which gets lost when single examples are passed through.

Modified token_similarity_kernel and perturb_func to work with the current instance-wise attribution:

@staticmethod
def token_similarity_kernel(
    original_input: tuple,
    perturbed_input: tuple,
    perturbed_interpretable_input: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    original_input_tensor = original_input[0]
    perturbed_input_tensor = perturbed_input[0]
    assert original_input_tensor.shape == perturbed_input_tensor.shape
    # Fraction of positions left unchanged by the perturbation.
    similarity = torch.sum(original_input_tensor == perturbed_input_tensor) / len(original_input_tensor)
    return similarity

def perturb_func(
    self,
    original_input: tuple,  # always needs to be the last argument before **kwargs due to "partial"
    **kwargs: Any,
) -> tuple:
    """Sampling function."""
    original_input_tensor = original_input[0]
    # Randomly mask roughly half of the positions with the pad token.
    mask = torch.randint(low=0, high=2, size=original_input_tensor.size()).to(self.attribution_model.device)
    perturbed_input = original_input_tensor * mask + (1 - mask) * self.attribution_model.tokenizer.pad_token_id
    return (perturbed_input,)  # single-element tuple, rather than tuple({perturbed_input})

There seem to be many ways to make this work, but I haven't found a clean and safe way yet.
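For reference, one way to keep the batch dimension when attributing single examples is to unsqueeze each slice before the forward pass. This is a sketch with a hypothetical helper, not part of the PR, assuming all tensor arguments are indexed per batch element:

```python
import torch

def select_single_example(args: tuple, b: int) -> tuple:
    # Index the b-th example of every tensor argument and restore the batch
    # dimension with unsqueeze(0), so that shape checks like
    # `bsz, seq_len = input_ids_shape[:2]` in the model still succeed.
    return tuple(
        arg[b].unsqueeze(0) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    )
```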
I implemented a rough solution for reshaping the 3D tensor into a 2D one for LimeBase to handle. The problem was that the linear model in LimeBase expects a 2D tensor of shape (n_samples x "everything else"), so my idea was to apply the reshaping implemented here: inseq/inseq/attr/feat/ops/lime.py, lines 261 to 278 in 3ee678f.
Do you have a suggestion on what might be the problem?
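For context, a minimal sketch of the kind of reshaping being discussed (shapes and names are illustrative, not the actual code at the referenced lines): flatten the per-token embedding dimension before fitting the interpretable linear model and restore it afterwards.

```python
import torch

# Perturbed samples collected for the surrogate model: [n_samples, seq_len, hidden_size]
samples = torch.randn(50, 7, 512)

# LimeBase's linear model expects 2D input: [n_samples, n_features]
flat_samples = samples.reshape(samples.shape[0], -1)

# ... fit the interpretable (linear) model on flat_samples ...

# The fitted coefficients can then be reshaped back and reduced per token.
coefs = torch.randn(flat_samples.shape[1])  # placeholder for the fitted weights
token_level_scores = coefs.reshape(7, 512).sum(dim=-1)  # one score per source token
```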
LIME is ready for testing. I used the perturb_func implemented in Thermostat, which gave me more sensible results.
LIME results are still somewhat strange. I'm not sure if all values are supposed to be positive.
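One quick way to check this (a hypothetical helper, not part of the PR) is to look at the sign distribution of the raw attribution scores before any normalization:

```python
import torch

def sign_summary(scores: torch.Tensor) -> dict:
    # Summarize the range and fraction of negative attribution scores.
    return {
        "min": scores.min().item(),
        "max": scores.max().item(),
        "frac_negative": (scores < 0).float().mean().item(),
    }
```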
Hey @nfelnlp, I finally had time to start reviewing the perturbation methods, here are some thoughts:

Occlusion

Lime

GradientSHAP
…ansform functions and uses UNK as default mask token. Renamed GradientSHAP.
Thank you very much for your feedback! I think we're approaching a publishable version for this branch. In Occlusion, I managed to correct the first point. The remaining points (second and third in Occlusion and second in LIME) require more time to investigate and implement. Let me know what you think of the proposed changes and how close we are to merging. Thanks a lot!
If I left the aggregation (using …), I got the following error:

File "/home/nfel/PycharmProjects/inseq/test.py", line 19, in <module>
mt_out.show()
File "/home/nfel/PycharmProjects/inseq/inseq/data/attribution.py", line 506, in show
attr.show(min_val, max_val, display, return_html, aggregator, **kwargs)
File "/home/nfel/PycharmProjects/inseq/inseq/data/attribution.py", line 200, in show
aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self
File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 130, in aggregate
return aggregator.aggregate(self, **kwargs)
File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 161, in aggregate
aggregated = super().aggregate(attr, aggregate_fn=aggregate_fn, **kwargs)
File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 74, in aggregate
cls.end_aggregation_hook(aggregated, **kwargs)
File "/home/nfel/PycharmProjects/inseq/inseq/data/aggregator.py", line 174, in end_aggregation_hook
assert len(attr.source_attributions.shape) == 2
AssertionError Does it mean that
|
The checks in …
(copied from the comment in the code I added with the last commit)
* origin/main: Add OcclusionAttribution and LimeAttribution (#145)
Description
Added classes for occlusion-based attribution methods, including a wrapper for the Captum implementation of Occlusion (Zeiler & Fergus, 2013).
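A hedged usage sketch (the model name and the "occlusion" method string are assumptions; the exact registration name added by this PR may differ):

```python
import inseq

# Load a seq2seq model with the occlusion attribution method from this PR.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-de", "occlusion")
out = model.attribute("Hello world, how are you?")
out.show()
```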
Related Issue
Related to #107
Type of Change