Add OcclusionAttribution and LimeAttribution #145
@@ -0,0 +1,74 @@ (new file)

```python
from typing import Any, Dict

import logging

from captum.attr import Occlusion

from ...data import PerturbationFeatureAttributionStepOutput
from ...utils import Registry
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution


logger = logging.getLogger(__name__)


class PerturbationMethodRegistry(FeatureAttribution, Registry):
    """Occlusion-based attribution methods."""

    @set_hook
    def hook(self, **kwargs):
        pass

    @unset_hook
    def unhook(self, **kwargs):
        pass


class OcclusionAttribution(PerturbationMethodRegistry):
    """Occlusion-based attribution method.
    Reference implementation:
    `https://captum.ai/api/occlusion.html <https://captum.ai/api/occlusion.html>`__.

    Usages in other implementations:
    `niuzaisheng/AttExplainer <https://github.com/niuzaisheng/AttExplainer/blob/main/baseline_methods/explain_baseline_captum.py>`__
    `andrewPoulton/explainable-asag <https://github.com/andrewPoulton/explainable-asag/blob/main/explanation.py>`__
    `copenlu/xai-benchmark <https://github.com/copenlu/xai-benchmark/blob/master/saliency_gen/interpret_grads_occ.py>`__
    `DFKI-NLP/thermostat <https://github.com/DFKI-NLP/thermostat/blob/main/src/thermostat/explainers/occlusion.py>`__
    """

    method_name = "occlusion"

    def __init__(self, attribution_model, **kwargs):
        super().__init__(attribution_model)
        self.is_layer_attribution = False
        self.method = Occlusion(self.attribution_model)

    def attribute_step(
        self,
        attribute_fn_main_args: Dict[str, Any],
        attribution_args: Dict[str, Any] = {},
    ) -> Any:
        if "sliding_window_shapes" not in attribution_args:
            # Sliding window shapes is defined as a tuple:
            # the first entry is between 1 and the length of the input,
            # the second entry is given by the max length of the underlying model.
            # If not explicitly given via attribution_args, the default is (1, model_max_length).
            attribution_args["sliding_window_shapes"] = (1, self.attribution_model.model_max_length)

        attr = self.method.attribute(
            **attribute_fn_main_args,
            **attribution_args,
        )
        source_attributions, target_attributions = get_source_target_attributions(
            attr, self.attribution_model.is_encoder_decoder
        )
        return PerturbationFeatureAttributionStepOutput(
            source_attributions=source_attributions,
            target_attributions=target_attributions,
        )
```

Review comments:

- On `class PerturbationMethodRegistry(FeatureAttribution, Registry)` (marked as resolved by gsarti): Let's call it `PerturbationAttribution` to keep it consistent with `GradientAttribution`. We might want to bulk change them to add the Registry specification at a later time though!
- On the `"""Occlusion-based attribution methods."""` docstring (marked as resolved by gsarti): Change to
- On the `sliding_window_shapes` default: I'm a bit puzzled by the second entry: the max length you take here is the max generation length that the model can handle, but my understanding was that this would be the
- Reply: You're right, I accidentally took the next best attribute in
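The method added here delegates to Captum's `Occlusion`: spans of the input are replaced by a baseline value and the importance of each span is the resulting drop in the model's output. As a hedged illustration of the underlying idea (plain NumPy, not the inseq or Captum API; `occlusion_attribution`, `forward_fn`, and the toy `model` below are illustrative names, and the window is fixed at a single position):

```python
import numpy as np

def occlusion_attribution(forward_fn, inputs, baseline=0.0):
    """Score each position by how much the output drops when that
    position is replaced by a baseline value (window size 1)."""
    base_score = forward_fn(inputs)
    scores = np.zeros(len(inputs))
    for i in range(len(inputs)):
        occluded = inputs.copy()
        occluded[i] = baseline  # occlude one position at a time
        scores[i] = base_score - forward_fn(occluded)
    return scores

# Toy model: a weighted sum, so occlusion recovers each term exactly.
weights = np.array([0.5, 2.0, -1.0])
model = lambda x: float(weights @ x)
print(occlusion_attribution(model, np.array([1.0, 1.0, 1.0])))
# For this linear model, each score equals w_i * x_i: [0.5, 2.0, -1.0]
```

In the PR's code, Captum's `sliding_window_shapes` argument generalizes the single-position window used above to multi-dimensional spans, which is exactly what the review comment on the default value is probing.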
@@ -474,3 +474,23 @@ class GradientFeatureAttributionStepOutput(FeatureAttributionStepOutput):

```python
    """

    _sequence_cls: Type["FeatureAttributionSequenceOutput"] = GradientFeatureAttributionSequenceOutput


# Perturbation attribution classes


@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
    """Raw output of a single sequence of perturbation feature attribution."""

    def __post_init__(self):
        super().__post_init__()
        self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions
        self._dict_aggregate_fn["target_attributions"]["sequence_aggregate"] = sum_normalize_attributions


@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionStepOutput(FeatureAttributionStepOutput):
    """Raw output of a single step of perturbation feature attribution."""

    _sequence_cls: Type["FeatureAttributionSequenceOutput"] = PerturbationFeatureAttributionSequenceOutput
```

Review comments:

- On the `sequence_aggregate` override: Do perturbation attributions have shape If indeed it is not the case, then we would not need a specific class for
- Reply: Yes, when I left out this aggregation step from the (see inseq/inseq/data/aggregator.py, lines 169 to 172 in 000aac5) I assume this will apply to other perturbation methods as well.
- Reply: I agree with you to have the return object to be of
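The thread above turns on what `sum_normalize_attributions` does to the attribution tensors before display. A hedged sketch of that kind of aggregator (NumPy; the axis conventions and zero-norm guard here are assumptions for illustration, not inseq's exact implementation):

```python
import numpy as np

def sum_normalize_attributions(attr):
    """Collapse a granular (hidden-size) trailing dimension by summing,
    then L2-normalize each column so token scores are comparable.
    Axis conventions are illustrative, not inseq's exact ones."""
    summed = attr.sum(axis=-1)  # (source_len, target_len, hidden) -> (source_len, target_len)
    norms = np.linalg.norm(summed, ord=2, axis=0, keepdims=True)
    return summed / np.where(norms == 0, 1.0, norms)  # guard against all-zero columns

# A uniform (source_len=2, target_len=3, hidden_size=4) tensor collapses to
# a (2, 3) matrix of identical normalized scores.
out = sum_normalize_attributions(np.ones((2, 3, 4)))
print(out.shape)  # (2, 3)
```

This matches the reviewers' point: if a perturbation method's raw attributions carry an extra granular dimension, some aggregation like the above has to run before the scores can be shown per token, which is what motivates the dedicated `Perturbation*Output` classes.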
- On the re-added lines (marked as resolved): This should be already in the main branch if you merged main!
- Reply: Sorry, this was an accident. Thanks for pointing it out!