Skip to content

Commit 19f781f

Browse files
committed
Moved rollout code to #254
1 parent e57eca3 commit 19f781f

File tree

5 files changed

+0
-247
lines changed

5 files changed

+0
-247
lines changed

CHANGELOG.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
88
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
99
- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
10-
- Added `rollout` (`inseq.data.aggregation_functions.RolloutAggregationFunction`) aggregation function for `SequenceAttributionAggregator` class ([#173](https://github.com/inseq-team/inseq/pull/173)).
1110
- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).
1211

1312
## 🔧 Fixes & Refactoring

inseq/attr/feat/ops/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
22
from .lime import Lime
33
from .monotonic_path_builder import MonotonicPathBuilder
4-
from .rollout import rollout_fn
54
from .sequential_integrated_gradients import SequentialIntegratedGradients
65
from .value_zeroing import ValueZeroing
76

@@ -10,6 +9,5 @@
109
"MonotonicPathBuilder",
1110
"ValueZeroing",
1211
"Lime",
13-
"rollout_fn",
1412
"SequentialIntegratedGradients",
1513
]

inseq/attr/feat/ops/rollout.py

Lines changed: 0 additions & 180 deletions
This file was deleted.

inseq/data/aggregation_functions.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020
from torch.linalg import vector_norm
2121

22-
from ..attr.feat.ops import rollout_fn
2322
from ..utils import Registry, available_classes
2423
from ..utils.typing import (
2524
ScoreTensor,
@@ -94,45 +93,6 @@ def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreT
9493
return vector_norm(scores, ord=vnorm_ord, dim=dim)
9594

9695

97-
class RolloutAggregationFunction(AggregationFunction):
98-
aggregation_function_name = "rollout"
99-
100-
def __init__(self):
101-
super().__init__()
102-
self.takes_single_tensor: bool = False
103-
self.takes_sequence_scores: bool = True
104-
105-
def __call__(
106-
self,
107-
scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
108-
dim: int,
109-
sequence_scores: dict[str, torch.Tensor] = {},
110-
) -> ScoreTensor:
111-
dec_self_prefix = "decoder_self"
112-
enc_self_prefix = "encoder_self"
113-
dec_match = [name for name in sequence_scores.keys() if name.startswith(dec_self_prefix)]
114-
enc_match = [name for name in sequence_scores.keys() if name.startswith(enc_self_prefix)]
115-
if isinstance(scores, torch.Tensor):
116-
# If no matching prefix is found, we assume the decoder-only target-only rollout case
117-
if not dec_match or not enc_match:
118-
return rollout_fn(scores, dim=dim)
119-
# If both prefixes are found, we assume the encoder-decoder source-only rollout case
120-
else:
121-
enc_match = sequence_scores[enc_match[0]]
122-
dec_match = sequence_scores[dec_match[0]]
123-
return rollout_fn((enc_match, scores, dec_match), dim=dim)[0]
124-
elif not enc_match:
125-
raise KeyError(
126-
"Could not find encoder self-importance scores in sequence scores. "
127-
"Encoder self-importance scores are required for encoder-decoder rollout. They should be provided "
128-
f"as an entry in the sequence scores dictionary with key starting with '{enc_self_prefix}', and "
129-
"value being a tensor of shape (src_seq_len, src_seq_len, ..., rollout_dim)."
130-
)
131-
else:
132-
enc_match = sequence_scores[enc_match[0]]
133-
return rollout_fn((enc_match,) + scores, dim=dim)
134-
135-
13696
DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
13797
"source_attributions": {"spans": "absmax"},
13898
"target_attributions": {"spans": "absmax"},

tests/attr/feat/ops/test_rollout.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)