|
19 | 19 | import torch
|
20 | 20 | from torch.linalg import vector_norm
|
21 | 21 |
|
22 |
| -from ..attr.feat.ops import rollout_fn |
23 | 22 | from ..utils import Registry, available_classes
|
24 | 23 | from ..utils.typing import (
|
25 | 24 | ScoreTensor,
|
@@ -94,45 +93,6 @@ def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreT
|
94 | 93 | return vector_norm(scores, ord=vnorm_ord, dim=dim)
|
95 | 94 |
|
96 | 95 |
|
97 |
| -class RolloutAggregationFunction(AggregationFunction): |
98 |
| - aggregation_function_name = "rollout" |
99 |
| - |
100 |
| - def __init__(self): |
101 |
| - super().__init__() |
102 |
| - self.takes_single_tensor: bool = False |
103 |
| - self.takes_sequence_scores: bool = True |
104 |
| - |
105 |
| - def __call__( |
106 |
| - self, |
107 |
| - scores: Union[torch.Tensor, tuple[torch.Tensor, ...]], |
108 |
| - dim: int, |
109 |
| - sequence_scores: dict[str, torch.Tensor] = {}, |
110 |
| - ) -> ScoreTensor: |
111 |
| - dec_self_prefix = "decoder_self" |
112 |
| - enc_self_prefix = "encoder_self" |
113 |
| - dec_match = [name for name in sequence_scores.keys() if name.startswith(dec_self_prefix)] |
114 |
| - enc_match = [name for name in sequence_scores.keys() if name.startswith(enc_self_prefix)] |
115 |
| - if isinstance(scores, torch.Tensor): |
116 |
| - # If no matching prefix is found, we assume the decoder-only target-only rollout case |
117 |
| - if not dec_match or not enc_match: |
118 |
| - return rollout_fn(scores, dim=dim) |
119 |
| - # If both prefixes are found, we assume the encoder-decoder source-only rollout case |
120 |
| - else: |
121 |
| - enc_match = sequence_scores[enc_match[0]] |
122 |
| - dec_match = sequence_scores[dec_match[0]] |
123 |
| - return rollout_fn((enc_match, scores, dec_match), dim=dim)[0] |
124 |
| - elif not enc_match: |
125 |
| - raise KeyError( |
126 |
| - "Could not find encoder self-importance scores in sequence scores. " |
127 |
| - "Encoder self-importance scores are required for encoder-decoder rollout. They should be provided " |
128 |
| - f"as an entry in the sequence scores dictionary with key starting with '{enc_self_prefix}', and " |
129 |
| - "value being a tensor of shape (src_seq_len, src_seq_len, ..., rollout_dim)." |
130 |
| - ) |
131 |
| - else: |
132 |
| - enc_match = sequence_scores[enc_match[0]] |
133 |
| - return rollout_fn((enc_match,) + scores, dim=dim) |
134 |
| - |
135 |
| - |
136 | 96 | DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
|
137 | 97 | "source_attributions": {"spans": "absmax"},
|
138 | 98 | "target_attributions": {"spans": "absmax"},
|
|
0 commit comments