import inspect
import math
import warnings
from typing import Any, Callable, Optional, cast

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
)
from captum._utils.models.linear_model import SkLearnLinearModel
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from captum.attr import LimeBase
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


class Lime(LimeBase):
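    """LIME attribution for sequence models, built on Captum's LimeBase.

    By default, a Ridge regression surrogate is trained on perturbed inputs in
    which random positions are replaced by the tokenizer's pad token, weighted
    by the token-identity similarity kernel defined below.
    """
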
    def __init__(
        self,
        attribution_model: Callable,
        interpretable_model: Optional[Model] = None,
        similarity_func: Optional[Callable] = None,
        perturb_func: Optional[Callable] = None,
        perturb_interpretable_space: bool = False,
        from_interp_rep_transform: Optional[Callable] = None,
        to_interp_rep_transform: Optional[Callable] = None,
    ) -> None:
        if interpretable_model is None:
            interpretable_model = SkLearnLinearModel("linear_model.Ridge")

        if similarity_func is None:
            similarity_func = self.token_similarity_kernel

        if perturb_func is None:
            perturb_func = self.perturb_func

        if to_interp_rep_transform is None:
            to_interp_rep_transform = self.to_interp_rep_transform

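        # Pass the configured components to LimeBase; which transform is applied
        # during attribution depends on perturb_interpretable_space.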
        super().__init__(
            forward_func=attribution_model,
            interpretable_model=interpretable_model,
            similarity_func=similarity_func,
            perturb_func=perturb_func,
            perturb_interpretable_space=perturb_interpretable_space,
            from_interp_rep_transform=from_interp_rep_transform,
            to_interp_rep_transform=to_interp_rep_transform,
        )
        self.attribution_model = attribution_model
        assert self.attribution_model.model.device is not None
        assert self.attribution_model.tokenizer.pad_token_id is not None

    # @log_usage
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""
        This method attributes the output of the model with the given target index
        (if it is provided, otherwise it assumes that the output is a
        scalar) to the inputs of the model using the approach described above.
        It trains an interpretable model and returns a representation of the
        interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can be provided as inputs only if forward_func
        returns a single value per batch (e.g. loss).
        The interpretable feature representation should still have shape
        1 x num_interp_features, corresponding to the interpretable
        representation for the full batch, and perturbations_per_eval
        must be set to 1.

        Args:

            inputs (tensor or tuple of tensors): Input for which LIME
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which the surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: 50
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fall back to
                        a simple output of progress.
                        Default: False
            **kwargs (Any, optional): Any additional arguments necessary for
                        sampling and transformation functions (provided to
                        the constructor).
                        Default: None

        Returns:
            **interpretable model representation**:
            - **interpretable model representation** (*Any*):
                A representation of the trained interpretable model. The return
                type matches the return type of train_interpretable_model_func.
                For example, this could contain coefficients of a
                linear surrogate model.
        """
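        # Sample perturbed inputs, evaluate the model on them in batches, and
        # collect (interpretable input, model output, similarity) triples for
        # fitting the surrogate model.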
        with torch.no_grad():
            inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
            device = inp_tensor.device

            interpretable_inps = []
            similarities = []
            outputs = []

            curr_model_inputs = []
            expanded_additional_args = None
            expanded_target = None
            perturb_generator = None
            if inspect.isgeneratorfunction(self.perturb_func):
                perturb_generator = self.perturb_func(inputs, **kwargs)

            if show_progress:
                attr_progress = progress(
                    total=math.ceil(n_samples / perturbations_per_eval),
                    desc=f"{self.get_name()} attribution",
                )
                attr_progress.update(0)

            batch_count = 0
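            # Draw n_samples perturbations: each sample is stored both as a model
            # input and as its interpretable representation, together with its
            # similarity to the original input.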
            for _ in range(n_samples):
                if perturb_generator:
                    try:
                        curr_sample = next(perturb_generator)
                    except StopIteration:
                        warnings.warn("Generator completed prior to given n_samples iterations!")
                        break
                else:
                    curr_sample = self.perturb_func(inputs, **kwargs)
                batch_count += 1
                if self.perturb_interpretable_space:
                    interpretable_inps.append(curr_sample)
                    curr_model_inputs.append(
                        self.from_interp_rep_transform(curr_sample, inputs, **kwargs)  # type: ignore
                    )
                else:
                    curr_model_inputs.append(curr_sample)
                    interpretable_inps.append(
                        self.to_interp_rep_transform(curr_sample, inputs, **kwargs)  # type: ignore
                    )
                curr_sim = self.similarity_func(inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs)
                similarities.append(
                    curr_sim.flatten() if isinstance(curr_sim, Tensor) else torch.tensor([curr_sim], device=device)
                )

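                # Once perturbations_per_eval samples have accumulated, evaluate
                # them in a single forward pass.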
                if len(curr_model_inputs) == perturbations_per_eval:
                    if expanded_additional_args is None:
                        expanded_additional_args = _expand_additional_forward_args(
                            additional_forward_args, len(curr_model_inputs)
                        )
                    if expanded_target is None:
                        expanded_target = _expand_target(target, len(curr_model_inputs))

                    model_out = self._evaluate_batch(
                        curr_model_inputs,
                        expanded_target,
                        expanded_additional_args,
                        device,
                    )

                    if show_progress:
                        attr_progress.update()

                    outputs.append(model_out)

                    curr_model_inputs = []

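            # Evaluate any leftover samples that did not fill a complete batch.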
            if len(curr_model_inputs) > 0:
                expanded_additional_args = _expand_additional_forward_args(
                    additional_forward_args, len(curr_model_inputs)
                )
                expanded_target = _expand_target(target, len(curr_model_inputs))
                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    expanded_additional_args,
                    device,
                )
                if show_progress:
                    attr_progress.update()
                outputs.append(model_out)

            if show_progress:
                attr_progress.close()

            # Modification of the original attribute function:
            # squeeze the batch dimension out of interpretable_inps
            # -> 2D tensor (n_samples ✕ (input_dim * embedding_dim))
            combined_interp_inps = torch.cat([i.view(-1).unsqueeze(dim=0) for i in interpretable_inps]).double()

            combined_outputs = (torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)).double()
            combined_sim = (
                torch.cat(similarities) if len(similarities[0].shape) > 0 else torch.stack(similarities)
            ).double()
            dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))

            # Second modification:
            # reshape the learned representation
            # -> 3D tensor (b=1 ✕ input_dim ✕ embedding_dim)
            return self.interpretable_model.representation().reshape(inp_tensor.shape)

    @staticmethod
    def token_similarity_kernel(
        original_input: tuple,
        perturbed_input: tuple,
        perturbed_interpretable_input: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
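        """Count the entries at which the original and perturbed inputs agree,
        normalized by the size of the first input dimension."""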
        original_input_tensor = original_input[0]
        perturbed_input_tensor = perturbed_input[0]
        assert original_input_tensor.shape == perturbed_input_tensor.shape
        similarity = torch.sum(original_input_tensor == perturbed_input_tensor) / len(original_input_tensor)
        return similarity

    def perturb_func(
        self,
        original_input: tuple,  # must remain the first positional parameter: called as perturb_func(inputs, **kwargs)
        **kwargs: Any,
    ) -> tuple:
        """Randomly mask input elements, replacing masked positions with the tokenizer's pad token id."""
        original_input_tensor = original_input[0]
        mask = torch.randint(low=0, high=2, size=original_input_tensor.size()).to(self.attribution_model.device)
        perturbed_input = original_input_tensor * mask + (1 - mask) * self.attribution_model.tokenizer.pad_token_id
        return (perturbed_input,)

    @staticmethod
    def to_interp_rep_transform(sample, original_input, **kwargs: Any):
        return sample[0]  # the perturbed tensor is the first entry of the tuple
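

# Minimal usage sketch (illustrative, not part of this module): assumes an
# inseq-style attribution model exposing `.model`, `.tokenizer` and `.device`,
# and an input tensor `input_tensor` of shape (1, input_dim, embedding_dim):
#
#     lime = Lime(attribution_model)
#     attribution = lime.attribute((input_tensor,), n_samples=100)
#
# The returned attribution tensor has the same shape as `input_tensor`.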