
Commit 3ee678f

committed
Added reshapes for interpretable inputs in LimeBase. (WIP)
1 parent cb8b4c7 commit 3ee678f

File tree

3 files changed: +324 -47 lines changed


inseq/attr/feat/ops/lime.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
import inspect
import math
import warnings
from functools import partial
from typing import Any, Callable, Optional, cast

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
)
from captum._utils.models.linear_model import SkLearnLinearModel
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from captum.attr import LimeBase
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


class Lime(LimeBase):
    def __init__(
        self,
        attribution_model: Callable,
        interpretable_model: Optional[Model] = None,
        similarity_func: Optional[Callable] = None,
        perturb_func: Optional[Callable] = None,
        perturb_interpretable_space: bool = False,
        from_interp_rep_transform: Optional[Callable] = None,
        to_interp_rep_transform: Optional[Callable] = None,
    ) -> None:
        if interpretable_model is None:
            interpretable_model = SkLearnLinearModel("linear_model.Ridge")

        if similarity_func is None:
            similarity_func = self.token_similarity_kernel

        if perturb_func is None:
            perturb_func = partial(
                self.perturb_func,
            )

        super().__init__(
            forward_func=attribution_model,
            interpretable_model=interpretable_model,
            similarity_func=similarity_func,
            perturb_func=perturb_func,
            perturb_interpretable_space=perturb_interpretable_space,
            from_interp_rep_transform=None,
            to_interp_rep_transform=self.to_interp_rep_transform,
        )
        self.attribution_model = attribution_model
        assert self.attribution_model.model.device is not None
        assert self.attribution_model.tokenizer.pad_token_id is not None

    # @log_usage
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""
        This method attributes the output of the model with the given target index
        (if provided; otherwise the output is assumed to be a scalar) to the
        inputs of the model using the approach described above. It trains an
        interpretable model and returns a representation of the interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can be provided only if forward_func
        returns a single value per batch (e.g. loss).
        The interpretable feature representation should still have shape
        1 x num_interp_features, corresponding to the interpretable
        representation for the full batch, and perturbations_per_eval
        must be set to 1.

        Args:

            inputs (tensor or tuple of tensors): Input for which LIME
                is computed. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                which the surrogate model is trained
                (for classification cases, this is usually the target class).
                If the network returns a scalar value per example,
                no target index is necessary.
                For general 2D outputs, targets can be either:

                - a single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - a list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - a single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - a list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
            additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a
                tuple containing multiple additional arguments including
                tensors or any arbitrary python types. These arguments
                are provided to forward_func in order following the
                arguments in inputs.
                For a tensor, the first dimension of the tensor must
                correspond to the number of examples. For all other types,
                the given argument is used for all forward evaluations.
                Note that attributions are not computed with respect
                to these arguments.
                Default: None
            n_samples (int, optional): The number of samples of the original
                model used to train the surrogate interpretable model.
                Default: `50` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                to be processed simultaneously in one call to forward_fn.
                Each forward pass will contain a maximum of
                perturbations_per_eval * #examples samples.
                For DataParallel models, each batch is split among the
                available devices, so evaluations on each available
                device contain at most
                (perturbations_per_eval * #examples) / num_devices
                samples.
                If the forward function returns a single scalar per batch,
                perturbations_per_eval must be set to 1.
                Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                It will try to use tqdm if available for advanced features
                (e.g. time estimation). Otherwise, it will fall back to
                a simple output of progress.
                Default: False
            **kwargs (Any, optional): Any additional arguments necessary for
                sampling and transformation functions (provided to
                constructor).
                Default: None

        Returns:
            **interpretable model representation**:
            - **interpretable model representation** (*Any*):
                A representation of the trained interpretable model. The return
                type matches the return type of train_interpretable_model_func.
                For example, this could contain coefficients of a
                linear surrogate model.
        """
        with torch.no_grad():
            inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
            device = inp_tensor.device

            interpretable_inps = []
            similarities = []
            outputs = []

            curr_model_inputs = []
            expanded_additional_args = None
            expanded_target = None
            perturb_generator = None
            if inspect.isgeneratorfunction(self.perturb_func):
                perturb_generator = self.perturb_func(inputs, **kwargs)

            if show_progress:
                attr_progress = progress(
                    total=math.ceil(n_samples / perturbations_per_eval),
                    desc=f"{self.get_name()} attribution",
                )
                attr_progress.update(0)

            # Sample perturbations, record their interpretable representations and
            # similarities, and evaluate the model in batches of perturbations_per_eval.
            batch_count = 0
            for _ in range(n_samples):
                if perturb_generator:
                    try:
                        curr_sample = next(perturb_generator)
                    except StopIteration:
                        warnings.warn("Generator completed prior to given n_samples iterations!")
                        break
                else:
                    curr_sample = self.perturb_func(inputs, **kwargs)
                batch_count += 1
                if self.perturb_interpretable_space:
                    interpretable_inps.append(curr_sample)
                    curr_model_inputs.append(
                        self.from_interp_rep_transform(curr_sample, inputs, **kwargs)  # type: ignore
                    )
                else:
                    curr_model_inputs.append(curr_sample)
                    interpretable_inps.append(
                        self.to_interp_rep_transform(curr_sample, inputs, **kwargs)  # type: ignore
                    )
                curr_sim = self.similarity_func(inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs)
                similarities.append(
                    curr_sim.flatten() if isinstance(curr_sim, Tensor) else torch.tensor([curr_sim], device=device)
                )

                if len(curr_model_inputs) == perturbations_per_eval:
                    if expanded_additional_args is None:
                        expanded_additional_args = _expand_additional_forward_args(
                            additional_forward_args, len(curr_model_inputs)
                        )
                    if expanded_target is None:
                        expanded_target = _expand_target(target, len(curr_model_inputs))

                    model_out = self._evaluate_batch(
                        curr_model_inputs,
                        expanded_target,
                        expanded_additional_args,
                        device,
                    )

                    if show_progress:
                        attr_progress.update()

                    outputs.append(model_out)

                    curr_model_inputs = []

            # Evaluate any leftover perturbations that did not fill a complete batch.
            if len(curr_model_inputs) > 0:
                expanded_additional_args = _expand_additional_forward_args(
                    additional_forward_args, len(curr_model_inputs)
                )
                expanded_target = _expand_target(target, len(curr_model_inputs))
                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    expanded_additional_args,
                    device,
                )
                if show_progress:
                    attr_progress.update()
                outputs.append(model_out)

            if show_progress:
                attr_progress.close()

            # Modification of the original attribute function:
            # squeeze the batch dimension out of interpretable_inps
            # -> 2D tensor (n_samples ✕ (input_dim * embedding_dim))
            combined_interp_inps = torch.cat([i.view(-1).unsqueeze(dim=0) for i in interpretable_inps]).double()

            combined_outputs = (torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)).double()
            combined_sim = (
                torch.cat(similarities) if len(similarities[0].shape) > 0 else torch.stack(similarities)
            ).double()
            dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))

            # Second modification:
            # reshape the learned representation
            # -> 3D tensor (b=1 ✕ input_dim ✕ embedding_dim)
            return self.interpretable_model.representation().reshape(inp_tensor.shape)
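
    # Default similarity kernel: counts the token positions where the perturbed ids
    # still match the original ids, normalized by the length of the first input
    # dimension.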
    @staticmethod
    def token_similarity_kernel(
        original_input: tuple,
        perturbed_input: tuple,
        perturbed_interpretable_input: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        original_input_tensor = original_input[0]
        perturbed_input_tensor = perturbed_input[0]
        assert original_input_tensor.shape == perturbed_input_tensor.shape
        similarity = torch.sum(original_input_tensor == perturbed_input_tensor) / len(original_input_tensor)
        return similarity
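
    # Default perturbation: draw a uniform random binary mask over the input ids and
    # replace the masked-out positions with the tokenizer's pad token, returning the
    # perturbed ids as a one-element tuple that mirrors the structure of the inputs.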
    def perturb_func(
        self,
        original_input: tuple,  # always needs to be last argument before **kwargs due to "partial"
        **kwargs: Any,
    ) -> tuple:
        """
        Sampling function
        """
        original_input_tensor = original_input[0]
        mask = torch.randint(low=0, high=2, size=original_input_tensor.size()).to(self.attribution_model.device)
        perturbed_input = original_input_tensor * mask + (1 - mask) * self.attribution_model.tokenizer.pad_token_id
        perturbed_input_tuple = (perturbed_input,)
        return perturbed_input_tuple
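
    # Interpretable representation: the perturbed tensor itself is used directly as
    # the interpretable features (no separate binary interpretable space).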
    @staticmethod
    def to_interp_rep_transform(sample, original_input, **kwargs: Any):
        return sample[0]  # access the first (and only) tensor of the perturbed-input tuple
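
For orientation, a minimal usage sketch of the new op. The wrapper object, variable
names and input shapes below are placeholders and not part of this commit; the only
requirements implied by the code above are that the attribution model exposes
.model.device, .device and .tokenizer.pad_token_id, and that a single example
(batch size 1) is attributed at a time.

    # hypothetical wiring, assuming `model` is an inseq-style attribution model wrapper
    lime = Lime(attribution_model=model)
    scores = lime.attribute(
        inputs=(input_tensor,),  # single-example input tensor, batch size 1
        n_samples=50,            # perturbations used to fit the linear surrogate
        show_progress=True,
    )
    # `scores` is the fitted surrogate's representation reshaped to input_tensor.shape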

0 commit comments
