21 | 21 | from torch import nn
22 | 22 | from torch.utils.hooks import RemovableHandle
23 | 23 |
24 |    | -from ....utils import StackFrame, find_block_stack, get_post_variable_assignment_hook, validate_indices
   | 24 | +from ....utils import (
   | 25 | +    StackFrame,
   | 26 | +    find_block_stack,
   | 27 | +    get_post_variable_assignment_hook,
   | 28 | +    recursive_get_submodule,
   | 29 | +    validate_indices,
   | 30 | +)
25 | 31 | from ....utils.typing import (
26 | 32 |     EmbeddingsTensor,
27 | 33 |     InseqAttribution,
@@ -99,6 +105,11 @@ def value_zeroing_forward_mid_hook(
 99 | 105 |     zeroed_units_indices: Optional[OneOrMoreIndices] = None,
100 | 106 |     batch_size: int = 1,
101 | 107 | ) -> None:
    | 108 | +    if varname not in frame.f_locals:
    | 109 | +        raise ValueError(
    | 110 | +            f"Variable {varname} not found in the local frame. "
    | 111 | +            f"Other variable names: {', '.join(frame.f_locals.keys())}"
    | 112 | +        )
102 | 113 |     # Zeroing value vectors corresponding to the given token index
103 | 114 |     if zeroed_token_index is not None:
104 | 115 |         values_size = frame.f_locals[varname].size()
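
The added guard assumes that get_post_variable_assignment_hook hands the hook the Python frame of the attention forward pass, so the expected local variable can be checked via frame.f_locals before it is indexed. A minimal standalone sketch of that check, using inspect.currentframe() in place of the library's hook machinery (the function and variable names below are illustrative only):

import inspect
from types import FrameType

def check_frame_for_variable(frame: FrameType, varname: str) -> None:
    # Same check as new lines 108-112: fail with an informative error
    # listing the available locals when the expected variable is missing.
    if varname not in frame.f_locals:
        raise ValueError(
            f"Variable {varname} not found in the local frame. "
            f"Other variable names: {', '.join(frame.f_locals.keys())}"
        )

def fake_attention_forward() -> None:
    value_states = [0.0, 1.0]  # stand-in for the real value tensor
    check_frame_for_variable(inspect.currentframe(), "value_states")

fake_attention_forward()  # passes: "value_states" is a local of the caller frame
# Asking the same frame for a name that was never assigned would raise the ValueError.
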
@@ -234,7 +245,9 @@ def compute_modules_post_zeroing_similarity(
234 | 245 |     value_zeroing_hook_handles: list[RemovableHandle] = []
235 | 246 |     # Value zeroing hooks are registered for every token separately since they are token-dependent
236 | 247 |     for block_idx, block in enumerate(modules):
237 |     | -        attention_module = block.get_submodule(attention_module_name)
    | 248 | +        attention_module = recursive_get_submodule(block, attention_module_name)
    | 249 | +        if attention_module is None:
    | 250 | +            raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.")
238 | 251 |         if isinstance(zeroed_units_indices, dict):
239 | 252 |             if block_idx not in zeroed_units_indices:
240 | 253 |                 continue
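
nn.Module.get_submodule raises AttributeError when the requested path does not exist, whereas the replacement expects a None return that the caller turns into a clearer ValueError naming the block index. The helper itself is not shown in this diff; below is a minimal sketch of a recursive lookup with that assumed return-None contract (the function name and the "self_attn" example are placeholders, not the library's actual implementation):

from typing import Optional
from torch import nn

def recursive_get_submodule_sketch(parent: nn.Module, target: str) -> Optional[nn.Module]:
    # Assumed contract: return the first descendant module exposed under the
    # attribute name `target`, or None if nothing matches (instead of raising).
    candidate = getattr(parent, target, None)
    if isinstance(candidate, nn.Module):
        return candidate
    for child in parent.children():
        found = recursive_get_submodule_sketch(child, target)
        if found is not None:
            return found
    return None

# Hypothetical usage mirroring the hunk above:
# attention_module = recursive_get_submodule_sketch(block, "self_attn")
# if attention_module is None:
#     raise ValueError("Attention module self_attn not found in this block.")
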