Commit e57eca3

Fix multiline value assignment hooking for VZ
1 parent efa9ca3 commit e57eca3

6 files changed, +53 -4 lines changed


inseq/attr/feat/ops/value_zeroing.py

Lines changed: 15 additions & 2 deletions

@@ -21,7 +21,13 @@
 from torch import nn
 from torch.utils.hooks import RemovableHandle
 
-from ....utils import StackFrame, find_block_stack, get_post_variable_assignment_hook, validate_indices
+from ....utils import (
+    StackFrame,
+    find_block_stack,
+    get_post_variable_assignment_hook,
+    recursive_get_submodule,
+    validate_indices,
+)
 from ....utils.typing import (
     EmbeddingsTensor,
     InseqAttribution,
@@ -99,6 +105,11 @@ def value_zeroing_forward_mid_hook(
     zeroed_units_indices: Optional[OneOrMoreIndices] = None,
     batch_size: int = 1,
 ) -> None:
+    if varname not in frame.f_locals:
+        raise ValueError(
+            f"Variable {varname} not found in the local frame. "
+            f"Other variable names: {', '.join(frame.f_locals.keys())}"
+        )
     # Zeroing value vectors corresponding to the given token index
     if zeroed_token_index is not None:
         values_size = frame.f_locals[varname].size()
@@ -234,7 +245,9 @@ def compute_modules_post_zeroing_similarity(
     value_zeroing_hook_handles: list[RemovableHandle] = []
     # Value zeroing hooks are registered for every token separately since they are token-dependent
     for block_idx, block in enumerate(modules):
-        attention_module = block.get_submodule(attention_module_name)
+        attention_module = recursive_get_submodule(block, attention_module_name)
+        if attention_module is None:
+            raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.")
         if isinstance(zeroed_units_indices, dict):
             if block_idx not in zeroed_units_indices:
                 continue
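The new guard surfaces a clear error when the traced frame does not contain the expected variable, instead of failing later with a KeyError on frame.f_locals[varname]. For context, a minimal sketch of the general frame-inspection mechanism that post-assignment hooks rely on (illustrative only, not inseq's exact implementation; forward and make_tracer below are made up for the example):

import sys

def forward():
    # Stand-in for a model's forward pass with a local we want to observe.
    value = [1.0, 2.0, 3.0]
    return sum(value)

def make_tracer(varname):
    def tracer(frame, event, arg):
        if event == "call":
            return tracer  # enable line-level tracing inside the called frame
        if event == "line" and varname in frame.f_locals:
            # The assignment has executed: the value is readable via the frame.
            print(f"{varname} is now:", frame.f_locals[varname])
            return None  # stop tracing this frame
        return tracer
    return tracer

sys.settrace(make_tracer("value"))
forward()  # prints: value is now: [1.0, 2.0, 3.0]
sys.settrace(None)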

inseq/commands/commands_utils.py

Lines changed: 0 additions & 1 deletion

@@ -18,5 +18,4 @@ def command_args_docstring(cls):
         field_help = field.metadata.get("help", "")
         docstring += textwrap.dedent(f"\n**{field.name}** (``{field_type}``): {field_help}\n")
     cls.__doc__ = docstring
-    print(docstring)
     return cls

inseq/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,7 @@
     get_sequences_from_batched_steps,
     normalize,
     pad_with_nan,
+    recursive_get_submodule,
     remap_from_filtered,
     top_p_logits_mask,
     validate_indices,
@@ -126,4 +127,5 @@
     "StackFrame",
     "validate_indices",
     "pad_with_nan",
+    "recursive_get_submodule",
 ]

inseq/utils/hooks.py

Lines changed: 17 additions & 1 deletion

@@ -5,6 +5,8 @@
 
 from torch import nn
 
+from .misc import get_left_padding
+
 StackFrame = TypeVar("StackFrame")
 
 
@@ -31,7 +33,21 @@ def get_last_variable_assignment_position(
     # Matches any assignment of variable varname
     pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$"
     code, startline = getsourcelines(getattr(module, fname))
-    line_numbers = [i for i, line in enumerate(code) if re.match(pattern, line)]
+    line_numbers = []
+    i = 0
+    while i < len(code):
+        line = code[i]
+        # Handles multi-line assignments
+        if re.match(pattern, line):
+            parentheses_count = line.count("(") - line.count(")")
+            ends_with_newline = lambda l: l.strip().endswith("\\")
+            follow_indent = lambda l, i: len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(l)
+            while (ends_with_newline(line) or follow_indent(line, i) or parentheses_count > 0) and len(code) > i + 1:
+                i += 1
+                line = code[i]
+                parentheses_count += line.count("(") - line.count(")")
+            line_numbers.append(i)
+        i += 1
     if len(line_numbers) == 0:
         return None
     return line_numbers[-1] + startline + 1
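To see why the old single-line match was insufficient: for an assignment whose right-hand side spans several lines, the list comprehension recorded the line that opens the statement, so the hook would fire before the value was fully assigned. A runnable sketch of the new scan on made-up source lines (code, varname, and the forward snippet are illustrative, not from inseq):

import re

def get_left_padding(text: str) -> int:
    return len(text) - len(text.lstrip())

# Made-up source where the last assignment to `value_states` spans three lines.
code = [
    "def forward(self, hidden_states):\n",
    "    value_states = self.v_proj(hidden_states)\n",
    "    value_states = value_states.view(\n",
    "        bsz, -1, self.num_heads, self.head_dim\n",
    "    ).transpose(1, 2)\n",
    "    return value_states\n",
]
varname = "value_states"
pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$"

# Old behavior: only the lines that *open* an assignment are flagged.
print([i for i, line in enumerate(code) if re.match(pattern, line)])  # [1, 2]

# New behavior: extend each match to the end of its statement by tracking
# continuation backslashes, deeper indentation, and unbalanced parentheses.
line_numbers = []
i = 0
while i < len(code):
    line = code[i]
    if re.match(pattern, line):
        parentheses_count = line.count("(") - line.count(")")
        while (
            line.strip().endswith("\\")
            or (len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(line))
            or parentheses_count > 0
        ) and len(code) > i + 1:
            i += 1
            line = code[i]
            parentheses_count += line.count("(") - line.count(")")
        line_numbers.append(i)
    i += 1
print(line_numbers)  # [1, 4]: the hook now targets the statement's closing line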

inseq/utils/misc.py

Lines changed: 5 additions & 0 deletions

@@ -434,3 +434,8 @@ def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str]
         else:
             removed_token_idxs += [idx]
     return clean_tokens, removed_token_idxs
+
+
+def get_left_padding(text: str):
+    """Returns the number of spaces at the beginning of a string."""
+    return len(text) - len(text.lstrip())
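A quick check of the helper's behavior. Note that str.lstrip() strips all leading whitespace, so a tab also counts toward the padding length, despite the docstring mentioning only spaces:

assert get_left_padding("    x = 1") == 4
assert get_left_padding("x = 1") == 0
assert get_left_padding("\tx = 1") == 1  # a tab counts as a single character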

inseq/utils/torch_utils.py

Lines changed: 14 additions & 0 deletions

@@ -348,3 +348,17 @@ def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False)
     if front:
         return torch.cat([nan_tensor, t], dim=dim)
     return torch.cat([t, nan_tensor], dim=dim)
+
+
+def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]:
+    if target == "":
+        return parent
+    mod = None
+    if hasattr(parent, target):
+        mod = getattr(parent, target)
+    else:
+        for submodule in parent.children():
+            mod = recursive_get_submodule(submodule, target)
+            if mod is not None:
+                break
+    return mod
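Unlike nn.Module.get_submodule, which resolves a dotted attribute path and raises AttributeError when a bare name is not a direct child, this helper searches the module tree depth-first for the first submodule exposed under that attribute name. A small self-contained sketch (the Block/Layer/Attention classes are made up for illustration):

from typing import Optional

from torch import nn

def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]:
    # Same logic as inseq.utils.torch_utils.recursive_get_submodule above.
    if target == "":
        return parent
    mod = None
    if hasattr(parent, target):
        mod = getattr(parent, target)
    else:
        for submodule in parent.children():
            mod = recursive_get_submodule(submodule, target)
            if mod is not None:
                break
    return mod

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_proj = nn.Linear(8, 8)

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = Attention()

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = Layer()

block = Block()
# block.get_submodule("self_attn") would raise AttributeError: the name is not
# a direct child; the full dotted path "layer.self_attn" would be required.
print(recursive_get_submodule(block, "self_attn"))  # finds the nested Attention

This is what lets value_zeroing.py look up an attention module by a bare name such as the one bound to attention_module_name, without knowing each architecture's exact nesting.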
