Skip to content

Commit

Permalink
Add split aggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Nov 4, 2024
1 parent d7c5269 commit 35bd4dc
Showing 1 changed file with 142 additions and 2 deletions.
144 changes: 142 additions & 2 deletions inseq/data/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar

import torch

Expand Down Expand Up @@ -541,7 +543,7 @@ def validate_spans(cls, span_sequence: list[TokenWithId], spans: IndexSpan | Non
prev_span_max = -1
for span in spans:
assert len(span) == 2, f"Spans must contain two indexes, got {spans}"
assert span[1] > span[0] + 1, f"Spans must be non-empty, got {spans}"
assert span[1] >= span[0] + 1, f"Spans must be non-empty, got {spans}"
assert (
span[0] >= prev_span_max
), f"Spans must be postive-valued, non-overlapping and in ascending order, got {spans}"
Expand Down Expand Up @@ -722,6 +724,144 @@ def get_spans(tokens: list[TokenWithId], special_chars: str | tuple[str, ...], i
return spans


class StringSplitAggregator(ContiguousSpanAggregator):
"""Aggregates contiguous tokens using specified strings as separators.
Args:
attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object to aggregate.
aggregate_fn (:obj:`Callable`, optional): Function to aggregate over the subwords.
Defaults to the highest absolute value score across the aggregated span, with original sign
preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
aggregate_source (bool, optional): Whether to aggregate over the source sequence. Defaults to True.
aggregate_target (bool, optional): Whether to aggregate over the target sequence. Defaults to True.
split_pattern (str): Regular expression pattern used to split the sequences.
split_mode (str, optional): Treatment for split tokens. If "single", these are kept separate from previous and
following tokens. If "start", they are concatenated with following tokens. If "end", they are concatenated
to previous tokens. Defaults to "single".
"""

aggregator_name = "split"

class SplitStrategy(Enum):
SINGLE = "single"
START = "start"
END = "end"

@classmethod
def aggregate(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_source: bool = True,
aggregate_target: bool = True,
split_pattern: str = None,
split_mode: Literal["single", "start", "end"] = SplitStrategy.SINGLE.value,
**kwargs,
):
source_spans = []
target_spans = []
if split_pattern is None:
raise ValueError("split_pattern is None. Provide a valid regular expression pattern to split the string.")
if aggregate_source:
source_spans = cls.get_spans(attr.source, split_pattern, split_mode)
if aggregate_target:
target_spans = cls.get_spans(attr.target, split_pattern, split_mode)
return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)

@classmethod
def get_spans(
cls,
tokens: list[TokenWithId],
split_pattern: str,
split_mode: Literal["single", "start", "end"] = SplitStrategy.SINGLE.value,
) -> list[tuple[int, int]]:
full_text = "".join(t.token for t in tokens)
curr_idx = 0
token_spans = []

# Generate token spans
for tok in tokens:
token_spans.append((curr_idx, curr_idx + len(tok.token)))
curr_idx += len(tok.token)

# Find all matches for the given pattern
matches = list(re.finditer(split_pattern, full_text))
if not matches:
return []
matches_spans = [(m.start(), m.end()) for m in matches]

# Create matches_tokens list
matches_tokens = []
for start, end in matches_spans:
token_start = next((i for i, (ts, te) in enumerate(token_spans) if ts <= start < te), None)
token_end = next((i for i, (ts, te) in enumerate(token_spans) if ts < end <= te), None) + 1
if token_start is not None and token_end is not None:
matches_tokens.append((token_start, token_end))

# Remove duplicate spans
seen_tokens = set()
matches_tokens = [m for m in matches_tokens if not (m in seen_tokens or seen_tokens.add(m))]

# If overlapping token spans are found, split them
non_overlapping_matches = []
for curr_idx, (start, end) in enumerate(matches_tokens):
curr_start, curr_end = start, end
if len(matches_tokens) > curr_idx + 1 and end > matches_tokens[curr_idx + 1][0]:
curr_end = matches_tokens[curr_idx + 1][0]
if curr_idx > 0 and start < non_overlapping_matches[-1][1]:
curr_start = non_overlapping_matches[-1][1]
non_overlapping_matches.append((curr_start, curr_end))
if curr_end != end and end < matches_tokens[curr_idx + 1][1]:
non_overlapping_matches.append((curr_end, end))
matches_tokens = non_overlapping_matches

# Fill missing spans
aggregate_spans = []
matched_span = []
if matches_tokens[0][0] != 0:
aggregate_spans.append((0, matches_tokens[0][0]))
matched_span.append(False)
for i in range(len(matches_tokens) - 1):
aggregate_spans.append(matches_tokens[i])
matched_span.append(True)
if matches_tokens[i][1] != matches_tokens[i + 1][0]:
aggregate_spans.append((matches_tokens[i][1], matches_tokens[i + 1][0]))
matched_span.append(False)
aggregate_spans.append(matches_tokens[-1])
matched_span.append(True)
if matches_tokens[-1][1] != len(tokens):
aggregate_spans.append((matches_tokens[-1][1], len(tokens)))
matched_span.append(False)

# Create aggregate spans based on the split strategy
if split_mode == cls.SplitStrategy.SINGLE.value:
return aggregate_spans
elif split_mode in (cls.SplitStrategy.START.value, cls.SplitStrategy.END.value):
merge_aggregate_spans = []
curr_span_start = 0

# If the strategy is "start", all match spans are concatenated to their following non-match spans
# If the strategy is "end", all match spans are concatenated to their preceding non-match spans
# Example:
# aggregate_spans = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 9)]
# matched_span = [False, True, True, False, True, False]
# Start strategy: [(0, 1), (1, 2), (2, 4), (4, 9)]
# End strategy: [(0, 2), (2, 3), (3, 5), (5, 9)]
for (start, end), is_match in zip(aggregate_spans, matched_span, strict=False):
if is_match:
if split_mode == cls.SplitStrategy.START.value and start != curr_span_start:
merge_aggregate_spans.append((curr_span_start, start))
curr_span_start = start
elif split_mode == cls.SplitStrategy.END.value:
merge_aggregate_spans.append((curr_span_start, end))
curr_span_start = end
if curr_span_start != aggregate_spans[-1][1]:
merge_aggregate_spans.append((curr_span_start, aggregate_spans[-1][1]))

return merge_aggregate_spans
else:
raise ValueError("Invalid split strategy: must be one of 'single', 'start', 'end'")


class PairAggregator(SequenceAttributionAggregator):
"""Aggregates two FeatureAttributionSequenceOutput object into a single one containing the diff.
Expand Down

0 comments on commit 35bd4dc

Please sign in to comment.